From d678e45c102b2a8853af623efef26ff53db2df9d Mon Sep 17 00:00:00 2001 From: Matej Ferencevic Date: Mon, 6 May 2019 13:35:22 +0200 Subject: [PATCH] Migrate RPC to SLK Summary: Migrate all RPCs Simplify Raft InstallSnapshot RPC Add missing Load and Save for `char` Reviewers: teon.banek, msantl Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D2001 --- src/communication/client.cpp | 7 +- src/communication/client.hpp | 9 +- src/communication/rpc/client.cpp | 75 ---------- src/communication/rpc/client.hpp | 93 ++++++++---- src/communication/rpc/client_pool.hpp | 12 +- src/communication/rpc/messages.hpp | 8 +- src/communication/rpc/protocol.cpp | 85 +++++------ src/communication/rpc/server.hpp | 61 ++------ .../distributed/distributed_counters.cpp | 17 ++- src/distributed/bfs_rpc_clients.cpp | 12 +- src/distributed/bfs_rpc_messages.lcp | 5 +- src/distributed/bfs_rpc_server.hpp | 57 ++++---- src/distributed/cluster_discovery_master.cpp | 12 +- src/distributed/cluster_discovery_worker.cpp | 6 +- src/distributed/coordination.hpp | 11 +- src/distributed/coordination_worker.cpp | 4 +- src/distributed/data_rpc_messages.lcp | 18 +-- src/distributed/data_rpc_server.cpp | 34 +++-- src/distributed/durability_rpc_worker.cpp | 18 ++- src/distributed/dynamic_worker.cpp | 6 +- src/distributed/index_rpc_server.cpp | 12 +- src/distributed/plan_consumer.cpp | 14 +- src/distributed/produce_rpc_server.cpp | 18 +-- src/distributed/pull_rpc_clients.cpp | 4 +- src/distributed/token_sharing_rpc_server.hpp | 2 +- src/distributed/updates_rpc_server.cpp | 50 +++---- src/raft/coordination.hpp | 11 +- src/raft/raft_rpc_messages.lcp | 36 +---- src/raft/raft_server.cpp | 64 +++++---- src/raft/storage_info.cpp | 6 +- src/slk/serialization.hpp | 2 + .../concurrent_id_mapper_master.cpp | 12 +- src/storage/distributed/rpc/serialization.cpp | 1 + src/storage/distributed/storage_gc_master.hpp | 6 +- .../distributed/engine_master.cpp | 81 +++++++---- .../distributed/engine_worker.cpp | 9 +- tests/benchmark/rpc.cpp | 21 ++- tests/unit/rpc.cpp | 132 ++++++++---------- tests/unit/rpc_messages.hpp | 65 +++++++++ 39 files changed, 543 insertions(+), 553 deletions(-) create mode 100644 tests/unit/rpc_messages.hpp diff --git a/src/communication/client.cpp b/src/communication/client.cpp index 9f5c50530..6e900b9ba 100644 --- a/src/communication/client.cpp +++ b/src/communication/client.cpp @@ -83,10 +83,11 @@ void Client::Close() { socket_.Close(); } -bool Client::Read(size_t len) { +bool Client::Read(size_t len, bool exactly_len) { + if (len == 0) return false; size_t received = 0; buffer_.write_end()->Resize(buffer_.read_end()->size() + len); - while (received < len) { + do { auto buff = buffer_.write_end()->Allocate(); if (ssl_) { // We clear errors here to prevent errors piling up in the internal @@ -140,7 +141,7 @@ bool Client::Read(size_t len) { buffer_.write_end()->Written(got); received += got; } - } + } while (received < len && exactly_len); return true; } diff --git a/src/communication/client.hpp b/src/communication/client.hpp index b7412cb78..9562b36d9 100644 --- a/src/communication/client.hpp +++ b/src/communication/client.hpp @@ -58,11 +58,12 @@ class Client final { void Close(); /** - * This function is used to receive `len` bytes from the socket and stores it - * in an internal buffer. It returns `true` if the read succeeded and `false` - * if it didn't. + * This function is used to receive exactly `len` bytes from the socket and + * stores it in an internal buffer. If `exactly_len` is set to `false` then + * less than `len` bytes can be received. It returns `true` if the read + * succeeded and `false` if it didn't. */ - bool Read(size_t len); + bool Read(size_t len, bool exactly_len = true); /** * This function returns a pointer to the read data that is currently stored diff --git a/src/communication/rpc/client.cpp b/src/communication/rpc/client.cpp index d6ccb0447..cc0005951 100644 --- a/src/communication/rpc/client.cpp +++ b/src/communication/rpc/client.cpp @@ -1,84 +1,9 @@ #include "communication/rpc/client.hpp" -#include -#include - -#include "gflags/gflags.h" - namespace communication::rpc { Client::Client(const io::network::Endpoint &endpoint) : endpoint_(endpoint) {} -::capnp::FlatArrayMessageReader Client::Send(::capnp::MessageBuilder *message) { - std::lock_guard guard(mutex_); - - // Check if the connection is broken (if we haven't used the client for a - // long time the server could have died). - if (client_ && client_->ErrorStatus()) { - client_ = std::nullopt; - } - - // Connect to the remote server. - if (!client_) { - client_.emplace(&context_); - if (!client_->Connect(endpoint_)) { - DLOG(ERROR) << "Couldn't connect to remote address " << endpoint_; - client_ = std::nullopt; - throw RpcFailedException(endpoint_); - } - } - - // Serialize and send request. - auto request_words = ::capnp::messageToFlatArray(*message); - auto request_bytes = request_words.asBytes(); - CHECK(request_bytes.size() <= std::numeric_limits::max()) - << fmt::format( - "Trying to send message of size {}, max message size is {}", - request_bytes.size(), std::numeric_limits::max()); - - MessageSize request_data_size = request_bytes.size(); - if (!client_->Write(reinterpret_cast(&request_data_size), - sizeof(MessageSize), true)) { - DLOG(ERROR) << "Couldn't send request size to " << client_->endpoint(); - client_ = std::nullopt; - throw RpcFailedException(endpoint_); - } - - if (!client_->Write(request_bytes.begin(), request_bytes.size())) { - DLOG(ERROR) << "Couldn't send request data to " << client_->endpoint(); - client_ = std::nullopt; - throw RpcFailedException(endpoint_); - } - - // Receive response data size. - if (!client_->Read(sizeof(MessageSize))) { - DLOG(ERROR) << "Couldn't get response from " << client_->endpoint(); - client_ = std::nullopt; - throw RpcFailedException(endpoint_); - } - MessageSize response_data_size = - *reinterpret_cast(client_->GetData()); - client_->ShiftData(sizeof(MessageSize)); - - // Receive response data. - if (!client_->Read(response_data_size)) { - DLOG(ERROR) << "Couldn't get response from " << client_->endpoint(); - client_ = std::nullopt; - throw RpcFailedException(endpoint_); - } - - // Read the response message. - auto data = ::kj::arrayPtr(client_->GetData(), response_data_size); - // Our data is word aligned and padded to 64bit because we use regular - // (non-packed) serialization of Cap'n Proto. So we can use reinterpret_cast. - auto data_words = - ::kj::arrayPtr(reinterpret_cast<::capnp::word *>(data.begin()), - reinterpret_cast<::capnp::word *>(data.end())); - ::capnp::FlatArrayMessageReader response_message(data_words.asConst()); - client_->ShiftData(response_data_size); - return response_message; -} - void Client::Abort() { if (!client_) return; // We need to call Shutdown on the client to abort any pending read or diff --git a/src/communication/rpc/client.hpp b/src/communication/rpc/client.hpp index aa17c8fa2..a735a0f09 100644 --- a/src/communication/rpc/client.hpp +++ b/src/communication/rpc/client.hpp @@ -10,10 +10,11 @@ #include "communication/client.hpp" #include "communication/rpc/exceptions.hpp" -#include "communication/rpc/messages.capnp.h" #include "communication/rpc/messages.hpp" #include "io/network/endpoint.hpp" -#include "utils/demangle.hpp" +#include "slk/serialization.hpp" +#include "slk/streams.hpp" +#include "utils/on_scope_exit.hpp" namespace communication::rpc { @@ -34,9 +35,9 @@ class Client { template typename TRequestResponse::Response Call(Args &&... args) { return CallWithLoad( - [](const auto &reader) { + [](auto *reader) { typename TRequestResponse::Response response; - Load(&response, reader); + TRequestResponse::Response::Load(&response, reader); return response; }, std::forward(args)...); @@ -45,29 +46,68 @@ class Client { /// Same as `Call` but the first argument is a response loading function. template typename TRequestResponse::Response CallWithLoad( - std::function - load, + std::function load, Args &&... args) { typename TRequestResponse::Request request(std::forward(args)...); auto req_type = TRequestResponse::Request::kType; - VLOG(12) << "[RpcClient] sent " << req_type.name; - ::capnp::MallocMessageBuilder req_msg; - { - auto builder = req_msg.initRoot(); - builder.setTypeId(req_type.id); - auto data_builder = builder.initData(); - auto req_builder = - data_builder - .template initAs(); - Save(request, &req_builder); - } - auto response = Send(&req_msg); - auto res_msg = response.getRoot(); auto res_type = TRequestResponse::Response::kType; - if (res_msg.getTypeId() != res_type.id) { - // Since message_id was checked in private Call function, this means - // something is very wrong (probably on the server side). + VLOG(12) << "[RpcClient] sent " << req_type.name; + + std::lock_guard guard(mutex_); + + // Check if the connection is broken (if we haven't used the client for a + // long time the server could have died). + if (client_ && client_->ErrorStatus()) { + client_ = std::nullopt; + } + + // Connect to the remote server. + if (!client_) { + client_.emplace(&context_); + if (!client_->Connect(endpoint_)) { + DLOG(ERROR) << "Couldn't connect to remote address " << endpoint_; + client_ = std::nullopt; + throw RpcFailedException(endpoint_); + } + } + + // Build and send the request. + slk::Builder req_builder( + [&](const uint8_t *data, size_t size, bool have_more) { + client_->Write(data, size, have_more); + }); + slk::Save(req_type.id, &req_builder); + TRequestResponse::Request::Save(request, &req_builder); + req_builder.Finalize(); + + // Receive response. + uint64_t response_data_size = 0; + while (true) { + auto ret = + slk::CheckStreamComplete(client_->GetData(), client_->GetDataSize()); + if (ret.status == slk::StreamStatus::INVALID) { + throw RpcFailedException(endpoint_); + } else if (ret.status == slk::StreamStatus::PARTIAL) { + if (!client_->Read(ret.stream_size - client_->GetDataSize(), + /* exactly_len = */ false)) { + throw RpcFailedException(endpoint_); + } + } else { + response_data_size = ret.stream_size; + break; + } + } + + // Load the response. + slk::Reader res_reader(client_->GetData(), response_data_size); + utils::OnScopeExit res_cleanup( + [&, response_data_size] { client_->ShiftData(response_data_size); }); + + uint64_t res_id = 0; + slk::Load(&res_id, &res_reader); + + // Check response ID. + if (res_id != res_type.id) { LOG(ERROR) << "Message response was of unexpected type"; client_ = std::nullopt; throw RpcFailedException(endpoint_); @@ -75,18 +115,13 @@ class Client { VLOG(12) << "[RpcClient] received " << res_type.name; - auto data_reader = - res_msg.getData() - .template getAs(); - return load(data_reader); + return load(&res_reader); } /// Call this function from another thread to abort a pending RPC call. void Abort(); private: - ::capnp::FlatArrayMessageReader Send(::capnp::MessageBuilder *message); - io::network::Endpoint endpoint_; // TODO (mferencevic): currently the RPC client is hardcoded not to use SSL communication::ClientContext context_; diff --git a/src/communication/rpc/client_pool.hpp b/src/communication/rpc/client_pool.hpp index 478fc0cee..64ee9eb6f 100644 --- a/src/communication/rpc/client_pool.hpp +++ b/src/communication/rpc/client_pool.hpp @@ -23,19 +23,15 @@ class ClientPool { return client->template Call( std::forward(args)...); }); - }; + } template - typename TRequestResponse::Response CallWithLoad( - std::function - load, - Args &&... args) { + typename TRequestResponse::Response CallWithLoad(Args &&... args) { return WithUnusedClient([&](const auto &client) { return client->template CallWithLoad( - load, std::forward(args)...); + std::forward(args)...); }); - }; + } private: template diff --git a/src/communication/rpc/messages.hpp b/src/communication/rpc/messages.hpp index a006b73dd..327101a81 100644 --- a/src/communication/rpc/messages.hpp +++ b/src/communication/rpc/messages.hpp @@ -14,11 +14,9 @@ using MessageSize = uint32_t; /// `TRequest` and `TResponse` are required to be classes which have a static /// member `kType` of `utils::TypeInfo` type. This is used for proper /// registration and deserialization of RPC types. Additionally, both `TRequest` -/// and `TResponse` are required to define a nested `Capnp` type, which -/// corresponds to the Cap'n Proto schema type, as well as defined the following -/// serialization functions: -/// * void Save(const TRequest|TResponse &, Capnp::Builder *, ...) -/// * void Load(const Capnp::Reader &, ...) +/// and `TResponse` are required to define the following serialization functions: +/// * static void Save(const TRequest|TResponse &, slk::Builder *, ...) +/// * static void Load(TRequest|TResponse *, slk::Reader *, ...) template struct RequestResponse { using Request = TRequest; diff --git a/src/communication/rpc/protocol.cpp b/src/communication/rpc/protocol.cpp index 223997429..7e5cadb19 100644 --- a/src/communication/rpc/protocol.cpp +++ b/src/communication/rpc/protocol.cpp @@ -1,14 +1,10 @@ -#include - -#include "capnp/message.h" -#include "capnp/serialize.h" -#include "fmt/format.h" - -#include "communication/rpc/messages.capnp.h" -#include "communication/rpc/messages.hpp" #include "communication/rpc/protocol.hpp" + +#include "communication/rpc/messages.hpp" #include "communication/rpc/server.hpp" -#include "utils/demangle.hpp" +#include "slk/serialization.hpp" +#include "slk/streams.hpp" +#include "utils/on_scope_exit.hpp" namespace communication::rpc { @@ -21,39 +17,40 @@ Session::Session(Server *server, const io::network::Endpoint &endpoint, output_stream_(output_stream) {} void Session::Execute() { - if (input_stream_->size() < sizeof(MessageSize)) return; - MessageSize request_len = - *reinterpret_cast(input_stream_->data()); - uint64_t request_size = sizeof(MessageSize) + request_len; - input_stream_->Resize(request_size); - if (input_stream_->size() < request_size) return; + auto ret = + slk::CheckStreamComplete(input_stream_->data(), input_stream_->size()); + if (ret.status == slk::StreamStatus::INVALID) { + throw SessionException("Received an invalid SLK stream!"); + } else if (ret.status == slk::StreamStatus::PARTIAL) { + input_stream_->Resize(ret.stream_size); + return; + } - // Read the request message. - auto data = - ::kj::arrayPtr(input_stream_->data() + sizeof(request_len), request_len); - // Our data is word aligned and padded to 64bit because we use regular - // (non-packed) serialization of Cap'n Proto. So we can use reinterpret_cast. - auto data_words = - ::kj::arrayPtr(reinterpret_cast<::capnp::word *>(data.begin()), - reinterpret_cast<::capnp::word *>(data.end())); - ::capnp::FlatArrayMessageReader request_message(data_words.asConst()); - auto request = request_message.getRoot(); - input_stream_->Shift(sizeof(MessageSize) + request_len); + // Remove the data from the stream on scope exit. + utils::OnScopeExit shift_data( + [&, ret] { input_stream_->Shift(ret.stream_size); }); - ::capnp::MallocMessageBuilder response_message; - // callback fills the message data - auto response_builder = response_message.initRoot(); + // Prepare SLK reader and builder. + slk::Reader req_reader(input_stream_->data(), input_stream_->size()); + slk::Builder res_builder( + [&](const uint8_t *data, size_t size, bool have_more) { + output_stream_->Write(data, size, have_more); + }); + + // Load the request ID. + uint64_t req_id = 0; + slk::Load(&req_id, &req_reader); // Access to `callbacks_` and `extended_callbacks_` is done here without // acquiring the `mutex_` because we don't allow RPC registration after the // server was started so those two maps will never be updated when we `find` // over them. - auto it = server_->callbacks_.find(request.getTypeId()); + auto it = server_->callbacks_.find(req_id); auto extended_it = server_->extended_callbacks_.end(); if (it == server_->callbacks_.end()) { // We couldn't find a regular callback to call, try to find an extended // callback to call. - extended_it = server_->extended_callbacks_.find(request.getTypeId()); + extended_it = server_->extended_callbacks_.find(req_id); if (extended_it == server_->extended_callbacks_.end()) { // Throw exception to close the socket and cleanup the session. @@ -61,29 +58,17 @@ void Session::Execute() { "Session trying to execute an unregistered RPC call!"); } VLOG(12) << "[RpcServer] received " << extended_it->second.req_type.name; - extended_it->second.callback(endpoint_, request, &response_builder); + slk::Save(extended_it->second.res_type.id, &res_builder); + extended_it->second.callback(endpoint_, &req_reader, &res_builder); } else { VLOG(12) << "[RpcServer] received " << it->second.req_type.name; - it->second.callback(request, &response_builder); + slk::Save(it->second.res_type.id, &res_builder); + it->second.callback(&req_reader, &res_builder); } - // Serialize and send response - auto response_words = ::capnp::messageToFlatArray(response_message); - auto response_bytes = response_words.asBytes(); - if (response_bytes.size() > std::numeric_limits::max()) { - throw SessionException(fmt::format( - "Trying to send response of size {}, max response size is {}", - response_bytes.size(), std::numeric_limits::max())); - } - - MessageSize input_stream_size = response_bytes.size(); - if (!output_stream_->Write(reinterpret_cast(&input_stream_size), - sizeof(MessageSize), true)) { - throw SessionException("Couldn't send response size!"); - } - if (!output_stream_->Write(response_bytes.begin(), response_bytes.size())) { - throw SessionException("Couldn't send response data!"); - } + // Finalize the SLK streams. + req_reader.Finalize(); + res_builder.Finalize(); VLOG(12) << "[RpcServer] sent " << (it != server_->callbacks_.end() diff --git a/src/communication/rpc/server.hpp b/src/communication/rpc/server.hpp index baa5b7fed..e56ea6319 100644 --- a/src/communication/rpc/server.hpp +++ b/src/communication/rpc/server.hpp @@ -4,14 +4,11 @@ #include #include -#include "capnp/any.h" - -#include "communication/rpc/messages.capnp.h" #include "communication/rpc/messages.hpp" #include "communication/rpc/protocol.hpp" #include "communication/server.hpp" #include "io/network/endpoint.hpp" -#include "utils/demangle.hpp" +#include "slk/streams.hpp" namespace communication::rpc { @@ -31,26 +28,14 @@ class Server { const io::network::Endpoint &endpoint() const; template - void Register(std::function< - void(const typename TRequestResponse::Request::Capnp::Reader &, - typename TRequestResponse::Response::Capnp::Builder *)> - callback) { + void Register(std::function callback) { std::lock_guard guard(lock_); - CHECK(!server_.IsRunning()) << "You can't register RPCs when the server is running!"; + CHECK(!server_.IsRunning()) + << "You can't register RPCs when the server is running!"; RpcCallback rpc; rpc.req_type = TRequestResponse::Request::kType; rpc.res_type = TRequestResponse::Response::kType; - rpc.callback = [callback = callback](const auto &reader, auto *builder) { - auto req_data = - reader.getData() - .template getAs(); - builder->setTypeId(TRequestResponse::Response::kType.id); - auto data_builder = builder->initData(); - auto res_builder = - data_builder - .template initAs(); - callback(req_data, &res_builder); - }; + rpc.callback = callback; if (extended_callbacks_.find(TRequestResponse::Request::kType.id) != extended_callbacks_.end()) { @@ -64,33 +49,16 @@ class Server { } template - void Register(std::function< - void(const io::network::Endpoint &, - const typename TRequestResponse::Request::Capnp::Reader &, - typename TRequestResponse::Response::Capnp::Builder *)> + void Register(std::function callback) { std::lock_guard guard(lock_); - CHECK(!server_.IsRunning()) << "You can't register RPCs when the server is running!"; + CHECK(!server_.IsRunning()) + << "You can't register RPCs when the server is running!"; RpcExtendedCallback rpc; rpc.req_type = TRequestResponse::Request::kType; rpc.res_type = TRequestResponse::Response::kType; - rpc.callback = [callback = callback](const io::network::Endpoint &endpoint, - const auto &reader, auto *builder) { - auto req_data = - reader.getData() - .template getAs(); - builder->setTypeId(TRequestResponse::Response::kType.id); - auto data_builder = builder->initData(); - auto res_builder = - data_builder - .template initAs(); - callback(endpoint, req_data, &res_builder); - }; - - if (callbacks_.find(TRequestResponse::Request::kType.id) != - callbacks_.end()) { - LOG(FATAL) << "Callback for that message type already registered!"; - } + rpc.callback = callback; auto got = extended_callbacks_.insert({TRequestResponse::Request::kType.id, rpc}); @@ -104,17 +72,14 @@ class Server { struct RpcCallback { utils::TypeInfo req_type; - std::function - callback; + std::function callback; utils::TypeInfo res_type; }; struct RpcExtendedCallback { utils::TypeInfo req_type; - std::function + std::function callback; utils::TypeInfo res_type; }; diff --git a/src/database/distributed/distributed_counters.cpp b/src/database/distributed/distributed_counters.cpp index 70eb5df11..6bb8049be 100644 --- a/src/database/distributed/distributed_counters.cpp +++ b/src/database/distributed/distributed_counters.cpp @@ -6,14 +6,19 @@ namespace database { MasterCounters::MasterCounters(distributed::Coordination *coordination) { coordination->Register( - [this](const auto &req_reader, auto *res_builder) { - CountersGetRes res(Get(req_reader.getName())); - Save(res, res_builder); + [this](auto *req_reader, auto *res_builder) { + CountersGetReq req; + slk::Load(&req, req_reader); + CountersGetRes res(Get(req.name)); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { - Set(req_reader.getName(), req_reader.getValue()); - return std::make_unique(); + [this](auto *req_reader, auto *res_builder) { + CountersSetReq req; + slk::Load(&req, req_reader); + Set(req.name, req.value); + CountersSetRes res; + slk::Save(res, res_builder); }); } diff --git a/src/distributed/bfs_rpc_clients.cpp b/src/distributed/bfs_rpc_clients.cpp index 593e15d2f..275a62391 100644 --- a/src/distributed/bfs_rpc_clients.cpp +++ b/src/distributed/bfs_rpc_clients.cpp @@ -76,9 +76,9 @@ std::optional BfsRpcClients::Pull( auto res = coordination_->GetClientPool(worker_id)->CallWithLoad( - [this, dba](const auto &reader) { + [this, dba](auto *res_reader) { SubcursorPullRes res; - Load(&res, reader, dba, this->data_manager_); + slk::Load(&res, res_reader, dba, this->data_manager_); return res; }, subcursor_id); @@ -149,9 +149,9 @@ PathSegment BfsRpcClients::ReconstructPath( auto res = coordination_->GetClientPool(worker_id)->CallWithLoad( - [this, dba](const auto &reader) { + [this, dba](auto *res_reader) { ReconstructPathRes res; - Load(&res, reader, dba, this->data_manager_); + slk::Load(&res, res_reader, dba, this->data_manager_); return res; }, subcursor_ids.at(worker_id), vertex); @@ -168,9 +168,9 @@ PathSegment BfsRpcClients::ReconstructPath( } auto res = coordination_->GetClientPool(worker_id)->CallWithLoad( - [this, dba](const auto &reader) { + [this, dba](auto *res_reader) { ReconstructPathRes res; - Load(&res, reader, dba, this->data_manager_); + slk::Load(&res, res_reader, dba, this->data_manager_); return res; }, subcursor_ids.at(worker_id), edge); diff --git a/src/distributed/bfs_rpc_messages.lcp b/src/distributed/bfs_rpc_messages.lcp index b36ab3de9..4fc321e3e 100644 --- a/src/distributed/bfs_rpc_messages.lcp +++ b/src/distributed/bfs_rpc_messages.lcp @@ -257,11 +257,12 @@ cpp<# cpp<#) :slk-load (lambda (member) #>cpp + auto *subcursor = subcursor_storage->Get(self->subcursor_id); size_t size; slk::Load(&size, reader); self->${member}.resize(size); for (size_t i = 0; i < size; ++i) { - slk::Load(&self->${member}[i], reader, dba, data_manager); + slk::Load(&self->${member}[i], reader, subcursor->db_accessor(), data_manager); } cpp<#) :capnp-type "List(Query.TypedValue)" @@ -281,7 +282,7 @@ cpp<# return value; }")) (worker-id :int16_t :dont-save t)) - (:serialize (:slk :load-args '((dba "database::GraphDbAccessor *") + (:serialize (:slk :load-args '((subcursor_storage "distributed::BfsSubcursorStorage *") (data-manager "distributed::DataManager *"))) (:capnp :load-args '((dba "database::GraphDbAccessor *") diff --git a/src/distributed/bfs_rpc_server.hpp b/src/distributed/bfs_rpc_server.hpp index 97ef284c0..ad007a405 100644 --- a/src/distributed/bfs_rpc_server.hpp +++ b/src/distributed/bfs_rpc_server.hpp @@ -21,11 +21,11 @@ class BfsRpcServer { distributed::Coordination *coordination, BfsSubcursorStorage *subcursor_storage) : db_(db), subcursor_storage_(subcursor_storage) { - coordination->Register([this](const auto &req_reader, + coordination->Register([this](auto *req_reader, auto *res_builder) { CreateBfsSubcursorReq req; auto ast_storage = std::make_unique(); - Load(&req, req_reader, ast_storage.get()); + slk::Load(&req, req_reader, ast_storage.get()); database::GraphDbAccessor *dba; { std::lock_guard guard(lock_); @@ -46,41 +46,41 @@ class BfsRpcServer { dba, req.direction, req.edge_types, std::move(req.symbol_table), std::move(ast_storage), req.filter_lambda, evaluation_context); CreateBfsSubcursorRes res(id); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { RegisterSubcursorsReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); subcursor_storage_->Get(req.subcursor_ids.at(db_->WorkerId())) ->RegisterSubcursors(req.subcursor_ids); RegisterSubcursorsRes res; - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { ResetSubcursorReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); subcursor_storage_->Get(req.subcursor_id)->Reset(); ResetSubcursorRes res; - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { SetSourceReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); subcursor_storage_->Get(req.subcursor_id)->SetSource(req.source); SetSourceRes res; - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { ExpandLevelReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto subcursor = subcursor_storage_->Get(req.member); ExpandResult result; try { @@ -90,32 +90,32 @@ class BfsRpcServer { result = ExpandResult::LAMBDA_ERROR; } ExpandLevelRes res(result); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { SubcursorPullReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto vertex = subcursor_storage_->Get(req.member)->Pull(); SubcursorPullRes res(vertex); - Save(res, res_builder, db_->WorkerId()); + slk::Save(res, res_builder, db_->WorkerId()); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { ExpandToRemoteVertexReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); ExpandToRemoteVertexRes res( subcursor_storage_->Get(req.subcursor_id) ->ExpandToLocalVertex(req.edge, req.vertex)); - Save(res, res_builder); + slk::Save(res, res_builder); }); - coordination->Register([this](const auto &req_reader, + coordination->Register([this](auto *req_reader, auto *res_builder) { ReconstructPathReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto subcursor = subcursor_storage_->Get(req.subcursor_id); PathSegment result; if (req.vertex) { @@ -127,18 +127,17 @@ class BfsRpcServer { } ReconstructPathRes res(result.edges, result.next_vertex, result.next_edge); - Save(res, res_builder, db_->WorkerId()); + slk::Save(res, res_builder, db_->WorkerId()); }); - coordination->Register([this](const auto &req_reader, + coordination->Register([this](auto *req_reader, auto *res_builder) { PrepareForExpandReq req; - auto subcursor_id = req_reader.getSubcursorId(); - auto *subcursor = subcursor_storage_->Get(subcursor_id); - Load(&req, req_reader, subcursor->db_accessor(), &db_->data_manager()); + slk::Load(&req, req_reader, subcursor_storage_, &db_->data_manager()); + auto *subcursor = subcursor_storage_->Get(req.subcursor_id); subcursor->PrepareForExpand(req.clear, std::move(req.frame)); PrepareForExpandRes res; - Save(res, res_builder); + slk::Save(res, res_builder); }); } diff --git a/src/distributed/cluster_discovery_master.cpp b/src/distributed/cluster_discovery_master.cpp index 802bbc029..fa7adb41a 100644 --- a/src/distributed/cluster_discovery_master.cpp +++ b/src/distributed/cluster_discovery_master.cpp @@ -13,13 +13,13 @@ ClusterDiscoveryMaster::ClusterDiscoveryMaster( MasterCoordination *coordination, const std::string &durability_directory) : coordination_(coordination), durability_directory_(durability_directory) { coordination_->Register([this](const auto &endpoint, - const auto &req_reader, + auto *req_reader, auto *res_builder) { bool registration_successful = false; bool durability_error = false; RegisterWorkerReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); // Compose the worker's endpoint from its connecting address and its // advertised port. @@ -70,15 +70,17 @@ ClusterDiscoveryMaster::ClusterDiscoveryMaster( RegisterWorkerRes res(registration_successful, durability_error, coordination_->RecoveredSnapshotTx(), coordination_->GetWorkers()); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { NotifyWorkerRecoveredReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); coordination_->WorkerRecoveredSnapshot(req.worker_id, req.recovery_info); + NotifyWorkerRecoveredRes res; + slk::Save(res, res_builder); }); } diff --git a/src/distributed/cluster_discovery_worker.cpp b/src/distributed/cluster_discovery_worker.cpp index 03f4c2da3..c886113e9 100644 --- a/src/distributed/cluster_discovery_worker.cpp +++ b/src/distributed/cluster_discovery_worker.cpp @@ -12,10 +12,12 @@ ClusterDiscoveryWorker::ClusterDiscoveryWorker(WorkerCoordination *coordination) : coordination_(coordination), client_pool_(coordination->GetClientPool(0)) { coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { ClusterDiscoveryReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); coordination_->RegisterWorker(req.worker_id, req.endpoint); + ClusterDiscoveryRes res; + slk::Save(res, res_builder); }); } diff --git a/src/distributed/coordination.hpp b/src/distributed/coordination.hpp index 17729d8e7..47f7eaa9f 100644 --- a/src/distributed/coordination.hpp +++ b/src/distributed/coordination.hpp @@ -73,18 +73,13 @@ class Coordination { } template - void Register(std::function< - void(const typename TRequestResponse::Request::Capnp::Reader &, - typename TRequestResponse::Response::Capnp::Builder *)> - callback) { + void Register(std::function callback) { server_.Register(callback); } template - void Register(std::function< - void(const io::network::Endpoint &, - const typename TRequestResponse::Request::Capnp::Reader &, - typename TRequestResponse::Response::Capnp::Builder *)> + void Register(std::function callback) { server_.Register(callback); } diff --git a/src/distributed/coordination_worker.cpp b/src/distributed/coordination_worker.cpp index 8089795cc..1d89e6db0 100644 --- a/src/distributed/coordination_worker.cpp +++ b/src/distributed/coordination_worker.cpp @@ -29,12 +29,12 @@ WorkerCoordination::WorkerCoordination( : Coordination(worker_endpoint, worker_id, master_endpoint, server_workers_count, client_workers_count) { server_.Register( - [&](const auto &req_reader, auto *res_builder) { + [&](auto *req_reader, auto *res_builder) { LOG(INFO) << "The master initiated shutdown of this worker."; Shutdown(); }); - server_.Register([&](const auto &req_reader, + server_.Register([&](auto *req_reader, auto *res_builder) { std::lock_guard guard(heartbeat_lock_); last_heartbeat_time_ = std::chrono::steady_clock::now(); diff --git a/src/distributed/data_rpc_messages.lcp b/src/distributed/data_rpc_messages.lcp index 88272c281..c218afcd9 100644 --- a/src/distributed/data_rpc_messages.lcp +++ b/src/distributed/data_rpc_messages.lcp @@ -153,10 +153,11 @@ cpp<# :slk-load (lambda (member) #>cpp - // slk::Load will read a bool which was explicity - // saved in :slk::save and based on that read record - // data - slk::Load(&self->edge_old_output, reader); + bool has_ptr; + slk::Load(&has_ptr, reader); + if (has_ptr) { + slk::Load(&self->edge_old_output, reader); + } cpp<#)) (edge-new-input "const Edge *" :capnp-type "Storage.Edge" @@ -191,10 +192,11 @@ cpp<# :slk-load (lambda (member) #>cpp - // slk::Load will read a bool which was explicity - // saved in :slk::save and based on that read record - // data - slk::Load(&self->edge_new_output, reader); + bool has_ptr; + slk::Load(&has_ptr, reader); + if (has_ptr) { + slk::Load(&self->edge_old_output, reader); + } cpp<#)) (worker-id :int64_t :dont-save t) (edge-old-output "std::unique_ptr" :initarg nil :dont-save t) diff --git a/src/distributed/data_rpc_server.cpp b/src/distributed/data_rpc_server.cpp index cc3a4e777..2567b7ef5 100644 --- a/src/distributed/data_rpc_server.cpp +++ b/src/distributed/data_rpc_server.cpp @@ -13,46 +13,50 @@ DataRpcServer::DataRpcServer(database::GraphDb *db, distributed::Coordination *coordination) : db_(db) { coordination->Register( - [this](const auto &req_reader, auto *res_builder) { - auto dba = db_->Access(req_reader.getMember().getTxId()); - auto vertex = dba->FindVertexRaw(req_reader.getMember().getGid()); + [this](auto *req_reader, auto *res_builder) { + VertexReq req; + slk::Load(&req, req_reader); + auto dba = db_->Access(req.member.tx_id); + auto vertex = dba->FindVertexRaw(req.member.gid); auto *old = vertex.GetOld(); auto *newr = vertex.GetNew() ? vertex.GetNew()->CloneData() : nullptr; db_->updates_server().ApplyDeltasToRecord( - dba->transaction().id_, req_reader.getMember().getGid(), - req_reader.getMember().getFromWorkerId(), &old, &newr); + dba->transaction().id_, req.member.gid, + req.member.from_worker_id, &old, &newr); VertexRes response(vertex.CypherId(), old, newr, db_->WorkerId()); - Save(response, res_builder); + slk::Save(response, res_builder); delete newr; }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { - auto dba = db_->Access(req_reader.getMember().getTxId()); - auto edge = dba->FindEdgeRaw(req_reader.getMember().getGid()); + [this](auto *req_reader, auto *res_builder) { + EdgeReq req; + slk::Load(&req, req_reader); + auto dba = db_->Access(req.member.tx_id); + auto edge = dba->FindEdgeRaw(req.member.gid); auto *old = edge.GetOld(); auto *newr = edge.GetNew() ? edge.GetNew()->CloneData() : nullptr; db_->updates_server().ApplyDeltasToRecord( - dba->transaction().id_, req_reader.getMember().getGid(), - req_reader.getMember().getFromWorkerId(), &old, &newr); + dba->transaction().id_, req.member.gid, + req.member.from_worker_id, &old, &newr); EdgeRes response(edge.CypherId(), old, newr, db_->WorkerId()); - Save(response, res_builder); + slk::Save(response, res_builder); delete newr; }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { VertexCountReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto dba = db_->Access(req.member); int64_t size = 0; for (auto vertex : dba->Vertices(false)) ++size; VertexCountRes res(size); - Save(res, res_builder); + slk::Save(res, res_builder); }); } diff --git a/src/distributed/durability_rpc_worker.cpp b/src/distributed/durability_rpc_worker.cpp index d1c4b6db6..af0a6d5cf 100644 --- a/src/distributed/durability_rpc_worker.cpp +++ b/src/distributed/durability_rpc_worker.cpp @@ -10,17 +10,21 @@ DurabilityRpcWorker::DurabilityRpcWorker( database::Worker *db, distributed::Coordination *coordination) : db_(db) { coordination->Register( - [this](const auto &req_reader, auto *res_builder) { - auto dba = db_->Access(req_reader.getMember()); + [this](auto *req_reader, auto *res_builder) { + MakeSnapshotReq req; + slk::Load(&req, req_reader); + auto dba = db_->Access(req.member); MakeSnapshotRes res(db_->MakeSnapshot(*dba)); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { - durability::RecoveryData recovery_data; - durability::Load(&recovery_data, req_reader.getMember()); - this->db_->RecoverWalAndIndexes(&recovery_data); + [this](auto *req_reader, auto *res_builder) { + RecoverWalAndIndexesReq req; + slk::Load(&req, req_reader); + this->db_->RecoverWalAndIndexes(&req.member); + RecoverWalAndIndexesRes res; + slk::Save(res, res_builder); }); } diff --git a/src/distributed/dynamic_worker.cpp b/src/distributed/dynamic_worker.cpp index 31219fa50..da00f77fb 100644 --- a/src/distributed/dynamic_worker.cpp +++ b/src/distributed/dynamic_worker.cpp @@ -9,11 +9,11 @@ DynamicWorkerAddition::DynamicWorkerAddition(database::GraphDb *db, distributed::Coordination *coordination) : db_(db), coordination_(coordination) { coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { DynamicWorkerReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); DynamicWorkerRes res(this->GetIndicesToCreate()); - Save(res, res_builder); + slk::Save(res, res_builder); }); } diff --git a/src/distributed/index_rpc_server.cpp b/src/distributed/index_rpc_server.cpp index b3c162e10..7eb064302 100644 --- a/src/distributed/index_rpc_server.cpp +++ b/src/distributed/index_rpc_server.cpp @@ -10,21 +10,25 @@ IndexRpcServer::IndexRpcServer(database::GraphDb *db, distributed::Coordination *coordination) : db_(db) { coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { CreateIndexReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); database::LabelPropertyIndex::Key key{req.label, req.property}; db_->storage().label_property_index_.CreateIndex(key); + CreateIndexRes res; + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { PopulateIndexReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); database::LabelPropertyIndex::Key key{req.label, req.property}; auto dba = db_->Access(req.tx_id); dba->PopulateIndex(key); dba->EnableIndex(key); + PopulateIndexRes res; + slk::Save(res, res_builder); }); } diff --git a/src/distributed/plan_consumer.cpp b/src/distributed/plan_consumer.cpp index cf48b7371..7c1e46d2e 100644 --- a/src/distributed/plan_consumer.cpp +++ b/src/distributed/plan_consumer.cpp @@ -4,19 +4,23 @@ namespace distributed { PlanConsumer::PlanConsumer(distributed::Coordination *coordination) { coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { DispatchPlanReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); plan_cache_.access().insert( req.plan_id, std::make_unique(req.plan, req.symbol_table, std::move(req.storage))); DispatchPlanRes res; - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { - plan_cache_.access().remove(req_reader.getMember()); + [this](auto *req_reader, auto *res_builder) { + RemovePlanReq req; + slk::Load(&req, req_reader); + plan_cache_.access().remove(req.member); + RemovePlanRes res; + slk::Save(res, res_builder); }); } diff --git a/src/distributed/produce_rpc_server.cpp b/src/distributed/produce_rpc_server.cpp index bf658803d..2a74fa3fb 100644 --- a/src/distributed/produce_rpc_server.cpp +++ b/src/distributed/produce_rpc_server.cpp @@ -111,32 +111,32 @@ ProduceRpcServer::ProduceRpcServer(database::Worker *db, plan_consumer_(plan_consumer), tx_engine_(tx_engine) { coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { PullReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); PullRes res(Pull(req)); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { ResetCursorReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); Reset(req); ResetCursorRes res; - Save(res, res_builder); + slk::Save(res, res_builder); }); CHECK(data_manager); coordination->Register( - [this, data_manager](const auto &req_reader, auto *res_builder) { + [this, data_manager](auto *req_reader, auto *res_builder) { TransactionCommandAdvancedReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); tx_engine_->UpdateCommand(req.member); data_manager->ClearCacheForSingleTransaction(req.member); TransactionCommandAdvancedRes res; - Save(res, res_builder); + slk::Save(res, res_builder); }); } diff --git a/src/distributed/pull_rpc_clients.cpp b/src/distributed/pull_rpc_clients.cpp index fba7837f8..57f819034 100644 --- a/src/distributed/pull_rpc_clients.cpp +++ b/src/distributed/pull_rpc_clients.cpp @@ -17,9 +17,9 @@ utils::Future PullRpcClients::Pull( worker_id, [data_manager = data_manager_, dba, plan_id, command_id, evaluation_context, symbols, accumulate, batch_size](int worker_id, ClientPool &client_pool) { - auto load_pull_res = [data_manager, dba](const auto &res_reader) { + auto load_pull_res = [data_manager, dba](auto *res_reader) { PullRes res; - Load(&res, res_reader, dba, data_manager); + slk::Load(&res, res_reader, dba, data_manager); return res; }; auto result = client_pool.CallWithLoad( diff --git a/src/distributed/token_sharing_rpc_server.hpp b/src/distributed/token_sharing_rpc_server.hpp index e74f3d6e8..5b3468aa7 100644 --- a/src/distributed/token_sharing_rpc_server.hpp +++ b/src/distributed/token_sharing_rpc_server.hpp @@ -27,7 +27,7 @@ class TokenSharingRpcServer { distributed::Coordination *coordination) : worker_id_(worker_id), coordination_(coordination), dgp_(db) { coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { token_ = true; }); + [this](auto *req_reader, auto *res_builder) { token_ = true; }); // TODO (buda): It's not trivial to move this part in the Start method // because worker then doesn't run the step. Will resolve that with // a different implementation of the token assignment. diff --git a/src/distributed/updates_rpc_server.cpp b/src/distributed/updates_rpc_server.cpp index 8a7c795c7..eef99bd0c 100644 --- a/src/distributed/updates_rpc_server.cpp +++ b/src/distributed/updates_rpc_server.cpp @@ -250,10 +250,10 @@ void UpdatesRpcServer::TransactionUpdates::ApplyDeltasToRecord( UpdatesRpcServer::UpdatesRpcServer(database::GraphDb *db, distributed::Coordination *coordination) : db_(db) { - coordination->Register([this](const auto &req_reader, + coordination->Register([this](auto *req_reader, auto *res_builder) { UpdateReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); using DeltaType = database::StateDelta::Type; auto &delta = req.member; switch (delta.type) { @@ -264,13 +264,13 @@ UpdatesRpcServer::UpdatesRpcServer(database::GraphDb *db, case database::StateDelta::Type::REMOVE_IN_EDGE: { UpdateRes res(GetUpdates(vertex_updates_, delta.transaction_id) .Emplace(delta, req.worker_id)); - Save(res, res_builder); + slk::Save(res, res_builder); return; } case DeltaType::SET_PROPERTY_EDGE: { UpdateRes res(GetUpdates(edge_updates_, delta.transaction_id) .Emplace(delta, req.worker_id)); - Save(res, res_builder); + slk::Save(res, res_builder); return; } default: @@ -280,29 +280,29 @@ UpdatesRpcServer::UpdatesRpcServer(database::GraphDb *db, }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { UpdateApplyReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); UpdateApplyRes res(Apply(req.member)); - Save(res, res_builder); + slk::Save(res, res_builder); }); - coordination->Register([this](const auto &req_reader, + coordination->Register([this](auto *req_reader, auto *res_builder) { CreateVertexReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto result = GetUpdates(vertex_updates_, req.member.tx_id) .CreateVertex(req.member.labels, req.member.properties, req.member.cypher_id); CreateVertexRes res( CreateResult{UpdateResult::DONE, result.cypher_id, result.gid}); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { CreateEdgeReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto data = req.member; auto creation_result = CreateEdge(data); @@ -319,53 +319,53 @@ UpdatesRpcServer::UpdatesRpcServer(database::GraphDb *db, } CreateEdgeRes res(creation_result); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { AddInEdgeReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto to_delta = database::StateDelta::AddInEdge( req.member.tx_id, req.member.to, req.member.from, req.member.edge_address, req.member.edge_type); auto result = GetUpdates(vertex_updates_, req.member.tx_id) .Emplace(to_delta, req.member.worker_id); AddInEdgeRes res(result); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { RemoveVertexReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto to_delta = database::StateDelta::RemoveVertex( req.member.tx_id, req.member.gid, req.member.check_empty); auto result = GetUpdates(vertex_updates_, req.member.tx_id) .Emplace(to_delta, req.member.worker_id); RemoveVertexRes res(result); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { RemoveEdgeReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); RemoveEdgeRes res(RemoveEdge(req.member)); - Save(res, res_builder); + slk::Save(res, res_builder); }); - coordination->Register([this](const auto &req_reader, + coordination->Register([this](auto *req_reader, auto *res_builder) { RemoveInEdgeReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); auto data = req.member; RemoveInEdgeRes res( GetUpdates(vertex_updates_, data.tx_id) .Emplace(database::StateDelta::RemoveInEdge(data.tx_id, data.vertex, data.edge_address), data.worker_id)); - Save(res, res_builder); + slk::Save(res, res_builder); }); } diff --git a/src/raft/coordination.hpp b/src/raft/coordination.hpp index 3e211f204..9495076ed 100644 --- a/src/raft/coordination.hpp +++ b/src/raft/coordination.hpp @@ -94,19 +94,14 @@ class Coordination final { /// Registers a RPC call on this node. template - void Register(std::function< - void(const typename TRequestResponse::Request::Capnp::Reader &, - typename TRequestResponse::Response::Capnp::Builder *)> - callback) { + void Register(std::function callback) { server_.Register(callback); } /// Registers an extended RPC call on this node. template - void Register(std::function< - void(const io::network::Endpoint &, - const typename TRequestResponse::Request::Capnp::Reader &, - typename TRequestResponse::Response::Capnp::Builder *)> + void Register(std::function callback) { server_.Register(callback); } diff --git a/src/raft/raft_rpc_messages.lcp b/src/raft/raft_rpc_messages.lcp index 7851d4012..bc0fd52b8 100644 --- a/src/raft/raft_rpc_messages.lcp +++ b/src/raft/raft_rpc_messages.lcp @@ -52,41 +52,7 @@ cpp<# ((leader-id :uint16_t) (term :uint64_t) (snapshot-metadata "raft::SnapshotMetadata" :capnp-type "Snap.SnapshotMetadata") - (data "std::unique_ptr" - :initarg :move - :capnp-type "Data" - :capnp-init nil - :capnp-save (lambda (builder member capnp-name) - #>cpp - auto data_builder = ${builder}->initData(self.size); - memcpy(data_builder.begin(), ${member}.get(), self.size); - cpp<#) - :slk-save (lambda (member) - #>cpp - slk::Save(self.size, builder); - for (uint32_t i = 0; i < self.size; ++i) { - slk::Save(self.data[i], builder); - } - cpp<#) - :capnp-load (lambda (reader member capnp-name) - (declare (ignore capnp-name)) - #>cpp - auto data_reader = ${reader}.getData(); - self->size = data_reader.size(); - ${member}.reset(new char[self->size]); - memcpy(${member}.get(), data_reader.begin(), self->size); - cpp<#) - :slk-load (lambda (member) - #>cpp - slk::Load(&self->size, reader); - self->data.reset(new char[self->size]); - for (uint32_t i = 0; i < self->size; ++i) { - uint8_t curr; - slk::Load(&curr, reader); - self->data[i] = curr; - } - cpp<#)) - (size :uint32_t :dont-save t))) + (data "std::string"))) (:response ((term :uint64_t)))) diff --git a/src/raft/raft_server.cpp b/src/raft/raft_server.cpp index c2b40a674..942508798 100644 --- a/src/raft/raft_server.cpp +++ b/src/raft/raft_server.cpp @@ -83,17 +83,17 @@ void RaftServer::Start() { // RPC registration coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { std::lock_guard guard(lock_); RequestVoteReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); // [Raft paper 5.1] // "If a server recieves a request with a stale term, // it rejects the request" if (exiting_ || req.term < current_term_) { RequestVoteRes res(false, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); return; } @@ -118,21 +118,21 @@ void RaftServer::Start() { last_entry_data.first, last_entry_data.second); RequestVoteRes res(grant_vote, current_term_); if (grant_vote) SetNextElectionTimePoint(); - Save(res, res_builder); + slk::Save(res, res_builder); }); - coordination_->Register([this](const auto &req_reader, + coordination_->Register([this](auto *req_reader, auto *res_builder) { std::lock_guard guard(lock_); AppendEntriesReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); // [Raft paper 5.1] // "If a server receives a request with a stale term, it rejects the // request" if (exiting_ || req.term < current_term_) { AppendEntriesRes res(false, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); return; } @@ -172,7 +172,7 @@ void RaftServer::Start() { snapshot_metadata->last_included_index == req.prev_log_index) { if (req.prev_log_term != snapshot_metadata->last_included_term) { AppendEntriesRes res(false, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); return; } } else if (snapshot_metadata && @@ -180,13 +180,13 @@ void RaftServer::Start() { LOG(ERROR) << "Received entries that are already commited and have been " "compacted"; AppendEntriesRes res(false, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); return; } else { if (log_size_ <= req.prev_log_index || GetLogEntry(req.prev_log_index).term != req.prev_log_term) { AppendEntriesRes res(false, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); return; } } @@ -204,24 +204,24 @@ void RaftServer::Start() { // Respond positively to a heartbeat. if (req.entries.empty()) { AppendEntriesRes res(true, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); if (mode_ != Mode::FOLLOWER) Transition(Mode::FOLLOWER); return; } AppendEntriesRes res(true, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { std::lock_guard guard(lock_); HeartbeatReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); if (exiting_ || req.term < current_term_) { HeartbeatRes res(false, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); return; } @@ -234,21 +234,21 @@ void RaftServer::Start() { election_change_.notify_all(); HeartbeatRes res(true, current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { // Acquire snapshot lock. std::lock_guard snapshot_guard(snapshot_lock_); std::lock_guard guard(lock_); InstallSnapshotReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); if (exiting_ || req.term < current_term_) { InstallSnapshotRes res(current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); return; } @@ -256,7 +256,7 @@ void RaftServer::Start() { if (req.snapshot_metadata.last_included_index == last_applied_ && req.snapshot_metadata.last_included_term == current_term_) { InstallSnapshotRes res(current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); return; } @@ -288,7 +288,7 @@ void RaftServer::Start() { VLOG(40) << "[InstallSnapshotRpc] Saving received snapshot."; std::ofstream output_stream; output_stream.open(snapshot_path, std::ios::out | std::ios::binary); - output_stream.write(req.data.get(), req.size); + output_stream.write(req.data.data(), req.data.size()); output_stream.flush(); output_stream.close(); } @@ -317,7 +317,7 @@ void RaftServer::Start() { SetLogSize(req.snapshot_metadata.last_included_index + 1); InstallSnapshotRes res(current_term_); - Save(res, res_builder); + slk::Save(res, res_builder); }); // start threads @@ -797,8 +797,12 @@ void RaftServer::SendSnapshot(uint16_t peer_id, const SnapshotMetadata &snapshot_metadata, std::unique_lock *lock) { uint64_t request_term = current_term_; - uint32_t snapshot_size = 0; - std::unique_ptr snapshot; + std::string snapshot_data; + + // TODO: The snapshot is currently sent all at once. Because the snapshot file + // can be extremely large (>100GB, it contains the whole database) it must be + // sent out in chunks! Reimplement this logic so that it sends out the + // snapshot in chunks. { const auto snapshot_path = durability::MakeSnapshotPath( @@ -807,19 +811,19 @@ void RaftServer::SendSnapshot(uint16_t peer_id, std::ifstream input_stream; input_stream.open(snapshot_path, std::ios::in | std::ios::binary); input_stream.seekg(0, std::ios::end); - snapshot_size = input_stream.tellg(); + uint64_t snapshot_size = input_stream.tellg(); - snapshot.reset(new char[snapshot_size]); + snapshot_data = std::string(snapshot_size, '\0'); input_stream.seekg(0, std::ios::beg); - input_stream.read(snapshot.get(), snapshot_size); + input_stream.read(snapshot_data.data(), snapshot_size); input_stream.close(); } VLOG(40) << "Server " << server_id_ << ": Sending Snapshot RPC to server " << peer_id << " (Term: " << current_term_ << ")"; - VLOG(40) << "Snapshot size: " << snapshot_size << " bytes."; + VLOG(40) << "Snapshot size: " << snapshot_data.size() << " bytes."; // Copy all internal variables before releasing the lock. auto server_id = server_id_; @@ -827,8 +831,8 @@ void RaftServer::SendSnapshot(uint16_t peer_id, // Execute the RPC. lock->unlock(); auto reply = coordination_->ExecuteOnOtherNode( - peer_id, server_id, request_term, snapshot_metadata, std::move(snapshot), - snapshot_size); + peer_id, server_id, request_term, snapshot_metadata, + std::move(snapshot_data)); lock->lock(); if (!reply) { diff --git a/src/raft/storage_info.cpp b/src/raft/storage_info.cpp index 000584458..31122b955 100644 --- a/src/raft/storage_info.cpp +++ b/src/raft/storage_info.cpp @@ -25,12 +25,12 @@ StorageInfo::~StorageInfo() {} void StorageInfo::Start() { coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { StorageInfoReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); StorageInfoRes res(this->server_id_, this->GetLocalStorageInfo()); - Save(res, res_builder); + slk::Save(res, res_builder); }); } diff --git a/src/slk/serialization.hpp b/src/slk/serialization.hpp index eefd1a8ba..aae3b2916 100644 --- a/src/slk/serialization.hpp +++ b/src/slk/serialization.hpp @@ -97,6 +97,7 @@ void Load( } MAKE_PRIMITIVE_SAVE(bool) +MAKE_PRIMITIVE_SAVE(char) MAKE_PRIMITIVE_SAVE(int8_t) MAKE_PRIMITIVE_SAVE(uint8_t) MAKE_PRIMITIVE_SAVE(int16_t) @@ -116,6 +117,7 @@ MAKE_PRIMITIVE_SAVE(double) } MAKE_PRIMITIVE_LOAD(bool) +MAKE_PRIMITIVE_LOAD(char) MAKE_PRIMITIVE_LOAD(int8_t) MAKE_PRIMITIVE_LOAD(uint8_t) MAKE_PRIMITIVE_LOAD(int16_t) diff --git a/src/storage/distributed/concurrent_id_mapper_master.cpp b/src/storage/distributed/concurrent_id_mapper_master.cpp index 791a74f4f..45be8eb03 100644 --- a/src/storage/distributed/concurrent_id_mapper_master.cpp +++ b/src/storage/distributed/concurrent_id_mapper_master.cpp @@ -16,18 +16,18 @@ void RegisterRpc(MasterConcurrentIdMapper *mapper, void RegisterRpc(MasterConcurrentIdMapper * mapper, \ distributed::Coordination * coordination) { \ coordination->Register( \ - [mapper](const auto &req_reader, auto *res_builder) { \ + [mapper](auto *req_reader, auto *res_builder) { \ type##IdReq req; \ - Load(&req, req_reader); \ + slk::Load(&req, req_reader); \ type##IdRes res(mapper->value_to_id(req.member)); \ - Save(res, res_builder); \ + slk::Save(res, res_builder); \ }); \ coordination->Register( \ - [mapper](const auto &req_reader, auto *res_builder) { \ + [mapper](auto *req_reader, auto *res_builder) { \ Id##type##Req req; \ - Load(&req, req_reader); \ + slk::Load(&req, req_reader); \ Id##type##Res res(mapper->id_to_value(req.member)); \ - Save(res, res_builder); \ + slk::Save(res, res_builder); \ }); \ } diff --git a/src/storage/distributed/rpc/serialization.cpp b/src/storage/distributed/rpc/serialization.cpp index 5eab2364d..8e3be03f0 100644 --- a/src/storage/distributed/rpc/serialization.cpp +++ b/src/storage/distributed/rpc/serialization.cpp @@ -357,6 +357,7 @@ template void SaveRecordAccessor(const TRecordAccessor &accessor, slk::Builder *builder, storage::SendVersions versions, int16_t worker_id) { bool reconstructed = false; + auto guard = storage::GetDataLock(accessor); if (!accessor.GetOld() && !accessor.GetNew()) { reconstructed = true; bool result = accessor.Reconstruct(); diff --git a/src/storage/distributed/storage_gc_master.hpp b/src/storage/distributed/storage_gc_master.hpp index 526ce2a14..d9f01fed5 100644 --- a/src/storage/distributed/storage_gc_master.hpp +++ b/src/storage/distributed/storage_gc_master.hpp @@ -26,11 +26,13 @@ class StorageGcMaster final : public StorageGcDistributed { : StorageGcDistributed(storage, tx_engine, pause_sec), coordination_(coordination) { coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { distributed::RanLocalGcReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); std::unique_lock lock(worker_safe_transaction_mutex_); worker_safe_transaction_[req.worker_id] = req.local_oldest_active; + distributed::RanLocalGcRes res; + slk::Save(res, res_builder); }); } diff --git a/src/transactions/distributed/engine_master.cpp b/src/transactions/distributed/engine_master.cpp index 6c040220c..eff4177e3 100644 --- a/src/transactions/distributed/engine_master.cpp +++ b/src/transactions/distributed/engine_master.cpp @@ -13,72 +13,99 @@ EngineMaster::EngineMaster(distributed::Coordination *coordination, durability::WriteAheadLog *wal) : engine_single_node_(wal), coordination_(coordination) { coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { + BeginReq req; + slk::Load(&req, req_reader); auto tx = this->Begin(); BeginRes res(TxAndSnapshot{tx->id_, tx->snapshot()}); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { - AdvanceRes res(this->Advance(req_reader.getMember())); - Save(res, res_builder); + [this](auto *req_reader, auto *res_builder) { + AdvanceReq req; + slk::Load(&req, req_reader); + AdvanceRes res(this->Advance(req.member)); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { - this->Commit(*this->RunningTransaction(req_reader.getMember())); + [this](auto *req_reader, auto *res_builder) { + CommitReq req; + slk::Load(&req, req_reader); + this->Commit(*this->RunningTransaction(req.member)); + CommitRes res; + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { - this->Abort(*this->RunningTransaction(req_reader.getMember())); + [this](auto *req_reader, auto *res_builder) { + AbortReq req; + slk::Load(&req, req_reader); + this->Abort(*this->RunningTransaction(req.member)); + AbortRes res; + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { // It is guaranteed that the Worker will not be requesting this for a // transaction that's done, and that there are no race conditions here. - SnapshotRes res( - this->RunningTransaction(req_reader.getMember())->snapshot()); - Save(res, res_builder); + SnapshotReq req; + slk::Load(&req, req_reader); + SnapshotRes res(this->RunningTransaction(req.member)->snapshot()); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { // It is guaranteed that the Worker will not be requesting this for a // transaction that's done, and that there are no race conditions here. - CommandRes res(this->RunningTransaction(req_reader.getMember())->cid()); - Save(res, res_builder); + CommandReq req; + slk::Load(&req, req_reader); + CommandRes res(this->RunningTransaction(req.member)->cid()); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { + GcSnapshotReq req; + slk::Load(&req, req_reader); GcSnapshotRes res(this->GlobalGcSnapshot()); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { - ClogInfoRes res(this->Info(req_reader.getMember())); - Save(res, res_builder); + [this](auto *req_reader, auto *res_builder) { + ClogInfoReq req; + slk::Load(&req, req_reader); + ClogInfoRes res(this->Info(req.member)); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { + ActiveTransactionsReq req; + slk::Load(&req, req_reader); ActiveTransactionsRes res(this->GlobalActiveTransactions()); - Save(res, res_builder); + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { - this->EnsureNextIdGreater(req_reader.getMember()); + [this](auto *req_reader, auto *res_builder) { + EnsureNextIdGreaterReq req; + slk::Load(&req, req_reader); + this->EnsureNextIdGreater(req.member); + EnsureNextIdGreaterRes res; + slk::Save(res, res_builder); }); coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { + [this](auto *req_reader, auto *res_builder) { + GlobalLastReq req; + slk::Load(&req, req_reader); GlobalLastRes res(this->GlobalLast()); - Save(res, res_builder); + slk::Save(res, res_builder); }); } diff --git a/src/transactions/distributed/engine_worker.cpp b/src/transactions/distributed/engine_worker.cpp index 762d7dd47..a8c0ec184 100644 --- a/src/transactions/distributed/engine_worker.cpp +++ b/src/transactions/distributed/engine_worker.cpp @@ -29,11 +29,14 @@ EngineWorker::EngineWorker(distributed::Coordination *coordination, // aborted. This mismatch in committed/aborted across workers is resolved by // using the master as a single source of truth when doing recovery. coordination_->Register( - [this](const auto &req_reader, auto *res_builder) { - auto tid = req_reader.getMember(); + [this](auto *req_reader, auto *res_builder) { + NotifyCommittedReq req; + slk::Load(&req, req_reader); if (wal_) { - wal_->Emplace(database::StateDelta::TxCommit(tid)); + wal_->Emplace(database::StateDelta::TxCommit(req.member)); } + NotifyCommittedRes res; + slk::Save(res, res_builder); }); } diff --git a/tests/benchmark/rpc.cpp b/tests/benchmark/rpc.cpp index 881b405b6..8380de767 100644 --- a/tests/benchmark/rpc.cpp +++ b/tests/benchmark/rpc.cpp @@ -9,26 +9,33 @@ #include "communication/rpc/client_pool.hpp" #include "communication/rpc/messages.hpp" #include "communication/rpc/server.hpp" +#include "slk/serialization.hpp" #include "utils/timer.hpp" struct EchoMessage { - using Capnp = ::capnp::AnyPointer; static const utils::TypeInfo kType; EchoMessage() {} // Needed for serialization. EchoMessage(const std::string &data) : data(data) {} + static void Load(EchoMessage *obj, slk::Reader *reader); + static void Save(const EchoMessage &obj, slk::Builder *builder); + std::string data; }; -void Save(const EchoMessage &echo, ::capnp::AnyPointer::Builder *builder) { - auto list_builder = builder->initAs<::capnp::List<::capnp::Text>>(1); - list_builder.set(0, echo.data); +namespace slk { +void Save(const EchoMessage &echo, Builder *builder) { + Save(echo.data, builder); } +void Load(EchoMessage *echo, Reader *reader) { Load(&echo->data, reader); } +} // namespace slk -void Load(EchoMessage *echo, const ::capnp::AnyPointer::Reader &reader) { - auto list_reader = reader.getAs<::capnp::List<::capnp::Text>>(); - echo->data = list_reader[0]; +void EchoMessage::Load(EchoMessage *obj, slk::Reader *reader) { + slk::Load(obj, reader); +} +void EchoMessage::Save(const EchoMessage &obj, slk::Builder *builder) { + slk::Save(obj, builder); } const utils::TypeInfo EchoMessage::kType{2, "EchoMessage"}; diff --git a/tests/unit/rpc.cpp b/tests/unit/rpc.cpp index 6247f67bd..7ada41543 100644 --- a/tests/unit/rpc.cpp +++ b/tests/unit/rpc.cpp @@ -10,88 +10,57 @@ #include "communication/rpc/server.hpp" #include "utils/timer.hpp" +#include "rpc_messages.hpp" + using namespace communication::rpc; using namespace std::literals::chrono_literals; -struct SumReq { - using Capnp = ::capnp::AnyPointer; - static const utils::TypeInfo kType; - - SumReq() {} // Needed for serialization. - SumReq(int x, int y) : x(x), y(y) {} - int x; - int y; -}; - -void Save(const SumReq &sum, ::capnp::AnyPointer::Builder *builder) { - auto list_builder = builder->initAs<::capnp::List>(2); - list_builder.set(0, sum.x); - list_builder.set(1, sum.y); +namespace slk { +void Save(const SumReq &sum, Builder *builder) { + Save(sum.x, builder); + Save(sum.y, builder); } -void Load(SumReq *sum, const ::capnp::AnyPointer::Reader &reader) { - auto list_reader = reader.getAs<::capnp::List>(); - sum->x = list_reader[0]; - sum->y = list_reader[1]; +void Load(SumReq *sum, Reader *reader) { + Load(&sum->x, reader); + Load(&sum->y, reader); } -const utils::TypeInfo SumReq::kType{0, "SumReq"}; +void Save(const SumRes &res, Builder *builder) { Save(res.sum, builder); } -struct SumRes { - using Capnp = ::capnp::AnyPointer; - static const utils::TypeInfo kType; +void Load(SumRes *res, Reader *reader) { Load(&res->sum, reader); } - SumRes() {} // Needed for serialization. - SumRes(int sum) : sum(sum) {} - - int sum; -}; - -void Save(const SumRes &res, ::capnp::AnyPointer::Builder *builder) { - auto list_builder = builder->initAs<::capnp::List>(1); - list_builder.set(0, res.sum); +void Save(const EchoMessage &echo, Builder *builder) { + Save(echo.data, builder); } -void Load(SumRes *res, const ::capnp::AnyPointer::Reader &reader) { - auto list_reader = reader.getAs<::capnp::List>(); - res->sum = list_reader[0]; +void Load(EchoMessage *echo, Reader *reader) { Load(&echo->data, reader); } +} // namespace slk + +void SumReq::Load(SumReq *obj, slk::Reader *reader) { slk::Load(obj, reader); } +void SumReq::Save(const SumReq &obj, slk::Builder *builder) { + slk::Save(obj, builder); } -const utils::TypeInfo SumRes::kType{1, "SumRes"}; - -using Sum = RequestResponse; - -struct EchoMessage { - using Capnp = ::capnp::AnyPointer; - static const utils::TypeInfo kType; - - EchoMessage() {} // Needed for serialization. - EchoMessage(const std::string &data) : data(data) {} - - std::string data; -}; - -void Save(const EchoMessage &echo, ::capnp::AnyPointer::Builder *builder) { - auto list_builder = builder->initAs<::capnp::List<::capnp::Text>>(1); - list_builder.set(0, echo.data); +void SumRes::Load(SumRes *obj, slk::Reader *reader) { slk::Load(obj, reader); } +void SumRes::Save(const SumRes &obj, slk::Builder *builder) { + slk::Save(obj, builder); } -void Load(EchoMessage *echo, const ::capnp::AnyPointer::Reader &reader) { - auto list_reader = reader.getAs<::capnp::List<::capnp::Text>>(); - echo->data = list_reader[0]; +void EchoMessage::Load(EchoMessage *obj, slk::Reader *reader) { + slk::Load(obj, reader); +} +void EchoMessage::Save(const EchoMessage &obj, slk::Builder *builder) { + slk::Save(obj, builder); } - -const utils::TypeInfo EchoMessage::kType{2, "EchoMessage"}; - -using Echo = RequestResponse; TEST(Rpc, Call) { Server server({"127.0.0.1", 0}); - server.Register([](const auto &req_reader, auto *res_builder) { + server.Register([](auto *req_reader, auto *res_builder) { SumReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); SumRes res(req.x + req.y); - Save(res, res_builder); + slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); @@ -106,12 +75,12 @@ TEST(Rpc, Call) { TEST(Rpc, Abort) { Server server({"127.0.0.1", 0}); - server.Register([](const auto &req_reader, auto *res_builder) { + server.Register([](auto *req_reader, auto *res_builder) { SumReq req; - Load(&req, req_reader); + slk::Load(&req, req_reader); std::this_thread::sleep_for(500ms); SumRes res(req.x + req.y); - Save(res, res_builder); + slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); @@ -149,8 +118,8 @@ TEST(Rpc, ClientPool) { Client client(server.endpoint()); - /* these calls should take more than 400ms because we're using a regular - * client */ + // These calls should take more than 400ms because we're using a regular + // client auto get_sum_client = [&client](int x, int y) { auto sum = client.Call(x, y); EXPECT_EQ(sum.sum, x + y); @@ -170,8 +139,8 @@ TEST(Rpc, ClientPool) { ClientPool pool(server.endpoint()); - /* these calls shouldn't take much more that 100ms because they execute in - * parallel */ + // These calls shouldn't take much more that 100ms because they execute in + // parallel auto get_sum = [&pool](int x, int y) { auto sum = pool.Call(x, y); EXPECT_EQ(sum.sum, x + y); @@ -192,10 +161,10 @@ TEST(Rpc, ClientPool) { TEST(Rpc, LargeMessage) { Server server({"127.0.0.1", 0}); - server.Register([](const auto &req_reader, auto *res_builder) { + server.Register([](auto *req_reader, auto *res_builder) { EchoMessage res; - Load(&res, req_reader); - Save(res, res_builder); + slk::Load(&res, req_reader); + slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); @@ -209,3 +178,24 @@ TEST(Rpc, LargeMessage) { server.Shutdown(); server.AwaitShutdown(); } + +TEST(Rpc, JumboMessage) { + Server server({"127.0.0.1", 0}); + server.Register([](auto *req_reader, auto *res_builder) { + EchoMessage res; + slk::Load(&res, req_reader); + slk::Save(res, res_builder); + }); + ASSERT_TRUE(server.Start()); + std::this_thread::sleep_for(100ms); + + // NOLINTNEXTLINE (bugprone-string-constructor) + std::string testdata(10000000, 'a'); + + Client client(server.endpoint()); + auto echo = client.Call(testdata); + EXPECT_EQ(echo.data, testdata); + + server.Shutdown(); + server.AwaitShutdown(); +} diff --git a/tests/unit/rpc_messages.hpp b/tests/unit/rpc_messages.hpp new file mode 100644 index 000000000..d031ce501 --- /dev/null +++ b/tests/unit/rpc_messages.hpp @@ -0,0 +1,65 @@ +#pragma once + +#include "communication/rpc/messages.hpp" +#include "slk/serialization.hpp" +#include "utils/typeinfo.hpp" + +struct SumReq { + static const utils::TypeInfo kType; + + SumReq() {} // Needed for serialization. + SumReq(int x, int y) : x(x), y(y) {} + + static void Load(SumReq *obj, slk::Reader *reader); + static void Save(const SumReq &obj, slk::Builder *builder); + + int x; + int y; +}; + +const utils::TypeInfo SumReq::kType{0, "SumReq"}; + +struct SumRes { + static const utils::TypeInfo kType; + + SumRes() {} // Needed for serialization. + SumRes(int sum) : sum(sum) {} + + static void Load(SumRes *obj, slk::Reader *reader); + static void Save(const SumRes &obj, slk::Builder *builder); + + int sum; +}; + +const utils::TypeInfo SumRes::kType{1, "SumRes"}; + +namespace slk { +void Save(const SumReq &sum, Builder *builder); +void Load(SumReq *sum, Reader *reader); + +void Save(const SumRes &res, Builder *builder); +void Load(SumRes *res, Reader *reader); +} // namespace slk + +using Sum = communication::rpc::RequestResponse; + +struct EchoMessage { + static const utils::TypeInfo kType; + + EchoMessage() {} // Needed for serialization. + EchoMessage(const std::string &data) : data(data) {} + + static void Load(EchoMessage *obj, slk::Reader *reader); + static void Save(const EchoMessage &obj, slk::Builder *builder); + + std::string data; +}; + +const utils::TypeInfo EchoMessage::kType{2, "EchoMessage"}; + +namespace slk { +void Save(const EchoMessage &echo, Builder *builder); +void Load(EchoMessage *echo, Reader *reader); +} // namespace slk + +using Echo = communication::rpc::RequestResponse;