Refactor network stack to use * instead of &

Reviewers: teon.banek, buda

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1587
This commit is contained in:
Matej Ferencevic 2018-09-03 15:29:06 +02:00
parent 4e996d2667
commit 240472e7cb
23 changed files with 175 additions and 175 deletions

View File

@ -41,8 +41,8 @@ class Session {
public: public:
using TEncoder = Encoder<ChunkedEncoderBuffer<TOutputStream>>; using TEncoder = Encoder<ChunkedEncoderBuffer<TOutputStream>>;
Session(TInputStream &input_stream, TOutputStream &output_stream) Session(TInputStream *input_stream, TOutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {} : input_stream_(*input_stream), output_stream_(*output_stream) {}
virtual ~Session() {} virtual ~Session() {}

View File

@ -5,35 +5,35 @@
namespace communication { namespace communication {
Buffer::Buffer() Buffer::Buffer()
: data_(kBufferInitialSize, 0), read_end_(*this), write_end_(*this) {} : data_(kBufferInitialSize, 0), read_end_(this), write_end_(this) {}
Buffer::ReadEnd::ReadEnd(Buffer &buffer) : buffer_(buffer) {} Buffer::ReadEnd::ReadEnd(Buffer *buffer) : buffer_(buffer) {}
uint8_t *Buffer::ReadEnd::data() { return buffer_.data(); } uint8_t *Buffer::ReadEnd::data() { return buffer_->data(); }
size_t Buffer::ReadEnd::size() const { return buffer_.size(); } size_t Buffer::ReadEnd::size() const { return buffer_->size(); }
void Buffer::ReadEnd::Shift(size_t len) { buffer_.Shift(len); } void Buffer::ReadEnd::Shift(size_t len) { buffer_->Shift(len); }
void Buffer::ReadEnd::Resize(size_t len) { buffer_.Resize(len); } void Buffer::ReadEnd::Resize(size_t len) { buffer_->Resize(len); }
void Buffer::ReadEnd::Clear() { buffer_.Clear(); } void Buffer::ReadEnd::Clear() { buffer_->Clear(); }
Buffer::WriteEnd::WriteEnd(Buffer &buffer) : buffer_(buffer) {} Buffer::WriteEnd::WriteEnd(Buffer *buffer) : buffer_(buffer) {}
io::network::StreamBuffer Buffer::WriteEnd::Allocate() { io::network::StreamBuffer Buffer::WriteEnd::Allocate() {
return buffer_.Allocate(); return buffer_->Allocate();
} }
void Buffer::WriteEnd::Written(size_t len) { buffer_.Written(len); } void Buffer::WriteEnd::Written(size_t len) { buffer_->Written(len); }
void Buffer::WriteEnd::Resize(size_t len) { buffer_.Resize(len); } void Buffer::WriteEnd::Resize(size_t len) { buffer_->Resize(len); }
void Buffer::WriteEnd::Clear() { buffer_.Clear(); } void Buffer::WriteEnd::Clear() { buffer_->Clear(); }
Buffer::ReadEnd &Buffer::read_end() { return read_end_; } Buffer::ReadEnd *Buffer::read_end() { return &read_end_; }
Buffer::WriteEnd &Buffer::write_end() { return write_end_; } Buffer::WriteEnd *Buffer::write_end() { return &write_end_; }
uint8_t *Buffer::data() { return data_.data(); } uint8_t *Buffer::data() { return data_.data(); }

View File

@ -39,7 +39,7 @@ class Buffer final {
*/ */
class ReadEnd { class ReadEnd {
public: public:
ReadEnd(Buffer &buffer); ReadEnd(Buffer *buffer);
ReadEnd(const ReadEnd &) = delete; ReadEnd(const ReadEnd &) = delete;
ReadEnd(ReadEnd &&) = delete; ReadEnd(ReadEnd &&) = delete;
@ -57,7 +57,7 @@ class Buffer final {
void Clear(); void Clear();
private: private:
Buffer &buffer_; Buffer *buffer_;
}; };
/** /**
@ -66,7 +66,7 @@ class Buffer final {
*/ */
class WriteEnd { class WriteEnd {
public: public:
WriteEnd(Buffer &buffer); WriteEnd(Buffer *buffer);
WriteEnd(const WriteEnd &) = delete; WriteEnd(const WriteEnd &) = delete;
WriteEnd(WriteEnd &&) = delete; WriteEnd(WriteEnd &&) = delete;
@ -82,20 +82,20 @@ class Buffer final {
void Clear(); void Clear();
private: private:
Buffer &buffer_; Buffer *buffer_;
}; };
/** /**
* This function returns a reference to the associated ReadEnd object for this * This function returns a pointer to the associated ReadEnd object for this
* buffer. * buffer.
*/ */
ReadEnd &read_end(); ReadEnd *read_end();
/** /**
* This function returns a reference to the associated WriteEnd object for * This function returns a pointer to the associated WriteEnd object for
* this buffer. * this buffer.
*/ */
WriteEnd &write_end(); WriteEnd *write_end();
private: private:
/** /**

View File

@ -83,9 +83,9 @@ void Client::Close() {
bool Client::Read(size_t len) { bool Client::Read(size_t len) {
size_t received = 0; size_t received = 0;
buffer_.write_end().Resize(buffer_.read_end().size() + len); buffer_.write_end()->Resize(buffer_.read_end()->size() + len);
while (received < len) { while (received < len) {
auto buff = buffer_.write_end().Allocate(); auto buff = buffer_.write_end()->Allocate();
if (ssl_) { if (ssl_) {
// We clear errors here to prevent errors piling up in the internal // We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this: // OpenSSL error queue. To see when could that be an issue read this:
@ -120,7 +120,7 @@ bool Client::Read(size_t len) {
} }
// Notify the buffer that it has new data. // Notify the buffer that it has new data.
buffer_.write_end().Written(got); buffer_.write_end()->Written(got);
received += got; received += got;
} else { } else {
// Read raw data from the socket. // Read raw data from the socket.
@ -135,20 +135,20 @@ bool Client::Read(size_t len) {
} }
// Notify the buffer that it has new data. // Notify the buffer that it has new data.
buffer_.write_end().Written(got); buffer_.write_end()->Written(got);
received += got; received += got;
} }
} }
return true; return true;
} }
uint8_t *Client::GetData() { return buffer_.read_end().data(); } uint8_t *Client::GetData() { return buffer_.read_end()->data(); }
size_t Client::GetDataSize() { return buffer_.read_end().size(); } size_t Client::GetDataSize() { return buffer_.read_end()->size(); }
void Client::ShiftData(size_t len) { buffer_.read_end().Shift(len); } void Client::ShiftData(size_t len) { buffer_.read_end()->Shift(len); }
void Client::ClearData() { buffer_.read_end().Clear(); } void Client::ClearData() { buffer_.read_end()->Clear(); }
bool Client::Write(const uint8_t *data, size_t len, bool have_more) { bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
if (ssl_) { if (ssl_) {

View File

@ -40,7 +40,7 @@ class Listener final {
using SessionHandler = Session<TSession, TSessionData>; using SessionHandler = Session<TSession, TSessionData>;
public: public:
Listener(TSessionData &data, ServerContext *context, Listener(TSessionData *data, ServerContext *context,
int inactivity_timeout_sec, const std::string &service_name, int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count) size_t workers_count)
: data_(data), : data_(data),
@ -227,7 +227,7 @@ class Listener final {
io::network::Epoll epoll_; io::network::Epoll epoll_;
TSessionData &data_; TSessionData *data_;
utils::SpinLock lock_; utils::SpinLock lock_;
std::vector<std::unique_ptr<SessionHandler>> sessions_; std::vector<std::unique_ptr<SessionHandler>> sessions_;

View File

@ -12,25 +12,25 @@
namespace communication::rpc { namespace communication::rpc {
Session::Session(Server &server, const io::network::Endpoint &endpoint, Session::Session(Server *server, const io::network::Endpoint &endpoint,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream) communication::OutputStream *output_stream)
: server_(server), : server_(server),
endpoint_(endpoint), endpoint_(endpoint),
input_stream_(input_stream), input_stream_(input_stream),
output_stream_(output_stream) {} output_stream_(output_stream) {}
void Session::Execute() { void Session::Execute() {
if (input_stream_.size() < sizeof(MessageSize)) return; if (input_stream_->size() < sizeof(MessageSize)) return;
MessageSize request_len = MessageSize request_len =
*reinterpret_cast<MessageSize *>(input_stream_.data()); *reinterpret_cast<MessageSize *>(input_stream_->data());
uint64_t request_size = sizeof(MessageSize) + request_len; uint64_t request_size = sizeof(MessageSize) + request_len;
input_stream_.Resize(request_size); input_stream_->Resize(request_size);
if (input_stream_.size() < request_size) return; if (input_stream_->size() < request_size) return;
// Read the request message. // Read the request message.
auto data = auto data =
::kj::arrayPtr(input_stream_.data() + sizeof(request_len), request_len); ::kj::arrayPtr(input_stream_->data() + sizeof(request_len), request_len);
// Our data is word aligned and padded to 64bit because we use regular // Our data is word aligned and padded to 64bit because we use regular
// (non-packed) serialization of Cap'n Proto. So we can use reinterpret_cast. // (non-packed) serialization of Cap'n Proto. So we can use reinterpret_cast.
auto data_words = auto data_words =
@ -38,18 +38,18 @@ void Session::Execute() {
reinterpret_cast<::capnp::word *>(data.end())); reinterpret_cast<::capnp::word *>(data.end()));
::capnp::FlatArrayMessageReader request_message(data_words.asConst()); ::capnp::FlatArrayMessageReader request_message(data_words.asConst());
auto request = request_message.getRoot<capnp::Message>(); auto request = request_message.getRoot<capnp::Message>();
input_stream_.Shift(sizeof(MessageSize) + request_len); input_stream_->Shift(sizeof(MessageSize) + request_len);
::capnp::MallocMessageBuilder response_message; ::capnp::MallocMessageBuilder response_message;
// callback fills the message data // callback fills the message data
auto response_builder = response_message.initRoot<capnp::Message>(); auto response_builder = response_message.initRoot<capnp::Message>();
auto callbacks_accessor = server_.callbacks_.access(); auto callbacks_accessor = server_->callbacks_.access();
auto it = callbacks_accessor.find(request.getTypeId()); auto it = callbacks_accessor.find(request.getTypeId());
if (it == callbacks_accessor.end()) { if (it == callbacks_accessor.end()) {
// We couldn't find a regular callback to call, try to find an extended // We couldn't find a regular callback to call, try to find an extended
// callback to call. // callback to call.
auto extended_callbacks_accessor = server_.extended_callbacks_.access(); auto extended_callbacks_accessor = server_->extended_callbacks_.access();
auto extended_it = extended_callbacks_accessor.find(request.getTypeId()); auto extended_it = extended_callbacks_accessor.find(request.getTypeId());
if (extended_it == extended_callbacks_accessor.end()) { if (extended_it == extended_callbacks_accessor.end()) {
// Throw exception to close the socket and cleanup the session. // Throw exception to close the socket and cleanup the session.
@ -73,11 +73,11 @@ void Session::Execute() {
} }
MessageSize input_stream_size = response_bytes.size(); MessageSize input_stream_size = response_bytes.size();
if (!output_stream_.Write(reinterpret_cast<uint8_t *>(&input_stream_size), if (!output_stream_->Write(reinterpret_cast<uint8_t *>(&input_stream_size),
sizeof(MessageSize), true)) { sizeof(MessageSize), true)) {
throw SessionException("Couldn't send response size!"); throw SessionException("Couldn't send response size!");
} }
if (!output_stream_.Write(response_bytes.begin(), response_bytes.size())) { if (!output_stream_->Write(response_bytes.begin(), response_bytes.size())) {
throw SessionException("Couldn't send response data!"); throw SessionException("Couldn't send response data!");
} }

View File

@ -36,9 +36,9 @@ class SessionException : public utils::BasicException {
*/ */
class Session { class Session {
public: public:
Session(Server &server, const io::network::Endpoint &endpoint, Session(Server *server, const io::network::Endpoint &endpoint,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream); communication::OutputStream *output_stream);
/** /**
* Executes the protocol after data has been read into the stream. * Executes the protocol after data has been read into the stream.
@ -48,10 +48,10 @@ class Session {
void Execute(); void Execute();
private: private:
Server &server_; Server *server_;
io::network::Endpoint endpoint_; io::network::Endpoint endpoint_;
communication::InputStream &input_stream_; communication::InputStream *input_stream_;
communication::OutputStream &output_stream_; communication::OutputStream *output_stream_;
}; };
} // namespace communication::rpc } // namespace communication::rpc

View File

@ -4,7 +4,7 @@ namespace communication::rpc {
Server::Server(const io::network::Endpoint &endpoint, Server::Server(const io::network::Endpoint &endpoint,
size_t workers_count) size_t workers_count)
: server_(endpoint, *this, &context_, -1, "RPC", workers_count) {} : server_(endpoint, this, &context_, -1, "RPC", workers_count) {}
void Server::StopProcessingCalls() { void Server::StopProcessingCalls() {
server_.Shutdown(); server_.Shutdown();

View File

@ -46,7 +46,7 @@ class Server final {
* Constructs and binds server to endpoint, operates on session data and * Constructs and binds server to endpoint, operates on session data and
* invokes workers_count workers * invokes workers_count workers
*/ */
Server(const io::network::Endpoint &endpoint, TSessionData &session_data, Server(const io::network::Endpoint &endpoint, TSessionData *session_data,
ServerContext *context, int inactivity_timeout_sec, ServerContext *context, int inactivity_timeout_sec,
const std::string &service_name, const std::string &service_name,
size_t workers_count = std::thread::hardware_concurrency()) size_t workers_count = std::thread::hardware_concurrency())

View File

@ -65,14 +65,14 @@ class OutputStream final {
template <class TSession, class TSessionData> template <class TSession, class TSessionData>
class Session final { class Session final {
public: public:
Session(io::network::Socket &&socket, TSessionData &data, Session(io::network::Socket &&socket, TSessionData *data,
ServerContext *context, int inactivity_timeout_sec) ServerContext *context, int inactivity_timeout_sec)
: socket_(std::move(socket)), : socket_(std::move(socket)),
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
return Write(data, len, have_more); return Write(data, len, have_more);
}), }),
session_(data, socket_.endpoint(), input_buffer_.read_end(), session_(data, socket_.endpoint(), input_buffer_.read_end(),
output_stream_), &output_stream_),
inactivity_timeout_sec_(inactivity_timeout_sec) { inactivity_timeout_sec_(inactivity_timeout_sec) {
// Set socket options. // Set socket options.
// The socket is set to be a non-blocking socket. We use the socket in a // The socket is set to be a non-blocking socket. We use the socket in a
@ -145,7 +145,7 @@ class Session final {
RefreshLastEventTime(); RefreshLastEventTime();
// Allocate the buffer to fill the data. // Allocate the buffer to fill the data.
auto buf = input_buffer_.write_end().Allocate(); auto buf = input_buffer_.write_end()->Allocate();
if (ssl_) { if (ssl_) {
// We clear errors here to prevent errors piling up in the internal // We clear errors here to prevent errors piling up in the internal
@ -181,7 +181,7 @@ class Session final {
return false; return false;
} else { } else {
// Notify the input buffer that it has new data. // Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len); input_buffer_.write_end()->Written(len);
} }
} else { } else {
// Read from the buffer at most buf.len bytes in a non-blocking fashion. // Read from the buffer at most buf.len bytes in a non-blocking fashion.
@ -205,7 +205,7 @@ class Session final {
throw SessionClosedException("Session was closed by client."); throw SessionClosedException("Session was closed by client.");
} else { } else {
// Notify the input buffer that it has new data. // Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len); input_buffer_.write_end()->Written(len);
} }
} }

View File

@ -79,7 +79,7 @@ void SingleNodeMain() {
} }
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)}, ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, &context, FLAGS_session_inactivity_timeout, &session_data, &context, FLAGS_session_inactivity_timeout,
service_name, FLAGS_num_workers); service_name, FLAGS_num_workers);
// Setup telemetry // Setup telemetry
@ -160,7 +160,7 @@ void MasterMain() {
} }
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)}, ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, &context, FLAGS_session_inactivity_timeout, &session_data, &context, FLAGS_session_inactivity_timeout,
service_name, FLAGS_num_workers); service_name, FLAGS_num_workers);
// Handler for regular termination signals // Handler for regular termination signals

