Migrate RPC to SLK

Summary:
Migrate all RPCs
Simplify Raft InstallSnapshot RPC
Add missing Load and Save for `char`

Reviewers: teon.banek, msantl

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2001
This commit is contained in:
Matej Ferencevic 2019-05-06 13:35:22 +02:00
parent 5833e6cc0f
commit d678e45c10
39 changed files with 543 additions and 553 deletions

View File

@ -83,10 +83,11 @@ void Client::Close() {
socket_.Close();
}
bool Client::Read(size_t len) {
bool Client::Read(size_t len, bool exactly_len) {
if (len == 0) return false;
size_t received = 0;
buffer_.write_end()->Resize(buffer_.read_end()->size() + len);
while (received < len) {
do {
auto buff = buffer_.write_end()->Allocate();
if (ssl_) {
// We clear errors here to prevent errors piling up in the internal
@ -140,7 +141,7 @@ bool Client::Read(size_t len) {
buffer_.write_end()->Written(got);
received += got;
}
}
} while (received < len && exactly_len);
return true;
}

View File

@ -58,11 +58,12 @@ class Client final {
void Close();
/**
* This function is used to receive `len` bytes from the socket and stores it
* in an internal buffer. It returns `true` if the read succeeded and `false`
* if it didn't.
* This function is used to receive exactly `len` bytes from the socket and
* stores it in an internal buffer. If `exactly_len` is set to `false` then
* less than `len` bytes can be received. It returns `true` if the read
* succeeded and `false` if it didn't.
*/
bool Read(size_t len);
bool Read(size_t len, bool exactly_len = true);
/**
* This function returns a pointer to the read data that is currently stored

View File

@ -1,84 +1,9 @@
#include "communication/rpc/client.hpp"
#include <chrono>
#include <thread>
#include "gflags/gflags.h"
namespace communication::rpc {
Client::Client(const io::network::Endpoint &endpoint) : endpoint_(endpoint) {}
::capnp::FlatArrayMessageReader Client::Send(::capnp::MessageBuilder *message) {
std::lock_guard<std::mutex> guard(mutex_);
// Check if the connection is broken (if we haven't used the client for a
// long time the server could have died).
if (client_ && client_->ErrorStatus()) {
client_ = std::nullopt;
}
// Connect to the remote server.
if (!client_) {
client_.emplace(&context_);
if (!client_->Connect(endpoint_)) {
DLOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
client_ = std::nullopt;
throw RpcFailedException(endpoint_);
}
}
// Serialize and send request.
auto request_words = ::capnp::messageToFlatArray(*message);
auto request_bytes = request_words.asBytes();
CHECK(request_bytes.size() <= std::numeric_limits<MessageSize>::max())
<< fmt::format(
"Trying to send message of size {}, max message size is {}",
request_bytes.size(), std::numeric_limits<MessageSize>::max());
MessageSize request_data_size = request_bytes.size();
if (!client_->Write(reinterpret_cast<uint8_t *>(&request_data_size),
sizeof(MessageSize), true)) {
DLOG(ERROR) << "Couldn't send request size to " << client_->endpoint();
client_ = std::nullopt;
throw RpcFailedException(endpoint_);
}
if (!client_->Write(request_bytes.begin(), request_bytes.size())) {
DLOG(ERROR) << "Couldn't send request data to " << client_->endpoint();
client_ = std::nullopt;
throw RpcFailedException(endpoint_);
}
// Receive response data size.
if (!client_->Read(sizeof(MessageSize))) {
DLOG(ERROR) << "Couldn't get response from " << client_->endpoint();
client_ = std::nullopt;
throw RpcFailedException(endpoint_);
}
MessageSize response_data_size =
*reinterpret_cast<MessageSize *>(client_->GetData());
client_->ShiftData(sizeof(MessageSize));
// Receive response data.
if (!client_->Read(response_data_size)) {
DLOG(ERROR) << "Couldn't get response from " << client_->endpoint();
client_ = std::nullopt;
throw RpcFailedException(endpoint_);
}
// Read the response message.
auto data = ::kj::arrayPtr(client_->GetData(), response_data_size);
// Our data is word aligned and padded to 64bit because we use regular
// (non-packed) serialization of Cap'n Proto. So we can use reinterpret_cast.
auto data_words =
::kj::arrayPtr(reinterpret_cast<::capnp::word *>(data.begin()),
reinterpret_cast<::capnp::word *>(data.end()));
::capnp::FlatArrayMessageReader response_message(data_words.asConst());
client_->ShiftData(response_data_size);
return response_message;
}
void Client::Abort() {
if (!client_) return;
// We need to call Shutdown on the client to abort any pending read or

View File

@ -10,10 +10,11 @@
#include "communication/client.hpp"
#include "communication/rpc/exceptions.hpp"
#include "communication/rpc/messages.capnp.h"
#include "communication/rpc/messages.hpp"
#include "io/network/endpoint.hpp"
#include "utils/demangle.hpp"
#include "slk/serialization.hpp"
#include "slk/streams.hpp"
#include "utils/on_scope_exit.hpp"
namespace communication::rpc {
@ -34,9 +35,9 @@ class Client {
template <class TRequestResponse, class... Args>
typename TRequestResponse::Response Call(Args &&... args) {
return CallWithLoad<TRequestResponse>(
[](const auto &reader) {
[](auto *reader) {
typename TRequestResponse::Response response;
Load(&response, reader);
TRequestResponse::Response::Load(&response, reader);
return response;
},
std::forward<Args>(args)...);
@ -45,29 +46,68 @@ class Client {
/// Same as `Call` but the first argument is a response loading function.
template <class TRequestResponse, class... Args>
typename TRequestResponse::Response CallWithLoad(
std::function<typename TRequestResponse::Response(
const typename TRequestResponse::Response::Capnp::Reader &)>
load,
std::function<typename TRequestResponse::Response(slk::Reader *)> load,
Args &&... args) {
typename TRequestResponse::Request request(std::forward<Args>(args)...);
auto req_type = TRequestResponse::Request::kType;
VLOG(12) << "[RpcClient] sent " << req_type.name;
::capnp::MallocMessageBuilder req_msg;
{
auto builder = req_msg.initRoot<capnp::Message>();
builder.setTypeId(req_type.id);
auto data_builder = builder.initData();
auto req_builder =
data_builder
.template initAs<typename TRequestResponse::Request::Capnp>();
Save(request, &req_builder);
}
auto response = Send(&req_msg);
auto res_msg = response.getRoot<capnp::Message>();
auto res_type = TRequestResponse::Response::kType;
if (res_msg.getTypeId() != res_type.id) {
// Since message_id was checked in private Call function, this means
// something is very wrong (probably on the server side).
VLOG(12) << "[RpcClient] sent " << req_type.name;
std::lock_guard<std::mutex> guard(mutex_);
// Check if the connection is broken (if we haven't used the client for a
// long time the server could have died).
if (client_ && client_->ErrorStatus()) {
client_ = std::nullopt;
}
// Connect to the remote server.
if (!client_) {
client_.emplace(&context_);
if (!client_->Connect(endpoint_)) {
DLOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
client_ = std::nullopt;
throw RpcFailedException(endpoint_);
}
}
// Build and send the request.
slk::Builder req_builder(
[&](const uint8_t *data, size_t size, bool have_more) {
client_->Write(data, size, have_more);
});
slk::Save(req_type.id, &req_builder);
TRequestResponse::Request::Save(request, &req_builder);
req_builder.Finalize();
// Receive response.
uint64_t response_data_size = 0;
while (true) {
auto ret =
slk::CheckStreamComplete(client_->GetData(), client_->GetDataSize());
if (ret.status == slk::StreamStatus::INVALID) {
throw RpcFailedException(endpoint_);
} else if (ret.status == slk::StreamStatus::PARTIAL) {
if (!client_->Read(ret.stream_size - client_->GetDataSize(),
/* exactly_len = */ false)) {
throw RpcFailedException(endpoint_);
}
} else {
response_data_size = ret.stream_size;
break;
}
}
// Load the response.
slk::Reader res_reader(client_->GetData(), response_data_size);
utils::OnScopeExit res_cleanup(
[&, response_data_size] { client_->ShiftData(response_data_size); });
uint64_t res_id = 0;
slk::Load(&res_id, &res_reader);
// Check response ID.
if (res_id != res_type.id) {
LOG(ERROR) << "Message response was of unexpected type";
client_ = std::nullopt;
throw RpcFailedException(endpoint_);
@ -75,18 +115,13 @@ class Client {
VLOG(12) << "[RpcClient] received " << res_type.name;
auto data_reader =
res_msg.getData()
.template getAs<typename TRequestResponse::Response::Capnp>();
return load(data_reader);
return load(&res_reader);
}
/// Call this function from another thread to abort a pending RPC call.
void Abort();
private:
::capnp::FlatArrayMessageReader Send(::capnp::MessageBuilder *message);
io::network::Endpoint endpoint_;
// TODO (mferencevic): currently the RPC client is hardcoded not to use SSL
communication::ClientContext context_;

View File

@ -23,19 +23,15 @@ class ClientPool {
return client->template Call<TRequestResponse>(
std::forward<Args>(args)...);
});
};
}
template <class TRequestResponse, class... Args>
typename TRequestResponse::Response CallWithLoad(
std::function<typename TRequestResponse::Response(
const typename TRequestResponse::Response::Capnp::Reader &)>
load,
Args &&... args) {
typename TRequestResponse::Response CallWithLoad(Args &&... args) {
return WithUnusedClient([&](const auto &client) {
return client->template CallWithLoad<TRequestResponse>(
load, std::forward<Args>(args)...);
std::forward<Args>(args)...);
});
};
}
private:
template <class TFun>

