Refactor network stack to use * instead of &

Reviewers: teon.banek, buda

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1587
This commit is contained in:
Matej Ferencevic 2018-09-03 15:29:06 +02:00
parent 4e996d2667
commit 240472e7cb
23 changed files with 175 additions and 175 deletions

View File

@ -41,8 +41,8 @@ class Session {
public:
using TEncoder = Encoder<ChunkedEncoderBuffer<TOutputStream>>;
Session(TInputStream &input_stream, TOutputStream &output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {}
Session(TInputStream *input_stream, TOutputStream *output_stream)
: input_stream_(*input_stream), output_stream_(*output_stream) {}
virtual ~Session() {}

View File

@ -5,35 +5,35 @@
namespace communication {
Buffer::Buffer()
: data_(kBufferInitialSize, 0), read_end_(*this), write_end_(*this) {}
: data_(kBufferInitialSize, 0), read_end_(this), write_end_(this) {}
Buffer::ReadEnd::ReadEnd(Buffer &buffer) : buffer_(buffer) {}
Buffer::ReadEnd::ReadEnd(Buffer *buffer) : buffer_(buffer) {}
uint8_t *Buffer::ReadEnd::data() { return buffer_.data(); }
uint8_t *Buffer::ReadEnd::data() { return buffer_->data(); }
size_t Buffer::ReadEnd::size() const { return buffer_.size(); }
size_t Buffer::ReadEnd::size() const { return buffer_->size(); }
void Buffer::ReadEnd::Shift(size_t len) { buffer_.Shift(len); }
void Buffer::ReadEnd::Shift(size_t len) { buffer_->Shift(len); }
void Buffer::ReadEnd::Resize(size_t len) { buffer_.Resize(len); }
void Buffer::ReadEnd::Resize(size_t len) { buffer_->Resize(len); }
void Buffer::ReadEnd::Clear() { buffer_.Clear(); }
void Buffer::ReadEnd::Clear() { buffer_->Clear(); }
Buffer::WriteEnd::WriteEnd(Buffer &buffer) : buffer_(buffer) {}
Buffer::WriteEnd::WriteEnd(Buffer *buffer) : buffer_(buffer) {}
io::network::StreamBuffer Buffer::WriteEnd::Allocate() {
return buffer_.Allocate();
return buffer_->Allocate();
}
void Buffer::WriteEnd::Written(size_t len) { buffer_.Written(len); }
void Buffer::WriteEnd::Written(size_t len) { buffer_->Written(len); }
void Buffer::WriteEnd::Resize(size_t len) { buffer_.Resize(len); }
void Buffer::WriteEnd::Resize(size_t len) { buffer_->Resize(len); }
void Buffer::WriteEnd::Clear() { buffer_.Clear(); }
void Buffer::WriteEnd::Clear() { buffer_->Clear(); }
Buffer::ReadEnd &Buffer::read_end() { return read_end_; }
Buffer::ReadEnd *Buffer::read_end() { return &read_end_; }
Buffer::WriteEnd &Buffer::write_end() { return write_end_; }
Buffer::WriteEnd *Buffer::write_end() { return &write_end_; }
uint8_t *Buffer::data() { return data_.data(); }

View File

@ -39,7 +39,7 @@ class Buffer final {
*/
class ReadEnd {
public:
ReadEnd(Buffer &buffer);
ReadEnd(Buffer *buffer);
ReadEnd(const ReadEnd &) = delete;
ReadEnd(ReadEnd &&) = delete;
@ -57,7 +57,7 @@ class Buffer final {
void Clear();
private:
Buffer &buffer_;
Buffer *buffer_;
};
/**
@ -66,7 +66,7 @@ class Buffer final {
*/
class WriteEnd {
public:
WriteEnd(Buffer &buffer);
WriteEnd(Buffer *buffer);
WriteEnd(const WriteEnd &) = delete;
WriteEnd(WriteEnd &&) = delete;
@ -82,20 +82,20 @@ class Buffer final {
void Clear();
private:
Buffer &buffer_;
Buffer *buffer_;
};
/**
* This function returns a reference to the associated ReadEnd object for this
* This function returns a pointer to the associated ReadEnd object for this
* buffer.
*/
ReadEnd &read_end();
ReadEnd *read_end();
/**
* This function returns a reference to the associated WriteEnd object for
* This function returns a pointer to the associated WriteEnd object for
* this buffer.
*/
WriteEnd &write_end();
WriteEnd *write_end();
private:
/**

View File

@ -83,9 +83,9 @@ void Client::Close() {
bool Client::Read(size_t len) {
size_t received = 0;
buffer_.write_end().Resize(buffer_.read_end().size() + len);
buffer_.write_end()->Resize(buffer_.read_end()->size() + len);
while (received < len) {
auto buff = buffer_.write_end().Allocate();
auto buff = buffer_.write_end()->Allocate();
if (ssl_) {
// We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this:
@ -120,7 +120,7 @@ bool Client::Read(size_t len) {
}
// Notify the buffer that it has new data.
buffer_.write_end().Written(got);
buffer_.write_end()->Written(got);
received += got;
} else {
// Read raw data from the socket.
@ -135,20 +135,20 @@ bool Client::Read(size_t len) {
}
// Notify the buffer that it has new data.
buffer_.write_end().Written(got);
buffer_.write_end()->Written(got);
received += got;
}
}
return true;
}
uint8_t *Client::GetData() { return buffer_.read_end().data(); }
uint8_t *Client::GetData() { return buffer_.read_end()->data(); }
size_t Client::GetDataSize() { return buffer_.read_end().size(); }
size_t Client::GetDataSize() { return buffer_.read_end()->size(); }
void Client::ShiftData(size_t len) { buffer_.read_end().Shift(len); }
void Client::ShiftData(size_t len) { buffer_.read_end()->Shift(len); }
void Client::ClearData() { buffer_.read_end().Clear(); }
void Client::ClearData() { buffer_.read_end()->Clear(); }
bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
if (ssl_) {

View File

@ -40,7 +40,7 @@ class Listener final {
using SessionHandler = Session<TSession, TSessionData>;
public:
Listener(TSessionData &data, ServerContext *context,
Listener(TSessionData *data, ServerContext *context,
int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count)
: data_(data),
@ -227,7 +227,7 @@ class Listener final {
io::network::Epoll epoll_;
TSessionData &data_;
TSessionData *data_;
utils::SpinLock lock_;
std::vector<std::unique_ptr<SessionHandler>> sessions_;

View File

@ -12,25 +12,25 @@
namespace communication::rpc {
Session::Session(Server &server, const io::network::Endpoint &endpoint,
communication::InputStream &input_stream,
communication::OutputStream &output_stream)
Session::Session(Server *server, const io::network::Endpoint &endpoint,
communication::InputStream *input_stream,
communication::OutputStream *output_stream)
: server_(server),
endpoint_(endpoint),
input_stream_(input_stream),
output_stream_(output_stream) {}
void Session::Execute() {
if (input_stream_.size() < sizeof(MessageSize)) return;
if (input_stream_->size() < sizeof(MessageSize)) return;
MessageSize request_len =
*reinterpret_cast<MessageSize *>(input_stream_.data());
*reinterpret_cast<MessageSize *>(input_stream_->data());
uint64_t request_size = sizeof(MessageSize) + request_len;
input_stream_.Resize(request_size);
if (input_stream_.size() < request_size) return;
input_stream_->Resize(request_size);
if (input_stream_->size() < request_size) return;
// Read the request message.
auto data =
::kj::arrayPtr(input_stream_.data() + sizeof(request_len), request_len);
::kj::arrayPtr(input_stream_->data() + sizeof(request_len), request_len);
// Our data is word aligned and padded to 64bit because we use regular
// (non-packed) serialization of Cap'n Proto. So we can use reinterpret_cast.
auto data_words =
@ -38,18 +38,18 @@ void Session::Execute() {
reinterpret_cast<::capnp::word *>(data.end()));
::capnp::FlatArrayMessageReader request_message(data_words.asConst());
auto request = request_message.getRoot<capnp::Message>();
input_stream_.Shift(sizeof(MessageSize) + request_len);
input_stream_->Shift(sizeof(MessageSize) + request_len);
::capnp::MallocMessageBuilder response_message;
// callback fills the message data
auto response_builder = response_message.initRoot<capnp::Message>();
auto callbacks_accessor = server_.callbacks_.access();
auto callbacks_accessor = server_->callbacks_.access();
auto it = callbacks_accessor.find(request.getTypeId());
if (it == callbacks_accessor.end()) {
// We couldn't find a regular callback to call, try to find an extended
// callback to call.
auto extended_callbacks_accessor = server_.extended_callbacks_.access();
auto extended_callbacks_accessor = server_->extended_callbacks_.access();
auto extended_it = extended_callbacks_accessor.find(request.getTypeId());
if (extended_it == extended_callbacks_accessor.end()) {
// Throw exception to close the socket and cleanup the session.
@ -73,11 +73,11 @@ void Session::Execute() {
}
MessageSize input_stream_size = response_bytes.size();
if (!output_stream_.Write(reinterpret_cast<uint8_t *>(&input_stream_size),
if (!output_stream_->Write(reinterpret_cast<uint8_t *>(&input_stream_size),
sizeof(MessageSize), true)) {
throw SessionException("Couldn't send response size!");
}
if (!output_stream_.Write(response_bytes.begin(), response_bytes.size())) {
if (!output_stream_->Write(response_bytes.begin(), response_bytes.size())) {
throw SessionException("Couldn't send response data!");
}

View File

@ -36,9 +36,9 @@ class SessionException : public utils::BasicException {
*/
class Session {
public:
Session(Server &server, const io::network::Endpoint &endpoint,
communication::InputStream &input_stream,
communication::OutputStream &output_stream);
Session(Server *server, const io::network::Endpoint &endpoint,
communication::InputStream *input_stream,
communication::OutputStream *output_stream);
/**
* Executes the protocol after data has been read into the stream.
@ -48,10 +48,10 @@ class Session {
void Execute();
private:
Server &server_;
Server *server_;
io::network::Endpoint endpoint_;
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
communication::InputStream *input_stream_;
communication::OutputStream *output_stream_;
};
} // namespace communication::rpc

View File

@ -4,7 +4,7 @@ namespace communication::rpc {
Server::Server(const io::network::Endpoint &endpoint,
size_t workers_count)
: server_(endpoint, *this, &context_, -1, "RPC", workers_count) {}
: server_(endpoint, this, &context_, -1, "RPC", workers_count) {}
void Server::StopProcessingCalls() {
server_.Shutdown();

View File

@ -46,7 +46,7 @@ class Server final {
* Constructs and binds server to endpoint, operates on session data and
* invokes workers_count workers
*/
Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
Server(const io::network::Endpoint &endpoint, TSessionData *session_data,
ServerContext *context, int inactivity_timeout_sec,
const std::string &service_name,
size_t workers_count = std::thread::hardware_concurrency())

View File

@ -65,14 +65,14 @@ class OutputStream final {
template <class TSession, class TSessionData>
class Session final {
public:
Session(io::network::Socket &&socket, TSessionData &data,
Session(io::network::Socket &&socket, TSessionData *data,
ServerContext *context, int inactivity_timeout_sec)
: socket_(std::move(socket)),
output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
return Write(data, len, have_more);
}),
session_(data, socket_.endpoint(), input_buffer_.read_end(),
output_stream_),
&output_stream_),
inactivity_timeout_sec_(inactivity_timeout_sec) {
// Set socket options.
// The socket is set to be a non-blocking socket. We use the socket in a
@ -145,7 +145,7 @@ class Session final {
RefreshLastEventTime();
// Allocate the buffer to fill the data.
auto buf = input_buffer_.write_end().Allocate();
auto buf = input_buffer_.write_end()->Allocate();
if (ssl_) {
// We clear errors here to prevent errors piling up in the internal
@ -181,7 +181,7 @@ class Session final {
return false;
} else {
// Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len);
input_buffer_.write_end()->Written(len);
}
} else {
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
@ -205,7 +205,7 @@ class Session final {
throw SessionClosedException("Session was closed by client.");
} else {
// Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len);
input_buffer_.write_end()->Written(len);
}
}

View File

@ -79,7 +79,7 @@ void SingleNodeMain() {
}
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, &context, FLAGS_session_inactivity_timeout,
&session_data, &context, FLAGS_session_inactivity_timeout,
service_name, FLAGS_num_workers);
// Setup telemetry
@ -160,7 +160,7 @@ void MasterMain() {
}
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, &context, FLAGS_session_inactivity_timeout,
&session_data, &context, FLAGS_session_inactivity_timeout,
service_name, FLAGS_num_workers);
// Handler for regular termination signals

View File

@ -22,14 +22,14 @@ DEFINE_uint64(memory_warning_threshold, 1024,
"less available RAM it will log a warning. Set to 0 to "
"disable.");
BoltSession::BoltSession(SessionData &data, const io::network::Endpoint &,
communication::InputStream &input_stream,
communication::OutputStream &output_stream)
BoltSession::BoltSession(SessionData *data, const io::network::Endpoint &,
communication::InputStream *input_stream,
communication::OutputStream *output_stream)
: communication::bolt::Session<communication::InputStream,
communication::OutputStream>(input_stream,
output_stream),
transaction_engine_(data.db, data.interpreter),
auth_(&data.auth) {}
transaction_engine_(data->db, data->interpreter),
auth_(&data->auth) {}
using TEncoder =
communication::bolt::Session<communication::InputStream,

View File

@ -31,9 +31,9 @@ class BoltSession final
: public communication::bolt::Session<communication::InputStream,
communication::OutputStream> {
public:
BoltSession(SessionData &data, const io::network::Endpoint &endpoint,
communication::InputStream &input_stream,
communication::OutputStream &output_stream);
BoltSession(SessionData *data, const io::network::Endpoint &endpoint,
communication::InputStream *input_stream,
communication::OutputStream *output_stream);
using communication::bolt::Session<communication::InputStream,
communication::OutputStream>::TEncoder;

View File

@ -23,28 +23,28 @@ class TestData {};
class TestSession {
public:
TestSession(TestData &, const io::network::Endpoint &,
communication::InputStream &input_stream,
communication::OutputStream &output_stream)
TestSession(TestData *, const io::network::Endpoint &,
communication::InputStream *input_stream,
communication::OutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() {
if (input_stream_.size() < 2) return;
const uint8_t *data = input_stream_.data();
if (input_stream_->size() < 2) return;
const uint8_t *data = input_stream_->data();
size_t size = data[0];
size <<= 8;
size += data[1];
input_stream_.Resize(size + 2);
if (input_stream_.size() < size + 2) return;
input_stream_->Resize(size + 2);
if (input_stream_->size() < size + 2) return;
for (int i = 0; i < REPLY; ++i)
ASSERT_TRUE(output_stream_.Write(data + 2, size));
ASSERT_TRUE(output_stream_->Write(data + 2, size));
input_stream_.Shift(size + 2);
input_stream_->Shift(size + 2);
}
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
communication::InputStream *input_stream_;
communication::OutputStream *output_stream_;
};
using ContextT = communication::ServerContext;

View File

@ -24,17 +24,17 @@ class TestData {};
class TestSession {
public:
TestSession(TestData &, const io::network::Endpoint &,
communication::InputStream &input_stream,
communication::OutputStream &output_stream)
TestSession(TestData *, const io::network::Endpoint &,
communication::InputStream *input_stream,
communication::OutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() {
output_stream_.Write(input_stream_.data(), input_stream_.size());
output_stream_->Write(input_stream_->data(), input_stream_->size());
}
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
communication::InputStream *input_stream_;
communication::OutputStream *output_stream_;
};
std::atomic<bool> run{true};
@ -65,7 +65,7 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
int N = (std::thread::hardware_concurrency() + 1) / 2;
int Nc = N * 3;
communication::ServerContext context;
communication::Server<TestSession, TestData> server(endpoint, data, &context,
communication::Server<TestSession, TestData> server(endpoint, &data, &context,
-1, "Test", N);
const auto &ep = server.endpoint();

View File

@ -22,7 +22,7 @@ TEST(Network, Server) {
TestData session_data;
int N = (std::thread::hardware_concurrency() + 1) / 2;
ContextT context;
ServerT server(endpoint, session_data, &context, -1, "Test", N);
ServerT server(endpoint, &session_data, &context, -1, "Test", N);
const auto &ep = server.endpoint();
// start clients

View File

@ -23,7 +23,7 @@ TEST(Network, SessionLeak) {
// initialize server
TestData session_data;
ContextT context;
ServerT server(endpoint, session_data, &context, -1, "Test", 2);
ServerT server(endpoint, &session_data, &context, -1, "Test", 2);
// start clients
int N = 50;

View File

@ -21,24 +21,24 @@ struct EchoData {};
class EchoSession {
public:
EchoSession(EchoData &, const io::network::Endpoint &,
communication::InputStream &input_stream,
communication::OutputStream &output_stream)
EchoSession(EchoData *, const io::network::Endpoint &,
communication::InputStream *input_stream,
communication::OutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() {
if (input_stream_.size() < message.size()) return;
if (input_stream_->size() < message.size()) return;
LOG(INFO) << "Server received message.";
if (!output_stream_.Write(input_stream_.data(), message.size())) {
if (!output_stream_->Write(input_stream_->data(), message.size())) {
throw utils::BasicException("Output stream write failed!");
}
LOG(INFO) << "Server sent message.";
input_stream_.Shift(message.size());
input_stream_->Shift(message.size());
}
private:
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
communication::InputStream *input_stream_;
communication::OutputStream *output_stream_;
};
int main(int argc, char **argv) {
@ -54,7 +54,7 @@ int main(int argc, char **argv) {
FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file,
FLAGS_server_verify_peer);
communication::Server<EchoSession, EchoData> server(
{"127.0.0.1", 0}, echo_data, &server_context, -1, "SSL", 1);
{"127.0.0.1", 0}, &echo_data, &server_context, -1, "SSL", 1);
// Initialize the client.
communication::ClientContext client_context(FLAGS_client_key_file,

View File

@ -19,37 +19,37 @@ struct EchoData {
class EchoSession {
public:
EchoSession(EchoData &data, const io::network::Endpoint &,
communication::InputStream &input_stream,
communication::OutputStream &output_stream)
EchoSession(EchoData *data, const io::network::Endpoint &,
communication::InputStream *input_stream,
communication::OutputStream *output_stream)
: data_(data),
input_stream_(input_stream),
output_stream_(output_stream) {}
void Execute() {
if (input_stream_.size() < 2) return;
const uint8_t *data = input_stream_.data();
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_.data());
input_stream_.Resize(size + 2);
if (input_stream_.size() < size + 2) return;
if (input_stream_->size() < 2) return;
const uint8_t *data = input_stream_->data();
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_->data());
input_stream_->Resize(size + 2);
if (input_stream_->size() < size + 2) return;
if (size == 0) {
LOG(INFO) << "Server received EOF message";
data_.alive.store(false);
data_->alive.store(false);
return;
}
LOG(INFO) << "Server received '"
<< std::string(reinterpret_cast<const char *>(data + 2), size)
<< "'";
if (!output_stream_.Write(data + 2, size)) {
if (!output_stream_->Write(data + 2, size)) {
throw utils::BasicException("Output stream write failed!");
}
input_stream_.Shift(size + 2);
input_stream_->Shift(size + 2);
}
private:
EchoData &data_;
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
EchoData *data_;
communication::InputStream *input_stream_;
communication::OutputStream *output_stream_;
};
int main(int argc, char **argv) {
@ -63,7 +63,7 @@ int main(int argc, char **argv) {
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file,
FLAGS_ca_file, FLAGS_verify_peer);
communication::Server<EchoSession, EchoData> server(endpoint, echo_data,
communication::Server<EchoSession, EchoData> server(endpoint, &echo_data,
&context, -1, "SSL", 1);
while (echo_data.alive) {

View File

@ -13,15 +13,15 @@ using ChunkStateT = communication::bolt::ChunkState;
TEST(BoltBuffer, CorrectChunk) {
uint8_t tmp[2000];
BufferT buffer;
DecoderBufferT decoder_buffer(buffer.read_end());
StreamBufferT sb = buffer.write_end().Allocate();
DecoderBufferT decoder_buffer(*buffer.read_end());
StreamBufferT sb = buffer.write_end()->Allocate();
sb.data[0] = 0x03;
sb.data[1] = 0xe8;
memcpy(sb.data + 2, data, 1000);
sb.data[1002] = 0;
sb.data[1003] = 0;
buffer.write_end().Written(1004);
buffer.write_end()->Written(1004);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
@ -29,21 +29,21 @@ TEST(BoltBuffer, CorrectChunk) {
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
ASSERT_EQ(buffer.read_end().size(), 0);
ASSERT_EQ(buffer.read_end()->size(), 0);
}
TEST(BoltBuffer, CorrectChunkTrailingData) {
uint8_t tmp[2000];
BufferT buffer;
DecoderBufferT decoder_buffer(buffer.read_end());
StreamBufferT sb = buffer.write_end().Allocate();
DecoderBufferT decoder_buffer(*buffer.read_end());
StreamBufferT sb = buffer.write_end()->Allocate();
sb.data[0] = 0x03;
sb.data[1] = 0xe8;
memcpy(sb.data + 2, data, 2002);
sb.data[1002] = 0;
sb.data[1003] = 0;
buffer.write_end().Written(2004);
buffer.write_end()->Written(2004);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
@ -51,66 +51,66 @@ TEST(BoltBuffer, CorrectChunkTrailingData) {
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
uint8_t *leftover = buffer.read_end().data();
ASSERT_EQ(buffer.read_end().size(), 1000);
uint8_t *leftover = buffer.read_end()->data();
ASSERT_EQ(buffer.read_end()->size(), 1000);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i + 1002], leftover[i]);
}
TEST(BoltBuffer, GraduallyPopulatedChunk) {
uint8_t tmp[2000];
BufferT buffer;
DecoderBufferT decoder_buffer(buffer.read_end());
StreamBufferT sb = buffer.write_end().Allocate();
DecoderBufferT decoder_buffer(*buffer.read_end());
StreamBufferT sb = buffer.write_end()->Allocate();
sb.data[0] = 0x03;
sb.data[1] = 0xe8;
buffer.write_end().Written(2);
buffer.write_end()->Written(2);
for (int i = 0; i < 5; ++i) {
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
sb = buffer.write_end().Allocate();
sb = buffer.write_end()->Allocate();
memcpy(sb.data, data + 200 * i, 200);
buffer.write_end().Written(200);
buffer.write_end()->Written(200);
}
sb = buffer.write_end().Allocate();
sb = buffer.write_end()->Allocate();
sb.data[0] = 0;
sb.data[1] = 0;
buffer.write_end().Written(2);
buffer.write_end()->Written(2);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
ASSERT_EQ(buffer.read_end().size(), 0);
ASSERT_EQ(buffer.read_end()->size(), 0);
}
TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) {
uint8_t tmp[2000];
BufferT buffer;
DecoderBufferT decoder_buffer(buffer.read_end());
StreamBufferT sb = buffer.write_end().Allocate();
DecoderBufferT decoder_buffer(*buffer.read_end());
StreamBufferT sb = buffer.write_end()->Allocate();
sb.data[0] = 0x03;
sb.data[1] = 0xe8;
buffer.write_end().Written(2);
buffer.write_end()->Written(2);
for (int i = 0; i < 5; ++i) {
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Partial);
sb = buffer.write_end().Allocate();
sb = buffer.write_end()->Allocate();
memcpy(sb.data, data + 200 * i, 200);
buffer.write_end().Written(200);
buffer.write_end()->Written(200);
}
sb = buffer.write_end().Allocate();
sb = buffer.write_end()->Allocate();
sb.data[0] = 0;
sb.data[1] = 0;
buffer.write_end().Written(2);
buffer.write_end()->Written(2);
sb = buffer.write_end().Allocate();
sb = buffer.write_end()->Allocate();
memcpy(sb.data, data, 1000);
buffer.write_end().Written(1000);
buffer.write_end()->Written(1000);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Whole);
ASSERT_EQ(decoder_buffer.GetChunk(), ChunkStateT::Done);
@ -118,8 +118,8 @@ TEST(BoltBuffer, GraduallyPopulatedChunkTrailingData) {
ASSERT_EQ(decoder_buffer.Read(tmp, 1000), true);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
uint8_t *leftover = buffer.read_end().data();
ASSERT_EQ(buffer.read_end().size(), 1000);
uint8_t *leftover = buffer.read_end()->data();
ASSERT_EQ(buffer.read_end()->size(), 1000);
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], leftover[i]);
}

View File

@ -20,8 +20,8 @@ class TestSession : public Session<TestInputStream, TestOutputStream> {
public:
using Session<TestInputStream, TestOutputStream>::TEncoder;
TestSession(TestSessionData &data, TestInputStream &input_stream,
TestOutputStream &output_stream)
TestSession(TestSessionData *data, TestInputStream *input_stream,
TestOutputStream *output_stream)
: Session<TestInputStream, TestOutputStream>(input_stream,
output_stream) {}
@ -61,11 +61,11 @@ class TestSession : public Session<TestInputStream, TestOutputStream> {
// TODO: This could be done in fixture.
// Shortcuts for writing variable initializations in tests
#define INIT_VARS \
TestInputStream input_stream; \
TestOutputStream output_stream; \
TestSessionData session_data; \
TestSession session(session_data, input_stream, output_stream); \
#define INIT_VARS \
TestInputStream input_stream; \
TestOutputStream output_stream; \
TestSessionData session_data; \
TestSession session(&session_data, &input_stream, &output_stream); \
std::vector<uint8_t> &output = output_stream.output;
// Sample testdata that has correct inputs and outputs.

View File

@ -8,47 +8,47 @@ using communication::Buffer;
TEST(CommunicationBuffer, AllocateAndWritten) {
Buffer buffer;
auto sb = buffer.write_end().Allocate();
auto sb = buffer.write_end()->Allocate();
memcpy(sb.data, data, 1000);
buffer.write_end().Written(1000);
buffer.write_end()->Written(1000);
ASSERT_EQ(buffer.read_end().size(), 1000);
ASSERT_EQ(buffer.read_end()->size(), 1000);
uint8_t *tmp = buffer.read_end().data();
uint8_t *tmp = buffer.read_end()->data();
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
}
TEST(CommunicationBuffer, Shift) {
Buffer buffer;
auto sb = buffer.write_end().Allocate();
auto sb = buffer.write_end()->Allocate();
memcpy(sb.data, data, 1000);
buffer.write_end().Written(1000);
buffer.write_end()->Written(1000);
sb = buffer.write_end().Allocate();
sb = buffer.write_end()->Allocate();
memcpy(sb.data, data + 1000, 1000);
buffer.write_end().Written(1000);
buffer.write_end()->Written(1000);
ASSERT_EQ(buffer.read_end().size(), 2000);
ASSERT_EQ(buffer.read_end()->size(), 2000);
uint8_t *tmp = buffer.read_end().data();
uint8_t *tmp = buffer.read_end()->data();
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i], tmp[i]);
buffer.read_end().Shift(1000);
ASSERT_EQ(buffer.read_end().size(), 1000);
tmp = buffer.read_end().data();
buffer.read_end()->Shift(1000);
ASSERT_EQ(buffer.read_end()->size(), 1000);
tmp = buffer.read_end()->data();
for (int i = 0; i < 1000; ++i) EXPECT_EQ(data[i + 1000], tmp[i]);
}
TEST(CommunicationBuffer, Resize) {
Buffer buffer;
auto sb = buffer.write_end().Allocate();
auto sb = buffer.write_end()->Allocate();
buffer.read_end().Resize(sb.len + 1000);
buffer.read_end()->Resize(sb.len + 1000);
auto sbn = buffer.write_end().Allocate();
auto sbn = buffer.write_end()->Allocate();
ASSERT_EQ(sb.len + 1000, sbn.len);
}

View File

@ -14,24 +14,24 @@ class TestData {};
class TestSession {
public:
TestSession(TestData &, const io::network::Endpoint &,
communication::InputStream &input_stream,
communication::OutputStream &output_stream)
TestSession(TestData *, const io::network::Endpoint &,
communication::InputStream *input_stream,
communication::OutputStream *output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() {
LOG(INFO) << "Received data: '"
<< std::string(
reinterpret_cast<const char *>(input_stream_.data()),
input_stream_.size())
reinterpret_cast<const char *>(input_stream_->data()),
input_stream_->size())
<< "'";
output_stream_.Write(input_stream_.data(), input_stream_.size());
input_stream_.Shift(input_stream_.size());
output_stream_->Write(input_stream_->data(), input_stream_->size());
input_stream_->Shift(input_stream_->size());
}
private:
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
communication::InputStream *input_stream_;
communication::OutputStream *output_stream_;
};
const std::string query("timeout test");
@ -55,7 +55,7 @@ TEST(NetworkTimeouts, InactiveSession) {
TestData test_data;
communication::ServerContext context;
communication::Server<TestSession, TestData> server{
{"127.0.0.1", 0}, test_data, &context, 2, "Test", 1};
{"127.0.0.1", 0}, &test_data, &context, 2, "Test", 1};
// Create the client and connect to the server.
io::network::Socket client;