View File

@ -22,14 +22,14 @@ DEFINE_uint64(memory_warning_threshold, 1024,
"less available RAM it will log a warning. Set to 0 to " "less available RAM it will log a warning. Set to 0 to "
"disable."); "disable.");
BoltSession::BoltSession(SessionData &data, const io::network::Endpoint &, BoltSession::BoltSession(SessionData *data, const io::network::Endpoint &,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream) communication::OutputStream *output_stream)
: communication::bolt::Session<communication::InputStream, : communication::bolt::Session<communication::InputStream,
communication::OutputStream>(input_stream, communication::OutputStream>(input_stream,
output_stream), output_stream),
transaction_engine_(data.db, data.interpreter), transaction_engine_(data->db, data->interpreter),
auth_(&data.auth) {} auth_(&data->auth) {}
using TEncoder = using TEncoder =
communication::bolt::Session<communication::InputStream, communication::bolt::Session<communication::InputStream,

View File

@ -31,9 +31,9 @@ class BoltSession final
: public communication::bolt::Session<communication::InputStream, : public communication::bolt::Session<communication::InputStream,
communication::OutputStream> { communication::OutputStream> {
public: public:
BoltSession(SessionData &data, const io::network::Endpoint &endpoint, BoltSession(SessionData *data, const io::network::Endpoint &endpoint,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream); communication::OutputStream *output_stream);
using communication::bolt::Session<communication::InputStream, using communication::bolt::Session<communication::InputStream,
communication::OutputStream>::TEncoder; communication::OutputStream>::TEncoder;

View File

@ -23,28 +23,28 @@ class TestData {};
class TestSession { class TestSession {
public: public:
TestSession(TestData &, const io::network::Endpoint &, TestSession(TestData *, const io::network::Endpoint &,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream) communication::OutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {} : input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() { void Execute() {
if (input_stream_.size() < 2) return; if (input_stream_->size() < 2) return;
const uint8_t *data = input_stream_.data(); const uint8_t *data = input_stream_->data();
size_t size = data[0]; size_t size = data[0];
size <<= 8; size <<= 8;
size += data[1]; size += data[1];
input_stream_.Resize(size + 2); input_stream_->Resize(size + 2);
if (input_stream_.size() < size + 2) return; if (input_stream_->size() < size + 2) return;
for (int i = 0; i < REPLY; ++i) for (int i = 0; i < REPLY; ++i)
ASSERT_TRUE(output_stream_.Write(data + 2, size)); ASSERT_TRUE(output_stream_->Write(data + 2, size));
input_stream_.Shift(size + 2); input_stream_->Shift(size + 2);
} }
communication::InputStream &input_stream_; communication::InputStream *input_stream_;
communication::OutputStream &output_stream_; communication::OutputStream *output_stream_;
}; };
using ContextT = communication::ServerContext; using ContextT = communication::ServerContext;

View File

@ -24,17 +24,17 @@ class TestData {};
class TestSession { class TestSession {
public: public:
TestSession(TestData &, const io::network::Endpoint &, TestSession(TestData *, const io::network::Endpoint &,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream) communication::OutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {} : input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() { void Execute() {
output_stream_.Write(input_stream_.data(), input_stream_.size()); output_stream_->Write(input_stream_->data(), input_stream_->size());
} }
communication::InputStream &input_stream_; communication::InputStream *input_stream_;
communication::OutputStream &output_stream_; communication::OutputStream *output_stream_;
}; };
std::atomic<bool> run{true}; std::atomic<bool> run{true};
@ -65,7 +65,7 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
int N = (std::thread::hardware_concurrency() + 1) / 2; int N = (std::thread::hardware_concurrency() + 1) / 2;
int Nc = N * 3; int Nc = N * 3;
communication::ServerContext context; communication::ServerContext context;
communication::Server<TestSession, TestData> server(endpoint, data, &context, communication::Server<TestSession, TestData> server(endpoint, &data, &context,
-1, "Test", N); -1, "Test", N);
const auto &ep = server.endpoint(); const auto &ep = server.endpoint();

View File

@ -22,7 +22,7 @@ TEST(Network, Server) {
TestData session_data; TestData session_data;
int N = (std::thread::hardware_concurrency() + 1) / 2; int N = (std::thread::hardware_concurrency() + 1) / 2;
ContextT context; ContextT context;
ServerT server(endpoint, session_data, &context, -1, "Test", N); ServerT server(endpoint, &session_data, &context, -1, "Test", N);
const auto &ep = server.endpoint(); const auto &ep = server.endpoint();
// start clients // start clients

View File

@ -23,7 +23,7 @@ TEST(Network, SessionLeak) {
// initialize server // initialize server
TestData session_data; TestData session_data;
ContextT context; ContextT context;
ServerT server(endpoint, session_data, &context, -1, "Test", 2); ServerT server(endpoint, &session_data, &context, -1, "Test", 2);
// start clients // start clients
int N = 50; int N = 50;

View File

@ -21,24 +21,24 @@ struct EchoData {};
class EchoSession { class EchoSession {
public: public:
EchoSession(EchoData &, const io::network::Endpoint &, EchoSession(EchoData *, const io::network::Endpoint &,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream) communication::OutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {} : input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() { void Execute() {
if (input_stream_.size() < message.size()) return; if (input_stream_->size() < message.size()) return;
LOG(INFO) << "Server received message."; LOG(INFO) << "Server received message.";
if (!output_stream_.Write(input_stream_.data(), message.size())) { if (!output_stream_->Write(input_stream_->data(), message.size())) {
throw utils::BasicException("Output stream write failed!"); throw utils::BasicException("Output stream write failed!");
} }
LOG(INFO) << "Server sent message."; LOG(INFO) << "Server sent message.";
input_stream_.Shift(message.size()); input_stream_->Shift(message.size());
} }
private: private:
communication::InputStream &input_stream_; communication::InputStream *input_stream_;
communication::OutputStream &output_stream_; communication::OutputStream *output_stream_;
}; };
int main(int argc, char **argv) { int main(int argc, char **argv) {
@ -54,7 +54,7 @@ int main(int argc, char **argv) {
FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file, FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file,
FLAGS_server_verify_peer); FLAGS_server_verify_peer);
communication::Server<EchoSession, EchoData> server( communication::Server<EchoSession, EchoData> server(
{"127.0.0.1", 0}, echo_data, &server_context, -1, "SSL", 1); {"127.0.0.1", 0}, &echo_data, &server_context, -1, "SSL", 1);
// Initialize the client. // Initialize the client.
communication::ClientContext client_context(FLAGS_client_key_file, communication::ClientContext client_context(FLAGS_client_key_file,

View File

@ -19,37 +19,37 @@ struct EchoData {
class EchoSession { class EchoSession {
public: public:
EchoSession(EchoData &data, const io::network::Endpoint &, EchoSession(EchoData *data, const io::network::Endpoint &,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream) communication::OutputStream *output_stream)
: data_(data), : data_(data),
input_stream_(input_stream), input_stream_(input_stream),
output_stream_(output_stream) {} output_stream_(output_stream) {}
void Execute() { void Execute() {
if (input_stream_.size() < 2) return; if (input_stream_->size() < 2) return;
const uint8_t *data = input_stream_.data(); const uint8_t *data = input_stream_->data();
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_.data()); uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_->data());
input_stream_.Resize(size + 2); input_stream_->Resize(size + 2);
if (input_stream_.size() < size + 2) return; if (input_stream_->size() < size + 2) return;
if (size == 0) { if (size == 0) {
LOG(INFO) << "Server received EOF message"; LOG(INFO) << "Server received EOF message";
data_.alive.store(false); data_->alive.store(false);
return; return;
} }
LOG(INFO) << "Server received '" LOG(INFO) << "Server received '"
<< std::string(reinterpret_cast<const char *>(data + 2), size) << std::string(reinterpret_cast<const char *>(data + 2), size)
<< "'"; << "'";
if (!output_stream_.Write(data + 2, size)) { if (!output_stream_->Write(data + 2, size)) {
throw utils::BasicException("Output stream write failed!"); throw utils::BasicException("Output stream write failed!");
} }
input_stream_.Shift(size + 2); input_stream_->Shift(size + 2);
} }
private: private:
EchoData &data_; EchoData *data_;
communication::InputStream &input_stream_; communication::InputStream *input_stream_;
communication::OutputStream &output_stream_; communication::OutputStream *output_stream_;
}; };
int main(int argc, char **argv) { int main(int argc, char **argv) {
@ -63,7 +63,7 @@ int main(int argc, char **argv) {
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port); io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file, communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file,
FLAGS_ca_file, FLAGS_verify_peer); FLAGS_ca_file, FLAGS_verify_peer);
communication::Server<EchoSession, EchoData> server(endpoint, echo_data, communication::Server<EchoSession, EchoData> server(endpoint, &echo_data,
&context, -1, "SSL", 1); &context, -1, "SSL", 1);
while (echo_data.alive) { while (echo_data.alive) {

View File

@ -13,15 +13,15 @@ using ChunkStateT = communication::bolt::ChunkState;
TEST(BoltBuffer, CorrectChunk) { TEST(BoltBuffer, CorrectChunk) {
uint8_t tmp[2000]; uint8_t tmp[2000];
BufferT buffer; BufferT buffer;
DecoderBufferT decoder_buffer(buffer.read_end()); DecoderBufferT decoder_buffer(*buffer.read_end());
StreamBufferT sb = buffer.write_end().Allocate(); StreamBufferT sb = buffer.write_end()->Allocate();
sb.data[0] = 0x03; sb.data[0] = 0x03;
sb.data[1] = 0xe8; sb.data[1] = 0xe8;
memcpy(sb.data + 2, data, 1000); memcpy(sb.data + 2, data, 1000);
sb.data[1002] = 0; sb.data[1002] = 0;
sb.data[1003] = 0; sb.data[1003] = 0;
buffer.write_end().Written(1004); buffer.write_end()->Written(1004);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
@ -29,21 +29,21 @@ TEST(BoltBuffer, CorrectChunk) {
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true); ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
ASSERT_EQ(buffer.read_end().size(), 0); ASSERT_EQ(buffer.read_end()->size(), 0);
} }
TEST(BoltBuffer, CorrectChunkTrailingData) { TEST(BoltBuffer, CorrectChunkTrailingData) {
uint8_t tmp[2000]; uint8_t tmp[2000];
BufferT buffer; BufferT buffer;
DecoderBufferT decoder_buffer(buffer.read_end()); DecoderBufferT decoder_buffer(*buffer.read_end());
StreamBufferT sb = buffer.write_end().Allocate(); StreamBufferT sb = buffer.write_end()->Allocate();
sb.data[0] = 0x03; sb.data[0] = 0x03;
sb.data[1] = 0xe8; sb.data[1] = 0xe8;
memcpy(sb.data + 2, data, 2002); memcpy(sb.data + 2, data, 2002);
sb.data[1002] = 0; sb.data[1002] = 0;
sb.data[1003] = 0; sb.data[1003] = 0;
buffer.write_end().Written(2004); buffer.write_end()->Written(2004);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
@ -51,66 +51,66 @@ TEST(BoltBuffer, CorrectChunkTrailingData) {
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true); ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
uint8_t *leftover = buffer.read_end().data(); uint8_t *leftover = buffer.read_end()->data();
ASSERT_EQ(buffer.read_end().size(), 1000); ASSERT_EQ(buffer.read_end()->size(), 1000);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i + 1002], leftover[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i + 1002], leftover[i]);
} }
TEST(BoltBuffer, GraduallyPopulatedChunk) { TEST(BoltBuffer, GraduallyPopulatedChunk) {
uint8_t tmp[2000]; uint8_t tmp[2000];
BufferT buffer; BufferT buffer;
DecoderBufferT decoder_buffer(buffer.read_end()); DecoderBufferT decoder_buffer(*buffer.read_end());
StreamBufferT sb = buffer.write_end().Allocate(); StreamBufferT sb = buffer.write_end()->Allocate();
sb.data[0] = 0x03; sb.data[0] = 0x03;
sb.data[1] = 0xe8; sb.data[1] = 0xe8;
buffer.write_end().Written(2); buffer.write_end()->Written(2);
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
sb = buffer.write_end().Allocate(); sb = buffer.write_end()->Allocate();
memcpy(sb.data, data + 200 * i, 200); memcpy(sb.data, data + 200 * i, 200);
buffer.write_end().Written(200); buffer.write_end()->Written(200);
} }
sb = buffer.write_end().Allocate(); sb = buffer.write_end()->Allocate();
sb.data[0] = 0; sb.data[0] = 0;
sb.data[1] = 0; sb.data[1] = 0;
buffer.write_end().Written(2); buffer.write_end()->Written(2);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true); ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
ASSERT_EQ(buffer.read_end().size(), 0); ASSERT_EQ(buffer.read_end()->size(), 0);
} }
TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) { TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) {
uint8_t tmp[2000]; uint8_t tmp[2000];
BufferT buffer; BufferT buffer;
DecoderBufferT decoder_buffer(buffer.read_end()); DecoderBufferT decoder_buffer(*buffer.read_end());
StreamBufferT sb = buffer.write_end().Allocate(); StreamBufferT sb = buffer.write_end()->Allocate();
sb.data[0] = 0x03; sb.data[0] = 0x03;
sb.data[1] = 0xe8; sb.data[1] = 0xe8;
buffer.write_end().Written(2); buffer.write_end()->Written(2);
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
sb = buffer.write_end().Allocate(); sb = buffer.write_end()->Allocate();
memcpy(sb.data, data + 200 * i, 200); memcpy(sb.data, data + 200 * i, 200);
buffer.write_end().Written(200); buffer.write_end()->Written(200);
} }
sb = buffer.write_end().Allocate(); sb = buffer.write_end()->Allocate();
sb.data[0] = 0; sb.data[0] = 0;
sb.data[1] = 0; sb.data[1] = 0;
buffer.write_end().Written(2); buffer.write_end()->Written(2);
sb = buffer.write_end().Allocate(); sb = buffer.write_end()->Allocate();
memcpy(sb.data, data, 1000); memcpy(sb.data, data, 1000);
buffer.write_end().Written(1000); buffer.write_end()->Written(1000);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done); ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
@ -118,8 +118,8 @@ TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) {
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true); ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
uint8_t *leftover = buffer.read_end().data(); uint8_t *leftover = buffer.read_end()->data();
ASSERT_EQ(buffer.read_end().size(), 1000); ASSERT_EQ(buffer.read_end()->size(), 1000);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], leftover[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], leftover[i]);
} }

