Add support for large messages to RPC

Reviewers: buda, teon.banek, mculinovic

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1196
This commit is contained in:
Matej Ferencevic 2018-02-15 13:48:03 +01:00
parent d15464e181
commit 763bdaac0f
9 changed files with 184 additions and 27 deletions

View File

@ -4,6 +4,7 @@
set(memgraph_src_files
communication/bolt/v1/decoder/decoded_value.cpp
communication/bolt/v1/session.cpp
communication/rpc/buffer.cpp
communication/rpc/client.cpp
communication/rpc/protocol.cpp
communication/rpc/server.cpp

View File

@ -0,0 +1,39 @@
#include "glog/logging.h"
#include "communication/rpc/buffer.hpp"
namespace communication::rpc {
Buffer::Buffer() : data_(kBufferInitialSize, 0) {}
io::network::StreamBuffer Buffer::Allocate() {
return {data_.data() + have_, data_.size() - have_};
}
void Buffer::Written(size_t len) {
have_ += len;
DCHECK(have_ <= data_.size()) << "Written more than storage has space!";
}
void Buffer::Shift(size_t len) {
DCHECK(len <= have_) << "Tried to shift more data than the buffer has!";
if (len == have_) {
have_ = 0;
} else {
data_.erase(data_.begin(), data_.begin() + len);
have_ -= len;
}
}
void Buffer::Resize(size_t len) {
if (len < data_.size()) return;
data_.resize(len, 0);
}
void Buffer::Clear() { have_ = 0; }
uint8_t *Buffer::data() { return data_.data(); }
size_t Buffer::size() { return have_; }
} // namespace communication::rpc

View File

@ -0,0 +1,86 @@
#pragma once
#include <vector>
#include "io/network/stream_buffer.hpp"
namespace communication::rpc {
// Initial capacity of the internal buffer.
const size_t kBufferInitialSize = 65536;
/**
* @brief Buffer
*
* Has methods for writing and reading raw data.
*
* Allocating, writing and written stores data in the buffer. The stored
* data can then be read using the pointer returned with the data function.
* This implementation stores data in a variable sized array (a vector).
* The internal array can only grow in size.
*/
class Buffer {
public:
Buffer();
/**
* Allocates a new StreamBuffer from the internal buffer.
* This function returns a pointer to the first currently free memory
* location in the internal buffer. Also, it returns the size of the
* available memory.
*/
io::network::StreamBuffer Allocate();
/**
* This method is used to notify the buffer that the data has been written.
* To write data to this buffer you should do this:
* Call Allocate(), then write to the returned data pointer.
* IMPORTANT: Don't write more data then the returned size, you will cause
* a memory overflow. Then call Written(size) with the length of data that
* you have written into the buffer.
*
* @param len the size of data that has been written into the buffer
*/
void Written(size_t len);
/**
* This method shifts the available data for len. It is used when you read
* some data from the buffer and you want to remove it from the buffer.
*
* @param len the length of data that has to be removed from the start of
* the buffer
*/
void Shift(size_t len);
/**
* This method resizes the internal data buffer.
* It is used to notify the buffer of the incoming message size.
* If the requested size is larger than the buffer size then the buffer is
* resized, if the requested size is smaller than the buffer size then
* nothing is done.
*
* @param len the desired size of the buffer
*/
void Resize(size_t len);
/**
* This method clears the buffer.
*/
void Clear();
/**
* This function returns a pointer to the internal buffer. It is used for
* reading data from the buffer.
*/
uint8_t *data();
/**
* This function returns the size of available data for reading.
*/
size_t size();
private:
std::vector<uint8_t> data_;
size_t have_{0};
};
} // namespace communication::rpc

View File

@ -27,7 +27,7 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
// Connect to the remote server.
if (!socket_) {
socket_.emplace();
received_bytes_ = 0;
buffer_.Clear();
if (!socket_->Connect(endpoint_)) {
LOG(ERROR) << "Couldn't connect to remote address: " << endpoint_;
socket_ = std::experimental::nullopt;
@ -70,12 +70,12 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
}
const std::string &request_buffer = request_stream.str();
MessageSize request_data_size = request_buffer.size();
int64_t request_size = sizeof(uint32_t) + request_data_size;
CHECK(request_size <= kMaxMessageSize) << fmt::format(
"Trying to send message of size {}, max message size is {}", request_size,
kMaxMessageSize);
CHECK(request_buffer.size() <= std::numeric_limits<MessageSize>::max())
<< fmt::format(
"Trying to send message of size {}, max message size is {}",
request_buffer.size(), std::numeric_limits<MessageSize>::max());
MessageSize request_data_size = request_buffer.size();
if (!socket_->Write(reinterpret_cast<uint8_t *>(&request_data_size),
sizeof(MessageSize), true)) {
LOG(ERROR) << "Couldn't send request size!";
@ -91,21 +91,22 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
// Receive response.
while (true) {
auto received = socket_->Read(buffer_.data() + received_bytes_,
buffer_.size() - received_bytes_);
auto buff = buffer_.Allocate();
auto received = socket_->Read(buff.data, buff.len);
if (received <= 0) {
socket_ = std::experimental::nullopt;
return nullptr;
}
received_bytes_ += received;
buffer_.Written(received);
if (received_bytes_ < sizeof(uint32_t) + sizeof(MessageSize)) continue;
if (buffer_.size() < sizeof(uint32_t) + sizeof(MessageSize)) continue;
uint32_t response_id = *reinterpret_cast<uint32_t *>(buffer_.data());
MessageSize response_data_size =
*reinterpret_cast<MessageSize *>(buffer_.data() + sizeof(uint32_t));
size_t response_size =
sizeof(uint32_t) + sizeof(MessageSize) + response_data_size;
if (received_bytes_ < response_size) continue;
buffer_.Resize(response_size);
if (buffer_.size() < response_size) continue;
std::unique_ptr<Message> response;
{
@ -119,9 +120,7 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
response_archive >> response;
}
std::copy(buffer_.begin() + response_size,
buffer_.begin() + received_bytes_, buffer_.begin());
received_bytes_ -= response_size;
buffer_.Shift(response_size);
if (response_id != request_id) {
// This can happen if some stale response arrives after we issued a new

View File

@ -6,6 +6,7 @@
#include <glog/logging.h>
#include "communication/rpc/buffer.hpp"
#include "communication/rpc/messages.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/socket.hpp"
@ -53,8 +54,7 @@ class Client {
std::experimental::optional<io::network::Socket> socket_;
uint32_t next_message_id_{0};
std::array<uint8_t, kMaxMessageSize> buffer_;
size_t received_bytes_{0};
Buffer buffer_;
std::mutex mutex_;
};

View File

@ -9,9 +9,7 @@
namespace communication::rpc {
// This buffer should be larger than the largest serialized message.
const uint64_t kMaxMessageSize = 262144;
using MessageSize = uint16_t;
using MessageSize = uint32_t;
/**
* Base class for messages.

View File

@ -22,6 +22,7 @@ void Session::Execute() {
if (!handshake_done_) {
if (buffer_.size() < sizeof(MessageSize)) return;
MessageSize service_len = *reinterpret_cast<MessageSize *>(buffer_.data());
buffer_.Resize(sizeof(MessageSize) + service_len);
if (buffer_.size() < sizeof(MessageSize) + service_len) return;
service_name_ = std::string(
reinterpret_cast<char *>(buffer_.data() + sizeof(MessageSize)),
@ -34,8 +35,9 @@ void Session::Execute() {
uint32_t message_id = *reinterpret_cast<uint32_t *>(buffer_.data());
MessageSize message_len =
*reinterpret_cast<MessageSize *>(buffer_.data() + sizeof(uint32_t));
if (buffer_.size() < sizeof(uint32_t) + sizeof(MessageSize) + message_len)
return;
uint64_t request_size = sizeof(uint32_t) + sizeof(MessageSize) + message_len;
buffer_.Resize(request_size);
if (buffer_.size() < request_size) return;
// TODO (mferencevic): check for exceptions
std::unique_ptr<Message> message;
@ -78,10 +80,10 @@ void SendMessage(Socket &socket, uint32_t message_id,
}
const std::string &buffer = stream.str();
uint64_t message_size = sizeof(MessageSize) + buffer.size();
CHECK(message_size <= kMaxMessageSize) << fmt::format(
"Trying to send message of size {}, max message size is {}", message_size,
kMaxMessageSize);
CHECK(buffer.size() <= std::numeric_limits<MessageSize>::max())
<< fmt::format(
"Trying to send message of size {}, max message size is {}",
buffer.size(), std::numeric_limits<MessageSize>::max());
if (!socket.Write(reinterpret_cast<uint8_t *>(&message_id), sizeof(uint32_t),
true)) {

View File

@ -4,7 +4,7 @@
#include <cstdint>
#include <memory>
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/rpc/buffer.hpp"
#include "communication/rpc/messages.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/socket.hpp"
@ -26,7 +26,6 @@ namespace communication::rpc {
using Endpoint = io::network::Endpoint;
using Socket = io::network::Socket;
using StreamBuffer = io::network::StreamBuffer;
using Buffer = bolt::Buffer<kMaxMessageSize>;
// Forward declaration of class System
class System;

View File

@ -54,6 +54,24 @@ struct SumRes : public Message {
BOOST_CLASS_EXPORT(SumRes);
using Sum = RequestResponse<SumReq, SumRes>;
struct EchoMessage : public Message {
EchoMessage(const std::string &data) : data(data) {}
std::string data;
private:
friend class boost::serialization::access;
EchoMessage() {} // Needed for serialization.
template <class TArchive>
void serialize(TArchive &ar, unsigned int) {
ar &boost::serialization::base_object<Message>(*this);
ar &data;
}
};
BOOST_CLASS_EXPORT(EchoMessage);
using Echo = RequestResponse<EchoMessage, EchoMessage>;
TEST(Rpc, Call) {
System server_system({"127.0.0.1", 0});
Server server(server_system, "main");
@ -143,3 +161,18 @@ TEST(Rpc, ClientPool) {
}
EXPECT_LE(t2.Elapsed(), 200ms);
}
TEST(Rpc, LargeMessage) {
System server_system({"127.0.0.1", 0});
Server server(server_system, "main");
server.Register<Echo>([](const EchoMessage &request) {
return std::make_unique<EchoMessage>(request.data);
});
std::this_thread::sleep_for(100ms);
std::string testdata(100000, 'a');
Client client(server_system.endpoint(), "main");
auto echo = client.Call<Echo>(testdata);
EXPECT_EQ(echo->data, testdata);
}