Add support for large messages to RPC
Reviewers: buda, teon.banek, mculinovic Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1196
This commit is contained in:
parent
d15464e181
commit
763bdaac0f
@ -4,6 +4,7 @@
|
||||
set(memgraph_src_files
|
||||
communication/bolt/v1/decoder/decoded_value.cpp
|
||||
communication/bolt/v1/session.cpp
|
||||
communication/rpc/buffer.cpp
|
||||
communication/rpc/client.cpp
|
||||
communication/rpc/protocol.cpp
|
||||
communication/rpc/server.cpp
|
||||
|
39
src/communication/rpc/buffer.cpp
Normal file
39
src/communication/rpc/buffer.cpp
Normal file
@ -0,0 +1,39 @@
|
||||
#include "glog/logging.h"
|
||||
|
||||
#include "communication/rpc/buffer.hpp"
|
||||
|
||||
namespace communication::rpc {
|
||||
|
||||
Buffer::Buffer() : data_(kBufferInitialSize, 0) {}
|
||||
|
||||
io::network::StreamBuffer Buffer::Allocate() {
|
||||
return {data_.data() + have_, data_.size() - have_};
|
||||
}
|
||||
|
||||
void Buffer::Written(size_t len) {
|
||||
have_ += len;
|
||||
DCHECK(have_ <= data_.size()) << "Written more than storage has space!";
|
||||
}
|
||||
|
||||
void Buffer::Shift(size_t len) {
|
||||
DCHECK(len <= have_) << "Tried to shift more data than the buffer has!";
|
||||
if (len == have_) {
|
||||
have_ = 0;
|
||||
} else {
|
||||
data_.erase(data_.begin(), data_.begin() + len);
|
||||
have_ -= len;
|
||||
}
|
||||
}
|
||||
|
||||
void Buffer::Resize(size_t len) {
|
||||
if (len < data_.size()) return;
|
||||
data_.resize(len, 0);
|
||||
}
|
||||
|
||||
void Buffer::Clear() { have_ = 0; }
|
||||
|
||||
uint8_t *Buffer::data() { return data_.data(); }
|
||||
|
||||
size_t Buffer::size() { return have_; }
|
||||
|
||||
} // namespace communication::rpc
|
86
src/communication/rpc/buffer.hpp
Normal file
86
src/communication/rpc/buffer.hpp
Normal file
@ -0,0 +1,86 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "io/network/stream_buffer.hpp"
|
||||
|
||||
namespace communication::rpc {
|
||||
|
||||
// Initial capacity of the internal buffer.
|
||||
const size_t kBufferInitialSize = 65536;
|
||||
|
||||
/**
|
||||
* @brief Buffer
|
||||
*
|
||||
* Has methods for writing and reading raw data.
|
||||
*
|
||||
* Allocating, writing and written stores data in the buffer. The stored
|
||||
* data can then be read using the pointer returned with the data function.
|
||||
* This implementation stores data in a variable sized array (a vector).
|
||||
* The internal array can only grow in size.
|
||||
*/
|
||||
class Buffer {
|
||||
public:
|
||||
Buffer();
|
||||
|
||||
/**
|
||||
* Allocates a new StreamBuffer from the internal buffer.
|
||||
* This function returns a pointer to the first currently free memory
|
||||
* location in the internal buffer. Also, it returns the size of the
|
||||
* available memory.
|
||||
*/
|
||||
io::network::StreamBuffer Allocate();
|
||||
|
||||
/**
|
||||
* This method is used to notify the buffer that the data has been written.
|
||||
* To write data to this buffer you should do this:
|
||||
* Call Allocate(), then write to the returned data pointer.
|
||||
* IMPORTANT: Don't write more data then the returned size, you will cause
|
||||
* a memory overflow. Then call Written(size) with the length of data that
|
||||
* you have written into the buffer.
|
||||
*
|
||||
* @param len the size of data that has been written into the buffer
|
||||
*/
|
||||
void Written(size_t len);
|
||||
|
||||
/**
|
||||
* This method shifts the available data for len. It is used when you read
|
||||
* some data from the buffer and you want to remove it from the buffer.
|
||||
*
|
||||
* @param len the length of data that has to be removed from the start of
|
||||
* the buffer
|
||||
*/
|
||||
void Shift(size_t len);
|
||||
|
||||
/**
|
||||
* This method resizes the internal data buffer.
|
||||
* It is used to notify the buffer of the incoming message size.
|
||||
* If the requested size is larger than the buffer size then the buffer is
|
||||
* resized, if the requested size is smaller than the buffer size then
|
||||
* nothing is done.
|
||||
*
|
||||
* @param len the desired size of the buffer
|
||||
*/
|
||||
void Resize(size_t len);
|
||||
|
||||
/**
|
||||
* This method clears the buffer.
|
||||
*/
|
||||
void Clear();
|
||||
|
||||
/**
|
||||
* This function returns a pointer to the internal buffer. It is used for
|
||||
* reading data from the buffer.
|
||||
*/
|
||||
uint8_t *data();
|
||||
|
||||
/**
|
||||
* This function returns the size of available data for reading.
|
||||
*/
|
||||
size_t size();
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> data_;
|
||||
size_t have_{0};
|
||||
};
|
||||
} // namespace communication::rpc
|
@ -27,7 +27,7 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
|
||||
// Connect to the remote server.
|
||||
if (!socket_) {
|
||||
socket_.emplace();
|
||||
received_bytes_ = 0;
|
||||
buffer_.Clear();
|
||||
if (!socket_->Connect(endpoint_)) {
|
||||
LOG(ERROR) << "Couldn't connect to remote address: " << endpoint_;
|
||||
socket_ = std::experimental::nullopt;
|
||||
@ -70,12 +70,12 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
|
||||
}
|
||||
|
||||
const std::string &request_buffer = request_stream.str();
|
||||
MessageSize request_data_size = request_buffer.size();
|
||||
int64_t request_size = sizeof(uint32_t) + request_data_size;
|
||||
CHECK(request_size <= kMaxMessageSize) << fmt::format(
|
||||
"Trying to send message of size {}, max message size is {}", request_size,
|
||||
kMaxMessageSize);
|
||||
CHECK(request_buffer.size() <= std::numeric_limits<MessageSize>::max())
|
||||
<< fmt::format(
|
||||
"Trying to send message of size {}, max message size is {}",
|
||||
request_buffer.size(), std::numeric_limits<MessageSize>::max());
|
||||
|
||||
MessageSize request_data_size = request_buffer.size();
|
||||
if (!socket_->Write(reinterpret_cast<uint8_t *>(&request_data_size),
|
||||
sizeof(MessageSize), true)) {
|
||||
LOG(ERROR) << "Couldn't send request size!";
|
||||
@ -91,21 +91,22 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
|
||||
|
||||
// Receive response.
|
||||
while (true) {
|
||||
auto received = socket_->Read(buffer_.data() + received_bytes_,
|
||||
buffer_.size() - received_bytes_);
|
||||
auto buff = buffer_.Allocate();
|
||||
auto received = socket_->Read(buff.data, buff.len);
|
||||
if (received <= 0) {
|
||||
socket_ = std::experimental::nullopt;
|
||||
return nullptr;
|
||||
}
|
||||
received_bytes_ += received;
|
||||
buffer_.Written(received);
|
||||
|
||||
if (received_bytes_ < sizeof(uint32_t) + sizeof(MessageSize)) continue;
|
||||
if (buffer_.size() < sizeof(uint32_t) + sizeof(MessageSize)) continue;
|
||||
uint32_t response_id = *reinterpret_cast<uint32_t *>(buffer_.data());
|
||||
MessageSize response_data_size =
|
||||
*reinterpret_cast<MessageSize *>(buffer_.data() + sizeof(uint32_t));
|
||||
size_t response_size =
|
||||
sizeof(uint32_t) + sizeof(MessageSize) + response_data_size;
|
||||
if (received_bytes_ < response_size) continue;
|
||||
buffer_.Resize(response_size);
|
||||
if (buffer_.size() < response_size) continue;
|
||||
|
||||
std::unique_ptr<Message> response;
|
||||
{
|
||||
@ -119,9 +120,7 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
|
||||
response_archive >> response;
|
||||
}
|
||||
|
||||
std::copy(buffer_.begin() + response_size,
|
||||
buffer_.begin() + received_bytes_, buffer_.begin());
|
||||
received_bytes_ -= response_size;
|
||||
buffer_.Shift(response_size);
|
||||
|
||||
if (response_id != request_id) {
|
||||
// This can happen if some stale response arrives after we issued a new
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "communication/rpc/buffer.hpp"
|
||||
#include "communication/rpc/messages.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
@ -53,8 +54,7 @@ class Client {
|
||||
std::experimental::optional<io::network::Socket> socket_;
|
||||
|
||||
uint32_t next_message_id_{0};
|
||||
std::array<uint8_t, kMaxMessageSize> buffer_;
|
||||
size_t received_bytes_{0};
|
||||
Buffer buffer_;
|
||||
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
@ -9,9 +9,7 @@
|
||||
|
||||
namespace communication::rpc {
|
||||
|
||||
// This buffer should be larger than the largest serialized message.
|
||||
const uint64_t kMaxMessageSize = 262144;
|
||||
using MessageSize = uint16_t;
|
||||
using MessageSize = uint32_t;
|
||||
|
||||
/**
|
||||
* Base class for messages.
|
||||
|
@ -22,6 +22,7 @@ void Session::Execute() {
|
||||
if (!handshake_done_) {
|
||||
if (buffer_.size() < sizeof(MessageSize)) return;
|
||||
MessageSize service_len = *reinterpret_cast<MessageSize *>(buffer_.data());
|
||||
buffer_.Resize(sizeof(MessageSize) + service_len);
|
||||
if (buffer_.size() < sizeof(MessageSize) + service_len) return;
|
||||
service_name_ = std::string(
|
||||
reinterpret_cast<char *>(buffer_.data() + sizeof(MessageSize)),
|
||||
@ -34,8 +35,9 @@ void Session::Execute() {
|
||||
uint32_t message_id = *reinterpret_cast<uint32_t *>(buffer_.data());
|
||||
MessageSize message_len =
|
||||
*reinterpret_cast<MessageSize *>(buffer_.data() + sizeof(uint32_t));
|
||||
if (buffer_.size() < sizeof(uint32_t) + sizeof(MessageSize) + message_len)
|
||||
return;
|
||||
uint64_t request_size = sizeof(uint32_t) + sizeof(MessageSize) + message_len;
|
||||
buffer_.Resize(request_size);
|
||||
if (buffer_.size() < request_size) return;
|
||||
|
||||
// TODO (mferencevic): check for exceptions
|
||||
std::unique_ptr<Message> message;
|
||||
@ -78,10 +80,10 @@ void SendMessage(Socket &socket, uint32_t message_id,
|
||||
}
|
||||
|
||||
const std::string &buffer = stream.str();
|
||||
uint64_t message_size = sizeof(MessageSize) + buffer.size();
|
||||
CHECK(message_size <= kMaxMessageSize) << fmt::format(
|
||||
"Trying to send message of size {}, max message size is {}", message_size,
|
||||
kMaxMessageSize);
|
||||
CHECK(buffer.size() <= std::numeric_limits<MessageSize>::max())
|
||||
<< fmt::format(
|
||||
"Trying to send message of size {}, max message size is {}",
|
||||
buffer.size(), std::numeric_limits<MessageSize>::max());
|
||||
|
||||
if (!socket.Write(reinterpret_cast<uint8_t *>(&message_id), sizeof(uint32_t),
|
||||
true)) {
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "communication/bolt/v1/decoder/buffer.hpp"
|
||||
#include "communication/rpc/buffer.hpp"
|
||||
#include "communication/rpc/messages.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
@ -26,7 +26,6 @@ namespace communication::rpc {
|
||||
using Endpoint = io::network::Endpoint;
|
||||
using Socket = io::network::Socket;
|
||||
using StreamBuffer = io::network::StreamBuffer;
|
||||
using Buffer = bolt::Buffer<kMaxMessageSize>;
|
||||
|
||||
// Forward declaration of class System
|
||||
class System;
|
||||
|
@ -54,6 +54,24 @@ struct SumRes : public Message {
|
||||
BOOST_CLASS_EXPORT(SumRes);
|
||||
using Sum = RequestResponse<SumReq, SumRes>;
|
||||
|
||||
struct EchoMessage : public Message {
|
||||
EchoMessage(const std::string &data) : data(data) {}
|
||||
std::string data;
|
||||
|
||||
private:
|
||||
friend class boost::serialization::access;
|
||||
EchoMessage() {} // Needed for serialization.
|
||||
|
||||
template <class TArchive>
|
||||
void serialize(TArchive &ar, unsigned int) {
|
||||
ar &boost::serialization::base_object<Message>(*this);
|
||||
ar &data;
|
||||
}
|
||||
};
|
||||
BOOST_CLASS_EXPORT(EchoMessage);
|
||||
|
||||
using Echo = RequestResponse<EchoMessage, EchoMessage>;
|
||||
|
||||
TEST(Rpc, Call) {
|
||||
System server_system({"127.0.0.1", 0});
|
||||
Server server(server_system, "main");
|
||||
@ -143,3 +161,18 @@ TEST(Rpc, ClientPool) {
|
||||
}
|
||||
EXPECT_LE(t2.Elapsed(), 200ms);
|
||||
}
|
||||
|
||||
TEST(Rpc, LargeMessage) {
|
||||
System server_system({"127.0.0.1", 0});
|
||||
Server server(server_system, "main");
|
||||
server.Register<Echo>([](const EchoMessage &request) {
|
||||
return std::make_unique<EchoMessage>(request.data);
|
||||
});
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
std::string testdata(100000, 'a');
|
||||
|
||||
Client client(server_system.endpoint(), "main");
|
||||
auto echo = client.Call<Echo>(testdata);
|
||||
EXPECT_EQ(echo->data, testdata);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user