View File

@ -20,8 +20,8 @@ class TestSession : public Session<TestInputStream, TestOutputStream> {
public: public:
using Session<TestInputStream, TestOutputStream>::TEncoder; using Session<TestInputStream, TestOutputStream>::TEncoder;
TestSession(TestSessionData &data, TestInputStream &input_stream, TestSession(TestSessionData *data, TestInputStream *input_stream,
TestOutputStream &output_stream) TestOutputStream *output_stream)
: Session<TestInputStream, TestOutputStream>(input_stream, : Session<TestInputStream, TestOutputStream>(input_stream,
output_stream) {} output_stream) {}
@ -61,11 +61,11 @@ class TestSession : public Session<TestInputStream, TestOutputStream> {
// TODO: This could be done in fixture. // TODO: This could be done in fixture.
// Shortcuts for writing variable initializations in tests // Shortcuts for writing variable initializations in tests
#define INIT_VARS \ #define INIT_VARS \
TestInputStream input_stream; \ TestInputStream input_stream; \
TestOutputStream output_stream; \ TestOutputStream output_stream; \
TestSessionData session_data; \ TestSessionData session_data; \
TestSession session(session_data, input_stream, output_stream); \ TestSession session(&session_data, &input_stream, &output_stream); \
std::vector<uint8_t> &output = output_stream.output; std::vector<uint8_t> &output = output_stream.output;
// Sample testdata that has correct inputs and outputs. // Sample testdata that has correct inputs and outputs.

View File

@ -8,47 +8,47 @@ using communication::Buffer;
TEST(CommunicationBuffer, AllocateAndWritten) { TEST(CommunicationBuffer, AllocateAndWritten) {
Buffer buffer; Buffer buffer;
auto sb = buffer.write_end().Allocate(); auto sb = buffer.write_end()->Allocate();
memcpy(sb.data, data, 1000); memcpy(sb.data, data, 1000);
buffer.write_end().Written(1000); buffer.write_end()->Written(1000);
ASSERT_EQ(buffer.read_end().size(), 1000); ASSERT_EQ(buffer.read_end()->size(), 1000);
uint8_t *tmp = buffer.read_end().data(); uint8_t *tmp = buffer.read_end()->data();
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
} }
TEST(CommunicationBuffer, Shift) { TEST(CommunicationBuffer, Shift) {
Buffer buffer; Buffer buffer;
auto sb = buffer.write_end().Allocate(); auto sb = buffer.write_end()->Allocate();
memcpy(sb.data, data, 1000); memcpy(sb.data, data, 1000);
buffer.write_end().Written(1000); buffer.write_end()->Written(1000);
sb = buffer.write_end().Allocate(); sb = buffer.write_end()->Allocate();
memcpy(sb.data, data + 1000, 1000); memcpy(sb.data, data + 1000, 1000);
buffer.write_end().Written(1000); buffer.write_end()->Written(1000);
ASSERT_EQ(buffer.read_end().size(), 2000); ASSERT_EQ(buffer.read_end()->size(), 2000);
uint8_t *tmp = buffer.read_end().data(); uint8_t *tmp = buffer.read_end()->data();
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
buffer.read_end().Shift(1000); buffer.read_end()->Shift(1000);
ASSERT_EQ(buffer.read_end().size(), 1000); ASSERT_EQ(buffer.read_end()->size(), 1000);
tmp = buffer.read_end().data(); tmp = buffer.read_end()->data();
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i + 1000], tmp[i]); for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i + 1000], tmp[i]);
} }
TEST(CommunicationBuffer, Resize) { TEST(CommunicationBuffer, Resize) {
Buffer buffer; Buffer buffer;
auto sb = buffer.write_end().Allocate(); auto sb = buffer.write_end()->Allocate();
buffer.read_end().Resize(sb.len + 1000); buffer.read_end()->Resize(sb.len + 1000);
auto sbn = buffer.write_end().Allocate(); auto sbn = buffer.write_end()->Allocate();
ASSERT_EQ(sb.len + 1000, sbn.len); ASSERT_EQ(sb.len + 1000, sbn.len);
} }