View File

@ -14,11 +14,9 @@ using MessageSize = uint32_t;
/// `TRequest` and `TResponse` are required to be classes which have a static
/// member `kType` of `utils::TypeInfo` type. This is used for proper
/// registration and deserialization of RPC types. Additionally, both `TRequest`
/// and `TResponse` are required to define a nested `Capnp` type, which
/// corresponds to the Cap'n Proto schema type, as well as defined the following
/// serialization functions:
/// * void Save(const TRequest|TResponse &, Capnp::Builder *, ...)
/// * void Load(const Capnp::Reader &, ...)
/// and `TResponse` are required to define the following serialization functions:
/// * static void Save(const TRequest|TResponse &, slk::Builder *, ...)
/// * static void Load(TRequest|TResponse *, slk::Reader *, ...)
template <typename TRequest, typename TResponse>
struct RequestResponse {
using Request = TRequest;

View File

@ -1,14 +1,10 @@
#include <sstream>
#include "capnp/message.h"
#include "capnp/serialize.h"
#include "fmt/format.h"
#include "communication/rpc/messages.capnp.h"
#include "communication/rpc/messages.hpp"
#include "communication/rpc/protocol.hpp"
#include "communication/rpc/messages.hpp"
#include "communication/rpc/server.hpp"
#include "utils/demangle.hpp"
#include "slk/serialization.hpp"
#include "slk/streams.hpp"
#include "utils/on_scope_exit.hpp"
namespace communication::rpc {
@ -21,39 +17,40 @@ Session::Session(Server *server, const io::network::Endpoint &endpoint,
output_stream_(output_stream) {}
void Session::Execute() {
if (input_stream_->size() < sizeof(MessageSize)) return;
MessageSize request_len =
*reinterpret_cast<MessageSize *>(input_stream_->data());
uint64_t request_size = sizeof(MessageSize) + request_len;
input_stream_->Resize(request_size);
if (input_stream_->size() < request_size) return;
auto ret =
slk::CheckStreamComplete(input_stream_->data(), input_stream_->size());
if (ret.status == slk::StreamStatus::INVALID) {
throw SessionException("Received an invalid SLK stream!");
} else if (ret.status == slk::StreamStatus::PARTIAL) {
input_stream_->Resize(ret.stream_size);
return;
}
// Read the request message.
auto data =
::kj::arrayPtr(input_stream_->data() + sizeof(request_len), request_len);
// Our data is word aligned and padded to 64bit because we use regular
// (non-packed) serialization of Cap'n Proto. So we can use reinterpret_cast.
auto data_words =
::kj::arrayPtr(reinterpret_cast<::capnp::word *>(data.begin()),
reinterpret_cast<::capnp::word *>(data.end()));
::capnp::FlatArrayMessageReader request_message(data_words.asConst());
auto request = request_message.getRoot<capnp::Message>();
input_stream_->Shift(sizeof(MessageSize) + request_len);
// Remove the data from the stream on scope exit.
utils::OnScopeExit shift_data(
[&, ret] { input_stream_->Shift(ret.stream_size); });
::capnp::MallocMessageBuilder response_message;
// callback fills the message data
auto response_builder = response_message.initRoot<capnp::Message>();
// Prepare SLK reader and builder.
slk::Reader req_reader(input_stream_->data(), input_stream_->size());
slk::Builder res_builder(
[&](const uint8_t *data, size_t size, bool have_more) {
output_stream_->Write(data, size, have_more);
});
// Load the request ID.
uint64_t req_id = 0;
slk::Load(&req_id, &req_reader);
// Access to `callbacks_` and `extended_callbacks_` is done here without
// acquiring the `mutex_` because we don't allow RPC registration after the
// server was started so those two maps will never be updated when we `find`
// over them.
auto it = server_->callbacks_.find(request.getTypeId());
auto it = server_->callbacks_.find(req_id);
auto extended_it = server_->extended_callbacks_.end();
if (it == server_->callbacks_.end()) {
// We couldn't find a regular callback to call, try to find an extended
// callback to call.
extended_it = server_->extended_callbacks_.find(request.getTypeId());
extended_it = server_->extended_callbacks_.find(req_id);
if (extended_it == server_->extended_callbacks_.end()) {
// Throw exception to close the socket and cleanup the session.
@ -61,29 +58,17 @@ void Session::Execute() {
"Session trying to execute an unregistered RPC call!");
}
VLOG(12) << "[RpcServer] received " << extended_it->second.req_type.name;
extended_it->second.callback(endpoint_, request, &response_builder);
slk::Save(extended_it->second.res_type.id, &res_builder);
extended_it->second.callback(endpoint_, &req_reader, &res_builder);
} else {
VLOG(12) << "[RpcServer] received " << it->second.req_type.name;
it->second.callback(request, &response_builder);
slk::Save(it->second.res_type.id, &res_builder);
it->second.callback(&req_reader, &res_builder);
}
// Serialize and send response
auto response_words = ::capnp::messageToFlatArray(response_message);
auto response_bytes = response_words.asBytes();
if (response_bytes.size() > std::numeric_limits<MessageSize>::max()) {
throw SessionException(fmt::format(
"Trying to send response of size {}, max response size is {}",
response_bytes.size(), std::numeric_limits<MessageSize>::max()));
}
MessageSize input_stream_size = response_bytes.size();
if (!output_stream_->Write(reinterpret_cast<uint8_t *>(&input_stream_size),
sizeof(MessageSize), true)) {
throw SessionException("Couldn't send response size!");
}
if (!output_stream_->Write(response_bytes.begin(), response_bytes.size())) {
throw SessionException("Couldn't send response data!");
}
// Finalize the SLK streams.
req_reader.Finalize();
res_builder.Finalize();
VLOG(12) << "[RpcServer] sent "
<< (it != server_->callbacks_.end()

View File

@ -4,14 +4,11 @@
#include <mutex>
#include <vector>
#include "capnp/any.h"
#include "communication/rpc/messages.capnp.h"
#include "communication/rpc/messages.hpp"
#include "communication/rpc/protocol.hpp"
#include "communication/server.hpp"
#include "io/network/endpoint.hpp"
#include "utils/demangle.hpp"
#include "slk/streams.hpp"
namespace communication::rpc {
@ -31,26 +28,14 @@ class Server {
const io::network::Endpoint &endpoint() const;
template <class TRequestResponse>
void Register(std::function<
void(const typename TRequestResponse::Request::Capnp::Reader &,
typename TRequestResponse::Response::Capnp::Builder *)>
callback) {
void Register(std::function<void(slk::Reader *, slk::Builder *)> callback) {
std::lock_guard<std::mutex> guard(lock_);
CHECK(!server_.IsRunning()) << "You can't register RPCs when the server is running!";
CHECK(!server_.IsRunning())
<< "You can't register RPCs when the server is running!";
RpcCallback rpc;
rpc.req_type = TRequestResponse::Request::kType;
rpc.res_type = TRequestResponse::Response::kType;
rpc.callback = [callback = callback](const auto &reader, auto *builder) {
auto req_data =
reader.getData()
.template getAs<typename TRequestResponse::Request::Capnp>();
builder->setTypeId(TRequestResponse::Response::kType.id);
auto data_builder = builder->initData();
auto res_builder =
data_builder
.template initAs<typename TRequestResponse::Response::Capnp>();
callback(req_data, &res_builder);
};
rpc.callback = callback;
if (extended_callbacks_.find(TRequestResponse::Request::kType.id) !=
extended_callbacks_.end()) {
@ -64,33 +49,16 @@ class Server {
}
template <class TRequestResponse>
void Register(std::function<
void(const io::network::Endpoint &,
const typename TRequestResponse::Request::Capnp::Reader &,
typename TRequestResponse::Response::Capnp::Builder *)>
void Register(std::function<void(const io::network::Endpoint &, slk::Reader *,
slk::Builder *)>
callback) {
std::lock_guard<std::mutex> guard(lock_);
CHECK(!server_.IsRunning()) << "You can't register RPCs when the server is running!";
CHECK(!server_.IsRunning())
<< "You can't register RPCs when the server is running!";
RpcExtendedCallback rpc;
rpc.req_type = TRequestResponse::Request::kType;
rpc.res_type = TRequestResponse::Response::kType;
rpc.callback = [callback = callback](const io::network::Endpoint &endpoint,
const auto &reader, auto *builder) {
auto req_data =
reader.getData()
.template getAs<typename TRequestResponse::Request::Capnp>();
builder->setTypeId(TRequestResponse::Response::kType.id);
auto data_builder = builder->initData();
auto res_builder =
data_builder
.template initAs<typename TRequestResponse::Response::Capnp>();
callback(endpoint, req_data, &res_builder);
};
if (callbacks_.find(TRequestResponse::Request::kType.id) !=
callbacks_.end()) {
LOG(FATAL) << "Callback for that message type already registered!";
}
rpc.callback = callback;
auto got =
extended_callbacks_.insert({TRequestResponse::Request::kType.id, rpc});
@ -104,17 +72,14 @@ class Server {
struct RpcCallback {
utils::TypeInfo req_type;
std::function<void(const capnp::Message::Reader &,
capnp::Message::Builder *)>
callback;
std::function<void(slk::Reader *, slk::Builder *)> callback;
utils::TypeInfo res_type;
};
struct RpcExtendedCallback {
utils::TypeInfo req_type;
std::function<void(const io::network::Endpoint &,
const capnp::Message::Reader &,
capnp::Message::Builder *)>
std::function<void(const io::network::Endpoint &, slk::Reader *,
slk::Builder *)>
callback;
utils::TypeInfo res_type;
};

View File

@ -6,14 +6,19 @@ namespace database {
MasterCounters::MasterCounters(distributed::Coordination *coordination) {
coordination->Register<CountersGetRpc>(
[this](const auto &req_reader, auto *res_builder) {
CountersGetRes res(Get(req_reader.getName()));
Save(res, res_builder);
[this](auto *req_reader, auto *res_builder) {
CountersGetReq req;
slk::Load(&req, req_reader);
CountersGetRes res(Get(req.name));
slk::Save(res, res_builder);
});
coordination->Register<CountersSetRpc>(
[this](const auto &req_reader, auto *res_builder) {
Set(req_reader.getName(), req_reader.getValue());
return std::make_unique<CountersSetRes>();
[this](auto *req_reader, auto *res_builder) {
CountersSetReq req;
slk::Load(&req, req_reader);
Set(req.name, req.value);
CountersSetRes res;
slk::Save(res, res_builder);
});
}

View File

@ -76,9 +76,9 @@ std::optional<VertexAccessor> BfsRpcClients::Pull(
auto res =
coordination_->GetClientPool(worker_id)->CallWithLoad<SubcursorPullRpc>(
[this, dba](const auto &reader) {
[this, dba](auto *res_reader) {
SubcursorPullRes res;
Load(&res, reader, dba, this->data_manager_);
slk::Load(&res, res_reader, dba, this->data_manager_);
return res;
},
subcursor_id);
@ -149,9 +149,9 @@ PathSegment BfsRpcClients::ReconstructPath(
auto res =
coordination_->GetClientPool(worker_id)->CallWithLoad<ReconstructPathRpc>(
[this, dba](const auto &reader) {
[this, dba](auto *res_reader) {
ReconstructPathRes res;
Load(&res, reader, dba, this->data_manager_);
slk::Load(&res, res_reader, dba, this->data_manager_);
return res;
},
subcursor_ids.at(worker_id), vertex);
@ -168,9 +168,9 @@ PathSegment BfsRpcClients::ReconstructPath(
}
auto res =
coordination_->GetClientPool(worker_id)->CallWithLoad<ReconstructPathRpc>(
[this, dba](const auto &reader) {
[this, dba](auto *res_reader) {
ReconstructPathRes res;
Load(&res, reader, dba, this->data_manager_);
slk::Load(&res, res_reader, dba, this->data_manager_);
return res;
},
subcursor_ids.at(worker_id), edge);

View File

@ -257,11 +257,12 @@ cpp<#
cpp<#)
:slk-load (lambda (member)
#>cpp
auto *subcursor = subcursor_storage->Get(self->subcursor_id);
size_t size;
slk::Load(&size, reader);
self->${member}.resize(size);
for (size_t i = 0; i < size; ++i) {
slk::Load(&self->${member}[i], reader, dba, data_manager);
slk::Load(&self->${member}[i], reader, subcursor->db_accessor(), data_manager);
}
cpp<#)
:capnp-type "List(Query.TypedValue)"
@ -281,7 +282,7 @@ cpp<#
return value;
}"))
(worker-id :int16_t :dont-save t))
(:serialize (:slk :load-args '((dba "database::GraphDbAccessor *")
(:serialize (:slk :load-args '((subcursor_storage "distributed::BfsSubcursorStorage *")
(data-manager "distributed::DataManager *")))
(:capnp
:load-args '((dba "database::GraphDbAccessor *")

View File

@ -21,11 +21,11 @@ class BfsRpcServer {
distributed::Coordination *coordination,
BfsSubcursorStorage *subcursor_storage)
: db_(db), subcursor_storage_(subcursor_storage) {
coordination->Register<CreateBfsSubcursorRpc>([this](const auto &req_reader,
coordination->Register<CreateBfsSubcursorRpc>([this](auto *req_reader,
auto *res_builder) {
CreateBfsSubcursorReq req;
auto ast_storage = std::make_unique<query::AstStorage>();
Load(&req, req_reader, ast_storage.get());
slk::Load(&req, req_reader, ast_storage.get());
database::GraphDbAccessor *dba;
{
std::lock_guard<std::mutex> guard(lock_);
@ -46,41 +46,41 @@ class BfsRpcServer {
dba, req.direction, req.edge_types, std::move(req.symbol_table),
std::move(ast_storage), req.filter_lambda, evaluation_context);
CreateBfsSubcursorRes res(id);
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<RegisterSubcursorsRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
RegisterSubcursorsReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
subcursor_storage_->Get(req.subcursor_ids.at(db_->WorkerId()))
->RegisterSubcursors(req.subcursor_ids);
RegisterSubcursorsRes res;
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<ResetSubcursorRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
ResetSubcursorReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
subcursor_storage_->Get(req.subcursor_id)->Reset();
ResetSubcursorRes res;
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<SetSourceRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
SetSourceReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
subcursor_storage_->Get(req.subcursor_id)->SetSource(req.source);
SetSourceRes res;
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<ExpandLevelRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
ExpandLevelReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto subcursor = subcursor_storage_->Get(req.member);
ExpandResult result;
try {
@ -90,32 +90,32 @@ class BfsRpcServer {
result = ExpandResult::LAMBDA_ERROR;
}
ExpandLevelRes res(result);
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<SubcursorPullRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
SubcursorPullReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto vertex = subcursor_storage_->Get(req.member)->Pull();
SubcursorPullRes res(vertex);
Save(res, res_builder, db_->WorkerId());
slk::Save(res, res_builder, db_->WorkerId());
});
coordination->Register<ExpandToRemoteVertexRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
ExpandToRemoteVertexReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
ExpandToRemoteVertexRes res(
subcursor_storage_->Get(req.subcursor_id)
->ExpandToLocalVertex(req.edge, req.vertex));
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<ReconstructPathRpc>([this](const auto &req_reader,
coordination->Register<ReconstructPathRpc>([this](auto *req_reader,
auto *res_builder) {
ReconstructPathReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto subcursor = subcursor_storage_->Get(req.subcursor_id);
PathSegment result;
if (req.vertex) {
@ -127,18 +127,17 @@ class BfsRpcServer {
}
ReconstructPathRes res(result.edges, result.next_vertex,
result.next_edge);
Save(res, res_builder, db_->WorkerId());
slk::Save(res, res_builder, db_->WorkerId());
});
coordination->Register<PrepareForExpandRpc>([this](const auto &req_reader,
coordination->Register<PrepareForExpandRpc>([this](auto *req_reader,
auto *res_builder) {
PrepareForExpandReq req;
auto subcursor_id = req_reader.getSubcursorId();
auto *subcursor = subcursor_storage_->Get(subcursor_id);
Load(&req, req_reader, subcursor->db_accessor(), &db_->data_manager());
slk::Load(&req, req_reader, subcursor_storage_, &db_->data_manager());
auto *subcursor = subcursor_storage_->Get(req.subcursor_id);
subcursor->PrepareForExpand(req.clear, std::move(req.frame));
PrepareForExpandRes res;
Save(res, res_builder);
slk::Save(res, res_builder);
});
}

View File

@ -13,13 +13,13 @@ ClusterDiscoveryMaster::ClusterDiscoveryMaster(
MasterCoordination *coordination, const std::string &durability_directory)
: coordination_(coordination), durability_directory_(durability_directory) {
coordination_->Register<RegisterWorkerRpc>([this](const auto &endpoint,
const auto &req_reader,
auto *req_reader,
auto *res_builder) {
bool registration_successful = false;
bool durability_error = false;
RegisterWorkerReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
// Compose the worker's endpoint from its connecting address and its
// advertised port.
@ -70,15 +70,17 @@ ClusterDiscoveryMaster::ClusterDiscoveryMaster(
RegisterWorkerRes res(registration_successful, durability_error,
coordination_->RecoveredSnapshotTx(),
coordination_->GetWorkers());
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination_->Register<NotifyWorkerRecoveredRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
NotifyWorkerRecoveredReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
coordination_->WorkerRecoveredSnapshot(req.worker_id,
req.recovery_info);
NotifyWorkerRecoveredRes res;
slk::Save(res, res_builder);
});
}

View File

@ -12,10 +12,12 @@ ClusterDiscoveryWorker::ClusterDiscoveryWorker(WorkerCoordination *coordination)
: coordination_(coordination),
client_pool_(coordination->GetClientPool(0)) {
coordination->Register<ClusterDiscoveryRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
ClusterDiscoveryReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
coordination_->RegisterWorker(req.worker_id, req.endpoint);
ClusterDiscoveryRes res;
slk::Save(res, res_builder);
});
}

View File

@ -73,18 +73,13 @@ class Coordination {
}
template <class TRequestResponse>
void Register(std::function<
void(const typename TRequestResponse::Request::Capnp::Reader &,
typename TRequestResponse::Response::Capnp::Builder *)>
callback) {
void Register(std::function<void(slk::Reader *, slk::Builder *)> callback) {
server_.Register<TRequestResponse>(callback);
}
template <class TRequestResponse>
void Register(std::function<
void(const io::network::Endpoint &,
const typename TRequestResponse::Request::Capnp::Reader &,
typename TRequestResponse::Response::Capnp::Builder *)>
void Register(std::function<void(const io::network::Endpoint &, slk::Reader *,
slk::Builder *)>
callback) {
server_.Register<TRequestResponse>(callback);
}

View File

@ -29,12 +29,12 @@ WorkerCoordination::WorkerCoordination(
: Coordination(worker_endpoint, worker_id, master_endpoint,
server_workers_count, client_workers_count) {
server_.Register<StopWorkerRpc>(
[&](const auto &req_reader, auto *res_builder) {
[&](auto *req_reader, auto *res_builder) {
LOG(INFO) << "The master initiated shutdown of this worker.";
Shutdown();
});
server_.Register<HeartbeatRpc>([&](const auto &req_reader,
server_.Register<HeartbeatRpc>([&](auto *req_reader,
auto *res_builder) {
std::lock_guard<std::mutex> guard(heartbeat_lock_);
last_heartbeat_time_ = std::chrono::steady_clock::now();

View File

@ -153,10 +153,11 @@ cpp<#
:slk-load
(lambda (member)
#>cpp
// slk::Load will read a bool which was explicity
// saved in :slk::save and based on that read record
// data
bool has_ptr;
slk::Load(&has_ptr, reader);
if (has_ptr) {
slk::Load(&self->edge_old_output, reader);
}
cpp<#))
(edge-new-input "const Edge *"
:capnp-type "Storage.Edge"
@ -191,10 +192,11 @@ cpp<#
:slk-load
(lambda (member)
#>cpp
// slk::Load will read a bool which was explicity
// saved in :slk::save and based on that read record
// data
slk::Load(&self->edge_new_output, reader);
bool has_ptr;
slk::Load(&has_ptr, reader);
if (has_ptr) {
slk::Load(&self->edge_old_output, reader);
}
cpp<#))
(worker-id :int64_t :dont-save t)
(edge-old-output "std::unique_ptr<Edge>" :initarg nil :dont-save t)

View File

@ -13,46 +13,50 @@ DataRpcServer::DataRpcServer(database::GraphDb *db,
distributed::Coordination *coordination)
: db_(db) {
coordination->Register<VertexRpc>(
[this](const auto &req_reader, auto *res_builder) {
auto dba = db_->Access(req_reader.getMember().getTxId());
auto vertex = dba->FindVertexRaw(req_reader.getMember().getGid());
[this](auto *req_reader, auto *res_builder) {
VertexReq req;
slk::Load(&req, req_reader);
auto dba = db_->Access(req.member.tx_id);
auto vertex = dba->FindVertexRaw(req.member.gid);
auto *old = vertex.GetOld();
auto *newr = vertex.GetNew() ? vertex.GetNew()->CloneData() : nullptr;
db_->updates_server().ApplyDeltasToRecord(
dba->transaction().id_, req_reader.getMember().getGid(),
req_reader.getMember().getFromWorkerId(), &old, &newr);
dba->transaction().id_, req.member.gid,
req.member.from_worker_id, &old, &newr);
VertexRes response(vertex.CypherId(), old, newr, db_->WorkerId());
Save(response, res_builder);
slk::Save(response, res_builder);
delete newr;
});
coordination->Register<EdgeRpc>(
[this](const auto &req_reader, auto *res_builder) {
auto dba = db_->Access(req_reader.getMember().getTxId());
auto edge = dba->FindEdgeRaw(req_reader.getMember().getGid());
[this](auto *req_reader, auto *res_builder) {
EdgeReq req;
slk::Load(&req, req_reader);
auto dba = db_->Access(req.member.tx_id);
auto edge = dba->FindEdgeRaw(req.member.gid);
auto *old = edge.GetOld();
auto *newr = edge.GetNew() ? edge.GetNew()->CloneData() : nullptr;
db_->updates_server().ApplyDeltasToRecord(
dba->transaction().id_, req_reader.getMember().getGid(),
req_reader.getMember().getFromWorkerId(), &old, &newr);
dba->transaction().id_, req.member.gid,
req.member.from_worker_id, &old, &newr);
EdgeRes response(edge.CypherId(), old, newr, db_->WorkerId());
Save(response, res_builder);
slk::Save(response, res_builder);
delete newr;
});
coordination->Register<VertexCountRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
VertexCountReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto dba = db_->Access(req.member);
int64_t size = 0;
for (auto vertex : dba->Vertices(false)) ++size;
VertexCountRes res(size);
Save(res, res_builder);
slk::Save(res, res_builder);
});
}

View File

@ -10,17 +10,21 @@ DurabilityRpcWorker::DurabilityRpcWorker(
database::Worker *db, distributed::Coordination *coordination)
: db_(db) {
coordination->Register<MakeSnapshotRpc>(
[this](const auto &req_reader, auto *res_builder) {
auto dba = db_->Access(req_reader.getMember());
[this](auto *req_reader, auto *res_builder) {
MakeSnapshotReq req;
slk::Load(&req, req_reader);
auto dba = db_->Access(req.member);
MakeSnapshotRes res(db_->MakeSnapshot(*dba));
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<RecoverWalAndIndexesRpc>(
[this](const auto &req_reader, auto *res_builder) {
durability::RecoveryData recovery_data;
durability::Load(&recovery_data, req_reader.getMember());
this->db_->RecoverWalAndIndexes(&recovery_data);
[this](auto *req_reader, auto *res_builder) {
RecoverWalAndIndexesReq req;
slk::Load(&req, req_reader);
this->db_->RecoverWalAndIndexes(&req.member);
RecoverWalAndIndexesRes res;
slk::Save(res, res_builder);
});
}

View File

@ -9,11 +9,11 @@ DynamicWorkerAddition::DynamicWorkerAddition(database::GraphDb *db,
distributed::Coordination *coordination)
: db_(db), coordination_(coordination) {
coordination_->Register<DynamicWorkerRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
DynamicWorkerReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
DynamicWorkerRes res(this->GetIndicesToCreate());
Save(res, res_builder);
slk::Save(res, res_builder);
});
}

View File

@ -10,21 +10,25 @@ IndexRpcServer::IndexRpcServer(database::GraphDb *db,
distributed::Coordination *coordination)
: db_(db) {
coordination->Register<CreateIndexRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
CreateIndexReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
database::LabelPropertyIndex::Key key{req.label, req.property};
db_->storage().label_property_index_.CreateIndex(key);
CreateIndexRes res;
slk::Save(res, res_builder);
});
coordination->Register<PopulateIndexRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
PopulateIndexReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
database::LabelPropertyIndex::Key key{req.label, req.property};
auto dba = db_->Access(req.tx_id);
dba->PopulateIndex(key);
dba->EnableIndex(key);
PopulateIndexRes res;
slk::Save(res, res_builder);
});
}

View File

@ -4,19 +4,23 @@ namespace distributed {
PlanConsumer::PlanConsumer(distributed::Coordination *coordination) {
coordination->Register<DispatchPlanRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
DispatchPlanReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
plan_cache_.access().insert(
req.plan_id, std::make_unique<PlanPack>(req.plan, req.symbol_table,
std::move(req.storage)));
DispatchPlanRes res;
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<RemovePlanRpc>(
[this](const auto &req_reader, auto *res_builder) {
plan_cache_.access().remove(req_reader.getMember());
[this](auto *req_reader, auto *res_builder) {
RemovePlanReq req;
slk::Load(&req, req_reader);
plan_cache_.access().remove(req.member);
RemovePlanRes res;
slk::Save(res, res_builder);
});
}

View File

@ -111,32 +111,32 @@ ProduceRpcServer::ProduceRpcServer(database::Worker *db,
plan_consumer_(plan_consumer),
tx_engine_(tx_engine) {
coordination->Register<PullRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
PullReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
PullRes res(Pull(req));
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<ResetCursorRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
ResetCursorReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
Reset(req);
ResetCursorRes res;
Save(res, res_builder);
slk::Save(res, res_builder);
});
CHECK(data_manager);
coordination->Register<TransactionCommandAdvancedRpc>(
[this, data_manager](const auto &req_reader, auto *res_builder) {
[this, data_manager](auto *req_reader, auto *res_builder) {
TransactionCommandAdvancedReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
tx_engine_->UpdateCommand(req.member);
data_manager->ClearCacheForSingleTransaction(req.member);
TransactionCommandAdvancedRes res;
Save(res, res_builder);
slk::Save(res, res_builder);
});
}

View File

@ -17,9 +17,9 @@ utils::Future<PullData> PullRpcClients::Pull(
worker_id, [data_manager = data_manager_, dba, plan_id, command_id,
evaluation_context, symbols, accumulate,
batch_size](int worker_id, ClientPool &client_pool) {
auto load_pull_res = [data_manager, dba](const auto &res_reader) {
auto load_pull_res = [data_manager, dba](auto *res_reader) {
PullRes res;
Load(&res, res_reader, dba, data_manager);
slk::Load(&res, res_reader, dba, data_manager);
return res;
};
auto result = client_pool.CallWithLoad<PullRpc>(

View File

@ -27,7 +27,7 @@ class TokenSharingRpcServer {
distributed::Coordination *coordination)
: worker_id_(worker_id), coordination_(coordination), dgp_(db) {
coordination_->Register<distributed::TokenTransferRpc>(
[this](const auto &req_reader, auto *res_builder) { token_ = true; });
[this](auto *req_reader, auto *res_builder) { token_ = true; });
// TODO (buda): It's not trivial to move this part in the Start method
// because worker then doesn't run the step. Will resolve that with
// a different implementation of the token assignment.

View File

@ -250,10 +250,10 @@ void UpdatesRpcServer::TransactionUpdates<TRecordAccessor>::ApplyDeltasToRecord(
UpdatesRpcServer::UpdatesRpcServer(database::GraphDb *db,
distributed::Coordination *coordination)
: db_(db) {
coordination->Register<UpdateRpc>([this](const auto &req_reader,
coordination->Register<UpdateRpc>([this](auto *req_reader,
auto *res_builder) {
UpdateReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
using DeltaType = database::StateDelta::Type;
auto &delta = req.member;
switch (delta.type) {
@ -264,13 +264,13 @@ UpdatesRpcServer::UpdatesRpcServer(database::GraphDb *db,
case database::StateDelta::Type::REMOVE_IN_EDGE: {
UpdateRes res(GetUpdates(vertex_updates_, delta.transaction_id)
.Emplace(delta, req.worker_id));
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
case DeltaType::SET_PROPERTY_EDGE: {
UpdateRes res(GetUpdates(edge_updates_, delta.transaction_id)
.Emplace(delta, req.worker_id));
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
default:
@ -280,29 +280,29 @@ UpdatesRpcServer::UpdatesRpcServer(database::GraphDb *db,
});
coordination->Register<UpdateApplyRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
UpdateApplyReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
UpdateApplyRes res(Apply(req.member));
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<CreateVertexRpc>([this](const auto &req_reader,
coordination->Register<CreateVertexRpc>([this](auto *req_reader,
auto *res_builder) {
CreateVertexReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto result = GetUpdates(vertex_updates_, req.member.tx_id)
.CreateVertex(req.member.labels, req.member.properties,
req.member.cypher_id);
CreateVertexRes res(
CreateResult{UpdateResult::DONE, result.cypher_id, result.gid});
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<CreateEdgeRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
CreateEdgeReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto data = req.member;
auto creation_result = CreateEdge(data);
@ -319,53 +319,53 @@ UpdatesRpcServer::UpdatesRpcServer(database::GraphDb *db,
}
CreateEdgeRes res(creation_result);
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<AddInEdgeRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
AddInEdgeReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto to_delta = database::StateDelta::AddInEdge(
req.member.tx_id, req.member.to, req.member.from,
req.member.edge_address, req.member.edge_type);
auto result = GetUpdates(vertex_updates_, req.member.tx_id)
.Emplace(to_delta, req.member.worker_id);
AddInEdgeRes res(result);
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<RemoveVertexRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
RemoveVertexReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto to_delta = database::StateDelta::RemoveVertex(
req.member.tx_id, req.member.gid, req.member.check_empty);
auto result = GetUpdates(vertex_updates_, req.member.tx_id)
.Emplace(to_delta, req.member.worker_id);
RemoveVertexRes res(result);
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<RemoveEdgeRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
RemoveEdgeReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
RemoveEdgeRes res(RemoveEdge(req.member));
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination->Register<RemoveInEdgeRpc>([this](const auto &req_reader,
coordination->Register<RemoveInEdgeRpc>([this](auto *req_reader,
auto *res_builder) {
RemoveInEdgeReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
auto data = req.member;
RemoveInEdgeRes res(
GetUpdates(vertex_updates_, data.tx_id)
.Emplace(database::StateDelta::RemoveInEdge(data.tx_id, data.vertex,
data.edge_address),
data.worker_id));
Save(res, res_builder);
slk::Save(res, res_builder);
});
}

View File

@ -94,19 +94,14 @@ class Coordination final {
/// Registers a RPC call on this node.
template <class TRequestResponse>
void Register(std::function<
void(const typename TRequestResponse::Request::Capnp::Reader &,
typename TRequestResponse::Response::Capnp::Builder *)>
callback) {
void Register(std::function<void(slk::Reader *, slk::Builder *)> callback) {
server_.Register<TRequestResponse>(callback);
}
/// Registers an extended RPC call on this node.
template <class TRequestResponse>
void Register(std::function<
void(const io::network::Endpoint &,
const typename TRequestResponse::Request::Capnp::Reader &,
typename TRequestResponse::Response::Capnp::Builder *)>
void Register(std::function<void(const io::network::Endpoint &, slk::Reader *,
slk::Builder *)>
callback) {
server_.Register<TRequestResponse>(callback);
}

View File

@ -52,41 +52,7 @@ cpp<#
((leader-id :uint16_t)
(term :uint64_t)
(snapshot-metadata "raft::SnapshotMetadata" :capnp-type "Snap.SnapshotMetadata")
(data "std::unique_ptr<char[]>"
:initarg :move
:capnp-type "Data"
:capnp-init nil
:capnp-save (lambda (builder member capnp-name)
#>cpp
auto data_builder = ${builder}->initData(self.size);
memcpy(data_builder.begin(), ${member}.get(), self.size);
cpp<#)
:slk-save (lambda (member)
#>cpp
slk::Save(self.size, builder);
for (uint32_t i = 0; i < self.size; ++i) {
slk::Save(self.data[i], builder);
}
cpp<#)
:capnp-load (lambda (reader member capnp-name)
(declare (ignore capnp-name))
#>cpp
auto data_reader = ${reader}.getData();
self->size = data_reader.size();
${member}.reset(new char[self->size]);
memcpy(${member}.get(), data_reader.begin(), self->size);
cpp<#)
:slk-load (lambda (member)
#>cpp
slk::Load(&self->size, reader);
self->data.reset(new char[self->size]);
for (uint32_t i = 0; i < self->size; ++i) {
uint8_t curr;
slk::Load(&curr, reader);
self->data[i] = curr;
}
cpp<#))
(size :uint32_t :dont-save t)))
(data "std::string")))
(:response
((term :uint64_t))))

View File

@ -83,17 +83,17 @@ void RaftServer::Start() {
// RPC registration
coordination_->Register<RequestVoteRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
std::lock_guard<std::mutex> guard(lock_);
RequestVoteReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
// [Raft paper 5.1]
// "If a server recieves a request with a stale term,
// it rejects the request"
if (exiting_ || req.term < current_term_) {
RequestVoteRes res(false, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
@ -118,21 +118,21 @@ void RaftServer::Start() {
last_entry_data.first, last_entry_data.second);
RequestVoteRes res(grant_vote, current_term_);
if (grant_vote) SetNextElectionTimePoint();
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination_->Register<AppendEntriesRpc>([this](const auto &req_reader,
coordination_->Register<AppendEntriesRpc>([this](auto *req_reader,
auto *res_builder) {
std::lock_guard<std::mutex> guard(lock_);
AppendEntriesReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
// [Raft paper 5.1]
// "If a server receives a request with a stale term, it rejects the
// request"
if (exiting_ || req.term < current_term_) {
AppendEntriesRes res(false, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
@ -172,7 +172,7 @@ void RaftServer::Start() {
snapshot_metadata->last_included_index == req.prev_log_index) {
if (req.prev_log_term != snapshot_metadata->last_included_term) {
AppendEntriesRes res(false, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
} else if (snapshot_metadata &&
@ -180,13 +180,13 @@ void RaftServer::Start() {
LOG(ERROR) << "Received entries that are already commited and have been "
"compacted";
AppendEntriesRes res(false, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
return;
} else {
if (log_size_ <= req.prev_log_index ||
GetLogEntry(req.prev_log_index).term != req.prev_log_term) {
AppendEntriesRes res(false, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
}
@ -204,24 +204,24 @@ void RaftServer::Start() {
// Respond positively to a heartbeat.
if (req.entries.empty()) {
AppendEntriesRes res(true, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
if (mode_ != Mode::FOLLOWER) Transition(Mode::FOLLOWER);
return;
}
AppendEntriesRes res(true, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination_->Register<HeartbeatRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
std::lock_guard<std::mutex> guard(lock_);
HeartbeatReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
if (exiting_ || req.term < current_term_) {
HeartbeatRes res(false, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
@ -234,21 +234,21 @@ void RaftServer::Start() {
election_change_.notify_all();
HeartbeatRes res(true, current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination_->Register<InstallSnapshotRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
// Acquire snapshot lock.
std::lock_guard<std::mutex> snapshot_guard(snapshot_lock_);
std::lock_guard<std::mutex> guard(lock_);
InstallSnapshotReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
if (exiting_ || req.term < current_term_) {
InstallSnapshotRes res(current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
@ -256,7 +256,7 @@ void RaftServer::Start() {
if (req.snapshot_metadata.last_included_index == last_applied_ &&
req.snapshot_metadata.last_included_term == current_term_) {
InstallSnapshotRes res(current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
return;
}
@ -288,7 +288,7 @@ void RaftServer::Start() {
VLOG(40) << "[InstallSnapshotRpc] Saving received snapshot.";
std::ofstream output_stream;
output_stream.open(snapshot_path, std::ios::out | std::ios::binary);
output_stream.write(req.data.get(), req.size);
output_stream.write(req.data.data(), req.data.size());
output_stream.flush();
output_stream.close();
}
@ -317,7 +317,7 @@ void RaftServer::Start() {
SetLogSize(req.snapshot_metadata.last_included_index + 1);
InstallSnapshotRes res(current_term_);
Save(res, res_builder);
slk::Save(res, res_builder);
});
// start threads
@ -797,8 +797,12 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
const SnapshotMetadata &snapshot_metadata,
std::unique_lock<std::mutex> *lock) {
uint64_t request_term = current_term_;
uint32_t snapshot_size = 0;
std::unique_ptr<char[]> snapshot;
std::string snapshot_data;
// TODO: The snapshot is currently sent all at once. Because the snapshot file
// can be extremely large (>100GB, it contains the whole database) it must be
// sent out in chunks! Reimplement this logic so that it sends out the
// snapshot in chunks.
{
const auto snapshot_path = durability::MakeSnapshotPath(
@ -807,19 +811,19 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
std::ifstream input_stream;
input_stream.open(snapshot_path, std::ios::in | std::ios::binary);
input_stream.seekg(0, std::ios::end);
snapshot_size = input_stream.tellg();
uint64_t snapshot_size = input_stream.tellg();
snapshot.reset(new char[snapshot_size]);
snapshot_data = std::string(snapshot_size, '\0');
input_stream.seekg(0, std::ios::beg);
input_stream.read(snapshot.get(), snapshot_size);
input_stream.read(snapshot_data.data(), snapshot_size);
input_stream.close();
}
VLOG(40) << "Server " << server_id_
<< ": Sending Snapshot RPC to server " << peer_id
<< " (Term: " << current_term_ << ")";
VLOG(40) << "Snapshot size: " << snapshot_size << " bytes.";
VLOG(40) << "Snapshot size: " << snapshot_data.size() << " bytes.";
// Copy all internal variables before releasing the lock.
auto server_id = server_id_;
@ -827,8 +831,8 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
// Execute the RPC.
lock->unlock();
auto reply = coordination_->ExecuteOnOtherNode<InstallSnapshotRpc>(
peer_id, server_id, request_term, snapshot_metadata, std::move(snapshot),
snapshot_size);
peer_id, server_id, request_term, snapshot_metadata,
std::move(snapshot_data));
lock->lock();
if (!reply) {

View File

@ -25,12 +25,12 @@ StorageInfo::~StorageInfo() {}
void StorageInfo::Start() {
coordination_->Register<StorageInfoRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
StorageInfoReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
StorageInfoRes res(this->server_id_, this->GetLocalStorageInfo());
Save(res, res_builder);
slk::Save(res, res_builder);
});
}

View File

@ -97,6 +97,7 @@ void Load(
}
MAKE_PRIMITIVE_SAVE(bool)
MAKE_PRIMITIVE_SAVE(char)
MAKE_PRIMITIVE_SAVE(int8_t)
MAKE_PRIMITIVE_SAVE(uint8_t)
MAKE_PRIMITIVE_SAVE(int16_t)
@ -116,6 +117,7 @@ MAKE_PRIMITIVE_SAVE(double)
}
MAKE_PRIMITIVE_LOAD(bool)
MAKE_PRIMITIVE_LOAD(char)
MAKE_PRIMITIVE_LOAD(int8_t)
MAKE_PRIMITIVE_LOAD(uint8_t)
MAKE_PRIMITIVE_LOAD(int16_t)

View File

@ -16,18 +16,18 @@ void RegisterRpc(MasterConcurrentIdMapper<TId> *mapper,
void RegisterRpc<type>(MasterConcurrentIdMapper<type> * mapper, \
distributed::Coordination * coordination) { \
coordination->Register<type##IdRpc>( \
[mapper](const auto &req_reader, auto *res_builder) { \
[mapper](auto *req_reader, auto *res_builder) { \
type##IdReq req; \
Load(&req, req_reader); \
slk::Load(&req, req_reader); \
type##IdRes res(mapper->value_to_id(req.member)); \
Save(res, res_builder); \
slk::Save(res, res_builder); \
}); \
coordination->Register<Id##type##Rpc>( \
[mapper](const auto &req_reader, auto *res_builder) { \
[mapper](auto *req_reader, auto *res_builder) { \
Id##type##Req req; \
Load(&req, req_reader); \
slk::Load(&req, req_reader); \
Id##type##Res res(mapper->id_to_value(req.member)); \
Save(res, res_builder); \
slk::Save(res, res_builder); \
}); \
}

View File

@ -357,6 +357,7 @@ template <class TRecordAccessor>
void SaveRecordAccessor(const TRecordAccessor &accessor, slk::Builder *builder,
storage::SendVersions versions, int16_t worker_id) {
bool reconstructed = false;
auto guard = storage::GetDataLock(accessor);
if (!accessor.GetOld() && !accessor.GetNew()) {
reconstructed = true;
bool result = accessor.Reconstruct();

View File

@ -26,11 +26,13 @@ class StorageGcMaster final : public StorageGcDistributed {
: StorageGcDistributed(storage, tx_engine, pause_sec),
coordination_(coordination) {
coordination_->Register<distributed::RanLocalGcRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
distributed::RanLocalGcReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
std::unique_lock<std::mutex> lock(worker_safe_transaction_mutex_);
worker_safe_transaction_[req.worker_id] = req.local_oldest_active;
distributed::RanLocalGcRes res;
slk::Save(res, res_builder);
});
}

View File

@ -13,72 +13,99 @@ EngineMaster::EngineMaster(distributed::Coordination *coordination,
durability::WriteAheadLog *wal)
: engine_single_node_(wal), coordination_(coordination) {
coordination_->Register<BeginRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
BeginReq req;
slk::Load(&req, req_reader);
auto tx = this->Begin();
BeginRes res(TxAndSnapshot{tx->id_, tx->snapshot()});
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination_->Register<AdvanceRpc>(
[this](const auto &req_reader, auto *res_builder) {
AdvanceRes res(this->Advance(req_reader.getMember()));
Save(res, res_builder);
[this](auto *req_reader, auto *res_builder) {
AdvanceReq req;
slk::Load(&req, req_reader);
AdvanceRes res(this->Advance(req.member));
slk::Save(res, res_builder);
});
coordination_->Register<CommitRpc>(
[this](const auto &req_reader, auto *res_builder) {
this->Commit(*this->RunningTransaction(req_reader.getMember()));
[this](auto *req_reader, auto *res_builder) {
CommitReq req;
slk::Load(&req, req_reader);
this->Commit(*this->RunningTransaction(req.member));
CommitRes res;
slk::Save(res, res_builder);
});
coordination_->Register<AbortRpc>(
[this](const auto &req_reader, auto *res_builder) {
this->Abort(*this->RunningTransaction(req_reader.getMember()));
[this](auto *req_reader, auto *res_builder) {
AbortReq req;
slk::Load(&req, req_reader);
this->Abort(*this->RunningTransaction(req.member));
AbortRes res;
slk::Save(res, res_builder);
});
coordination_->Register<SnapshotRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
// It is guaranteed that the Worker will not be requesting this for a
// transaction that's done, and that there are no race conditions here.
SnapshotRes res(
this->RunningTransaction(req_reader.getMember())->snapshot());
Save(res, res_builder);
SnapshotReq req;
slk::Load(&req, req_reader);
SnapshotRes res(this->RunningTransaction(req.member)->snapshot());
slk::Save(res, res_builder);
});
coordination_->Register<CommandRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
// It is guaranteed that the Worker will not be requesting this for a
// transaction that's done, and that there are no race conditions here.
CommandRes res(this->RunningTransaction(req_reader.getMember())->cid());
Save(res, res_builder);
CommandReq req;
slk::Load(&req, req_reader);
CommandRes res(this->RunningTransaction(req.member)->cid());
slk::Save(res, res_builder);
});
coordination_->Register<GcSnapshotRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
GcSnapshotReq req;
slk::Load(&req, req_reader);
GcSnapshotRes res(this->GlobalGcSnapshot());
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination_->Register<ClogInfoRpc>(
[this](const auto &req_reader, auto *res_builder) {
ClogInfoRes res(this->Info(req_reader.getMember()));
Save(res, res_builder);
[this](auto *req_reader, auto *res_builder) {
ClogInfoReq req;
slk::Load(&req, req_reader);
ClogInfoRes res(this->Info(req.member));
slk::Save(res, res_builder);
});
coordination_->Register<ActiveTransactionsRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
ActiveTransactionsReq req;
slk::Load(&req, req_reader);
ActiveTransactionsRes res(this->GlobalActiveTransactions());
Save(res, res_builder);
slk::Save(res, res_builder);
});
coordination_->Register<EnsureNextIdGreaterRpc>(
[this](const auto &req_reader, auto *res_builder) {
this->EnsureNextIdGreater(req_reader.getMember());
[this](auto *req_reader, auto *res_builder) {
EnsureNextIdGreaterReq req;
slk::Load(&req, req_reader);
this->EnsureNextIdGreater(req.member);
EnsureNextIdGreaterRes res;
slk::Save(res, res_builder);
});
coordination_->Register<GlobalLastRpc>(
[this](const auto &req_reader, auto *res_builder) {
[this](auto *req_reader, auto *res_builder) {
GlobalLastReq req;
slk::Load(&req, req_reader);
GlobalLastRes res(this->GlobalLast());
Save(res, res_builder);
slk::Save(res, res_builder);
});
}

View File

@ -29,11 +29,14 @@ EngineWorker::EngineWorker(distributed::Coordination *coordination,
// aborted. This mismatch in committed/aborted across workers is resolved by
// using the master as a single source of truth when doing recovery.
coordination_->Register<NotifyCommittedRpc>(
[this](const auto &req_reader, auto *res_builder) {
auto tid = req_reader.getMember();
[this](auto *req_reader, auto *res_builder) {
NotifyCommittedReq req;
slk::Load(&req, req_reader);
if (wal_) {
wal_->Emplace(database::StateDelta::TxCommit(tid));
wal_->Emplace(database::StateDelta::TxCommit(req.member));
}
NotifyCommittedRes res;
slk::Save(res, res_builder);
});
}

View File

@ -9,26 +9,33 @@
#include "communication/rpc/client_pool.hpp"
#include "communication/rpc/messages.hpp"
#include "communication/rpc/server.hpp"
#include "slk/serialization.hpp"
#include "utils/timer.hpp"
struct EchoMessage {
using Capnp = ::capnp::AnyPointer;
static const utils::TypeInfo kType;
EchoMessage() {} // Needed for serialization.
EchoMessage(const std::string &data) : data(data) {}
static void Load(EchoMessage *obj, slk::Reader *reader);
static void Save(const EchoMessage &obj, slk::Builder *builder);
std::string data;
};
void Save(const EchoMessage &echo, ::capnp::AnyPointer::Builder *builder) {
auto list_builder = builder->initAs<::capnp::List<::capnp::Text>>(1);
list_builder.set(0, echo.data);
namespace slk {
void Save(const EchoMessage &echo, Builder *builder) {
Save(echo.data, builder);
}
void Load(EchoMessage *echo, Reader *reader) { Load(&echo->data, reader); }
} // namespace slk
void Load(EchoMessage *echo, const ::capnp::AnyPointer::Reader &reader) {
auto list_reader = reader.getAs<::capnp::List<::capnp::Text>>();
echo->data = list_reader[0];
void EchoMessage::Load(EchoMessage *obj, slk::Reader *reader) {
slk::Load(obj, reader);
}
void EchoMessage::Save(const EchoMessage &obj, slk::Builder *builder) {
slk::Save(obj, builder);
}
const utils::TypeInfo EchoMessage::kType{2, "EchoMessage"};

View File

@ -10,88 +10,57 @@
#include "communication/rpc/server.hpp"
#include "utils/timer.hpp"
#include "rpc_messages.hpp"
using namespace communication::rpc;
using namespace std::literals::chrono_literals;
struct SumReq {
using Capnp = ::capnp::AnyPointer;
static const utils::TypeInfo kType;
SumReq() {} // Needed for serialization.
SumReq(int x, int y) : x(x), y(y) {}
int x;
int y;
};
void Save(const SumReq &sum, ::capnp::AnyPointer::Builder *builder) {
auto list_builder = builder->initAs<::capnp::List<int>>(2);
list_builder.set(0, sum.x);
list_builder.set(1, sum.y);
namespace slk {
void Save(const SumReq &sum, Builder *builder) {
Save(sum.x, builder);
Save(sum.y, builder);
}
void Load(SumReq *sum, const ::capnp::AnyPointer::Reader &reader) {
auto list_reader = reader.getAs<::capnp::List<int>>();
sum->x = list_reader[0];
sum->y = list_reader[1];
void Load(SumReq *sum, Reader *reader) {
Load(&sum->x, reader);
Load(&sum->y, reader);
}
const utils::TypeInfo SumReq::kType{0, "SumReq"};
void Save(const SumRes &res, Builder *builder) { Save(res.sum, builder); }
struct SumRes {
using Capnp = ::capnp::AnyPointer;
static const utils::TypeInfo kType;
void Load(SumRes *res, Reader *reader) { Load(&res->sum, reader); }
SumRes() {} // Needed for serialization.
SumRes(int sum) : sum(sum) {}
int sum;
};
void Save(const SumRes &res, ::capnp::AnyPointer::Builder *builder) {
auto list_builder = builder->initAs<::capnp::List<int>>(1);
list_builder.set(0, res.sum);
void Save(const EchoMessage &echo, Builder *builder) {
Save(echo.data, builder);
}
void Load(SumRes *res, const ::capnp::AnyPointer::Reader &reader) {
auto list_reader = reader.getAs<::capnp::List<int>>();
res->sum = list_reader[0];
void Load(EchoMessage *echo, Reader *reader) { Load(&echo->data, reader); }
} // namespace slk
void SumReq::Load(SumReq *obj, slk::Reader *reader) { slk::Load(obj, reader); }
void SumReq::Save(const SumReq &obj, slk::Builder *builder) {
slk::Save(obj, builder);
}
const utils::TypeInfo SumRes::kType{1, "SumRes"};
using Sum = RequestResponse<SumReq, SumRes>;
struct EchoMessage {
using Capnp = ::capnp::AnyPointer;
static const utils::TypeInfo kType;
EchoMessage() {} // Needed for serialization.
EchoMessage(const std::string &data) : data(data) {}
std::string data;
};
void Save(const EchoMessage &echo, ::capnp::AnyPointer::Builder *builder) {
auto list_builder = builder->initAs<::capnp::List<::capnp::Text>>(1);
list_builder.set(0, echo.data);
void SumRes::Load(SumRes *obj, slk::Reader *reader) { slk::Load(obj, reader); }
void SumRes::Save(const SumRes &obj, slk::Builder *builder) {
slk::Save(obj, builder);
}
void Load(EchoMessage *echo, const ::capnp::AnyPointer::Reader &reader) {
auto list_reader = reader.getAs<::capnp::List<::capnp::Text>>();
echo->data = list_reader[0];
void EchoMessage::Load(EchoMessage *obj, slk::Reader *reader) {
slk::Load(obj, reader);
}
void EchoMessage::Save(const EchoMessage &obj, slk::Builder *builder) {
slk::Save(obj, builder);
}
const utils::TypeInfo EchoMessage::kType{2, "EchoMessage"};
using Echo = RequestResponse<EchoMessage, EchoMessage>;
TEST(Rpc, Call) {
Server server({"127.0.0.1", 0});
server.Register<Sum>([](const auto &req_reader, auto *res_builder) {
server.Register<Sum>([](auto *req_reader, auto *res_builder) {
SumReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
SumRes res(req.x + req.y);
Save(res, res_builder);
slk::Save(res, res_builder);
});
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
@ -106,12 +75,12 @@ TEST(Rpc, Call) {
TEST(Rpc, Abort) {
Server server({"127.0.0.1", 0});
server.Register<Sum>([](const auto &req_reader, auto *res_builder) {
server.Register<Sum>([](auto *req_reader, auto *res_builder) {
SumReq req;
Load(&req, req_reader);
slk::Load(&req, req_reader);
std::this_thread::sleep_for(500ms);
SumRes res(req.x + req.y);
Save(res, res_builder);
slk::Save(res, res_builder);
});
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
@ -149,8 +118,8 @@ TEST(Rpc, ClientPool) {
Client client(server.endpoint());
/* these calls should take more than 400ms because we're using a regular
* client */
// These calls should take more than 400ms because we're using a regular
// client
auto get_sum_client = [&client](int x, int y) {
auto sum = client.Call<Sum>(x, y);
EXPECT_EQ(sum.sum, x + y);
@ -170,8 +139,8 @@ TEST(Rpc, ClientPool) {
ClientPool pool(server.endpoint());
/* these calls shouldn't take much more that 100ms because they execute in
* parallel */
// These calls shouldn't take much more that 100ms because they execute in
// parallel
auto get_sum = [&pool](int x, int y) {
auto sum = pool.Call<Sum>(x, y);
EXPECT_EQ(sum.sum, x + y);
@ -192,10 +161,10 @@ TEST(Rpc, ClientPool) {
TEST(Rpc, LargeMessage) {
Server server({"127.0.0.1", 0});
server.Register<Echo>([](const auto &req_reader, auto *res_builder) {
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
EchoMessage res;
Load(&res, req_reader);
Save(res, res_builder);
slk::Load(&res, req_reader);
slk::Save(res, res_builder);
});
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
@ -209,3 +178,24 @@ TEST(Rpc, LargeMessage) {
server.Shutdown();
server.AwaitShutdown();
}
TEST(Rpc, JumboMessage) {
Server server({"127.0.0.1", 0});
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
EchoMessage res;
slk::Load(&res, req_reader);
slk::Save(res, res_builder);
});
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
// NOLINTNEXTLINE (bugprone-string-constructor)
std::string testdata(10000000, 'a');
Client client(server.endpoint());
auto echo = client.Call<Echo>(testdata);
EXPECT_EQ(echo.data, testdata);
server.Shutdown();
server.AwaitShutdown();
}

View File

@ -0,0 +1,65 @@
#pragma once
#include "communication/rpc/messages.hpp"
#include "slk/serialization.hpp"
#include "utils/typeinfo.hpp"
struct SumReq {
static const utils::TypeInfo kType;
SumReq() {} // Needed for serialization.
SumReq(int x, int y) : x(x), y(y) {}
static void Load(SumReq *obj, slk::Reader *reader);
static void Save(const SumReq &obj, slk::Builder *builder);
int x;
int y;
};
const utils::TypeInfo SumReq::kType{0, "SumReq"};
struct SumRes {
static const utils::TypeInfo kType;
SumRes() {} // Needed for serialization.
SumRes(int sum) : sum(sum) {}
static void Load(SumRes *obj, slk::Reader *reader);
static void Save(const SumRes &obj, slk::Builder *builder);
int sum;
};
const utils::TypeInfo SumRes::kType{1, "SumRes"};
namespace slk {
void Save(const SumReq &sum, Builder *builder);
void Load(SumReq *sum, Reader *reader);
void Save(const SumRes &res, Builder *builder);
void Load(SumRes *res, Reader *reader);
} // namespace slk
using Sum = communication::rpc::RequestResponse<SumReq, SumRes>;
struct EchoMessage {
static const utils::TypeInfo kType;
EchoMessage() {} // Needed for serialization.
EchoMessage(const std::string &data) : data(data) {}
static void Load(EchoMessage *obj, slk::Reader *reader);
static void Save(const EchoMessage &obj, slk::Builder *builder);
std::string data;
};
const utils::TypeInfo EchoMessage::kType{2, "EchoMessage"};
namespace slk {
void Save(const EchoMessage &echo, Builder *builder);
void Load(EchoMessage *echo, Reader *reader);
} // namespace slk
using Echo = communication::rpc::RequestResponse<EchoMessage, EchoMessage>;