diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c30e99fcf..766b1017e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,6 +4,7 @@ set(memgraph_src_files communication/bolt/v1/decoder/decoded_value.cpp communication/bolt/v1/session.cpp + communication/rpc/buffer.cpp communication/rpc/client.cpp communication/rpc/protocol.cpp communication/rpc/server.cpp diff --git a/src/communication/rpc/buffer.cpp b/src/communication/rpc/buffer.cpp new file mode 100644 index 000000000..e77de0012 --- /dev/null +++ b/src/communication/rpc/buffer.cpp @@ -0,0 +1,39 @@ +#include "glog/logging.h" + +#include "communication/rpc/buffer.hpp" + +namespace communication::rpc { + +Buffer::Buffer() : data_(kBufferInitialSize, 0) {} + +io::network::StreamBuffer Buffer::Allocate() { + return {data_.data() + have_, data_.size() - have_}; +} + +void Buffer::Written(size_t len) { + have_ += len; + DCHECK(have_ <= data_.size()) << "Written more than storage has space!"; +} + +void Buffer::Shift(size_t len) { + DCHECK(len <= have_) << "Tried to shift more data than the buffer has!"; + if (len == have_) { + have_ = 0; + } else { + data_.erase(data_.begin(), data_.begin() + len); + have_ -= len; + } +} + +void Buffer::Resize(size_t len) { + if (len < data_.size()) return; + data_.resize(len, 0); +} + +void Buffer::Clear() { have_ = 0; } + +uint8_t *Buffer::data() { return data_.data(); } + +size_t Buffer::size() { return have_; } + +} // namespace communication::rpc diff --git a/src/communication/rpc/buffer.hpp b/src/communication/rpc/buffer.hpp new file mode 100644 index 000000000..2dc03a17a --- /dev/null +++ b/src/communication/rpc/buffer.hpp @@ -0,0 +1,86 @@ +#pragma once + +#include + +#include "io/network/stream_buffer.hpp" + +namespace communication::rpc { + +// Initial capacity of the internal buffer. +const size_t kBufferInitialSize = 65536; + +/** + * @brief Buffer + * + * Has methods for writing and reading raw data. + * + * Allocating, writing and written stores data in the buffer. The stored + * data can then be read using the pointer returned with the data function. + * This implementation stores data in a variable sized array (a vector). + * The internal array can only grow in size. + */ +class Buffer { + public: + Buffer(); + + /** + * Allocates a new StreamBuffer from the internal buffer. + * This function returns a pointer to the first currently free memory + * location in the internal buffer. Also, it returns the size of the + * available memory. + */ + io::network::StreamBuffer Allocate(); + + /** + * This method is used to notify the buffer that the data has been written. + * To write data to this buffer you should do this: + * Call Allocate(), then write to the returned data pointer. + * IMPORTANT: Don't write more data then the returned size, you will cause + * a memory overflow. Then call Written(size) with the length of data that + * you have written into the buffer. + * + * @param len the size of data that has been written into the buffer + */ + void Written(size_t len); + + /** + * This method shifts the available data for len. It is used when you read + * some data from the buffer and you want to remove it from the buffer. + * + * @param len the length of data that has to be removed from the start of + * the buffer + */ + void Shift(size_t len); + + /** + * This method resizes the internal data buffer. + * It is used to notify the buffer of the incoming message size. + * If the requested size is larger than the buffer size then the buffer is + * resized, if the requested size is smaller than the buffer size then + * nothing is done. + * + * @param len the desired size of the buffer + */ + void Resize(size_t len); + + /** + * This method clears the buffer. + */ + void Clear(); + + /** + * This function returns a pointer to the internal buffer. It is used for + * reading data from the buffer. + */ + uint8_t *data(); + + /** + * This function returns the size of available data for reading. + */ + size_t size(); + + private: + std::vector data_; + size_t have_{0}; +}; +} // namespace communication::rpc diff --git a/src/communication/rpc/client.cpp b/src/communication/rpc/client.cpp index f77410216..f96fb5f41 100644 --- a/src/communication/rpc/client.cpp +++ b/src/communication/rpc/client.cpp @@ -27,7 +27,7 @@ std::unique_ptr Client::Call(std::unique_ptr request) { // Connect to the remote server. if (!socket_) { socket_.emplace(); - received_bytes_ = 0; + buffer_.Clear(); if (!socket_->Connect(endpoint_)) { LOG(ERROR) << "Couldn't connect to remote address: " << endpoint_; socket_ = std::experimental::nullopt; @@ -70,12 +70,12 @@ std::unique_ptr Client::Call(std::unique_ptr request) { } const std::string &request_buffer = request_stream.str(); - MessageSize request_data_size = request_buffer.size(); - int64_t request_size = sizeof(uint32_t) + request_data_size; - CHECK(request_size <= kMaxMessageSize) << fmt::format( - "Trying to send message of size {}, max message size is {}", request_size, - kMaxMessageSize); + CHECK(request_buffer.size() <= std::numeric_limits::max()) + << fmt::format( + "Trying to send message of size {}, max message size is {}", + request_buffer.size(), std::numeric_limits::max()); + MessageSize request_data_size = request_buffer.size(); if (!socket_->Write(reinterpret_cast(&request_data_size), sizeof(MessageSize), true)) { LOG(ERROR) << "Couldn't send request size!"; @@ -91,21 +91,22 @@ std::unique_ptr Client::Call(std::unique_ptr request) { // Receive response. while (true) { - auto received = socket_->Read(buffer_.data() + received_bytes_, - buffer_.size() - received_bytes_); + auto buff = buffer_.Allocate(); + auto received = socket_->Read(buff.data, buff.len); if (received <= 0) { socket_ = std::experimental::nullopt; return nullptr; } - received_bytes_ += received; + buffer_.Written(received); - if (received_bytes_ < sizeof(uint32_t) + sizeof(MessageSize)) continue; + if (buffer_.size() < sizeof(uint32_t) + sizeof(MessageSize)) continue; uint32_t response_id = *reinterpret_cast(buffer_.data()); MessageSize response_data_size = *reinterpret_cast(buffer_.data() + sizeof(uint32_t)); size_t response_size = sizeof(uint32_t) + sizeof(MessageSize) + response_data_size; - if (received_bytes_ < response_size) continue; + buffer_.Resize(response_size); + if (buffer_.size() < response_size) continue; std::unique_ptr response; { @@ -119,9 +120,7 @@ std::unique_ptr Client::Call(std::unique_ptr request) { response_archive >> response; } - std::copy(buffer_.begin() + response_size, - buffer_.begin() + received_bytes_, buffer_.begin()); - received_bytes_ -= response_size; + buffer_.Shift(response_size); if (response_id != request_id) { // This can happen if some stale response arrives after we issued a new diff --git a/src/communication/rpc/client.hpp b/src/communication/rpc/client.hpp index 598b5745b..105c4c0be 100644 --- a/src/communication/rpc/client.hpp +++ b/src/communication/rpc/client.hpp @@ -6,6 +6,7 @@ #include +#include "communication/rpc/buffer.hpp" #include "communication/rpc/messages.hpp" #include "io/network/endpoint.hpp" #include "io/network/socket.hpp" @@ -53,8 +54,7 @@ class Client { std::experimental::optional socket_; uint32_t next_message_id_{0}; - std::array buffer_; - size_t received_bytes_{0}; + Buffer buffer_; std::mutex mutex_; }; diff --git a/src/communication/rpc/messages.hpp b/src/communication/rpc/messages.hpp index b9b74e1e4..0693ee8a3 100644 --- a/src/communication/rpc/messages.hpp +++ b/src/communication/rpc/messages.hpp @@ -9,9 +9,7 @@ namespace communication::rpc { -// This buffer should be larger than the largest serialized message. -const uint64_t kMaxMessageSize = 262144; -using MessageSize = uint16_t; +using MessageSize = uint32_t; /** * Base class for messages. diff --git a/src/communication/rpc/protocol.cpp b/src/communication/rpc/protocol.cpp index 783c7151c..a652eb6b4 100644 --- a/src/communication/rpc/protocol.cpp +++ b/src/communication/rpc/protocol.cpp @@ -22,6 +22,7 @@ void Session::Execute() { if (!handshake_done_) { if (buffer_.size() < sizeof(MessageSize)) return; MessageSize service_len = *reinterpret_cast(buffer_.data()); + buffer_.Resize(sizeof(MessageSize) + service_len); if (buffer_.size() < sizeof(MessageSize) + service_len) return; service_name_ = std::string( reinterpret_cast(buffer_.data() + sizeof(MessageSize)), @@ -34,8 +35,9 @@ void Session::Execute() { uint32_t message_id = *reinterpret_cast(buffer_.data()); MessageSize message_len = *reinterpret_cast(buffer_.data() + sizeof(uint32_t)); - if (buffer_.size() < sizeof(uint32_t) + sizeof(MessageSize) + message_len) - return; + uint64_t request_size = sizeof(uint32_t) + sizeof(MessageSize) + message_len; + buffer_.Resize(request_size); + if (buffer_.size() < request_size) return; // TODO (mferencevic): check for exceptions std::unique_ptr message; @@ -78,10 +80,10 @@ void SendMessage(Socket &socket, uint32_t message_id, } const std::string &buffer = stream.str(); - uint64_t message_size = sizeof(MessageSize) + buffer.size(); - CHECK(message_size <= kMaxMessageSize) << fmt::format( - "Trying to send message of size {}, max message size is {}", message_size, - kMaxMessageSize); + CHECK(buffer.size() <= std::numeric_limits::max()) + << fmt::format( + "Trying to send message of size {}, max message size is {}", + buffer.size(), std::numeric_limits::max()); if (!socket.Write(reinterpret_cast(&message_id), sizeof(uint32_t), true)) { diff --git a/src/communication/rpc/protocol.hpp b/src/communication/rpc/protocol.hpp index c6f831b45..d90b4bc4b 100644 --- a/src/communication/rpc/protocol.hpp +++ b/src/communication/rpc/protocol.hpp @@ -4,7 +4,7 @@ #include #include -#include "communication/bolt/v1/decoder/buffer.hpp" +#include "communication/rpc/buffer.hpp" #include "communication/rpc/messages.hpp" #include "io/network/endpoint.hpp" #include "io/network/socket.hpp" @@ -26,7 +26,6 @@ namespace communication::rpc { using Endpoint = io::network::Endpoint; using Socket = io::network::Socket; using StreamBuffer = io::network::StreamBuffer; -using Buffer = bolt::Buffer; // Forward declaration of class System class System; diff --git a/tests/unit/rpc.cpp b/tests/unit/rpc.cpp index f55b73a5c..be2031785 100644 --- a/tests/unit/rpc.cpp +++ b/tests/unit/rpc.cpp @@ -54,6 +54,24 @@ struct SumRes : public Message { BOOST_CLASS_EXPORT(SumRes); using Sum = RequestResponse; +struct EchoMessage : public Message { + EchoMessage(const std::string &data) : data(data) {} + std::string data; + + private: + friend class boost::serialization::access; + EchoMessage() {} // Needed for serialization. + + template + void serialize(TArchive &ar, unsigned int) { + ar &boost::serialization::base_object(*this); + ar &data; + } +}; +BOOST_CLASS_EXPORT(EchoMessage); + +using Echo = RequestResponse; + TEST(Rpc, Call) { System server_system({"127.0.0.1", 0}); Server server(server_system, "main"); @@ -143,3 +161,18 @@ TEST(Rpc, ClientPool) { } EXPECT_LE(t2.Elapsed(), 200ms); } + +TEST(Rpc, LargeMessage) { + System server_system({"127.0.0.1", 0}); + Server server(server_system, "main"); + server.Register([](const EchoMessage &request) { + return std::make_unique(request.data); + }); + std::this_thread::sleep_for(100ms); + + std::string testdata(100000, 'a'); + + Client client(server_system.endpoint(), "main"); + auto echo = client.Call(testdata); + EXPECT_EQ(echo->data, testdata); +}