View File

@ -14,24 +14,24 @@ class TestData {};
class TestSession { class TestSession {
public: public:
TestSession(TestData &, const io::network::Endpoint &, TestSession(TestData *, const io::network::Endpoint &,
communication::InputStream &input_stream, communication::InputStream *input_stream,
communication::OutputStream &output_stream) communication::OutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {} : input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() { void Execute() {
LOG(INFO) << "Received data: '" LOG(INFO) << "Received data: '"
<< std::string( << std::string(
reinterpret_cast<const char *>(input_stream_.data()), reinterpret_cast<const char *>(input_stream_->data()),
input_stream_.size()) input_stream_->size())
<< "'"; << "'";
output_stream_.Write(input_stream_.data(), input_stream_.size()); output_stream_->Write(input_stream_->data(), input_stream_->size());
input_stream_.Shift(input_stream_.size()); input_stream_->Shift(input_stream_->size());
} }
private: private:
communication::InputStream &input_stream_; communication::InputStream *input_stream_;
communication::OutputStream &output_stream_; communication::OutputStream *output_stream_;
}; };
const std::string query("timeout test"); const std::string query("timeout test");
@ -55,7 +55,7 @@ TEST(NetworkTimeouts, InactiveSession) {
TestData test_data; TestData test_data;
communication::ServerContext context; communication::ServerContext context;
communication::Server<TestSession, TestData> server{ communication::Server<TestSession, TestData> server{
{"127.0.0.1", 0}, test_data, &context, 2, "Test", 1}; {"127.0.0.1", 0}, &test_data, &context, 2, "Test", 1};
// Create the client and connect to the server. // Create the client and connect to the server.
io::network::Socket client; io::network::Socket client;