Refactor network layer to use streams

Summary:
The network layer now has a `Session` that handles all things that should be
done before the `Execute` method is called on sessions. Also, all sessions
now communicate using streams instead of holding the input buffer and writing
to the `Socket`. This design will allow implementation of a SSL middleware.

Reviewers: buda, dgleich

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1314
This commit is contained in:
Matej Ferencevic 2018-03-23 16:32:17 +01:00
parent b2f3bf9709
commit f1a8d7cd3d
27 changed files with 778 additions and 521 deletions

View File

@ -2,8 +2,8 @@
# all memgraph src files # all memgraph src files
set(memgraph_src_files set(memgraph_src_files
communication/buffer.cpp
communication/bolt/v1/decoder/decoded_value.cpp communication/bolt/v1/decoder/decoded_value.cpp
communication/bolt/v1/session.cpp
communication/rpc/buffer.cpp communication/rpc/buffer.cpp
communication/rpc/client.cpp communication/rpc/client.cpp
communication/rpc/protocol.cpp communication/rpc/protocol.cpp

View File

@ -263,8 +263,8 @@ class Client {
// decoder objects // decoder objects
Buffer<> buffer_; Buffer<> buffer_;
ChunkedDecoderBuffer decoder_buffer_{buffer_}; ChunkedDecoderBuffer<Buffer<>> decoder_buffer_{buffer_};
Decoder<ChunkedDecoderBuffer> decoder_{decoder_buffer_}; Decoder<ChunkedDecoderBuffer<Buffer<>>> decoder_{decoder_buffer_};
// encoder objects // encoder objects
ChunkedEncoderBuffer<Socket> encoder_buffer_{socket_}; ChunkedEncoderBuffer<Socket> encoder_buffer_{socket_};

View File

@ -9,7 +9,6 @@
#include <glog/logging.h> #include <glog/logging.h>
#include "communication/bolt/v1/constants.hpp" #include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/decoder/buffer.hpp"
namespace communication::bolt { namespace communication::bolt {
@ -39,12 +38,10 @@ enum class ChunkState : uint8_t {
* chunk for validity and then copies only data from the chunk. The headers * chunk for validity and then copies only data from the chunk. The headers
* aren't copied so that the decoder can read only the raw encoded data. * aren't copied so that the decoder can read only the raw encoded data.
*/ */
template <typename TBuffer>
class ChunkedDecoderBuffer { class ChunkedDecoderBuffer {
private:
using StreamBufferT = io::network::StreamBuffer;
public: public:
ChunkedDecoderBuffer(Buffer<> &buffer) : buffer_(buffer) { ChunkedDecoderBuffer(TBuffer &buffer) : buffer_(buffer) {
data_.reserve(MAX_CHUNK_SIZE); data_.reserve(MAX_CHUNK_SIZE);
} }
@ -130,8 +127,8 @@ class ChunkedDecoderBuffer {
size_t Size() { return data_.size() - pos_; } size_t Size() { return data_.size() - pos_; }
private: private:
Buffer<> &buffer_; TBuffer &buffer_;
std::vector<uint8_t> data_; std::vector<uint8_t> data_;
size_t pos_{0}; size_t pos_{0};
}; };
} } // namespace communication::bolt

View File

@ -1,7 +0,0 @@
#include "communication/bolt/v1/session.hpp"
// Danger: If multiple sessions are associated with one worker nonactive
// sessions could become blocked for FLAGS_session_inactivity_timeout time.
// TODO: We should never associate more sessions with one worker.
DEFINE_int32(session_inactivity_timeout, 1800,
"Time in seconds after which inactive sessions will be closed");

View File

@ -14,17 +14,13 @@
#include "communication/bolt/v1/states/executing.hpp" #include "communication/bolt/v1/states/executing.hpp"
#include "communication/bolt/v1/states/handshake.hpp" #include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/states/init.hpp" #include "communication/bolt/v1/states/init.hpp"
#include "communication/buffer.hpp"
#include "database/graph_db.hpp" #include "database/graph_db.hpp"
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp"
#include "query/interpreter.hpp" #include "query/interpreter.hpp"
#include "threading/sync/spinlock.hpp" #include "threading/sync/spinlock.hpp"
#include "transactions/transaction.hpp" #include "transactions/transaction.hpp"
#include "utils/exceptions.hpp" #include "utils/exceptions.hpp"
DECLARE_int32(session_inactivity_timeout);
namespace communication::bolt { namespace communication::bolt {
/** Encapsulates Dbms and Interpreter that are passed through the network server /** Encapsulates Dbms and Interpreter that are passed through the network server
@ -49,18 +45,21 @@ class SessionException : public utils::BasicException {
* *
* This class is responsible for handling a single client connection. * This class is responsible for handling a single client connection.
* *
* @tparam TSocket type of socket (could be a network socket or test socket) * @tparam TInputStream type of input stream that will be used
* @tparam TOutputStream type of output stream that will be used
*/ */
template <typename TSocket> template <typename TInputStream, typename TOutputStream>
class Session { class Session {
public: public:
using ResultStreamT = ResultStream<Encoder<ChunkedEncoderBuffer<TSocket>>>; using ResultStreamT =
using StreamBuffer = io::network::StreamBuffer; ResultStream<Encoder<ChunkedEncoderBuffer<TOutputStream>>>;
Session(TSocket &&socket, SessionData &data) Session(SessionData &data, TInputStream &input_stream,
: socket_(std::move(socket)), TOutputStream &output_stream)
db_(data.db), : db_(data.db),
interpreter_(data.interpreter) {} interpreter_(data.interpreter),
input_stream_(input_stream),
output_stream_(output_stream) {}
~Session() { ~Session() {
if (db_accessor_) { if (db_accessor_) {
@ -68,24 +67,19 @@ class Session {
} }
} }
/**
* @return the socket id
*/
int Id() const { return socket_.fd(); }
/** /**
* Executes the session after data has been read into the buffer. * Executes the session after data has been read into the buffer.
* Goes through the bolt states in order to execute commands from the client. * Goes through the bolt states in order to execute commands from the client.
*/ */
void Execute() { void Execute() {
if (UNLIKELY(!handshake_done_)) { if (UNLIKELY(!handshake_done_)) {
if (buffer_.size() < HANDSHAKE_SIZE) { if (input_stream_.size() < HANDSHAKE_SIZE) {
DLOG(WARNING) << fmt::format("Received partial handshake of size {}", DLOG(WARNING) << fmt::format("Received partial handshake of size {}",
buffer_.size()); input_stream_.size());
return; return;
} }
DLOG(WARNING) << fmt::format("Decoding handshake of size {}", DLOG(WARNING) << fmt::format("Decoding handshake of size {}",
buffer_.size()); input_stream_.size());
state_ = StateHandshakeRun(*this); state_ = StateHandshakeRun(*this);
if (UNLIKELY(state_ == State::Close)) { if (UNLIKELY(state_ == State::Close)) {
ClientFailureInvalidData(); ClientFailureInvalidData();
@ -129,43 +123,12 @@ class Session {
return; return;
} }
DLOG(INFO) << fmt::format("Buffer size: {}", buffer_.size()); DLOG(INFO) << fmt::format("Input stream size: {}", input_stream_.size());
DLOG(INFO) << fmt::format("Decoder buffer size: {}", DLOG(INFO) << fmt::format("Decoder buffer size: {}",
decoder_buffer_.Size()); decoder_buffer_.Size());
} }
} }
/**
* Allocates data from the internal buffer.
* Used in the underlying network stack to asynchronously read data
* from the client.
* @returns a StreamBuffer to the allocated internal data buffer
*/
StreamBuffer Allocate() { return buffer_.Allocate(); }
/**
* Notifies the internal buffer of written data.
* Used in the underlying network stack to notify the internal buffer
* how many bytes of data have been written.
* @param len how many data was written to the buffer
*/
void Written(size_t len) { buffer_.Written(len); }
/**
* Returns true if session has timed out. Session times out if there was no
* activity in FLAGS_sessions_inactivity_timeout seconds or if there is a
* active transaction with shoul_abort flag set to true.
* This function must be thread safe because this function and
* `RefreshLastEventTime` are called from different threads in the
* network stack.
*/
bool TimedOut() {
std::unique_lock<SpinLock> guard(lock_);
return last_event_time_ +
std::chrono::seconds(FLAGS_session_inactivity_timeout) <
std::chrono::steady_clock::now();
}
/** /**
* Commits associated transaction. * Commits associated transaction.
*/ */
@ -184,34 +147,19 @@ class Session {
db_accessor_ = nullptr; db_accessor_ = nullptr;
} }
TSocket &socket() { return socket_; }
/**
* Function that is called by the network stack to set the last event time.
* It is used to determine whether the session has timed out.
* This function must be thread safe because this function and
* `TimedOut` are called from different threads in the network stack.
*/
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
std::unique_lock<SpinLock> guard(lock_);
last_event_time_ = last_event_time;
}
// TODO: Rethink if there is a way to hide some members. At the momement all // TODO: Rethink if there is a way to hide some members. At the momement all
// of them are public. // of them are public.
TSocket socket_;
database::MasterBase &db_; database::MasterBase &db_;
query::Interpreter &interpreter_; query::Interpreter &interpreter_;
TInputStream &input_stream_;
TOutputStream &output_stream_;
ChunkedEncoderBuffer<TSocket> encoder_buffer_{socket_}; ChunkedEncoderBuffer<TOutputStream> encoder_buffer_{output_stream_};
Encoder<ChunkedEncoderBuffer<TSocket>> encoder_{encoder_buffer_}; Encoder<ChunkedEncoderBuffer<TOutputStream>> encoder_{encoder_buffer_};
ResultStreamT output_stream_{encoder_}; ResultStreamT result_stream_{encoder_};
Buffer<> buffer_; ChunkedDecoderBuffer<TInputStream> decoder_buffer_{input_stream_};
ChunkedDecoderBuffer decoder_buffer_{buffer_}; Decoder<ChunkedDecoderBuffer<TInputStream>> decoder_{decoder_buffer_};
Decoder<ChunkedDecoderBuffer> decoder_{decoder_buffer_};
bool handshake_done_{false}; bool handshake_done_{false};
State state_{State::Handshake}; State state_{State::Handshake};
@ -219,11 +167,6 @@ class Session {
// there is no associated transaction. // there is no associated transaction.
std::unique_ptr<database::GraphDbAccessor> db_accessor_; std::unique_ptr<database::GraphDbAccessor> db_accessor_;
// Time of the last event and associated lock.
std::chrono::time_point<std::chrono::steady_clock> last_event_time_ =
std::chrono::steady_clock::now();
SpinLock lock_;
private: private:
void ClientFailureInvalidData() { void ClientFailureInvalidData() {
// Set the state to Close. // Set the state to Close.
@ -237,11 +180,7 @@ class Session {
"Check the server logs for more details."}}); "Check the server logs for more details."}});
// Throw an exception to indicate that something went wrong with execution // Throw an exception to indicate that something went wrong with execution
// of the session to trigger session cleanup and socket close. // of the session to trigger session cleanup and socket close.
if (TimedOut()) { throw SessionException("Something went wrong during session execution!");
throw SessionException("The session has timed out!");
} else {
throw SessionException("The client has sent invalid data!");
}
} }
}; };
} // namespace communication::bolt } // namespace communication::bolt

View File

@ -128,7 +128,7 @@ State HandleRun(TSession &session, State state, Marker marker) {
session session
.interpreter_(query.ValueString(), *session.db_accessor_, params_tv, .interpreter_(query.ValueString(), *session.db_accessor_, params_tv,
in_explicit_transaction) in_explicit_transaction)
.PullAll(session.output_stream_); .PullAll(session.result_stream_);
if (!in_explicit_transaction) { if (!in_explicit_transaction) {
session.Commit(); session.Commit();

View File

@ -15,7 +15,7 @@ namespace communication::bolt {
*/ */
template <typename TSession> template <typename TSession>
State StateHandshakeRun(TSession &session) { State StateHandshakeRun(TSession &session) {
auto precmp = memcmp(session.buffer_.data(), kPreamble, sizeof(kPreamble)); auto precmp = memcmp(session.input_stream_.data(), kPreamble, sizeof(kPreamble));
if (UNLIKELY(precmp != 0)) { if (UNLIKELY(precmp != 0)) {
DLOG(WARNING) << "Received a wrong preamble!"; DLOG(WARNING) << "Received a wrong preamble!";
return State::Close; return State::Close;
@ -25,14 +25,14 @@ State StateHandshakeRun(TSession &session) {
// make sense to check which version the client prefers this will change in // make sense to check which version the client prefers this will change in
// the future. // the future.
if (!session.socket_.Write(kProtocol, sizeof(kProtocol))) { if (!session.output_stream_.Write(kProtocol, sizeof(kProtocol))) {
DLOG(WARNING) << "Couldn't write handshake response!"; DLOG(WARNING) << "Couldn't write handshake response!";
return State::Close; return State::Close;
} }
// Delete data from buffer. It is guaranteed that there will more than, or // Delete data from the input stream. It is guaranteed that there will more
// equal to 20 bytes (HANDSHAKE_SIZE) in the buffer. // than, or equal to 20 bytes (HANDSHAKE_SIZE) in the buffer.
session.buffer_.Shift(HANDSHAKE_SIZE); session.input_stream_.Shift(HANDSHAKE_SIZE);
return State::Init; return State::Init;
} }

View File

@ -0,0 +1,71 @@
#include "glog/logging.h"
#include "communication/buffer.hpp"
namespace communication {
Buffer::Buffer()
: data_(kBufferInitialSize, 0), read_end_(*this), write_end_(*this) {}
Buffer::ReadEnd::ReadEnd(Buffer &buffer) : buffer_(buffer) {}
uint8_t *Buffer::ReadEnd::data() { return buffer_.data(); }
size_t Buffer::ReadEnd::size() const { return buffer_.size(); }
void Buffer::ReadEnd::Shift(size_t len) { buffer_.Shift(len); }
void Buffer::ReadEnd::Resize(size_t len) { buffer_.Resize(len); }
void Buffer::ReadEnd::Clear() { buffer_.Clear(); }
Buffer::WriteEnd::WriteEnd(Buffer &buffer) : buffer_(buffer) {}
io::network::StreamBuffer Buffer::WriteEnd::Allocate() {
return buffer_.Allocate();
}
void Buffer::WriteEnd::Written(size_t len) { buffer_.Written(len); }
void Buffer::WriteEnd::Resize(size_t len) { buffer_.Resize(len); }
void Buffer::WriteEnd::Clear() { buffer_.Clear(); }
Buffer::ReadEnd &Buffer::read_end() { return read_end_; }
Buffer::WriteEnd &Buffer::write_end() { return write_end_; }
uint8_t *Buffer::data() { return data_.data(); }
size_t Buffer::size() const { return have_; }
void Buffer::Shift(size_t len) {
DCHECK(len <= have_) << "Tried to shift more data than the buffer has!";
if (len == have_) {
have_ = 0;
} else {
memmove(data_.data(), data_.data() + len, have_ - len);
have_ -= len;
}
}
io::network::StreamBuffer Buffer::Allocate() {
DCHECK(data_.size() > have_) << "The buffer thinks that there is more data "
"in the buffer than there is underlying "
"storage space!";
return {data_.data() + have_, data_.size() - have_};
}
void Buffer::Written(size_t len) {
have_ += len;
DCHECK(have_ <= data_.size()) << "Written more than storage has space!";
}
void Buffer::Resize(size_t len) {
if (len < data_.size()) return;
data_.resize(len, 0);
}
void Buffer::Clear() { have_ = 0; }
} // namespace communication

View File

@ -0,0 +1,148 @@
#pragma once
#include <vector>
#include "io/network/stream_buffer.hpp"
namespace communication {
/**
* @brief Buffer
*
* Has methods for writing and reading raw data.
*
* Allocating, writing and written stores data in the buffer. The stored
* data can then be read using the pointer returned with the data function.
* This implementation stores data in a variable sized array (a vector).
* The internal array can only grow in size.
*
* This buffer is NOT thread safe. It is intended to be used in the network
* stack where all execution when it is being done is being done on a single
* thread.
*/
class Buffer {
private:
// Initial capacity of the internal buffer.
const size_t kBufferInitialSize = 65536;
public:
Buffer();
/**
* This class provides all functions from the buffer that are needed to allow
* reading data from the buffer.
*/
class ReadEnd {
public:
ReadEnd(Buffer &buffer);
uint8_t *data();
size_t size() const;
void Shift(size_t len);
void Resize(size_t len);
void Clear();
private:
Buffer &buffer_;
};
/**
* This class provides all functions from the buffer that are needed to allow
* writing data to the buffer.
*/
class WriteEnd {
public:
WriteEnd(Buffer &buffer);
io::network::StreamBuffer Allocate();
void Written(size_t len);
void Resize(size_t len);
void Clear();
private:
Buffer &buffer_;
};
/**
* This function returns a reference to the associated ReadEnd object for this
* buffer.
*/
ReadEnd &read_end();
/**
* This function returns a reference to the associated WriteEnd object for
* this buffer.
*/
WriteEnd &write_end();
private:
/**
* This function returns a pointer to the internal buffer. It is used for
* reading data from the buffer.
*/
uint8_t *data();
/**
* This function returns the size of available data for reading.
*/
size_t size() const;
/**
* This method shifts the available data for len. It is used when you read
* some data from the buffer and you want to remove it from the buffer.
*
* @param len the length of data that has to be removed from the start of
* the buffer
*/
void Shift(size_t len);
/**
* Allocates a new StreamBuffer from the internal buffer.
* This function returns a pointer to the first currently free memory
* location in the internal buffer. Also, it returns the size of the
* available memory.
*/
io::network::StreamBuffer Allocate();
/**
* This method is used to notify the buffer that the data has been written.
* To write data to this buffer you should do this:
* Call Allocate(), then write to the returned data pointer.
* IMPORTANT: Don't write more data then the returned size, you will cause
* a memory overflow. Then call Written(size) with the length of data that
* you have written into the buffer.
*
* @param len the size of data that has been written into the buffer
*/
void Written(size_t len);
/**
* This method resizes the internal data buffer.
* It is used to notify the buffer of the incoming message size.
* If the requested size is larger than the buffer size then the buffer is
* resized, if the requested size is smaller than the buffer size then
* nothing is done.
*
* @param len the desired size of the buffer
*/
void Resize(size_t len);
/**
* This method clears the buffer. It doesn't release the underlying storage
* space.
*/
void Clear();
std::vector<uint8_t> data_;
size_t have_{0};
ReadEnd read_end_;
WriteEnd write_end_;
};
} // namespace communication

View File

@ -10,6 +10,7 @@
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <glog/logging.h> #include <glog/logging.h>
#include "communication/session.hpp"
#include "io/network/epoll.hpp" #include "io/network/epoll.hpp"
#include "io/network/socket.hpp" #include "io/network/socket.hpp"
#include "threading/sync/spinlock.hpp" #include "threading/sync/spinlock.hpp"
@ -35,11 +36,16 @@ class Listener {
// can take a long time. // can take a long time.
static const int kMaxEvents = 1; static const int kMaxEvents = 1;
using SessionHandler = Session<TSession, TSessionData>;
public: public:
Listener(TSessionData &data, bool check_for_timeouts, Listener(TSessionData &data, int inactivity_timeout_sec,
const std::string &service_name) const std::string &service_name)
: data_(data), alive_(true) { : data_(data),
if (check_for_timeouts) { alive_(true),
inactivity_timeout_sec_(inactivity_timeout_sec),
service_name_(service_name) {
if (inactivity_timeout_sec_ > 0) {
thread_ = std::thread([this, service_name]() { thread_ = std::thread([this, service_name]() {
utils::ThreadSetName(fmt::format("{} timeout", service_name)); utils::ThreadSetName(fmt::format("{} timeout", service_name));
while (alive_) { while (alive_) {
@ -47,7 +53,7 @@ class Listener {
std::unique_lock<SpinLock> guard(lock_); std::unique_lock<SpinLock> guard(lock_);
for (auto &session : sessions_) { for (auto &session : sessions_) {
if (session->TimedOut()) { if (session->TimedOut()) {
LOG(WARNING) << "Session associated with " LOG(WARNING) << service_name << " session associated with "
<< session->socket().endpoint() << " timed out."; << session->socket().endpoint() << " timed out.";
// Here we shutdown the socket to terminate any leftover // Here we shutdown the socket to terminate any leftover
// blocking `Write` calls and to signal an event that the // blocking `Write` calls and to signal an event that the
@ -94,8 +100,8 @@ class Listener {
int fd = connection.fd(); int fd = connection.fd();
// Create a new Session for the connection. // Create a new Session for the connection.
sessions_.push_back( sessions_.push_back(std::make_unique<SessionHandler>(
std::make_unique<TSession>(std::move(connection), data_)); std::move(connection), data_, inactivity_timeout_sec_));
// Register the connection in Epoll. // Register the connection in Epoll.
// We want to listen to an incoming event which is edge triggered and // We want to listen to an incoming event which is edge triggered and
@ -130,7 +136,8 @@ class Listener {
// dereference it here. It is safe to dereference the pointer because // dereference it here. It is safe to dereference the pointer because
// this design guarantees that there will never be an event that has // this design guarantees that there will never be an event that has
// a stale Session pointer. // a stale Session pointer.
TSession &session = *reinterpret_cast<TSession *>(event.data.ptr); SessionHandler &session =
*reinterpret_cast<SessionHandler *>(event.data.ptr);
// Process epoll events. We use epoll in edge-triggered mode so we process // Process epoll events. We use epoll in edge-triggered mode so we process
// all events here. Only one of the `if` statements must be executed // all events here. Only one of the `if` statements must be executed
@ -139,82 +146,56 @@ class Listener {
// segfault. // segfault.
if (event.events & EPOLLIN) { if (event.events & EPOLLIN) {
// Read and process all incoming data. // Read and process all incoming data.
while (ReadAndProcessSession(session)) while (ExecuteSession(session))
; ;
} else if (event.events & EPOLLRDHUP) { } else if (event.events & EPOLLRDHUP) {
// The client closed the connection. // The client closed the connection.
LOG(INFO) << "Client " << session.socket().endpoint() LOG(INFO) << service_name_ << " client " << session.socket().endpoint()
<< " closed the connection."; << " closed the connection.";
CloseSession(session); CloseSession(session);
} else if (!(event.events & EPOLLIN) || } else if (!(event.events & EPOLLIN) ||
event.events & (EPOLLHUP | EPOLLERR)) { event.events & (EPOLLHUP | EPOLLERR)) {
// There was an error on the server side. // There was an error on the server side.
LOG(ERROR) << "Error occured in session associated with " LOG(ERROR) << "Error occured in " << service_name_
<< session.socket().endpoint(); << " session associated with " << session.socket().endpoint();
CloseSession(session); CloseSession(session);
} else { } else {
// Unhandled epoll event. // Unhandled epoll event.
LOG(ERROR) << "Unhandled event occured in session associated with " LOG(ERROR) << "Unhandled event occured in " << service_name_
<< session.socket().endpoint() << " events: " << event.events; << " session associated with " << session.socket().endpoint()
<< " events: " << event.events;
CloseSession(session); CloseSession(session);
} }
} }
private: private:
bool ReadAndProcessSession(TSession &session) { bool ExecuteSession(SessionHandler &session) {
// Refresh the last event time in the session. try {
// This function must be implemented thread safe. if (session.Execute()) {
session.RefreshLastEventTime(std::chrono::steady_clock::now()); // Session execution done, rearm epoll to send events for this
// socket.
// Allocate the buffer to fill the data.
auto buf = session.Allocate();
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
int len = session.socket().Read(buf.data, buf.len, true);
// Check for read errors.
if (len == -1) {
// This means read would block or read was interrupted by signal, we
// return `false` to stop reading of data.
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
// Rearm epoll to send events from this socket.
epoll_.Modify(session.socket().fd(), epoll_.Modify(session.socket().fd(),
EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &session); EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &session);
return false; return false;
} }
// Some other error occurred, close the session. } catch (const SessionClosedException &e) {
CloseSession(session); LOG(INFO) << service_name_ << " client " << session.socket().endpoint()
return false;
}
// The client has closed the connection.
if (len == 0) {
LOG(INFO) << "Client " << session.socket().endpoint()
<< " closed the connection."; << " closed the connection.";
CloseSession(session); CloseSession(session);
return false; return false;
}
// Notify the session that it has new data.
session.Written(len);
// Execute the session.
try {
session.Execute();
session.RefreshLastEventTime(std::chrono::steady_clock::now());
} catch (const std::exception &e) { } catch (const std::exception &e) {
// Catch all exceptions. // Catch all exceptions.
LOG(ERROR) << "Exception was thrown while processing event in session " LOG(ERROR) << "Exception was thrown while processing event in "
"associated with " << service_name_ << " session associated with "
<< session.socket().endpoint() << session.socket().endpoint()
<< " with message: " << e.what(); << " with message: " << e.what();
CloseSession(session); CloseSession(session);
return false; return false;
} }
return true; return true;
} }
void CloseSession(TSession &session) { void CloseSession(SessionHandler &session) {
// Deregister the Session's socket from epoll to disable further events. For // Deregister the Session's socket from epoll to disable further events. For
// a detailed description why this is necessary before destroying (closing) // a detailed description why this is necessary before destroying (closing)
// the socket, see: // the socket, see:
@ -222,9 +203,8 @@ class Listener {
epoll_.Delete(session.socket().fd()); epoll_.Delete(session.socket().fd());
std::unique_lock<SpinLock> guard(lock_); std::unique_lock<SpinLock> guard(lock_);
auto it = auto it = std::find_if(sessions_.begin(), sessions_.end(),
std::find_if(sessions_.begin(), sessions_.end(), [&](const auto &l) { return l.get() == &session; });
[&](const auto &l) { return l->Id() == session.Id(); });
CHECK(it != sessions_.end()) CHECK(it != sessions_.end())
<< "Trying to remove session that is not found in sessions!"; << "Trying to remove session that is not found in sessions!";
@ -241,9 +221,11 @@ class Listener {
TSessionData &data_; TSessionData &data_;
SpinLock lock_; SpinLock lock_;
std::vector<std::unique_ptr<TSession>> sessions_; std::vector<std::unique_ptr<SessionHandler>> sessions_;
std::thread thread_; std::thread thread_;
std::atomic<bool> alive_; std::atomic<bool> alive_;
const int inactivity_timeout_sec_;
const std::string service_name_;
}; };
} // namespace communication } // namespace communication

View File

@ -13,29 +13,33 @@
namespace communication::rpc { namespace communication::rpc {
Session::Session(Socket &&socket, Server &server) Session::Session(Server &server, communication::InputStream &input_stream,
: socket_(std::move(socket)), server_(server) {} communication::OutputStream &output_stream)
: server_(server),
input_stream_(input_stream),
output_stream_(output_stream) {}
void Session::Execute() { void Session::Execute() {
if (buffer_.size() < sizeof(MessageSize)) return; if (input_stream_.size() < sizeof(MessageSize)) return;
MessageSize request_len = *reinterpret_cast<MessageSize *>(buffer_.data()); MessageSize request_len =
*reinterpret_cast<MessageSize *>(input_stream_.data());
uint64_t request_size = sizeof(MessageSize) + request_len; uint64_t request_size = sizeof(MessageSize) + request_len;
buffer_.Resize(request_size); input_stream_.Resize(request_size);
if (buffer_.size() < request_size) return; if (input_stream_.size() < request_size) return;
// Read the request message. // Read the request message.
std::unique_ptr<Message> request([this, request_len]() { std::unique_ptr<Message> request([this, request_len]() {
Message *req_ptr = nullptr; Message *req_ptr = nullptr;
std::stringstream stream(std::ios_base::in | std::ios_base::binary); std::stringstream stream(std::ios_base::in | std::ios_base::binary);
stream.str(std::string( stream.str(std::string(
reinterpret_cast<char *>(buffer_.data() + sizeof(MessageSize)), reinterpret_cast<char *>(input_stream_.data() + sizeof(MessageSize)),
request_len)); request_len));
boost::archive::binary_iarchive archive(stream); boost::archive::binary_iarchive archive(stream);
// Sent from client.cpp // Sent from client.cpp
archive >> req_ptr; archive >> req_ptr;
return req_ptr; return req_ptr;
}()); }());
buffer_.Shift(sizeof(MessageSize) + request_len); input_stream_.Shift(sizeof(MessageSize) + request_len);
auto callbacks_accessor = server_.callbacks_.access(); auto callbacks_accessor = server_.callbacks_.access();
auto it = callbacks_accessor.find(request->type_index()); auto it = callbacks_accessor.find(request->type_index());
@ -71,12 +75,12 @@ void Session::Execute() {
buffer.size(), std::numeric_limits<MessageSize>::max())); buffer.size(), std::numeric_limits<MessageSize>::max()));
} }
MessageSize buffer_size = buffer.size(); MessageSize input_stream_size = buffer.size();
if (!socket_.Write(reinterpret_cast<uint8_t *>(&buffer_size), if (!output_stream_.Write(reinterpret_cast<uint8_t *>(&input_stream_size),
sizeof(MessageSize), true)) { sizeof(MessageSize), true)) {
throw SessionException("Couldn't send response size!"); throw SessionException("Couldn't send response size!");
} }
if (!socket_.Write(buffer)) { if (!output_stream_.Write(buffer)) {
throw SessionException("Couldn't send response data!"); throw SessionException("Couldn't send response data!");
} }
@ -85,8 +89,4 @@ void Session::Execute() {
LOG(INFO) << "[RpcServer] sent " << (res_type ? res_type.value() : ""); LOG(INFO) << "[RpcServer] sent " << (res_type ? res_type.value() : "");
} }
} }
StreamBuffer Session::Allocate() { return buffer_.Allocate(); }
void Session::Written(size_t len) { buffer_.Written(len); }
} // namespace communication::rpc } // namespace communication::rpc

View File

@ -4,8 +4,8 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include "communication/rpc/buffer.hpp"
#include "communication/rpc/messages.hpp" #include "communication/rpc/messages.hpp"
#include "communication/session.hpp"
#include "io/network/endpoint.hpp" #include "io/network/endpoint.hpp"
#include "io/network/socket.hpp" #include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp" #include "io/network/stream_buffer.hpp"
@ -43,50 +43,20 @@ class SessionException : public utils::BasicException {
*/ */
class Session { class Session {
public: public:
Session(Socket &&socket, Server &server); Session(Server &server, communication::InputStream &input_stream,
communication::OutputStream &output_stream);
int Id() const { return socket_.fd(); }
/** /**
* Executes the protocol after data has been read into the buffer. * Executes the protocol after data has been read into the stream.
* Goes through the protocol states in order to execute commands from the * Goes through the protocol states in order to execute commands from the
* client. * client.
*/ */
void Execute(); void Execute();
/**
* Allocates data from the internal buffer.
* Used in the underlying network stack to asynchronously read data
* from the client.
* @returns a StreamBuffer to the allocated internal data buffer
*/
StreamBuffer Allocate();
/**
* Notifies the internal buffer of written data.
* Used in the underlying network stack to notify the internal buffer
* how many bytes of data have been written.
* @param len how many data was written to the buffer
*/
void Written(size_t len);
bool TimedOut() { return false; }
Socket &socket() { return socket_; }
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
last_event_time_ = last_event_time;
}
private: private:
Socket socket_;
std::chrono::time_point<std::chrono::steady_clock> last_event_time_ =
std::chrono::steady_clock::now();
Server &server_; Server &server_;
communication::InputStream &input_stream_;
Buffer buffer_; communication::OutputStream &output_stream_;
}; };
} // namespace communication::rpc } // namespace communication::rpc

View File

@ -11,7 +11,7 @@ namespace communication::rpc {
Server::Server(const io::network::Endpoint &endpoint, Server::Server(const io::network::Endpoint &endpoint,
size_t workers_count) size_t workers_count)
: server_(endpoint, *this, false, "RPC", workers_count) {} : server_(endpoint, *this, -1, "RPC", workers_count) {}
void Server::StopProcessingCalls() { void Server::StopProcessingCalls() {
server_.Shutdown(); server_.Shutdown();

View File

@ -43,9 +43,10 @@ class Server {
* invokes workers_count workers * invokes workers_count workers
*/ */
Server(const io::network::Endpoint &endpoint, TSessionData &session_data, Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
bool check_for_timeouts, const std::string &service_name, int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count = std::thread::hardware_concurrency()) size_t workers_count = std::thread::hardware_concurrency())
: listener_(session_data, check_for_timeouts, service_name) { : listener_(session_data, inactivity_timeout_sec, service_name),
service_name_(service_name) {
// Without server we can't continue with application so we can just // Without server we can't continue with application so we can just
// terminate here. // terminate here.
if (!socket_.Bind(endpoint)) { if (!socket_.Bind(endpoint)) {
@ -120,7 +121,8 @@ class Server {
// Connection is not available anymore or configuration failed. // Connection is not available anymore or configuration failed.
return; return;
} }
LOG(INFO) << "Accepted a connection from " << s->endpoint(); LOG(INFO) << "Accepted a " << service_name_ << " connection from "
<< s->endpoint();
listener_.AddConnection(std::move(*s)); listener_.AddConnection(std::move(*s));
} }
@ -130,6 +132,8 @@ class Server {
Socket socket_; Socket socket_;
Listener<TSession, TSessionData> listener_; Listener<TSession, TSessionData> listener_;
const std::string service_name_;
}; };
} // namespace communication } // namespace communication

View File

@ -0,0 +1,161 @@
#pragma once
#include <algorithm>
#include <atomic>
#include <chrono>
#include <memory>
#include <mutex>
#include <thread>
#include <glog/logging.h>
#include "communication/buffer.hpp"
#include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp"
#include "threading/sync/spinlock.hpp"
#include "utils/exceptions.hpp"
namespace communication {
/**
* This exception is thrown to indicate to the communication stack that the
* session is closed and that cleanup should be performed.
*/
class SessionClosedException : public utils::BasicException {
using utils::BasicException::BasicException;
};
/**
* This is used to provide input to user sessions. All sessions used with the
* network stack should use this class as their input stream.
*/
using InputStream = Buffer::ReadEnd;
/**
* This is used to provide output from user sessions. All sessions used with the
* network stack should use this class for their output stream.
*/
class OutputStream {
public:
OutputStream(io::network::Socket &socket) : socket_(socket) {}
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
return socket_.Write(data, len, have_more);
}
bool Write(const std::string &str, bool have_more = false) {
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
have_more);
}
private:
io::network::Socket &socket_;
};
/**
* This class is used internally in the communication stack to handle all user
* sessions. It handles socket ownership, inactivity timeout and protocol
* wrapping.
*/
template <class TSession, class TSessionData>
class Session {
public:
Session(io::network::Socket &&socket, TSessionData &data,
int inactivity_timeout_sec)
: socket_(std::move(socket)),
output_stream_(socket_),
session_(data, input_buffer_.read_end(), output_stream_),
inactivity_timeout_sec_(inactivity_timeout_sec) {}
Session(const Session &) = delete;
Session(Session &&) = delete;
Session &operator=(const Session &) = delete;
Session &operator=(Session &&) = delete;
/**
* This function is called from the communication stack when an event occurs
* indicating that there is data waiting to be read. This function calls the
* `Execute` method from the supplied `TSession` and handles all things
* necessary before the execution (eg. reading data from network, protocol
* encapsulation, etc.). This function returns `true` if the session is done
* with execution (when all data is read and all processing is done). It
* returns `false` when there is more data that should be read and processed.
*/
bool Execute() {
// Refresh the last event time in the session.
RefreshLastEventTime();
// Allocate the buffer to fill the data.
auto buf = input_buffer_.write_end().Allocate();
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
int len = socket_.Read(buf.data, buf.len, true);
// Check for read errors.
if (len == -1) {
// This means read would block or read was interrupted by signal, we
// return `true` to indicate that all data is processad and to stop
// reading of data.
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
return true;
}
// Some other error occurred, throw an exception to start session cleanup.
throw utils::BasicException("Couldn't read data from socket!");
}
// The client has closed the connection.
if (len == 0) {
throw SessionClosedException("Session was closed by client.");
}
// Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len);
// Execute the session.
session_.Execute();
// Refresh the last event time.
RefreshLastEventTime();
return false;
}
/**
* Returns true if session has timed out. Session times out if there was no
* activity in inactivity_timeout_sec seconds. This function must be thread
* safe because this function and `RefreshLastEventTime` are called from
* different threads in the network stack.
*/
bool TimedOut() {
std::unique_lock<SpinLock> guard(lock_);
return last_event_time_ + std::chrono::seconds(inactivity_timeout_sec_) <
std::chrono::steady_clock::now();
}
/**
* Returns a reference to the internal socket.
*/
io::network::Socket &socket() { return socket_; }
private:
void RefreshLastEventTime() {
std::unique_lock<SpinLock> guard(lock_);
last_event_time_ = std::chrono::steady_clock::now();
}
// We own the socket.
io::network::Socket socket_;
// Input and output buffers/streams.
Buffer input_buffer_;
OutputStream output_stream_;
// Session that will be executed.
TSession session_;
// Time of the last event and associated lock.
std::chrono::time_point<std::chrono::steady_clock> last_event_time_{
std::chrono::steady_clock::now()};
SpinLock lock_;
const int inactivity_timeout_sec_;
};
} // namespace communication

View File

@ -28,7 +28,8 @@ namespace fs = std::experimental::filesystem;
using communication::bolt::SessionData; using communication::bolt::SessionData;
using io::network::Endpoint; using io::network::Endpoint;
using io::network::Socket; using io::network::Socket;
using SessionT = communication::bolt::Session<Socket>; using SessionT = communication::bolt::Session<communication::InputStream,
communication::OutputStream>;
using ServerT = communication::Server<SessionT, SessionData>; using ServerT = communication::Server<SessionT, SessionData>;
// General purpose flags. // General purpose flags.
@ -39,6 +40,10 @@ DEFINE_VALIDATED_int32(port, 7687, "Communication port on which to listen.",
DEFINE_VALIDATED_int32(num_workers, DEFINE_VALIDATED_int32(num_workers,
std::max(std::thread::hardware_concurrency(), 1U), std::max(std::thread::hardware_concurrency(), 1U),
"Number of workers (Bolt)", FLAG_IN_RANGE(1, INT32_MAX)); "Number of workers (Bolt)", FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
"Time in seconds after which inactive sessions will be "
"closed.",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_string(log_file, "", "Path to where the log should be stored."); DEFINE_string(log_file, "", "Path to where the log should be stored.");
DEFINE_HIDDEN_string( DEFINE_HIDDEN_string(
log_link_basename, "", log_link_basename, "",
@ -95,7 +100,8 @@ void MasterMain() {
database::Master db; database::Master db;
SessionData session_data{db}; SessionData session_data{db};
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)}, ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, true, "Bolt", FLAGS_num_workers); session_data, FLAGS_session_inactivity_timeout, "Bolt",
FLAGS_num_workers);
// Handler for regular termination signals // Handler for regular termination signals
auto shutdown = [&server] { auto shutdown = [&server] {
@ -121,7 +127,8 @@ void SingleNodeMain() {
database::SingleNode db; database::SingleNode db;
SessionData session_data{db}; SessionData session_data{db};
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)}, ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, true, "Bolt", FLAGS_num_workers); session_data, FLAGS_session_inactivity_timeout, "Bolt",
FLAGS_num_workers);
// Handler for regular termination signals // Handler for regular termination signals
auto shutdown = [&server] { auto shutdown = [&server] {

View File

@ -10,11 +10,8 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/server.hpp" #include "communication/server.hpp"
#include "database/graph_db_accessor.hpp" #include "database/graph_db_accessor.hpp"
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
static constexpr const int SIZE = 60000; static constexpr const int SIZE = 60000;
static constexpr const int REPLY = 10; static constexpr const int REPLY = 10;
@ -26,44 +23,27 @@ class TestData {};
class TestSession { class TestSession {
public: public:
TestSession(Socket &&socket, TestData &) : socket_(std::move(socket)) { TestSession(TestData &, communication::InputStream &input_stream,
event_.data.ptr = this; communication::OutputStream &output_stream)
} : input_stream_(input_stream), output_stream_(output_stream) {}
bool TimedOut() const { return false; }
int Id() const { return socket_.fd(); }
void Execute() { void Execute() {
if (buffer_.size() < 2) return; if (input_stream_.size() < 2) return;
const uint8_t *data = buffer_.data(); const uint8_t *data = input_stream_.data();
size_t size = data[0]; size_t size = data[0];
size <<= 8; size <<= 8;
size += data[1]; size += data[1];
if (buffer_.size() < size + 2) return; input_stream_.Resize(size + 2);
if (input_stream_.size() < size + 2) return;
for (int i = 0; i < REPLY; ++i) for (int i = 0; i < REPLY; ++i)
ASSERT_TRUE(this->socket_.Write(data + 2, size)); ASSERT_TRUE(output_stream_.Write(data + 2, size));
buffer_.Shift(size + 2); input_stream_.Shift(size + 2);
} }
io::network::StreamBuffer Allocate() { return buffer_.Allocate(); } communication::InputStream input_stream_;
communication::OutputStream output_stream_;
void Written(size_t len) { buffer_.Written(len); }
Socket &socket() { return socket_; }
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
last_event_time_ = last_event_time;
}
communication::bolt::Buffer<SIZE * 2> buffer_;
Socket socket_;
io::network::Epoll::Event event_;
std::chrono::time_point<std::chrono::steady_clock> last_event_time_;
}; };
using ServerT = communication::Server<TestSession, TestData>; using ServerT = communication::Server<TestSession, TestData>;

View File

@ -12,11 +12,8 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/server.hpp" #include "communication/server.hpp"
#include "database/graph_db_accessor.hpp" #include "database/graph_db_accessor.hpp"
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
static constexpr const char interface[] = "127.0.0.1"; static constexpr const char interface[] = "127.0.0.1";
@ -27,32 +24,16 @@ class TestData {};
class TestSession { class TestSession {
public: public:
TestSession(Socket &&socket, TestData &) : socket_(std::move(socket)) { TestSession(TestData &, communication::InputStream &input_stream,
event_.data.ptr = this; communication::OutputStream &output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() {
output_stream_.Write(input_stream_.data(), input_stream_.size());
} }
bool TimedOut() const { return false; } communication::InputStream input_stream_;
communication::OutputStream output_stream_;
int Id() const { return socket_.fd(); }
void Execute() { this->socket_.Write(buffer_.data(), buffer_.size()); }
io::network::StreamBuffer Allocate() { return buffer_.Allocate(); }
void Written(size_t len) { buffer_.Written(len); }
Socket &socket() { return socket_; }
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
last_event_time_ = last_event_time;
}
Socket socket_;
communication::bolt::Buffer<> buffer_;
io::network::Epoll::Event event_;
std::chrono::time_point<std::chrono::steady_clock> last_event_time_;
}; };
std::atomic<bool> run{true}; std::atomic<bool> run{true};
@ -82,7 +63,8 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
TestData data; TestData data;
int N = (std::thread::hardware_concurrency() + 1) / 2; int N = (std::thread::hardware_concurrency() + 1) / 2;
int Nc = N * 3; int Nc = N * 3;
communication::Server<TestSession, TestData> server(endpoint, data, false, "Test", N); communication::Server<TestSession, TestData> server(endpoint, data, -1,
"Test", N);
const auto &ep = server.endpoint(); const auto &ep = server.endpoint();
// start clients // start clients

View File

@ -21,7 +21,7 @@ TEST(Network, Server) {
// initialize server // initialize server
TestData session_data; TestData session_data;
int N = (std::thread::hardware_concurrency() + 1) / 2; int N = (std::thread::hardware_concurrency() + 1) / 2;
ServerT server(endpoint, session_data, false, "Test", N); ServerT server(endpoint, session_data, -1, "Test", N);
const auto &ep = server.endpoint(); const auto &ep = server.endpoint();
// start clients // start clients

View File

@ -22,7 +22,7 @@ TEST(Network, SessionLeak) {
// initialize server // initialize server
TestData session_data; TestData session_data;
ServerT server(endpoint, session_data, false, "Test", 2); ServerT server(endpoint, session_data, -1, "Test", 2);
// start clients // start clients
int N = 50; int N = 50;

View File

@ -7,7 +7,7 @@ uint8_t data[SIZE];
using BufferT = communication::bolt::Buffer<>; using BufferT = communication::bolt::Buffer<>;
using StreamBufferT = io::network::StreamBuffer; using StreamBufferT = io::network::StreamBuffer;
using DecoderBufferT = communication::bolt::ChunkedDecoderBuffer; using DecoderBufferT = communication::bolt::ChunkedDecoderBuffer<BufferT>;
using ChunkStateT = communication::bolt::ChunkState; using ChunkStateT = communication::bolt::ChunkState;
TEST(BoltBuffer, CorrectChunk) { TEST(BoltBuffer, CorrectChunk) {

View File

@ -2,12 +2,11 @@
#include "communication/bolt/v1/encoder/chunked_encoder_buffer.hpp" #include "communication/bolt/v1/encoder/chunked_encoder_buffer.hpp"
// aliases // aliases
using SocketT = TestSocket; using BufferT = communication::bolt::ChunkedEncoderBuffer<TestOutputStream>;
using BufferT = communication::bolt::ChunkedEncoderBuffer<SocketT>;
// constants // constants
using communication::bolt::CHUNK_HEADER_SIZE;
using communication::bolt::CHUNK_END_MARKER_SIZE; using communication::bolt::CHUNK_END_MARKER_SIZE;
using communication::bolt::CHUNK_HEADER_SIZE;
using communication::bolt::MAX_CHUNK_SIZE; using communication::bolt::MAX_CHUNK_SIZE;
using communication::bolt::WHOLE_CHUNK_SIZE; using communication::bolt::WHOLE_CHUNK_SIZE;
@ -53,8 +52,8 @@ TEST(BoltChunkedEncoderBuffer, OneSmallChunk) {
int size = 100; int size = 100;
// initialize tested buffer // initialize tested buffer
SocketT socket(10); TestOutputStream output_stream;
BufferT buffer(socket); BufferT buffer(output_stream);
// write into buffer // write into buffer
buffer.Write(test_data, size); buffer.Write(test_data, size);
@ -62,7 +61,7 @@ TEST(BoltChunkedEncoderBuffer, OneSmallChunk) {
// check the output array // check the output array
// the array should look like: [0, 100, first 100 bytes of test data, 0, 0] // the array should look like: [0, 100, first 100 bytes of test data, 0, 0]
VerifyChunkOfTestData(socket.output.data(), size); VerifyChunkOfTestData(output_stream.output.data(), size);
} }
TEST(BoltChunkedEncoderBuffer, TwoSmallChunks) { TEST(BoltChunkedEncoderBuffer, TwoSmallChunks) {
@ -70,8 +69,8 @@ TEST(BoltChunkedEncoderBuffer, TwoSmallChunks) {
int size2 = 200; int size2 = 200;
// initialize tested buffer // initialize tested buffer
SocketT socket(10); TestOutputStream output_stream;
BufferT buffer(socket); BufferT buffer(output_stream);
// write into buffer // write into buffer
buffer.Write(test_data, size1); buffer.Write(test_data, size1);
@ -83,7 +82,7 @@ TEST(BoltChunkedEncoderBuffer, TwoSmallChunks) {
// the output array should look like this: // the output array should look like this:
// [0, 100, first 100 bytes of test data, 0, 0] + // [0, 100, first 100 bytes of test data, 0, 0] +
// [0, 100, second 100 bytes of test data, 0, 0] // [0, 100, second 100 bytes of test data, 0, 0]
auto data = socket.output.data(); auto data = output_stream.output.data();
VerifyChunkOfTestData(data, size1); VerifyChunkOfTestData(data, size1);
VerifyChunkOfTestData( VerifyChunkOfTestData(
data + CHUNK_HEADER_SIZE + size1 + CHUNK_END_MARKER_SIZE, size2, size1); data + CHUNK_HEADER_SIZE + size1 + CHUNK_END_MARKER_SIZE, size2, size1);
@ -91,8 +90,8 @@ TEST(BoltChunkedEncoderBuffer, TwoSmallChunks) {
TEST(BoltChunkedEncoderBuffer, OneAndAHalfOfMaxChunk) { TEST(BoltChunkedEncoderBuffer, OneAndAHalfOfMaxChunk) {
// initialize tested buffer // initialize tested buffer
SocketT socket(10); TestOutputStream output_stream;
BufferT buffer(socket); BufferT buffer(output_stream);
// write into buffer // write into buffer
buffer.Write(test_data, TEST_DATA_SIZE); buffer.Write(test_data, TEST_DATA_SIZE);
@ -102,7 +101,7 @@ TEST(BoltChunkedEncoderBuffer, OneAndAHalfOfMaxChunk) {
// the output array should look like this: // the output array should look like this:
// [0xFF, 0xFF, first 65535 bytes of test data, // [0xFF, 0xFF, first 65535 bytes of test data,
// 0x86, 0xA1, 34465 bytes of test data after the first 65535 bytes, 0, 0] // 0x86, 0xA1, 34465 bytes of test data after the first 65535 bytes, 0, 0]
auto output = socket.output.data(); auto output = output_stream.output.data();
VerifyChunkOfTestData(output, MAX_CHUNK_SIZE, 0, false); VerifyChunkOfTestData(output, MAX_CHUNK_SIZE, 0, false);
VerifyChunkOfTestData(output + WHOLE_CHUNK_SIZE, VerifyChunkOfTestData(output + WHOLE_CHUNK_SIZE,
TEST_DATA_SIZE - MAX_CHUNK_SIZE, MAX_CHUNK_SIZE); TEST_DATA_SIZE - MAX_CHUNK_SIZE, MAX_CHUNK_SIZE);

View File

@ -12,18 +12,39 @@
/** /**
* TODO (mferencevic): document * TODO (mferencevic): document
*/ */
class TestSocket { class TestInputStream {
public: public:
explicit TestSocket(int socket) : socket_(socket) {} uint8_t *data() { return data_.data(); }
TestSocket(const TestSocket &) = default;
TestSocket &operator=(const TestSocket &) = default;
TestSocket(TestSocket &&) = default;
TestSocket &operator=(TestSocket &&) = default;
int id() const { return socket_; } size_t size() { return data_.size(); }
bool Write(const uint8_t *data, size_t len, bool have_more = false, void Clear() { data_.clear(); }
const std::function<bool()> & = [] { return false; }) {
void Write(const uint8_t *data, size_t len) {
for (size_t i = 0; i < len; ++i) {
data_.push_back(data[i]);
}
}
void Write(const char *data, size_t len) {
Write(reinterpret_cast<const uint8_t *>(data), len);
}
void Shift(size_t count) {
CHECK(count <= data_.size());
data_.erase(data_.begin(), data_.begin() + count);
}
private:
std::vector<uint8_t> data_;
};
/**
* TODO (mferencevic): document
*/
class TestOutputStream {
public:
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
if (!write_success_) return false; if (!write_success_) return false;
for (size_t i = 0; i < len; ++i) output.push_back(data[i]); for (size_t i = 0; i < len; ++i) output.push_back(data[i]);
return true; return true;
@ -34,7 +55,6 @@ class TestSocket {
std::vector<uint8_t> output; std::vector<uint8_t> output;
protected: protected:
int socket_;
bool write_success_{true}; bool write_success_{true};
}; };
@ -43,14 +63,14 @@ class TestSocket {
*/ */
class TestBuffer { class TestBuffer {
public: public:
TestBuffer(TestSocket &socket) : socket_(socket) {} TestBuffer(TestOutputStream &output_stream) : output_stream_(output_stream) {}
void Write(const uint8_t *data, size_t n) { socket_.Write(data, n); } void Write(const uint8_t *data, size_t n) { output_stream_.Write(data, n); }
void Chunk() {} void Chunk() {}
bool Flush() { return true; } bool Flush() { return true; }
private: private:
TestSocket &socket_; TestOutputStream &output_stream_;
}; };
/** /**

View File

@ -45,10 +45,10 @@ void CheckTypeSize(std::vector<uint8_t> &v, int typ, uint64_t size) {
} }
void CheckInt(std::vector<uint8_t> &output, int64_t value) { void CheckInt(std::vector<uint8_t> &output, int64_t value) {
TestSocket test_socket(20); TestOutputStream output_stream;
TestBuffer encoder_buffer(test_socket); TestBuffer encoder_buffer(output_stream);
communication::bolt::BaseEncoder<TestBuffer> bolt_encoder(encoder_buffer); communication::bolt::BaseEncoder<TestBuffer> bolt_encoder(encoder_buffer);
std::vector<uint8_t> &encoded = test_socket.output; std::vector<uint8_t> &encoded = output_stream.output;
bolt_encoder.WriteInt(value); bolt_encoder.WriteInt(value);
CheckOutput(output, encoded.data(), encoded.size(), false); CheckOutput(output, encoded.data(), encoded.size(), false);
} }
@ -58,10 +58,10 @@ void CheckRecordHeader(std::vector<uint8_t> &v, uint64_t size) {
CheckTypeSize(v, LIST, size); CheckTypeSize(v, LIST, size);
} }
TestSocket test_socket(10); TestOutputStream output_stream;
TestBuffer encoder_buffer(test_socket); TestBuffer encoder_buffer(output_stream);
communication::bolt::Encoder<TestBuffer> bolt_encoder(encoder_buffer); communication::bolt::Encoder<TestBuffer> bolt_encoder(encoder_buffer);
std::vector<uint8_t> &output = test_socket.output; std::vector<uint8_t> &output = output_stream.output;
TEST(BoltEncoder, NullAndBool) { TEST(BoltEncoder, NullAndBool) {
output.clear(); output.clear();

View File

@ -7,7 +7,7 @@
using query::TypedValue; using query::TypedValue;
using BufferT = communication::bolt::ChunkedEncoderBuffer<TestSocket>; using BufferT = communication::bolt::ChunkedEncoderBuffer<TestOutputStream>;
using EncoderT = communication::bolt::Encoder<BufferT>; using EncoderT = communication::bolt::Encoder<BufferT>;
using ResultStreamT = communication::bolt::ResultStream<EncoderT>; using ResultStreamT = communication::bolt::ResultStream<EncoderT>;
@ -21,11 +21,11 @@ const uint8_t summary_output[] =
"\x00\x0C\xB1\x70\xA1\x87\x63\x68\x61\x6E\x67\x65\x64\x0A\x00\x00"; "\x00\x0C\xB1\x70\xA1\x87\x63\x68\x61\x6E\x67\x65\x64\x0A\x00\x00";
TEST(Bolt, ResultStream) { TEST(Bolt, ResultStream) {
TestSocket socket(10); TestOutputStream output_stream;
BufferT buffer(socket); BufferT buffer(output_stream);
EncoderT encoder(buffer); EncoderT encoder(buffer);
ResultStreamT result_stream(encoder); ResultStreamT result_stream(encoder);
std::vector<uint8_t> &output = socket.output; std::vector<uint8_t> &output = output_stream.output;
std::vector<std::string> headers; std::vector<std::string> headers;
for (int i = 0; i < 10; ++i) for (int i = 0; i < 10; ++i)

View File

@ -9,16 +9,18 @@
// TODO: This could be done in fixture. // TODO: This could be done in fixture.
// Shortcuts for writing variable initializations in tests // Shortcuts for writing variable initializations in tests
#define INIT_VARS \ #define INIT_VARS \
TestSocket socket(10); \ TestInputStream input_stream; \
TestOutputStream output_stream; \
database::SingleNode db; \ database::SingleNode db; \
SessionData session_data{db}; \ SessionData session_data{db}; \
SessionT session(std::move(socket), session_data); \ SessionT session(session_data, input_stream, output_stream); \
std::vector<uint8_t> &output = session.socket().output; std::vector<uint8_t> &output = output_stream.output;
using communication::bolt::SessionData; using communication::bolt::SessionData;
using communication::bolt::SessionException; using communication::bolt::SessionException;
using communication::bolt::State; using communication::bolt::State;
using SessionT = communication::bolt::Session<TestSocket>; using SessionT =
communication::bolt::Session<TestInputStream, TestOutputStream>;
using ResultStreamT = SessionT::ResultStreamT; using ResultStreamT = SessionT::ResultStreamT;
// Sample testdata that has correct inputs and outputs. // Sample testdata that has correct inputs and outputs.
@ -43,15 +45,15 @@ const uint8_t success_resp[] = {0x00, 0x03, 0xb1, 0x70, 0xa0, 0x00, 0x00};
const uint8_t ignored_resp[] = {0x00, 0x02, 0xb0, 0x7e, 0x00, 0x00}; const uint8_t ignored_resp[] = {0x00, 0x02, 0xb0, 0x7e, 0x00, 0x00};
// Write bolt chunk header (length) // Write bolt chunk header (length)
void WriteChunkHeader(SessionT &session, uint16_t len) { void WriteChunkHeader(TestInputStream &input_stream, uint16_t len) {
len = bswap(len); len = bswap(len);
auto buff = session.Allocate(); input_stream.Write(reinterpret_cast<uint8_t *>(&len), sizeof(len));
memcpy(buff.data, reinterpret_cast<uint8_t *>(&len), sizeof(len));
session.Written(sizeof(len));
} }
// Write bolt chunk tail (two zeros) // Write bolt chunk tail (two zeros)
void WriteChunkTail(SessionT &session) { WriteChunkHeader(session, 0); } void WriteChunkTail(TestInputStream &input_stream) {
WriteChunkHeader(input_stream, 0);
}
// Check that the server responded with a failure message. // Check that the server responded with a failure message.
void CheckFailureMessage(std::vector<uint8_t> &output) { void CheckFailureMessage(std::vector<uint8_t> &output) {
@ -83,71 +85,60 @@ void CheckIgnoreMessage(std::vector<uint8_t> &output) {
} }
// Execute and check a correct handshake // Execute and check a correct handshake
void ExecuteHandshake(SessionT &session, std::vector<uint8_t> &output) { void ExecuteHandshake(TestInputStream &input_stream, SessionT &session,
auto buff = session.Allocate(); std::vector<uint8_t> &output) {
memcpy(buff.data, handshake_req, 20); input_stream.Write(handshake_req, 20);
session.Written(20);
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Init); ASSERT_EQ(session.state_, State::Init);
PrintOutput(output); PrintOutput(output);
CheckOutput(output, handshake_resp, 4); CheckOutput(output, handshake_resp, 4);
} }
// Write bolt chunk and execute command // Write bolt chunk and execute command
void ExecuteCommand(SessionT &session, const uint8_t *data, size_t len, void ExecuteCommand(TestInputStream &input_stream, SessionT &session,
bool chunk = true) { const uint8_t *data, size_t len, bool chunk = true) {
if (chunk) WriteChunkHeader(session, len); if (chunk) WriteChunkHeader(input_stream, len);
auto buff = session.Allocate(); input_stream.Write(data, len);
memcpy(buff.data, data, len); if (chunk) WriteChunkTail(input_stream);
session.Written(len);
if (chunk) WriteChunkTail(session);
session.Execute(); session.Execute();
} }
// Execute and check a correct init // Execute and check a correct init
void ExecuteInit(SessionT &session, std::vector<uint8_t> &output) { void ExecuteInit(TestInputStream &input_stream, SessionT &session,
ExecuteCommand(session, init_req, sizeof(init_req)); std::vector<uint8_t> &output) {
ExecuteCommand(input_stream, session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, State::Idle); ASSERT_EQ(session.state_, State::Idle);
PrintOutput(output); PrintOutput(output);
CheckOutput(output, init_resp, 7); CheckOutput(output, init_resp, 7);
} }
// Write bolt encoded run request // Write bolt encoded run request
void WriteRunRequest(SessionT &session, const char *str) { void WriteRunRequest(TestInputStream &input_stream, const char *str) {
// write chunk header // write chunk header
auto len = strlen(str); auto len = strlen(str);
WriteChunkHeader(session, 3 + 2 + len + 1); WriteChunkHeader(input_stream, 3 + 2 + len + 1);
// write string header // write string header
auto buff = session.Allocate(); input_stream.Write(run_req_header, 3);
memcpy(buff.data, run_req_header, 3);
session.Written(3);
// write string length // write string length
WriteChunkHeader(session, len); WriteChunkHeader(input_stream, len);
// write string // write string
buff = session.Allocate(); input_stream.Write(str, len);
memcpy(buff.data, str, len);
session.Written(len);
// write empty map for parameters // write empty map for parameters
buff = session.Allocate(); input_stream.Write("\xA0", 1); // TinyMap0
buff.data[0] = 0xA0; // TinyMap0
session.Written(1);
// write chunk tail // write chunk tail
WriteChunkTail(session); WriteChunkTail(input_stream);
} }
TEST(BoltSession, HandshakeWrongPreamble) { TEST(BoltSession, HandshakeWrongPreamble) {
INIT_VARS; INIT_VARS;
auto buff = session.Allocate(); // write 0x00000001 five times
// copy 0x00000001 four times for (int i = 0; i < 5; ++i) input_stream.Write(handshake_req + 4, 4);
for (int i = 0; i < 4; ++i) memcpy(buff.data + i * 4, handshake_req + 4, 4);
session.Written(20);
ASSERT_THROW(session.Execute(), SessionException); ASSERT_THROW(session.Execute(), SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -158,15 +149,12 @@ TEST(BoltSession, HandshakeWrongPreamble) {
TEST(BoltSession, HandshakeInTwoPackets) { TEST(BoltSession, HandshakeInTwoPackets) {
INIT_VARS; INIT_VARS;
auto buff = session.Allocate(); input_stream.Write(handshake_req, 10);
memcpy(buff.data, handshake_req, 10);
session.Written(10);
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Handshake); ASSERT_EQ(session.state_, State::Handshake);
memcpy(buff.data + 10, handshake_req + 10, 10); input_stream.Write(handshake_req + 10, 10);
session.Written(10);
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Init); ASSERT_EQ(session.state_, State::Init);
@ -176,9 +164,9 @@ TEST(BoltSession, HandshakeInTwoPackets) {
TEST(BoltSession, HandshakeWriteFail) { TEST(BoltSession, HandshakeWriteFail) {
INIT_VARS; INIT_VARS;
session.socket().SetWriteSuccess(false); output_stream.SetWriteSuccess(false);
ASSERT_THROW( ASSERT_THROW(ExecuteCommand(input_stream, session, handshake_req,
ExecuteCommand(session, handshake_req, sizeof(handshake_req), false), sizeof(handshake_req), false),
SessionException); SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -187,13 +175,14 @@ TEST(BoltSession, HandshakeWriteFail) {
TEST(BoltSession, HandshakeOK) { TEST(BoltSession, HandshakeOK) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
} }
TEST(BoltSession, InitWrongSignature) { TEST(BoltSession, InitWrongSignature) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ASSERT_THROW(ExecuteCommand(session, run_req_header, sizeof(run_req_header)), ASSERT_THROW(ExecuteCommand(input_stream, session, run_req_header,
sizeof(run_req_header)),
SessionException); SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -202,11 +191,12 @@ TEST(BoltSession, InitWrongSignature) {
TEST(BoltSession, InitWrongMarker) { TEST(BoltSession, InitWrongMarker) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
// wrong marker, good signature // wrong marker, good signature
uint8_t data[2] = {0x00, init_req[1]}; uint8_t data[2] = {0x00, init_req[1]};
ASSERT_THROW(ExecuteCommand(session, data, 2), SessionException); ASSERT_THROW(ExecuteCommand(input_stream, session, data, 2),
SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
CheckFailureMessage(output); CheckFailureMessage(output);
@ -219,8 +209,9 @@ TEST(BoltSession, InitMissingData) {
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ASSERT_THROW(ExecuteCommand(session, init_req, len[i]), SessionException); ASSERT_THROW(ExecuteCommand(input_stream, session, init_req, len[i]),
SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
CheckFailureMessage(output); CheckFailureMessage(output);
@ -229,9 +220,10 @@ TEST(BoltSession, InitMissingData) {
TEST(BoltSession, InitWriteFail) { TEST(BoltSession, InitWriteFail) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
session.socket().SetWriteSuccess(false); output_stream.SetWriteSuccess(false);
ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)), ASSERT_THROW(
ExecuteCommand(input_stream, session, init_req, sizeof(init_req)),
SessionException); SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -240,19 +232,20 @@ TEST(BoltSession, InitWriteFail) {
TEST(BoltSession, InitOK) { TEST(BoltSession, InitOK) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
} }
TEST(BoltSession, ExecuteRunWrongMarker) { TEST(BoltSession, ExecuteRunWrongMarker) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
// wrong marker, good signature // wrong marker, good signature
uint8_t data[2] = {0x00, run_req_header[1]}; uint8_t data[2] = {0x00, run_req_header[1]};
ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException); ASSERT_THROW(ExecuteCommand(input_stream, session, data, sizeof(data)),
SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
CheckFailureMessage(output); CheckFailureMessage(output);
@ -265,9 +258,9 @@ TEST(BoltSession, ExecuteRunMissingData) {
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
ASSERT_THROW(ExecuteCommand(session, run_req_header, len[i]), ASSERT_THROW(ExecuteCommand(input_stream, session, run_req_header, len[i]),
SessionException); SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -280,11 +273,11 @@ TEST(BoltSession, ExecuteRunBasicException) {
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
session.socket().SetWriteSuccess(i == 0); output_stream.SetWriteSuccess(i == 0);
WriteRunRequest(session, "MATCH (omnom"); WriteRunRequest(input_stream, "MATCH (omnom");
if (i == 0) { if (i == 0) {
session.Execute(); session.Execute();
} else { } else {
@ -304,10 +297,10 @@ TEST(BoltSession, ExecuteRunBasicException) {
TEST(BoltSession, ExecuteRunWithoutPullAll) { TEST(BoltSession, ExecuteRunWithoutPullAll) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "RETURN 2"); WriteRunRequest(input_stream, "RETURN 2");
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Result); ASSERT_EQ(session.state_, State::Result);
@ -321,12 +314,13 @@ TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) {
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
// wrong marker, good signature // wrong marker, good signature
uint8_t data[2] = {0x00, dataset[i][1]}; uint8_t data[2] = {0x00, dataset[i][1]};
ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException); ASSERT_THROW(ExecuteCommand(input_stream, session, data, sizeof(data)),
SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
CheckFailureMessage(output); CheckFailureMessage(output);
@ -338,11 +332,12 @@ TEST(BoltSession, ExecutePullAllBufferEmpty) {
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
session.socket().SetWriteSuccess(i == 0); output_stream.SetWriteSuccess(i == 0);
ASSERT_THROW(ExecuteCommand(session, pullall_req, sizeof(pullall_req)), ASSERT_THROW(
ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)),
SessionException); SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -364,18 +359,19 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) {
for (int j = 0; j < 2; ++j) { for (int j = 0; j < 2; ++j) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "CREATE (n) RETURN n"); WriteRunRequest(input_stream, "CREATE (n) RETURN n");
session.Execute(); session.Execute();
if (j == 1) output.clear(); if (j == 1) output.clear();
session.socket().SetWriteSuccess(j == 0); output_stream.SetWriteSuccess(j == 0);
if (j == 0) { if (j == 0) {
ExecuteCommand(session, dataset[i], 2); ExecuteCommand(input_stream, session, dataset[i], 2);
} else { } else {
ASSERT_THROW(ExecuteCommand(session, dataset[i], 2), SessionException); ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2),
SessionException);
} }
if (j == 0) { if (j == 0) {
@ -393,9 +389,10 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) {
TEST(BoltSession, ExecuteInvalidMessage) { TEST(BoltSession, ExecuteInvalidMessage) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)), ASSERT_THROW(
ExecuteCommand(input_stream, session, init_req, sizeof(init_req)),
SessionException); SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -407,19 +404,20 @@ TEST(BoltSession, ErrorIgnoreMessage) {
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "MATCH (omnom"); WriteRunRequest(input_stream, "MATCH (omnom");
session.Execute(); session.Execute();
output.clear(); output.clear();
session.socket().SetWriteSuccess(i == 0); output_stream.SetWriteSuccess(i == 0);
if (i == 0) { if (i == 0) {
ExecuteCommand(session, init_req, sizeof(init_req)); ExecuteCommand(input_stream, session, init_req, sizeof(init_req));
} else { } else {
ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)), ASSERT_THROW(
ExecuteCommand(input_stream, session, init_req, sizeof(init_req)),
SessionException); SessionException);
} }
@ -440,21 +438,21 @@ TEST(BoltSession, ErrorRunAfterRun) {
// first test with socket write success, then with socket write fail // first test with socket write success, then with socket write fail
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "MATCH (n) RETURN n"); WriteRunRequest(input_stream, "MATCH (n) RETURN n");
session.Execute(); session.Execute();
output.clear(); output.clear();
session.socket().SetWriteSuccess(true); output_stream.SetWriteSuccess(true);
// Session holds results of last run. // Session holds results of last run.
ASSERT_EQ(session.state_, State::Result); ASSERT_EQ(session.state_, State::Result);
// New run request. // New run request.
WriteRunRequest(session, "MATCH (n) RETURN n"); WriteRunRequest(input_stream, "MATCH (n) RETURN n");
ASSERT_THROW(session.Execute(), SessionException); ASSERT_THROW(session.Execute(), SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -463,16 +461,17 @@ TEST(BoltSession, ErrorRunAfterRun) {
TEST(BoltSession, ErrorCantCleanup) { TEST(BoltSession, ErrorCantCleanup) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "MATCH (omnom"); WriteRunRequest(input_stream, "MATCH (omnom");
session.Execute(); session.Execute();
output.clear(); output.clear();
// there is data missing in the request, cleanup should fail // there is data missing in the request, cleanup should fail
ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req) - 10), ASSERT_THROW(
ExecuteCommand(input_stream, session, init_req, sizeof(init_req) - 10),
SessionException); SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
@ -482,17 +481,18 @@ TEST(BoltSession, ErrorCantCleanup) {
TEST(BoltSession, ErrorWrongMarker) { TEST(BoltSession, ErrorWrongMarker) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "MATCH (omnom"); WriteRunRequest(input_stream, "MATCH (omnom");
session.Execute(); session.Execute();
output.clear(); output.clear();
// wrong marker, good signature // wrong marker, good signature
uint8_t data[2] = {0x00, init_req[1]}; uint8_t data[2] = {0x00, init_req[1]};
ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException); ASSERT_THROW(ExecuteCommand(input_stream, session, data, sizeof(data)),
SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
CheckFailureMessage(output); CheckFailureMessage(output);
@ -507,19 +507,20 @@ TEST(BoltSession, ErrorOK) {
for (int j = 0; j < 2; ++j) { for (int j = 0; j < 2; ++j) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "MATCH (omnom"); WriteRunRequest(input_stream, "MATCH (omnom");
session.Execute(); session.Execute();
output.clear(); output.clear();
session.socket().SetWriteSuccess(j == 0); output_stream.SetWriteSuccess(j == 0);
if (j == 0) { if (j == 0) {
ExecuteCommand(session, dataset[i], 2); ExecuteCommand(input_stream, session, dataset[i], 2);
} else { } else {
ASSERT_THROW(ExecuteCommand(session, dataset[i], 2), SessionException); ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2),
SessionException);
} }
// assert that all data from the init message was cleaned up // assert that all data from the init message was cleaned up
@ -539,17 +540,18 @@ TEST(BoltSession, ErrorOK) {
TEST(BoltSession, ErrorMissingData) { TEST(BoltSession, ErrorMissingData) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "MATCH (omnom"); WriteRunRequest(input_stream, "MATCH (omnom");
session.Execute(); session.Execute();
output.clear(); output.clear();
// some marker, missing signature // some marker, missing signature
uint8_t data[1] = {0x00}; uint8_t data[1] = {0x00};
ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException); ASSERT_THROW(ExecuteCommand(input_stream, session, data, sizeof(data)),
SessionException);
ASSERT_EQ(session.state_, State::Close); ASSERT_EQ(session.state_, State::Close);
CheckFailureMessage(output); CheckFailureMessage(output);
@ -558,11 +560,11 @@ TEST(BoltSession, ErrorMissingData) {
TEST(BoltSession, MultipleChunksInOneExecute) { TEST(BoltSession, MultipleChunksInOneExecute) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "CREATE (n) RETURN n"); WriteRunRequest(input_stream, "CREATE (n) RETURN n");
ExecuteCommand(session, pullall_req, sizeof(pullall_req)); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req));
ASSERT_EQ(session.state_, State::Idle); ASSERT_EQ(session.state_, State::Idle);
PrintOutput(output); PrintOutput(output);
@ -584,13 +586,11 @@ TEST(BoltSession, MultipleChunksInOneExecute) {
TEST(BoltSession, PartialChunk) { TEST(BoltSession, PartialChunk) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteChunkHeader(session, sizeof(discardall_req)); WriteChunkHeader(input_stream, sizeof(discardall_req));
auto buff = session.Allocate(); input_stream.Write(discardall_req, sizeof(discardall_req));
memcpy(buff.data, discardall_req, sizeof(discardall_req));
session.Written(2);
// missing chunk tail // missing chunk tail
session.Execute(); session.Execute();
@ -598,7 +598,7 @@ TEST(BoltSession, PartialChunk) {
ASSERT_EQ(session.state_, State::Idle); ASSERT_EQ(session.state_, State::Idle);
ASSERT_EQ(output.size(), 0); ASSERT_EQ(output.size(), 0);
WriteChunkTail(session); WriteChunkTail(input_stream);
ASSERT_THROW(session.Execute(), SessionException); ASSERT_THROW(session.Execute(), SessionException);
@ -615,40 +615,40 @@ TEST(BoltSession, ExplicitTransactionValidQueries) {
for (const auto &transaction_end : transaction_ends) { for (const auto &transaction_end : transaction_ends) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "BEGIN"); WriteRunRequest(input_stream, "BEGIN");
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Result); ASSERT_EQ(session.state_, State::Result);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req)); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req));
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Idle); ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
WriteRunRequest(session, "MATCH (n) RETURN n"); WriteRunRequest(input_stream, "MATCH (n) RETURN n");
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Result); ASSERT_EQ(session.state_, State::Result);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req)); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req));
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Idle); ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
WriteRunRequest(session, transaction_end.c_str()); WriteRunRequest(input_stream, transaction_end.c_str());
session.Execute(); session.Execute();
ASSERT_FALSE(session.db_accessor_); ASSERT_FALSE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
ASSERT_EQ(session.state_, State::Result); ASSERT_EQ(session.state_, State::Result);
ExecuteCommand(session, pullall_req, sizeof(pullall_req)); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req));
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Idle); ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.db_accessor_); ASSERT_FALSE(session.db_accessor_);
@ -662,40 +662,41 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) {
for (const auto &transaction_end : transaction_ends) { for (const auto &transaction_end : transaction_ends) {
INIT_VARS; INIT_VARS;
ExecuteHandshake(session, output); ExecuteHandshake(input_stream, session, output);
ExecuteInit(session, output); ExecuteInit(input_stream, session, output);
WriteRunRequest(session, "BEGIN"); WriteRunRequest(input_stream, "BEGIN");
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Result); ASSERT_EQ(session.state_, State::Result);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req)); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req));
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Idle); ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
WriteRunRequest(session, "MATCH ("); WriteRunRequest(input_stream, "MATCH (");
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::ErrorWaitForRollback); ASSERT_EQ(session.state_, State::ErrorWaitForRollback);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckFailureMessage(output); CheckFailureMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req)); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req));
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::ErrorWaitForRollback); ASSERT_EQ(session.state_, State::ErrorWaitForRollback);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckIgnoreMessage(output); CheckIgnoreMessage(output);
ExecuteCommand(session, ackfailure_req, sizeof(ackfailure_req)); ExecuteCommand(input_stream, session, ackfailure_req,
sizeof(ackfailure_req));
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::WaitForRollback); ASSERT_EQ(session.state_, State::WaitForRollback);
ASSERT_TRUE(session.db_accessor_); ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
WriteRunRequest(session, transaction_end.c_str()); WriteRunRequest(input_stream, transaction_end.c_str());
if (transaction_end == "ROLLBACK") { if (transaction_end == "ROLLBACK") {
session.Execute(); session.Execute();
@ -703,7 +704,7 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) {
ASSERT_FALSE(session.db_accessor_); ASSERT_FALSE(session.db_accessor_);
CheckSuccessMessage(output); CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req)); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req));
session.Execute(); session.Execute();
ASSERT_EQ(session.state_, State::Idle); ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.db_accessor_); ASSERT_FALSE(session.db_accessor_);

View File

@ -1,74 +1,77 @@
#include <chrono> #include <chrono>
#include <experimental/filesystem>
#include <iostream> #include <iostream>
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <glog/logging.h> #include <glog/logging.h>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "communication/bolt/client.hpp"
#include "communication/bolt/v1/session.hpp"
#include "communication/server.hpp" #include "communication/server.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/socket.hpp" #include "io/network/socket.hpp"
DECLARE_int32(query_execution_time_sec);
DECLARE_int32(session_inactivity_timeout);
using namespace std::chrono_literals; using namespace std::chrono_literals;
class TestClientSocket;
using communication::bolt::ClientException;
using communication::bolt::SessionData;
using io::network::Endpoint;
using io::network::Socket;
using SessionT = communication::bolt::Session<Socket>;
using ResultStreamT = SessionT::ResultStreamT;
using ServerT = communication::Server<SessionT, SessionData>;
using ClientT = communication::bolt::Client<Socket>;
class RunningServer { class TestData {};
class TestSession {
public: public:
database::SingleNode db_; TestSession(TestData &, communication::InputStream &input_stream,
SessionData session_data_{db_}; communication::OutputStream &output_stream)
Endpoint endpoint_{"127.0.0.1", 0}; : input_stream_(input_stream), output_stream_(output_stream) {}
ServerT server_{endpoint_, session_data_, true, "Test", 1};
void Execute() {
LOG(INFO) << "Received data: '"
<< std::string(
reinterpret_cast<const char *>(input_stream_.data()),
input_stream_.size())
<< "'";
output_stream_.Write(input_stream_.data(), input_stream_.size());
input_stream_.Shift(input_stream_.size());
}
private:
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
}; };
class TestClient : public ClientT { const std::string query("timeout test");
public:
TestClient(Endpoint endpoint) bool QueryServer(io::network::Socket &socket) {
: ClientT( if (!socket.Write(query)) return false;
[&] { char response[105];
Socket socket; int len = 0;
socket.Connect(endpoint); while (len < query.size()) {
return socket; int got = socket.Read(response + len, query.size() - len);
}(), if (got <= 0) return false;
"", "") {} len += got;
}; }
if (std::string(response, strlen(response)) != query) return false;
return true;
}
TEST(NetworkTimeouts, InactiveSession) { TEST(NetworkTimeouts, InactiveSession) {
FLAGS_session_inactivity_timeout = 2; // Instantiate the server and set the session timeout to 2 seconds.
RunningServer rs; TestData test_data;
communication::Server<TestSession, TestData> server{
{"127.0.0.1", 0}, test_data, 2, "Test", 1};
TestClient client(rs.server_.endpoint()); // Create the client and connect to the server.
// Check that we can execute first query. io::network::Socket client;
client.Execute("RETURN 1", {}); ASSERT_TRUE(client.Connect(server.endpoint()));
// After sleep, session should still be alive. // Send some data to the server.
ASSERT_TRUE(QueryServer(client));
for (int i = 0; i < 3; ++i) {
// After this sleep the session should still be alive.
std::this_thread::sleep_for(500ms); std::this_thread::sleep_for(500ms);
client.Execute("RETURN 1", {});
// After sleep, session should still be alive. // Send some data to the server.
std::this_thread::sleep_for(500ms); ASSERT_TRUE(QueryServer(client));
client.Execute("RETURN 1", {}); }
// After sleep, session should still be alive. // After this sleep the session should have timed out.
std::this_thread::sleep_for(500ms);
client.Execute("RETURN 1", {});
// After sleep, session should have timed out.
std::this_thread::sleep_for(3500ms); std::this_thread::sleep_for(3500ms);
EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException); ASSERT_FALSE(QueryServer(client));
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {