Refactor network stack to use *
instead of &
Reviewers: teon.banek, buda Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1587
This commit is contained in:
parent
4e996d2667
commit
240472e7cb
@ -41,8 +41,8 @@ class Session {
|
||||
public:
|
||||
using TEncoder = Encoder<ChunkedEncoderBuffer<TOutputStream>>;
|
||||
|
||||
Session(TInputStream &input_stream, TOutputStream &output_stream)
|
||||
: input_stream_(input_stream), output_stream_(output_stream) {}
|
||||
Session(TInputStream *input_stream, TOutputStream *output_stream)
|
||||
: input_stream_(*input_stream), output_stream_(*output_stream) {}
|
||||
|
||||
virtual ~Session() {}
|
||||
|
||||
|
@ -5,35 +5,35 @@
|
||||
namespace communication {
|
||||
|
||||
Buffer::Buffer()
|
||||
: data_(kBufferInitialSize, 0), read_end_(*this), write_end_(*this) {}
|
||||
: data_(kBufferInitialSize, 0), read_end_(this), write_end_(this) {}
|
||||
|
||||
Buffer::ReadEnd::ReadEnd(Buffer &buffer) : buffer_(buffer) {}
|
||||
Buffer::ReadEnd::ReadEnd(Buffer *buffer) : buffer_(buffer) {}
|
||||
|
||||
uint8_t *Buffer::ReadEnd::data() { return buffer_.data(); }
|
||||
uint8_t *Buffer::ReadEnd::data() { return buffer_->data(); }
|
||||
|
||||
size_t Buffer::ReadEnd::size() const { return buffer_.size(); }
|
||||
size_t Buffer::ReadEnd::size() const { return buffer_->size(); }
|
||||
|
||||
void Buffer::ReadEnd::Shift(size_t len) { buffer_.Shift(len); }
|
||||
void Buffer::ReadEnd::Shift(size_t len) { buffer_->Shift(len); }
|
||||
|
||||
void Buffer::ReadEnd::Resize(size_t len) { buffer_.Resize(len); }
|
||||
void Buffer::ReadEnd::Resize(size_t len) { buffer_->Resize(len); }
|
||||
|
||||
void Buffer::ReadEnd::Clear() { buffer_.Clear(); }
|
||||
void Buffer::ReadEnd::Clear() { buffer_->Clear(); }
|
||||
|
||||
Buffer::WriteEnd::WriteEnd(Buffer &buffer) : buffer_(buffer) {}
|
||||
Buffer::WriteEnd::WriteEnd(Buffer *buffer) : buffer_(buffer) {}
|
||||
|
||||
io::network::StreamBuffer Buffer::WriteEnd::Allocate() {
|
||||
return buffer_.Allocate();
|
||||
return buffer_->Allocate();
|
||||
}
|
||||
|
||||
void Buffer::WriteEnd::Written(size_t len) { buffer_.Written(len); }
|
||||
void Buffer::WriteEnd::Written(size_t len) { buffer_->Written(len); }
|
||||
|
||||
void Buffer::WriteEnd::Resize(size_t len) { buffer_.Resize(len); }
|
||||
void Buffer::WriteEnd::Resize(size_t len) { buffer_->Resize(len); }
|
||||
|
||||
void Buffer::WriteEnd::Clear() { buffer_.Clear(); }
|
||||
void Buffer::WriteEnd::Clear() { buffer_->Clear(); }
|
||||
|
||||
Buffer::ReadEnd &Buffer::read_end() { return read_end_; }
|
||||
Buffer::ReadEnd *Buffer::read_end() { return &read_end_; }
|
||||
|
||||
Buffer::WriteEnd &Buffer::write_end() { return write_end_; }
|
||||
Buffer::WriteEnd *Buffer::write_end() { return &write_end_; }
|
||||
|
||||
uint8_t *Buffer::data() { return data_.data(); }
|
||||
|
||||
|
@ -39,7 +39,7 @@ class Buffer final {
|
||||
*/
|
||||
class ReadEnd {
|
||||
public:
|
||||
ReadEnd(Buffer &buffer);
|
||||
ReadEnd(Buffer *buffer);
|
||||
|
||||
ReadEnd(const ReadEnd &) = delete;
|
||||
ReadEnd(ReadEnd &&) = delete;
|
||||
@ -57,7 +57,7 @@ class Buffer final {
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
Buffer &buffer_;
|
||||
Buffer *buffer_;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -66,7 +66,7 @@ class Buffer final {
|
||||
*/
|
||||
class WriteEnd {
|
||||
public:
|
||||
WriteEnd(Buffer &buffer);
|
||||
WriteEnd(Buffer *buffer);
|
||||
|
||||
WriteEnd(const WriteEnd &) = delete;
|
||||
WriteEnd(WriteEnd &&) = delete;
|
||||
@ -82,20 +82,20 @@ class Buffer final {
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
Buffer &buffer_;
|
||||
Buffer *buffer_;
|
||||
};
|
||||
|
||||
/**
|
||||
* This function returns a reference to the associated ReadEnd object for this
|
||||
* This function returns a pointer to the associated ReadEnd object for this
|
||||
* buffer.
|
||||
*/
|
||||
ReadEnd &read_end();
|
||||
ReadEnd *read_end();
|
||||
|
||||
/**
|
||||
* This function returns a reference to the associated WriteEnd object for
|
||||
* This function returns a pointer to the associated WriteEnd object for
|
||||
* this buffer.
|
||||
*/
|
||||
WriteEnd &write_end();
|
||||
WriteEnd *write_end();
|
||||
|
||||
private:
|
||||
/**
|
||||
|
@ -83,9 +83,9 @@ void Client::Close() {
|
||||
|
||||
bool Client::Read(size_t len) {
|
||||
size_t received = 0;
|
||||
buffer_.write_end().Resize(buffer_.read_end().size() + len);
|
||||
buffer_.write_end()->Resize(buffer_.read_end()->size() + len);
|
||||
while (received < len) {
|
||||
auto buff = buffer_.write_end().Allocate();
|
||||
auto buff = buffer_.write_end()->Allocate();
|
||||
if (ssl_) {
|
||||
// We clear errors here to prevent errors piling up in the internal
|
||||
// OpenSSL error queue. To see when could that be an issue read this:
|
||||
@ -120,7 +120,7 @@ bool Client::Read(size_t len) {
|
||||
}
|
||||
|
||||
// Notify the buffer that it has new data.
|
||||
buffer_.write_end().Written(got);
|
||||
buffer_.write_end()->Written(got);
|
||||
received += got;
|
||||
} else {
|
||||
// Read raw data from the socket.
|
||||
@ -135,20 +135,20 @@ bool Client::Read(size_t len) {
|
||||
}
|
||||
|
||||
// Notify the buffer that it has new data.
|
||||
buffer_.write_end().Written(got);
|
||||
buffer_.write_end()->Written(got);
|
||||
received += got;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint8_t *Client::GetData() { return buffer_.read_end().data(); }
|
||||
uint8_t *Client::GetData() { return buffer_.read_end()->data(); }
|
||||
|
||||
size_t Client::GetDataSize() { return buffer_.read_end().size(); }
|
||||
size_t Client::GetDataSize() { return buffer_.read_end()->size(); }
|
||||
|
||||
void Client::ShiftData(size_t len) { buffer_.read_end().Shift(len); }
|
||||
void Client::ShiftData(size_t len) { buffer_.read_end()->Shift(len); }
|
||||
|
||||
void Client::ClearData() { buffer_.read_end().Clear(); }
|
||||
void Client::ClearData() { buffer_.read_end()->Clear(); }
|
||||
|
||||
bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
|
||||
if (ssl_) {
|
||||
|
@ -40,7 +40,7 @@ class Listener final {
|
||||
using SessionHandler = Session<TSession, TSessionData>;
|
||||
|
||||
public:
|
||||
Listener(TSessionData &data, ServerContext *context,
|
||||
Listener(TSessionData *data, ServerContext *context,
|
||||
int inactivity_timeout_sec, const std::string &service_name,
|
||||
size_t workers_count)
|
||||
: data_(data),
|
||||
@ -227,7 +227,7 @@ class Listener final {
|
||||
|
||||
io::network::Epoll epoll_;
|
||||
|
||||
TSessionData &data_;
|
||||
TSessionData *data_;
|
||||
|
||||
utils::SpinLock lock_;
|
||||
std::vector<std::unique_ptr<SessionHandler>> sessions_;
|
||||
|
@ -12,25 +12,25 @@
|
||||
|
||||
namespace communication::rpc {
|
||||
|
||||
Session::Session(Server &server, const io::network::Endpoint &endpoint,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
Session::Session(Server *server, const io::network::Endpoint &endpoint,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream)
|
||||
: server_(server),
|
||||
endpoint_(endpoint),
|
||||
input_stream_(input_stream),
|
||||
output_stream_(output_stream) {}
|
||||
|
||||
void Session::Execute() {
|
||||
if (input_stream_.size() < sizeof(MessageSize)) return;
|
||||
if (input_stream_->size() < sizeof(MessageSize)) return;
|
||||
MessageSize request_len =
|
||||
*reinterpret_cast<MessageSize *>(input_stream_.data());
|
||||
*reinterpret_cast<MessageSize *>(input_stream_->data());
|
||||
uint64_t request_size = sizeof(MessageSize) + request_len;
|
||||
input_stream_.Resize(request_size);
|
||||
if (input_stream_.size() < request_size) return;
|
||||
input_stream_->Resize(request_size);
|
||||
if (input_stream_->size() < request_size) return;
|
||||
|
||||
// Read the request message.
|
||||
auto data =
|
||||
::kj::arrayPtr(input_stream_.data() + sizeof(request_len), request_len);
|
||||
::kj::arrayPtr(input_stream_->data() + sizeof(request_len), request_len);
|
||||
// Our data is word aligned and padded to 64bit because we use regular
|
||||
// (non-packed) serialization of Cap'n Proto. So we can use reinterpret_cast.
|
||||
auto data_words =
|
||||
@ -38,18 +38,18 @@ void Session::Execute() {
|
||||
reinterpret_cast<::capnp::word *>(data.end()));
|
||||
::capnp::FlatArrayMessageReader request_message(data_words.asConst());
|
||||
auto request = request_message.getRoot<capnp::Message>();
|
||||
input_stream_.Shift(sizeof(MessageSize) + request_len);
|
||||
input_stream_->Shift(sizeof(MessageSize) + request_len);
|
||||
|
||||
::capnp::MallocMessageBuilder response_message;
|
||||
// callback fills the message data
|
||||
auto response_builder = response_message.initRoot<capnp::Message>();
|
||||
|
||||
auto callbacks_accessor = server_.callbacks_.access();
|
||||
auto callbacks_accessor = server_->callbacks_.access();
|
||||
auto it = callbacks_accessor.find(request.getTypeId());
|
||||
if (it == callbacks_accessor.end()) {
|
||||
// We couldn't find a regular callback to call, try to find an extended
|
||||
// callback to call.
|
||||
auto extended_callbacks_accessor = server_.extended_callbacks_.access();
|
||||
auto extended_callbacks_accessor = server_->extended_callbacks_.access();
|
||||
auto extended_it = extended_callbacks_accessor.find(request.getTypeId());
|
||||
if (extended_it == extended_callbacks_accessor.end()) {
|
||||
// Throw exception to close the socket and cleanup the session.
|
||||
@ -73,11 +73,11 @@ void Session::Execute() {
|
||||
}
|
||||
|
||||
MessageSize input_stream_size = response_bytes.size();
|
||||
if (!output_stream_.Write(reinterpret_cast<uint8_t *>(&input_stream_size),
|
||||
if (!output_stream_->Write(reinterpret_cast<uint8_t *>(&input_stream_size),
|
||||
sizeof(MessageSize), true)) {
|
||||
throw SessionException("Couldn't send response size!");
|
||||
}
|
||||
if (!output_stream_.Write(response_bytes.begin(), response_bytes.size())) {
|
||||
if (!output_stream_->Write(response_bytes.begin(), response_bytes.size())) {
|
||||
throw SessionException("Couldn't send response data!");
|
||||
}
|
||||
|
||||
|
@ -36,9 +36,9 @@ class SessionException : public utils::BasicException {
|
||||
*/
|
||||
class Session {
|
||||
public:
|
||||
Session(Server &server, const io::network::Endpoint &endpoint,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream);
|
||||
Session(Server *server, const io::network::Endpoint &endpoint,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream);
|
||||
|
||||
/**
|
||||
* Executes the protocol after data has been read into the stream.
|
||||
@ -48,10 +48,10 @@ class Session {
|
||||
void Execute();
|
||||
|
||||
private:
|
||||
Server &server_;
|
||||
Server *server_;
|
||||
io::network::Endpoint endpoint_;
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
communication::InputStream *input_stream_;
|
||||
communication::OutputStream *output_stream_;
|
||||
};
|
||||
|
||||
} // namespace communication::rpc
|
||||
|
@ -4,7 +4,7 @@ namespace communication::rpc {
|
||||
|
||||
Server::Server(const io::network::Endpoint &endpoint,
|
||||
size_t workers_count)
|
||||
: server_(endpoint, *this, &context_, -1, "RPC", workers_count) {}
|
||||
: server_(endpoint, this, &context_, -1, "RPC", workers_count) {}
|
||||
|
||||
void Server::StopProcessingCalls() {
|
||||
server_.Shutdown();
|
||||
|
@ -46,7 +46,7 @@ class Server final {
|
||||
* Constructs and binds server to endpoint, operates on session data and
|
||||
* invokes workers_count workers
|
||||
*/
|
||||
Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
|
||||
Server(const io::network::Endpoint &endpoint, TSessionData *session_data,
|
||||
ServerContext *context, int inactivity_timeout_sec,
|
||||
const std::string &service_name,
|
||||
size_t workers_count = std::thread::hardware_concurrency())
|
||||
|
@ -65,14 +65,14 @@ class OutputStream final {
|
||||
template <class TSession, class TSessionData>
|
||||
class Session final {
|
||||
public:
|
||||
Session(io::network::Socket &&socket, TSessionData &data,
|
||||
Session(io::network::Socket &&socket, TSessionData *data,
|
||||
ServerContext *context, int inactivity_timeout_sec)
|
||||
: socket_(std::move(socket)),
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
|
||||
return Write(data, len, have_more);
|
||||
}),
|
||||
session_(data, socket_.endpoint(), input_buffer_.read_end(),
|
||||
output_stream_),
|
||||
&output_stream_),
|
||||
inactivity_timeout_sec_(inactivity_timeout_sec) {
|
||||
// Set socket options.
|
||||
// The socket is set to be a non-blocking socket. We use the socket in a
|
||||
@ -145,7 +145,7 @@ class Session final {
|
||||
RefreshLastEventTime();
|
||||
|
||||
// Allocate the buffer to fill the data.
|
||||
auto buf = input_buffer_.write_end().Allocate();
|
||||
auto buf = input_buffer_.write_end()->Allocate();
|
||||
|
||||
if (ssl_) {
|
||||
// We clear errors here to prevent errors piling up in the internal
|
||||
@ -181,7 +181,7 @@ class Session final {
|
||||
return false;
|
||||
} else {
|
||||
// Notify the input buffer that it has new data.
|
||||
input_buffer_.write_end().Written(len);
|
||||
input_buffer_.write_end()->Written(len);
|
||||
}
|
||||
} else {
|
||||
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
|
||||
@ -205,7 +205,7 @@ class Session final {
|
||||
throw SessionClosedException("Session was closed by client.");
|
||||
} else {
|
||||
// Notify the input buffer that it has new data.
|
||||
input_buffer_.write_end().Written(len);
|
||||
input_buffer_.write_end()->Written(len);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +79,7 @@ void SingleNodeMain() {
|
||||
}
|
||||
|
||||
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
|
||||
session_data, &context, FLAGS_session_inactivity_timeout,
|
||||
&session_data, &context, FLAGS_session_inactivity_timeout,
|
||||
service_name, FLAGS_num_workers);
|
||||
|
||||
// Setup telemetry
|
||||
@ -160,7 +160,7 @@ void MasterMain() {
|
||||
}
|
||||
|
||||
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
|
||||
session_data, &context, FLAGS_session_inactivity_timeout,
|
||||
&session_data, &context, FLAGS_session_inactivity_timeout,
|
||||
service_name, FLAGS_num_workers);
|
||||
|
||||
// Handler for regular termination signals
|
||||
|
@ -22,14 +22,14 @@ DEFINE_uint64(memory_warning_threshold, 1024,
|
||||
"less available RAM it will log a warning. Set to 0 to "
|
||||
"disable.");
|
||||
|
||||
BoltSession::BoltSession(SessionData &data, const io::network::Endpoint &,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
BoltSession::BoltSession(SessionData *data, const io::network::Endpoint &,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream)
|
||||
: communication::bolt::Session<communication::InputStream,
|
||||
communication::OutputStream>(input_stream,
|
||||
output_stream),
|
||||
transaction_engine_(data.db, data.interpreter),
|
||||
auth_(&data.auth) {}
|
||||
transaction_engine_(data->db, data->interpreter),
|
||||
auth_(&data->auth) {}
|
||||
|
||||
using TEncoder =
|
||||
communication::bolt::Session<communication::InputStream,
|
||||
|
@ -31,9 +31,9 @@ class BoltSession final
|
||||
: public communication::bolt::Session<communication::InputStream,
|
||||
communication::OutputStream> {
|
||||
public:
|
||||
BoltSession(SessionData &data, const io::network::Endpoint &endpoint,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream);
|
||||
BoltSession(SessionData *data, const io::network::Endpoint &endpoint,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream);
|
||||
|
||||
using communication::bolt::Session<communication::InputStream,
|
||||
communication::OutputStream>::TEncoder;
|
||||
|
@ -23,28 +23,28 @@ class TestData {};
|
||||
|
||||
class TestSession {
|
||||
public:
|
||||
TestSession(TestData &, const io::network::Endpoint &,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
TestSession(TestData *, const io::network::Endpoint &,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream)
|
||||
: input_stream_(input_stream), output_stream_(output_stream) {}
|
||||
|
||||
void Execute() {
|
||||
if (input_stream_.size() < 2) return;
|
||||
const uint8_t *data = input_stream_.data();
|
||||
if (input_stream_->size() < 2) return;
|
||||
const uint8_t *data = input_stream_->data();
|
||||
size_t size = data[0];
|
||||
size <<= 8;
|
||||
size += data[1];
|
||||
input_stream_.Resize(size + 2);
|
||||
if (input_stream_.size() < size + 2) return;
|
||||
input_stream_->Resize(size + 2);
|
||||
if (input_stream_->size() < size + 2) return;
|
||||
|
||||
for (int i = 0; i < REPLY; ++i)
|
||||
ASSERT_TRUE(output_stream_.Write(data + 2, size));
|
||||
ASSERT_TRUE(output_stream_->Write(data + 2, size));
|
||||
|
||||
input_stream_.Shift(size + 2);
|
||||
input_stream_->Shift(size + 2);
|
||||
}
|
||||
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
communication::InputStream *input_stream_;
|
||||
communication::OutputStream *output_stream_;
|
||||
};
|
||||
|
||||
using ContextT = communication::ServerContext;
|
||||
|
@ -24,17 +24,17 @@ class TestData {};
|
||||
|
||||
class TestSession {
|
||||
public:
|
||||
TestSession(TestData &, const io::network::Endpoint &,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
TestSession(TestData *, const io::network::Endpoint &,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream)
|
||||
: input_stream_(input_stream), output_stream_(output_stream) {}
|
||||
|
||||
void Execute() {
|
||||
output_stream_.Write(input_stream_.data(), input_stream_.size());
|
||||
output_stream_->Write(input_stream_->data(), input_stream_->size());
|
||||
}
|
||||
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
communication::InputStream *input_stream_;
|
||||
communication::OutputStream *output_stream_;
|
||||
};
|
||||
|
||||
std::atomic<bool> run{true};
|
||||
@ -65,7 +65,7 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
|
||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||
int Nc = N * 3;
|
||||
communication::ServerContext context;
|
||||
communication::Server<TestSession, TestData> server(endpoint, data, &context,
|
||||
communication::Server<TestSession, TestData> server(endpoint, &data, &context,
|
||||
-1, "Test", N);
|
||||
|
||||
const auto &ep = server.endpoint();
|
||||
|
@ -22,7 +22,7 @@ TEST(Network, Server) {
|
||||
TestData session_data;
|
||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||
ContextT context;
|
||||
ServerT server(endpoint, session_data, &context, -1, "Test", N);
|
||||
ServerT server(endpoint, &session_data, &context, -1, "Test", N);
|
||||
|
||||
const auto &ep = server.endpoint();
|
||||
// start clients
|
||||
|
@ -23,7 +23,7 @@ TEST(Network, SessionLeak) {
|
||||
// initialize server
|
||||
TestData session_data;
|
||||
ContextT context;
|
||||
ServerT server(endpoint, session_data, &context, -1, "Test", 2);
|
||||
ServerT server(endpoint, &session_data, &context, -1, "Test", 2);
|
||||
|
||||
// start clients
|
||||
int N = 50;
|
||||
|
@ -21,24 +21,24 @@ struct EchoData {};
|
||||
|
||||
class EchoSession {
|
||||
public:
|
||||
EchoSession(EchoData &, const io::network::Endpoint &,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
EchoSession(EchoData *, const io::network::Endpoint &,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream)
|
||||
: input_stream_(input_stream), output_stream_(output_stream) {}
|
||||
|
||||
void Execute() {
|
||||
if (input_stream_.size() < message.size()) return;
|
||||
if (input_stream_->size() < message.size()) return;
|
||||
LOG(INFO) << "Server received message.";
|
||||
if (!output_stream_.Write(input_stream_.data(), message.size())) {
|
||||
if (!output_stream_->Write(input_stream_->data(), message.size())) {
|
||||
throw utils::BasicException("Output stream write failed!");
|
||||
}
|
||||
LOG(INFO) << "Server sent message.";
|
||||
input_stream_.Shift(message.size());
|
||||
input_stream_->Shift(message.size());
|
||||
}
|
||||
|
||||
private:
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
communication::InputStream *input_stream_;
|
||||
communication::OutputStream *output_stream_;
|
||||
};
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
@ -54,7 +54,7 @@ int main(int argc, char **argv) {
|
||||
FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file,
|
||||
FLAGS_server_verify_peer);
|
||||
communication::Server<EchoSession, EchoData> server(
|
||||
{"127.0.0.1", 0}, echo_data, &server_context, -1, "SSL", 1);
|
||||
{"127.0.0.1", 0}, &echo_data, &server_context, -1, "SSL", 1);
|
||||
|
||||
// Initialize the client.
|
||||
communication::ClientContext client_context(FLAGS_client_key_file,
|
||||
|
@ -19,37 +19,37 @@ struct EchoData {
|
||||
|
||||
class EchoSession {
|
||||
public:
|
||||
EchoSession(EchoData &data, const io::network::Endpoint &,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
EchoSession(EchoData *data, const io::network::Endpoint &,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream)
|
||||
: data_(data),
|
||||
input_stream_(input_stream),
|
||||
output_stream_(output_stream) {}
|
||||
|
||||
void Execute() {
|
||||
if (input_stream_.size() < 2) return;
|
||||
const uint8_t *data = input_stream_.data();
|
||||
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_.data());
|
||||
input_stream_.Resize(size + 2);
|
||||
if (input_stream_.size() < size + 2) return;
|
||||
if (input_stream_->size() < 2) return;
|
||||
const uint8_t *data = input_stream_->data();
|
||||
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_->data());
|
||||
input_stream_->Resize(size + 2);
|
||||
if (input_stream_->size() < size + 2) return;
|
||||
if (size == 0) {
|
||||
LOG(INFO) << "Server received EOF message";
|
||||
data_.alive.store(false);
|
||||
data_->alive.store(false);
|
||||
return;
|
||||
}
|
||||
LOG(INFO) << "Server received '"
|
||||
<< std::string(reinterpret_cast<const char *>(data + 2), size)
|
||||
<< "'";
|
||||
if (!output_stream_.Write(data + 2, size)) {
|
||||
if (!output_stream_->Write(data + 2, size)) {
|
||||
throw utils::BasicException("Output stream write failed!");
|
||||
}
|
||||
input_stream_.Shift(size + 2);
|
||||
input_stream_->Shift(size + 2);
|
||||
}
|
||||
|
||||
private:
|
||||
EchoData &data_;
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
EchoData *data_;
|
||||
communication::InputStream *input_stream_;
|
||||
communication::OutputStream *output_stream_;
|
||||
};
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
@ -63,7 +63,7 @@ int main(int argc, char **argv) {
|
||||
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file,
|
||||
FLAGS_ca_file, FLAGS_verify_peer);
|
||||
communication::Server<EchoSession, EchoData> server(endpoint, echo_data,
|
||||
communication::Server<EchoSession, EchoData> server(endpoint, &echo_data,
|
||||
&context, -1, "SSL", 1);
|
||||
|
||||
while (echo_data.alive) {
|
||||
|
@ -13,15 +13,15 @@ using ChunkStateT = communication::bolt::ChunkState;
|
||||
TEST(BoltBuffer, CorrectChunk) {
|
||||
uint8_t tmp[2000];
|
||||
BufferT buffer;
|
||||
DecoderBufferT decoder_buffer(buffer.read_end());
|
||||
StreamBufferT sb = buffer.write_end().Allocate();
|
||||
DecoderBufferT decoder_buffer(*buffer.read_end());
|
||||
StreamBufferT sb = buffer.write_end()->Allocate();
|
||||
|
||||
sb.data[0] = 0x03;
|
||||
sb.data[1] = 0xe8;
|
||||
memcpy(sb.data + 2, data, 1000);
|
||||
sb.data[1002] = 0;
|
||||
sb.data[1003] = 0;
|
||||
buffer.write_end().Written(1004);
|
||||
buffer.write_end()->Written(1004);
|
||||
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
|
||||
@ -29,21 +29,21 @@ TEST(BoltBuffer, CorrectChunk) {
|
||||
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
|
||||
|
||||
ASSERT_EQ(buffer.read_end().size(), 0);
|
||||
ASSERT_EQ(buffer.read_end()->size(), 0);
|
||||
}
|
||||
|
||||
TEST(BoltBuffer, CorrectChunkTrailingData) {
|
||||
uint8_t tmp[2000];
|
||||
BufferT buffer;
|
||||
DecoderBufferT decoder_buffer(buffer.read_end());
|
||||
StreamBufferT sb = buffer.write_end().Allocate();
|
||||
DecoderBufferT decoder_buffer(*buffer.read_end());
|
||||
StreamBufferT sb = buffer.write_end()->Allocate();
|
||||
|
||||
sb.data[0] = 0x03;
|
||||
sb.data[1] = 0xe8;
|
||||
memcpy(sb.data + 2, data, 2002);
|
||||
sb.data[1002] = 0;
|
||||
sb.data[1003] = 0;
|
||||
buffer.write_end().Written(2004);
|
||||
buffer.write_end()->Written(2004);
|
||||
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
|
||||
@ -51,66 +51,66 @@ TEST(BoltBuffer, CorrectChunkTrailingData) {
|
||||
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
|
||||
|
||||
uint8_t *leftover = buffer.read_end().data();
|
||||
ASSERT_EQ(buffer.read_end().size(), 1000);
|
||||
uint8_t *leftover = buffer.read_end()->data();
|
||||
ASSERT_EQ(buffer.read_end()->size(), 1000);
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i + 1002], leftover[i]);
|
||||
}
|
||||
|
||||
TEST(BoltBuffer, GraduallyPopulatedChunk) {
|
||||
uint8_t tmp[2000];
|
||||
BufferT buffer;
|
||||
DecoderBufferT decoder_buffer(buffer.read_end());
|
||||
StreamBufferT sb = buffer.write_end().Allocate();
|
||||
DecoderBufferT decoder_buffer(*buffer.read_end());
|
||||
StreamBufferT sb = buffer.write_end()->Allocate();
|
||||
|
||||
sb.data[0] = 0x03;
|
||||
sb.data[1] = 0xe8;
|
||||
buffer.write_end().Written(2);
|
||||
buffer.write_end()->Written(2);
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
|
||||
sb = buffer.write_end().Allocate();
|
||||
sb = buffer.write_end()->Allocate();
|
||||
memcpy(sb.data, data + 200 * i, 200);
|
||||
buffer.write_end().Written(200);
|
||||
buffer.write_end()->Written(200);
|
||||
}
|
||||
|
||||
sb = buffer.write_end().Allocate();
|
||||
sb = buffer.write_end()->Allocate();
|
||||
sb.data[0] = 0;
|
||||
sb.data[1] = 0;
|
||||
buffer.write_end().Written(2);
|
||||
buffer.write_end()->Written(2);
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
|
||||
|
||||
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
|
||||
|
||||
ASSERT_EQ(buffer.read_end().size(), 0);
|
||||
ASSERT_EQ(buffer.read_end()->size(), 0);
|
||||
}
|
||||
|
||||
TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) {
|
||||
uint8_t tmp[2000];
|
||||
BufferT buffer;
|
||||
DecoderBufferT decoder_buffer(buffer.read_end());
|
||||
StreamBufferT sb = buffer.write_end().Allocate();
|
||||
DecoderBufferT decoder_buffer(*buffer.read_end());
|
||||
StreamBufferT sb = buffer.write_end()->Allocate();
|
||||
|
||||
sb.data[0] = 0x03;
|
||||
sb.data[1] = 0xe8;
|
||||
buffer.write_end().Written(2);
|
||||
buffer.write_end()->Written(2);
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
|
||||
sb = buffer.write_end().Allocate();
|
||||
sb = buffer.write_end()->Allocate();
|
||||
memcpy(sb.data, data + 200 * i, 200);
|
||||
buffer.write_end().Written(200);
|
||||
buffer.write_end()->Written(200);
|
||||
}
|
||||
|
||||
sb = buffer.write_end().Allocate();
|
||||
sb = buffer.write_end()->Allocate();
|
||||
sb.data[0] = 0;
|
||||
sb.data[1] = 0;
|
||||
buffer.write_end().Written(2);
|
||||
buffer.write_end()->Written(2);
|
||||
|
||||
sb = buffer.write_end().Allocate();
|
||||
sb = buffer.write_end()->Allocate();
|
||||
memcpy(sb.data, data, 1000);
|
||||
buffer.write_end().Written(1000);
|
||||
buffer.write_end()->Written(1000);
|
||||
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
|
||||
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
|
||||
@ -118,8 +118,8 @@ TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) {
|
||||
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
|
||||
|
||||
uint8_t *leftover = buffer.read_end().data();
|
||||
ASSERT_EQ(buffer.read_end().size(), 1000);
|
||||
uint8_t *leftover = buffer.read_end()->data();
|
||||
ASSERT_EQ(buffer.read_end()->size(), 1000);
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], leftover[i]);
|
||||
}
|
||||
|
||||
|
@ -20,8 +20,8 @@ class TestSession : public Session<TestInputStream, TestOutputStream> {
|
||||
public:
|
||||
using Session<TestInputStream, TestOutputStream>::TEncoder;
|
||||
|
||||
TestSession(TestSessionData &data, TestInputStream &input_stream,
|
||||
TestOutputStream &output_stream)
|
||||
TestSession(TestSessionData *data, TestInputStream *input_stream,
|
||||
TestOutputStream *output_stream)
|
||||
: Session<TestInputStream, TestOutputStream>(input_stream,
|
||||
output_stream) {}
|
||||
|
||||
@ -61,11 +61,11 @@ class TestSession : public Session<TestInputStream, TestOutputStream> {
|
||||
|
||||
// TODO: This could be done in fixture.
|
||||
// Shortcuts for writing variable initializations in tests
|
||||
#define INIT_VARS \
|
||||
TestInputStream input_stream; \
|
||||
TestOutputStream output_stream; \
|
||||
TestSessionData session_data; \
|
||||
TestSession session(session_data, input_stream, output_stream); \
|
||||
#define INIT_VARS \
|
||||
TestInputStream input_stream; \
|
||||
TestOutputStream output_stream; \
|
||||
TestSessionData session_data; \
|
||||
TestSession session(&session_data, &input_stream, &output_stream); \
|
||||
std::vector<uint8_t> &output = output_stream.output;
|
||||
|
||||
// Sample testdata that has correct inputs and outputs.
|
||||
|
@ -8,47 +8,47 @@ using communication::Buffer;
|
||||
|
||||
TEST(CommunicationBuffer, AllocateAndWritten) {
|
||||
Buffer buffer;
|
||||
auto sb = buffer.write_end().Allocate();
|
||||
auto sb = buffer.write_end()->Allocate();
|
||||
|
||||
memcpy(sb.data, data, 1000);
|
||||
buffer.write_end().Written(1000);
|
||||
buffer.write_end()->Written(1000);
|
||||
|
||||
ASSERT_EQ(buffer.read_end().size(), 1000);
|
||||
ASSERT_EQ(buffer.read_end()->size(), 1000);
|
||||
|
||||
uint8_t *tmp = buffer.read_end().data();
|
||||
uint8_t *tmp = buffer.read_end()->data();
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
|
||||
}
|
||||
|
||||
TEST(CommunicationBuffer, Shift) {
|
||||
Buffer buffer;
|
||||
auto sb = buffer.write_end().Allocate();
|
||||
auto sb = buffer.write_end()->Allocate();
|
||||
|
||||
memcpy(sb.data, data, 1000);
|
||||
buffer.write_end().Written(1000);
|
||||
buffer.write_end()->Written(1000);
|
||||
|
||||
sb = buffer.write_end().Allocate();
|
||||
sb = buffer.write_end()->Allocate();
|
||||
memcpy(sb.data, data + 1000, 1000);
|
||||
buffer.write_end().Written(1000);
|
||||
buffer.write_end()->Written(1000);
|
||||
|
||||
ASSERT_EQ(buffer.read_end().size(), 2000);
|
||||
ASSERT_EQ(buffer.read_end()->size(), 2000);
|
||||
|
||||
uint8_t *tmp = buffer.read_end().data();
|
||||
uint8_t *tmp = buffer.read_end()->data();
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
|
||||
|
||||
buffer.read_end().Shift(1000);
|
||||
ASSERT_EQ(buffer.read_end().size(), 1000);
|
||||
tmp = buffer.read_end().data();
|
||||
buffer.read_end()->Shift(1000);
|
||||
ASSERT_EQ(buffer.read_end()->size(), 1000);
|
||||
tmp = buffer.read_end()->data();
|
||||
|
||||
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i + 1000], tmp[i]);
|
||||
}
|
||||
|
||||
TEST(CommunicationBuffer, Resize) {
|
||||
Buffer buffer;
|
||||
auto sb = buffer.write_end().Allocate();
|
||||
auto sb = buffer.write_end()->Allocate();
|
||||
|
||||
buffer.read_end().Resize(sb.len + 1000);
|
||||
buffer.read_end()->Resize(sb.len + 1000);
|
||||
|
||||
auto sbn = buffer.write_end().Allocate();
|
||||
auto sbn = buffer.write_end()->Allocate();
|
||||
ASSERT_EQ(sb.len + 1000, sbn.len);
|
||||
}
|
||||
|
||||
|
@ -14,24 +14,24 @@ class TestData {};
|
||||
|
||||
class TestSession {
|
||||
public:
|
||||
TestSession(TestData &, const io::network::Endpoint &,
|
||||
communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
TestSession(TestData *, const io::network::Endpoint &,
|
||||
communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream)
|
||||
: input_stream_(input_stream), output_stream_(output_stream) {}
|
||||
|
||||
void Execute() {
|
||||
LOG(INFO) << "Received data: '"
|
||||
<< std::string(
|
||||
reinterpret_cast<const char *>(input_stream_.data()),
|
||||
input_stream_.size())
|
||||
reinterpret_cast<const char *>(input_stream_->data()),
|
||||
input_stream_->size())
|
||||
<< "'";
|
||||
output_stream_.Write(input_stream_.data(), input_stream_.size());
|
||||
input_stream_.Shift(input_stream_.size());
|
||||
output_stream_->Write(input_stream_->data(), input_stream_->size());
|
||||
input_stream_->Shift(input_stream_->size());
|
||||
}
|
||||
|
||||
private:
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
communication::InputStream *input_stream_;
|
||||
communication::OutputStream *output_stream_;
|
||||
};
|
||||
|
||||
const std::string query("timeout test");
|
||||
@ -55,7 +55,7 @@ TEST(NetworkTimeouts, InactiveSession) {
|
||||
TestData test_data;
|
||||
communication::ServerContext context;
|
||||
communication::Server<TestSession, TestData> server{
|
||||
{"127.0.0.1", 0}, test_data, &context, 2, "Test", 1};
|
||||
{"127.0.0.1", 0}, &test_data, &context, 2, "Test", 1};
|
||||
|
||||
// Create the client and connect to the server.
|
||||
io::network::Socket client;
|
||||
|
Loading…
Reference in New Issue
Block a user