Throw exceptions on RPC failure and Distributed error handling

Summary:
This diff changes the RPC layer to return `TResponse` directly to the user when
issuing a `Call<...>` RPC call. On failure the call now throws an exception
(instead of returning `nullopt` as before).
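
For illustration only, a call site under the new API looks roughly like the
sketch below. `ClientPool`, `CountersGetRpc` and `RpcFailedException` are taken
from this diff; the surrounding handling is hypothetical, not code from this
commit.

```cpp
// Before this diff, Call<...> returned std::experimental::optional<Response>
// and every call site had to CHECK or branch on it. It now returns the
// response directly and throws on any communication error.
try {
  // `client_pool` is assumed to be a communication::rpc::ClientPool
  // connected to the master, as used elsewhere in this diff.
  int64_t value = client_pool.Call<CountersGetRpc>("my_counter").value;
  // ... use `value` ...
} catch (const communication::rpc::RpcFailedException &e) {
  // A persistent failure eventually triggers a graceful cluster shutdown;
  // here the caller just logs the endpoint that couldn't be reached.
  LOG(WARNING) << "RPC to " << e.endpoint() << " failed: " << e.what();
}
```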

All servers (network, RPC and distributed) now have explicit `Shutdown`
methods so that a controlled shutdown can always be performed. The object
destructors now contain `CHECK`s that enforce that both `Shutdown` and
`AwaitShutdown` were called.
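
A minimal sketch of the new lifetime contract, assuming a
`communication::rpc::Server` as changed in this diff (the endpoint and worker
count are illustrative):

```cpp
// Hypothetical lifetime of an RPC server under the new contract.
io::network::Endpoint endpoint{"127.0.0.1", 10000};  // illustrative address
communication::rpc::Server server(endpoint, /*workers_count=*/4);

// ... register RPC handlers, serve requests ...

server.Shutdown();       // stop accepting new requests
server.AwaitShutdown();  // block until all worker threads have stopped

// The underlying communication::Server destructor now CHECKs that both
// calls above were made, so forgetting them aborts loudly instead of
// silently leaking running threads.
```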

The distributed Memgraph is changed so that none of the binaries (master/workers)
crash when there is a communication failure. Instead, the whole cluster starts
a graceful shutdown when a persistent communication error is detected.
Transient errors are allowed during execution. The transaction that errored out
will be aborted on the whole cluster. The cluster state is managed using a new
Heartbeat RPC call.
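
A hypothetical driver for the master binary, using the `Shutdown`/`AwaitShutdown`
API added below; the callback contents and return-code handling are assumptions,
not part of this diff:

```cpp
// Hypothetical main() fragment for the master binary.
database::Config config;  // assumed to be parsed from the command-line flags
database::Master db{config};

// Something else (a signal handler, or the shutdown triggered when the
// Heartbeat machinery detects a persistent communication failure) calls
// db.Shutdown() at some point.

// AwaitShutdown blocks until that happens, runs the callback while cluster
// RPCs are still usable, and only then stops the RPC server. It returns
// false if, for example, the exit snapshot couldn't be made because the
// cluster is in a degraded state.
bool ok = db.AwaitShutdown([&] {
  // Stop everything that can still issue queries or RPCs
  // (e.g. the Bolt server in the real binary).
});
return ok ? 0 : 1;
```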

Reviewers: buda, teon.banek, msantl

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1604
This commit is contained in:
Matej Ferencevic 2018-09-27 15:07:46 +02:00
parent 13529411db
commit 53c405c699
86 changed files with 1474 additions and 1012 deletions

View File

@ -32,7 +32,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
// Create a new SSL object that will be used for SSL communication.
ssl_ = SSL_new(context_->context());
if (ssl_ == nullptr) {
LOG(WARNING) << "Couldn't create client SSL object!";
LOG(ERROR) << "Couldn't create client SSL object!";
socket_.Close();
return false;
}
@ -43,7 +43,7 @@ bool Client::Connect(const io::network::Endpoint &endpoint) {
// handle that in our socket destructor).
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
if (bio_ == nullptr) {
LOG(WARNING) << "Couldn't create client BIO object!";
LOG(ERROR) << "Couldn't create client BIO object!";
socket_.Close();
return false;
}
@ -111,7 +111,7 @@ bool Client::Read(size_t len) {
continue;
} else {
// This is a fatal error.
LOG(WARNING) << "Received an unexpected SSL error: " << err;
LOG(ERROR) << "Received an unexpected SSL error: " << err;
return false;
}
} else if (got == 0) {

View File

@ -85,8 +85,13 @@ class Listener final {
}
~Listener() {
Shutdown();
AwaitShutdown();
bool worker_alive = false;
for (auto &thread : worker_threads_) {
if (thread.joinable()) worker_alive = true;
}
CHECK(!alive_ && !worker_alive && !timeout_thread_.joinable())
<< "You should call Shutdown and AwaitShutdown on "
"communication::Listener!";
}
Listener(const Listener &) = delete;

View File

@ -1,27 +1,17 @@
#include "communication/rpc/client.hpp"
#include <chrono>
#include <thread>
#include "gflags/gflags.h"
#include "communication/rpc/client.hpp"
DEFINE_HIDDEN_bool(rpc_random_latency, false,
"If a random wait should happen on each RPC call, to "
"simulate network latency.");
namespace communication::rpc {
Client::Client(const io::network::Endpoint &endpoint) : endpoint_(endpoint) {}
std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
::capnp::MessageBuilder *message) {
::capnp::FlatArrayMessageReader Client::Send(::capnp::MessageBuilder *message) {
std::lock_guard<std::mutex> guard(mutex_);
if (FLAGS_rpc_random_latency) {
auto microseconds = (int)(1000 * rand_(gen_));
std::this_thread::sleep_for(std::chrono::microseconds(microseconds));
}
// Check if the connection is broken (if we haven't used the client for a
// long time the server could have died).
if (client_ && client_->ErrorStatus()) {
@ -32,9 +22,9 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
if (!client_) {
client_.emplace(&context_);
if (!client_->Connect(endpoint_)) {
LOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
DLOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
client_ = std::experimental::nullopt;
return std::experimental::nullopt;
throw RpcFailedException(endpoint_);
}
}
@ -49,22 +39,22 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
MessageSize request_data_size = request_bytes.size();
if (!client_->Write(reinterpret_cast<uint8_t *>(&request_data_size),
sizeof(MessageSize), true)) {
LOG(ERROR) << "Couldn't send request size to " << client_->endpoint();
DLOG(ERROR) << "Couldn't send request size to " << client_->endpoint();
client_ = std::experimental::nullopt;
return std::experimental::nullopt;
throw RpcFailedException(endpoint_);
}
if (!client_->Write(request_bytes.begin(), request_bytes.size())) {
LOG(ERROR) << "Couldn't send request data to " << client_->endpoint();
DLOG(ERROR) << "Couldn't send request data to " << client_->endpoint();
client_ = std::experimental::nullopt;
return std::experimental::nullopt;
throw RpcFailedException(endpoint_);
}
// Receive response data size.
if (!client_->Read(sizeof(MessageSize))) {
LOG(ERROR) << "Couldn't get response from " << client_->endpoint();
DLOG(ERROR) << "Couldn't get response from " << client_->endpoint();
client_ = std::experimental::nullopt;
return std::experimental::nullopt;
throw RpcFailedException(endpoint_);
}
MessageSize response_data_size =
*reinterpret_cast<MessageSize *>(client_->GetData());
@ -72,9 +62,9 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
// Receive response data.
if (!client_->Read(response_data_size)) {
LOG(ERROR) << "Couldn't get response from " << client_->endpoint();
DLOG(ERROR) << "Couldn't get response from " << client_->endpoint();
client_ = std::experimental::nullopt;
return std::experimental::nullopt;
throw RpcFailedException(endpoint_);
}
// Read the response message.
@ -86,7 +76,7 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
reinterpret_cast<::capnp::word *>(data.end()));
::capnp::FlatArrayMessageReader response_message(data_words.asConst());
client_->ShiftData(response_data_size);
return std::experimental::make_optional(std::move(response_message));
return response_message;
}
void Client::Abort() {

View File

@ -3,13 +3,13 @@
#include <experimental/optional>
#include <memory>
#include <mutex>
#include <random>
#include <capnp/message.h>
#include <capnp/serialize.h>
#include <glog/logging.h>
#include "communication/client.hpp"
#include "communication/rpc/exceptions.hpp"
#include "communication/rpc/messages.capnp.h"
#include "communication/rpc/messages.hpp"
#include "io/network/endpoint.hpp"
@ -22,11 +22,17 @@ class Client {
public:
explicit Client(const io::network::Endpoint &endpoint);
/// Call function can initiate only one request at the time. Function blocks
/// until there is a response. If there was an error nullptr is returned.
/// Call a previously defined and registered RPC call. This function can
/// initiate only one request at a time. The call blocks until a response is
/// received.
///
/// @returns TRequestResponse::Response object that was specified to be
/// returned by the RPC call
/// @throws RpcFailedException if an error occurred while executing the
/// RPC call (e.g. connection failed, remote end died, etc.)
template <class TRequestResponse, class... Args>
std::experimental::optional<typename TRequestResponse::Response> Call(
Args &&... args) {
typename TRequestResponse::Response Call(Args &&... args) {
return CallWithLoad<TRequestResponse>(
[](const auto &reader) {
typename TRequestResponse::Response response;
@ -38,7 +44,7 @@ class Client {
/// Same as `Call` but the first argument is a response loading function.
template <class TRequestResponse, class... Args>
std::experimental::optional<typename TRequestResponse::Response> CallWithLoad(
typename TRequestResponse::Response CallWithLoad(
std::function<typename TRequestResponse::Response(
const typename TRequestResponse::Response::Capnp::Reader &)>
load,
@ -56,18 +62,15 @@ class Client {
.template initAs<typename TRequestResponse::Request::Capnp>();
request.Save(&req_builder);
}
auto maybe_response = Send(&req_msg);
if (!maybe_response) {
return std::experimental::nullopt;
}
auto res_msg = maybe_response->getRoot<capnp::Message>();
auto response = Send(&req_msg);
auto res_msg = response.getRoot<capnp::Message>();
auto res_type = TRequestResponse::Response::TypeInfo;
if (res_msg.getTypeId() != res_type.id) {
// Since message_id was checked in private Call function, this means
// something is very wrong (probably on the server side).
LOG(ERROR) << "Message response was of unexpected type";
client_ = std::experimental::nullopt;
return std::experimental::nullopt;
throw RpcFailedException(endpoint_);
}
VLOG(12) << "[RpcClient] received " << res_type.name;
@ -75,15 +78,14 @@ class Client {
auto data_reader =
res_msg.getData()
.template getAs<typename TRequestResponse::Response::Capnp>();
return std::experimental::make_optional(load(data_reader));
return load(data_reader);
}
/// Call this function from another thread to abort a pending RPC call.
void Abort();
private:
std::experimental::optional<::capnp::FlatArrayMessageReader> Send(
::capnp::MessageBuilder *message);
::capnp::FlatArrayMessageReader Send(::capnp::MessageBuilder *message);
io::network::Endpoint endpoint_;
// TODO (mferencevic): currently the RPC client is hardcoded not to use SSL
@ -91,11 +93,6 @@ class Client {
std::experimental::optional<communication::Client> client_;
std::mutex mutex_;
// Random generator for simulated network latency (enable with a flag).
// Distribution parameters are rule-of-thumb chosen.
std::mt19937 gen_{std::random_device{}()};
std::lognormal_distribution<> rand_{0.0, 1.11};
};
} // namespace communication::rpc

View File

@ -18,8 +18,7 @@ class ClientPool {
: endpoint_(endpoint) {}
template <class TRequestResponse, class... Args>
std::experimental::optional<typename TRequestResponse::Response> Call(
Args &&... args) {
typename TRequestResponse::Response Call(Args &&... args) {
return WithUnusedClient([&](const auto &client) {
return client->template Call<TRequestResponse>(
std::forward<Args>(args)...);
@ -27,7 +26,7 @@ class ClientPool {
};
template <class TRequestResponse, class... Args>
std::experimental::optional<typename TRequestResponse::Response> CallWithLoad(
typename TRequestResponse::Response CallWithLoad(
std::function<typename TRequestResponse::Response(
const typename TRequestResponse::Response::Capnp::Reader &)>
load,

View File

@ -0,0 +1,25 @@
#include "io/network/endpoint.hpp"
#include "utils/exceptions.hpp"
namespace communication::rpc {
/// Exception that is thrown whenever an RPC call fails.
/// This exception inherits `std::exception` directly because
/// `utils::BasicException` is used for transient errors that should be reported
/// to the user and `utils::StacktraceException` is used for fatal errors.
/// This exception always requires explicit handling.
class RpcFailedException final : public utils::BasicException {
public:
RpcFailedException(const io::network::Endpoint &endpoint)
: utils::BasicException::BasicException(
"Couldn't communicate with the cluster! Please contact your "
"database administrator."),
endpoint_(endpoint) {}
/// Returns the endpoint associated with the error.
const io::network::Endpoint &endpoint() const { return endpoint_; }
private:
io::network::Endpoint endpoint_;
};
} // namespace communication::rpc

View File

@ -6,8 +6,11 @@ Server::Server(const io::network::Endpoint &endpoint,
size_t workers_count)
: server_(endpoint, this, &context_, -1, "RPC", workers_count) {}
void Server::StopProcessingCalls() {
void Server::Shutdown() {
server_.Shutdown();
}
void Server::AwaitShutdown() {
server_.AwaitShutdown();
}

View File

@ -25,7 +25,8 @@ class Server {
Server &operator=(const Server &) = delete;
Server &operator=(Server &&) = delete;
void StopProcessingCalls();
void Shutdown();
void AwaitShutdown();
const io::network::Endpoint &endpoint() const;

View File

@ -80,8 +80,9 @@ class Server final {
}
~Server() {
Shutdown();
AwaitShutdown();
CHECK(!alive_ && !thread_.joinable()) << "You should call Shutdown and "
"AwaitShutdown on "
"communication::Server!";
}
Server(const Server &) = delete;

View File

@ -60,9 +60,13 @@ DEFINE_VALIDATED_HIDDEN_int32(
"indicates the port on which to serve. If zero (default value), a port is "
"chosen at random. Sent to the master when registring worker node.",
FLAG_IN_RANGE(0, std::numeric_limits<uint16_t>::max()));
DEFINE_VALIDATED_HIDDEN_int32(rpc_num_workers,
DEFINE_VALIDATED_HIDDEN_int32(rpc_num_client_workers,
std::max(std::thread::hardware_concurrency(), 1U),
"Number of workers (RPC)",
"Number of client workers (RPC)",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_VALIDATED_HIDDEN_int32(rpc_num_server_workers,
std::max(std::thread::hardware_concurrency(), 1U),
"Number of server workers (RPC)",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_VALIDATED_int32(recovering_cluster_size, 0,
"Number of workers (including master) in the "
@ -94,7 +98,8 @@ database::Config::Config()
,
// Distributed flags.
dynamic_graph_partitioner_enabled{FLAGS_dynamic_graph_partitioner_enabled},
rpc_num_workers{FLAGS_rpc_num_workers},
rpc_num_client_workers{FLAGS_rpc_num_client_workers},
rpc_num_server_workers{FLAGS_rpc_num_server_workers},
worker_id{FLAGS_worker_id},
master_endpoint{FLAGS_master_host,
static_cast<uint16_t>(FLAGS_master_port)},

View File

@ -37,14 +37,11 @@ WorkerCounters::WorkerCounters(
: master_client_pool_(master_client_pool) {}
int64_t WorkerCounters::Get(const std::string &name) {
auto response = master_client_pool_->Call<CountersGetRpc>(name);
CHECK(response) << "CountersGetRpc failed";
return response->value;
return master_client_pool_->Call<CountersGetRpc>(name).value;
}
void WorkerCounters::Set(const std::string &name, int64_t value) {
auto response = master_client_pool_->Call<CountersSetRpc>(name, value);
CHECK(response) << "CountersSetRpc failed";
master_client_pool_->Call<CountersSetRpc>(name, value);
}
} // namespace database

View File

@ -15,6 +15,7 @@
#include "distributed/durability_rpc_master.hpp"
#include "distributed/durability_rpc_worker.hpp"
#include "distributed/dynamic_worker.hpp"
#include "distributed/index_rpc_messages.hpp"
#include "distributed/index_rpc_server.hpp"
#include "distributed/plan_consumer.hpp"
#include "distributed/plan_dispatcher.hpp"
@ -366,27 +367,27 @@ class DistributedAccessor : public GraphDbAccessor {
};
class MasterAccessor final : public DistributedAccessor {
distributed::IndexRpcClients *index_rpc_clients_{nullptr};
distributed::PullRpcClients *pull_clients_{nullptr};
distributed::Coordination *coordination_;
distributed::PullRpcClients *pull_clients_;
int worker_id_{0};
public:
MasterAccessor(Master *db, distributed::IndexRpcClients *index_rpc_clients,
MasterAccessor(Master *db, distributed::Coordination *coordination,
distributed::PullRpcClients *pull_clients_,
DistributedVertexAccessor *vertex_accessor,
DistributedEdgeAccessor *edge_accessor)
: DistributedAccessor(db, vertex_accessor, edge_accessor),
index_rpc_clients_(index_rpc_clients),
coordination_(coordination),
pull_clients_(pull_clients_),
worker_id_(db->WorkerId()) {}
MasterAccessor(Master *db, tx::TransactionId tx_id,
distributed::IndexRpcClients *index_rpc_clients,
distributed::Coordination *coordination,
distributed::PullRpcClients *pull_clients_,
DistributedVertexAccessor *vertex_accessor,
DistributedEdgeAccessor *edge_accessor)
: DistributedAccessor(db, tx_id, vertex_accessor, edge_accessor),
index_rpc_clients_(index_rpc_clients),
coordination_(coordination),
pull_clients_(pull_clients_),
worker_id_(db->WorkerId()) {}
@ -395,8 +396,17 @@ class MasterAccessor final : public DistributedAccessor {
index_rpc_completions;
// Notify all workers to create the index
index_rpc_completions.emplace(index_rpc_clients_->GetCreateIndexFutures(
key.label_, key.property_, worker_id_));
index_rpc_completions.emplace(coordination_->ExecuteOnWorkers<bool>(
worker_id_,
[&key](int worker_id, communication::rpc::ClientPool &client_pool) {
try {
client_pool.Call<distributed::CreateIndexRpc>(key.label_,
key.property_);
return true;
} catch (const communication::rpc::RpcFailedException &) {
return false;
}
}));
if (index_rpc_completions) {
// Wait first, check later - so that every thread finishes and none
@ -404,6 +414,8 @@ class MasterAccessor final : public DistributedAccessor {
// that we notify other workers to stop building indexes
for (auto &index_built : *index_rpc_completions) index_built.wait();
for (auto &index_built : *index_rpc_completions) {
// TODO: `get()` can throw an exception, should we delete the index when
// it throws?
if (!index_built.get()) {
db().storage().label_property_index().DeleteIndex(key);
throw IndexCreationOnWorkerException("Index exists on a worker");
@ -418,8 +430,17 @@ class MasterAccessor final : public DistributedAccessor {
// since they don't have to wait anymore
std::experimental::optional<std::vector<utils::Future<bool>>>
index_rpc_completions;
index_rpc_completions.emplace(index_rpc_clients_->GetPopulateIndexFutures(
key.label_, key.property_, transaction_id(), worker_id_));
index_rpc_completions.emplace(coordination_->ExecuteOnWorkers<bool>(
worker_id_, [this, &key](int worker_id,
communication::rpc::ClientPool &client_pool) {
try {
client_pool.Call<distributed::PopulateIndexRpc>(
key.label_, key.property_, transaction_id());
return true;
} catch (const communication::rpc::RpcFailedException &) {
return false;
}
}));
// Populate our own storage
GraphDbAccessor::PopulateIndexFromBuildIndex(key);
@ -432,6 +453,8 @@ class MasterAccessor final : public DistributedAccessor {
// that we notify other workers to stop building indexes
for (auto &index_built : *index_rpc_completions) index_built.wait();
for (auto &index_built : *index_rpc_completions) {
// TODO: `get()` can throw an exception, should we delete the index when
// it throws?
if (!index_built.get()) {
db().storage().label_property_index().DeleteIndex(key);
throw IndexCreationOnWorkerException("Index exists on a worker");
@ -440,11 +463,12 @@ class MasterAccessor final : public DistributedAccessor {
}
}
// TODO (mferencevic): Move this logic into the transaction engine.
void AdvanceCommand() override {
DistributedAccessor::AdvanceCommand();
auto tx_id = transaction_id();
auto futures = pull_clients_->NotifyAllTransactionCommandAdvanced(tx_id);
for (auto &future : futures) future.wait();
for (auto &future : futures) future.get();
}
};
@ -591,13 +615,14 @@ class Master {
// constructors of members.
database::Master *self_{nullptr};
communication::rpc::Server server_{
config_.master_endpoint, static_cast<size_t>(config_.rpc_num_workers)};
tx::EngineMaster tx_engine_{server_, rpc_worker_clients_, &wal_};
distributed::MasterCoordination coordination_{server_.endpoint()};
config_.master_endpoint,
static_cast<size_t>(config_.rpc_num_server_workers)};
tx::EngineMaster tx_engine_{&server_, &coordination_, &wal_};
distributed::MasterCoordination coordination_{server_.endpoint(),
config_.rpc_num_client_workers};
std::unique_ptr<StorageGcMaster> storage_gc_ =
std::make_unique<StorageGcMaster>(
*storage_, tx_engine_, config_.gc_cycle_sec, server_, coordination_);
distributed::RpcWorkerClients rpc_worker_clients_{coordination_};
TypemapPack<storage::MasterConcurrentIdMapper> typemap_pack_{server_};
database::MasterCounters counters_{&server_};
distributed::BfsSubcursorStorage subcursor_storage_{self_,
@ -605,25 +630,19 @@ class Master {
distributed::BfsRpcServer bfs_subcursor_server_{self_, &server_,
&subcursor_storage_};
distributed::BfsRpcClients bfs_subcursor_clients_{
self_, &subcursor_storage_, &rpc_worker_clients_, &data_manager_};
distributed::DurabilityRpcMaster durability_rpc_{rpc_worker_clients_};
self_, &subcursor_storage_, &coordination_, &data_manager_};
distributed::DurabilityRpcMaster durability_rpc_{&coordination_};
distributed::DataRpcServer data_server_{self_, &server_};
distributed::DataRpcClients data_clients_{rpc_worker_clients_};
distributed::PlanDispatcher plan_dispatcher_{rpc_worker_clients_};
distributed::PullRpcClients pull_clients_{&rpc_worker_clients_,
&data_manager_};
distributed::IndexRpcClients index_rpc_clients_{rpc_worker_clients_};
distributed::DataRpcClients data_clients_{&coordination_};
distributed::PlanDispatcher plan_dispatcher_{&coordination_};
distributed::PullRpcClients pull_clients_{&coordination_, &data_manager_};
distributed::UpdatesRpcServer updates_server_{self_, &server_};
distributed::UpdatesRpcClients updates_clients_{rpc_worker_clients_};
distributed::UpdatesRpcClients updates_clients_{&coordination_};
distributed::DataManager data_manager_{*self_, data_clients_};
distributed::ClusterDiscoveryMaster cluster_discovery_{
server_, coordination_, rpc_worker_clients_,
config_.durability_directory};
distributed::TokenSharingRpcClients token_sharing_clients_{
&rpc_worker_clients_};
&server_, &coordination_, config_.durability_directory};
distributed::TokenSharingRpcServer token_sharing_server_{
self_, config_.worker_id, &coordination_, &server_,
&token_sharing_clients_};
self_, config_.worker_id, &coordination_, &server_};
distributed::DynamicWorkerAddition dynamic_worker_addition_{self_, &server_};
};
@ -739,34 +758,17 @@ Master::Master(Config config)
}
}
Master::~Master() {
snapshot_creator_ = nullptr;
is_accepting_transactions_ = false;
impl_->tx_engine_.LocalForEachActiveTransaction(
[](auto &t) { t.set_should_abort(); });
// We are not a worker, so we can do a snapshot on exit if it's enabled. Doing
// this on the master forces workers to do the same through rpcs
if (impl_->config_.snapshot_on_exit) {
auto dba = Access();
MakeSnapshot(*dba);
}
// Transactional cache cleanup must be stopped before all of the objects
// that were registered for cleanup are destructed.
impl_->tx_engine_.StopTransactionalCacheCleanup();
}
Master::~Master() {}
std::unique_ptr<GraphDbAccessor> Master::Access() {
return std::make_unique<MasterAccessor>(
this, &impl_->index_rpc_clients_, &impl_->pull_clients_,
this, &impl_->coordination_, &impl_->pull_clients_,
&impl_->vertex_accessor_, &impl_->edge_accessor_);
}
std::unique_ptr<GraphDbAccessor> Master::Access(tx::TransactionId tx_id) {
return std::make_unique<MasterAccessor>(
this, tx_id, &impl_->index_rpc_clients_, &impl_->pull_clients_,
this, tx_id, &impl_->coordination_, &impl_->pull_clients_,
&impl_->vertex_accessor_, &impl_->edge_accessor_);
}
@ -823,6 +825,7 @@ bool Master::MakeSnapshot(GraphDbAccessor &accessor) {
void Master::ReinitializeStorage() {
// Release gc scheduler to stop it from touching storage
impl_->storage_gc_->Stop();
impl_->storage_gc_ = nullptr;
impl_->storage_ = std::make_unique<Storage>(
impl_->config_.worker_id, impl_->config_.properties_on_disk);
@ -839,6 +842,71 @@ io::network::Endpoint Master::GetEndpoint(int worker_id) {
return impl_->coordination_.GetEndpoint(worker_id);
}
bool Master::AwaitShutdown(std::function<void(void)> call_before_shutdown) {
bool ret =
impl_->coordination_.AwaitShutdown(
[this, &call_before_shutdown](bool is_cluster_alive) -> bool {
snapshot_creator_ = nullptr;
// Stop all running transactions. This will allow all shutdowns in
// the callback that depend on query execution to be aborted and
// cleaned up.
// TODO (mferencevic): When we have full cluster management
// (detection of failure and automatic failure recovery) this should
// be done directly through the transaction engine (e.g. using
// cluster degraded/operational hooks and callbacks).
is_accepting_transactions_ = false;
impl_->tx_engine_.LocalForEachActiveTransaction(
[](auto &t) { t.set_should_abort(); });
// Call the toplevel callback to stop everything that the caller
// wants us to stop.
call_before_shutdown();
// Now we stop everything that calls RPCs (garbage collection, etc.)
// Stop the storage garbage collector.
impl_->storage_gc_->Stop();
// Transactional cache cleanup must be stopped before all of the
// objects that were registered for cleanup are destructed.
impl_->tx_engine_.StopTransactionalCacheCleanup();
// We are not a worker, so we can do a snapshot on exit if it's
// enabled. Doing this on the master forces workers to do the same
// through RPCs. If the cluster is in a degraded state then don't
// attempt to do a snapshot because the snapshot can't be created on
// all workers. The cluster will have to recover from a previous
// snapshot and WALs.
if (impl_->config_.snapshot_on_exit) {
if (is_cluster_alive) {
auto dba = Access();
// Here we make the snapshot and return the snapshot creation
// success to the caller.
return MakeSnapshot(*dba);
} else {
LOG(WARNING)
<< "Because the cluster is in a degraded state we can't "
"create a snapshot. The cluster will be recovered from "
"previous snapshots and WALs.";
}
}
// The shutdown was completed successfully.
return true;
});
// We stop the RPC server to disable further requests.
// TODO (mferencevic): Move the RPC into coordination.
impl_->server_.Shutdown();
impl_->server_.AwaitShutdown();
// Return the shutdown success status.
return ret;
}
void Master::Shutdown() { return impl_->coordination_.Shutdown(); }
distributed::BfsRpcClients &Master::bfs_subcursor_clients() {
return impl_->bfs_subcursor_clients_;
}
@ -867,15 +935,10 @@ distributed::PlanDispatcher &Master::plan_dispatcher() {
return impl_->plan_dispatcher_;
}
distributed::IndexRpcClients &Master::index_rpc_clients() {
return impl_->index_rpc_clients_;
}
VertexAccessor InsertVertexIntoRemote(
GraphDbAccessor *dba, int worker_id,
const std::vector<storage::Label> &labels,
const std::unordered_map<storage::Property, PropertyValue>
&properties,
const std::unordered_map<storage::Property, PropertyValue> &properties,
std::experimental::optional<int64_t> cypher_id) {
// TODO: Replace this with virtual call or some other mechanism.
auto *distributed_db =
@ -930,44 +993,43 @@ class Worker {
// constructors of members.
database::Worker *self_{nullptr};
communication::rpc::Server server_{
config_.worker_endpoint, static_cast<size_t>(config_.rpc_num_workers)};
distributed::WorkerCoordination coordination_{server_,
config_.master_endpoint};
distributed::RpcWorkerClients rpc_worker_clients_{coordination_};
tx::EngineWorker tx_engine_{server_, rpc_worker_clients_.GetClientPool(0),
&wal_};
config_.worker_endpoint,
static_cast<size_t>(config_.rpc_num_server_workers)};
distributed::WorkerCoordination coordination_{
&server_, config_.master_endpoint, config_.worker_id,
config_.rpc_num_client_workers};
// TODO (mferencevic): Pass the coordination object directly wherever there is
// a `GetClientPool(xyz)` call.
tx::EngineWorker tx_engine_{&server_, coordination_.GetClientPool(0), &wal_};
std::unique_ptr<StorageGcWorker> storage_gc_ =
std::make_unique<StorageGcWorker>(
*storage_, tx_engine_, config_.gc_cycle_sec,
rpc_worker_clients_.GetClientPool(0), config_.worker_id);
*coordination_.GetClientPool(0), config_.worker_id);
TypemapPack<storage::WorkerConcurrentIdMapper> typemap_pack_{
rpc_worker_clients_.GetClientPool(0)};
database::WorkerCounters counters_{&rpc_worker_clients_.GetClientPool(0)};
*coordination_.GetClientPool(0)};
database::WorkerCounters counters_{coordination_.GetClientPool(0)};
distributed::BfsSubcursorStorage subcursor_storage_{self_,
&bfs_subcursor_clients_};
distributed::BfsRpcServer bfs_subcursor_server_{self_, &server_,
&subcursor_storage_};
distributed::BfsRpcClients bfs_subcursor_clients_{
self_, &subcursor_storage_, &rpc_worker_clients_, &data_manager_};
self_, &subcursor_storage_, &coordination_, &data_manager_};
distributed::DataRpcServer data_server_{self_, &server_};
distributed::DataRpcClients data_clients_{rpc_worker_clients_};
distributed::DataRpcClients data_clients_{&coordination_};
distributed::PlanConsumer plan_consumer_{server_};
distributed::ProduceRpcServer produce_server_{self_, &tx_engine_, server_,
plan_consumer_, &data_manager_};
distributed::IndexRpcServer index_rpc_server_{*self_, server_};
distributed::UpdatesRpcServer updates_server_{self_, &server_};
distributed::UpdatesRpcClients updates_clients_{rpc_worker_clients_};
distributed::UpdatesRpcClients updates_clients_{&coordination_};
distributed::DataManager data_manager_{*self_, data_clients_};
distributed::DurabilityRpcWorker durability_rpc_{self_, &server_};
distributed::ClusterDiscoveryWorker cluster_discovery_{
server_, coordination_, rpc_worker_clients_.GetClientPool(0)};
distributed::TokenSharingRpcClients token_sharing_clients_{
&rpc_worker_clients_};
server_, coordination_, *coordination_.GetClientPool(0)};
distributed::TokenSharingRpcServer token_sharing_server_{
self_, config_.worker_id, &coordination_, &server_,
&token_sharing_clients_};
self_, config_.worker_id, &coordination_, &server_};
distributed::DynamicWorkerRegistration dynamic_worker_registration_{
&rpc_worker_clients_.GetClientPool(0)};
coordination_.GetClientPool(0)};
};
} // namespace impl
@ -1064,14 +1126,7 @@ Worker::Worker(Config config)
}
}
Worker::~Worker() {
is_accepting_transactions_ = false;
impl_->tx_engine_.LocalForEachActiveTransaction(
[](auto &t) { t.set_should_abort(); });
// Transactional cache cleanup must be stopped before all of the objects
// that were registered for cleanup are destructed.
impl_->tx_engine_.StopTransactionalCacheCleanup();
}
Worker::~Worker() {}
std::unique_ptr<GraphDbAccessor> Worker::Access() {
return std::make_unique<WorkerAccessor>(this, &impl_->vertex_accessor_,
@ -1127,12 +1182,13 @@ bool Worker::MakeSnapshot(GraphDbAccessor &accessor) {
void Worker::ReinitializeStorage() {
// Release gc scheduler to stop it from touching storage
impl_->storage_gc_->Stop();
impl_->storage_gc_ = nullptr;
impl_->storage_ = std::make_unique<Storage>(
impl_->config_.worker_id, impl_->config_.properties_on_disk);
impl_->storage_gc_ = std::make_unique<StorageGcWorker>(
*impl_->storage_, impl_->tx_engine_, impl_->config_.gc_cycle_sec,
impl_->rpc_worker_clients_.GetClientPool(0), impl_->config_.worker_id);
*impl_->coordination_.GetClientPool(0), impl_->config_.worker_id);
}
void Worker::RecoverWalAndIndexes(durability::RecoveryData *recovery_data) {
@ -1150,10 +1206,43 @@ io::network::Endpoint Worker::GetEndpoint(int worker_id) {
return impl_->coordination_.GetEndpoint(worker_id);
}
void Worker::WaitForShutdown() {
return impl_->coordination_.WaitForShutdown();
bool Worker::AwaitShutdown(std::function<void(void)> call_before_shutdown) {
bool ret = impl_->coordination_.AwaitShutdown(
[this, &call_before_shutdown](bool is_cluster_alive) -> bool {
// Stop all running transactions. This will allow all shutdowns in the
// callback that depend on query execution to be aborted and cleaned up.
// TODO (mferencevic): See the note for this same code for the `Master`.
is_accepting_transactions_ = false;
impl_->tx_engine_.LocalForEachActiveTransaction(
[](auto &t) { t.set_should_abort(); });
// Call the toplevel callback to stop everything that the caller wants
// us to stop.
call_before_shutdown();
// Now we stop everything that calls RPCs (garbage collection, etc.)
// Stop the storage garbage collector.
impl_->storage_gc_->Stop();
// Transactional cache cleanup must be stopped before all of the objects
// that were registered for cleanup are destructed.
impl_->tx_engine_.StopTransactionalCacheCleanup();
// The worker shutdown always succeeds.
return true;
});
// Stop the RPC server
impl_->server_.Shutdown();
impl_->server_.AwaitShutdown();
// Return the shutdown success status.
return ret;
}
void Worker::Shutdown() { return impl_->coordination_.Shutdown(); }
distributed::BfsRpcClients &Worker::bfs_subcursor_clients() {
return impl_->bfs_subcursor_clients_;
}

View File

@ -66,6 +66,8 @@ class Master final : public DistributedGraphDb {
/** Gets the endpoint of the worker with the given id. */
// TODO make const once Coordination::GetEndpoint is const.
io::network::Endpoint GetEndpoint(int worker_id);
bool AwaitShutdown(std::function<void(void)> call_before_shutdown = [] {});
void Shutdown();
distributed::BfsRpcClients &bfs_subcursor_clients() override;
distributed::DataRpcClients &data_clients() override;
@ -111,7 +113,8 @@ class Worker final : public DistributedGraphDb {
/** Gets the endpoint of the worker with the given id. */
// TODO make const once Coordination::GetEndpoint is const.
io::network::Endpoint GetEndpoint(int worker_id);
void WaitForShutdown();
bool AwaitShutdown(std::function<void(void)> call_before_shutdown = [] {});
void Shutdown();
distributed::BfsRpcClients &bfs_subcursor_clients() override;
distributed::DataRpcClients &data_clients() override;

View File

@ -40,7 +40,8 @@ struct Config {
// Distributed master/worker flags.
bool dynamic_graph_partitioner_enabled{false};
int rpc_num_workers{0};
int rpc_num_client_workers{0};
int rpc_num_server_workers{0};
int worker_id{0};
io::network::Endpoint master_endpoint{"0.0.0.0", 0};
io::network::Endpoint worker_endpoint{"0.0.0.0", 0};

View File

@ -13,6 +13,7 @@
#include "storage/gid.hpp"
#include "storage/vertex.hpp"
#include "transactions/engine.hpp"
#include "utils/exceptions.hpp"
#include "utils/scheduler.hpp"
#include "utils/timer.hpp"
@ -46,13 +47,19 @@ class StorageGc {
vertices_(storage.vertices_),
edges_(storage.edges_) {
if (pause_sec > 0)
scheduler_.Run("Storage GC", std::chrono::seconds(pause_sec),
[this] { CollectGarbage(); });
scheduler_.Run(
"Storage GC", std::chrono::seconds(pause_sec), [this] {
try {
CollectGarbage();
} catch (const utils::BasicException &e) {
DLOG(WARNING)
<< "Couldn't perform storage garbage collection due to: "
<< e.what();
}
});
}
virtual ~StorageGc() {
scheduler_.Stop();
edges_.record_deleter_.FreeExpiredObjects(tx::Transaction::MaxId());
vertices_.record_deleter_.FreeExpiredObjects(tx::Transaction::MaxId());
edges_.version_list_deleter_.FreeExpiredObjects(tx::Transaction::MaxId());

View File

@ -31,6 +31,11 @@ class StorageGcMaster : public StorageGc {
// We have to stop the scheduler before destroying this class because otherwise
// a task might try to use methods of this class, which would cause a pure
// virtual method call since they are not implemented in the base class.
CHECK(!scheduler_.IsRunning())
<< "You must call Stop on database::StorageGcMaster!";
}
void Stop() {
scheduler_.Stop();
rpc_server_.UnRegister<distributed::RanLocalGcRpc>();
}

View File

@ -19,15 +19,24 @@ class StorageGcWorker : public StorageGc {
// We have to stop the scheduler before destroying this class because otherwise
// a task might try to use methods of this class, which would cause a pure
// virtual method call since they are not implemented in the base class.
scheduler_.Stop();
CHECK(!scheduler_.IsRunning())
<< "You must call Stop on database::StorageGcWorker!";
}
void Stop() { scheduler_.Stop(); }
void CollectCommitLogGarbage(tx::TransactionId oldest_active) final {
// We first need to delete transactions that we can delete to be sure that
// the locks are released as well. Otherwise some new transaction might
// try to acquire a lock which hasn't been released (if the transaction
// cache cleaner was not scheduled at this time), and take a look into the
// commit log which no longer contains that transaction id.
// TODO: when I (mferencevic) refactored the transaction engine code, I
// found out that the function `ClearTransactionalCache` of the
// `tx::EngineWorker` was called periodically in the transactional cache
// cleaner. That code was then moved and can now be found in the
// `tx::EngineDistributed` garbage collector. This may not be correct,
// @storage_team please investigate this.
dynamic_cast<tx::EngineWorker &>(tx_engine_)
.ClearTransactionalCache(oldest_active);
auto safe_to_delete = GetClogSafeTransaction(oldest_active);

View File

@ -8,24 +8,23 @@ namespace distributed {
BfsRpcClients::BfsRpcClients(database::DistributedGraphDb *db,
BfsSubcursorStorage *subcursor_storage,
RpcWorkerClients *clients,
Coordination *coordination,
DataManager *data_manager)
: db_(db),
subcursor_storage_(subcursor_storage),
clients_(clients),
coordination_(coordination),
data_manager_(data_manager) {}
std::unordered_map<int16_t, int64_t> BfsRpcClients::CreateBfsSubcursors(
tx::TransactionId tx_id, query::EdgeAtom::Direction direction,
const std::vector<storage::EdgeType> &edge_types,
query::GraphView graph_view) {
auto futures = clients_->ExecuteOnWorkers<std::pair<int16_t, int64_t>>(
auto futures = coordination_->ExecuteOnWorkers<std::pair<int16_t, int64_t>>(
db_->WorkerId(),
[tx_id, direction, &edge_types, graph_view](int worker_id, auto &client) {
auto res = client.template Call<CreateBfsSubcursorRpc>(
tx_id, direction, edge_types, graph_view);
CHECK(res) << "CreateBfsSubcursor RPC failed!";
return std::make_pair(worker_id, res->member);
return std::make_pair(worker_id, res.member);
});
std::unordered_map<int16_t, int64_t> subcursor_ids;
subcursor_ids.emplace(
@ -40,10 +39,9 @@ std::unordered_map<int16_t, int64_t> BfsRpcClients::CreateBfsSubcursors(
void BfsRpcClients::RegisterSubcursors(
const std::unordered_map<int16_t, int64_t> &subcursor_ids) {
auto futures = clients_->ExecuteOnWorkers<void>(
auto futures = coordination_->ExecuteOnWorkers<void>(
db_->WorkerId(), [&subcursor_ids](int worker_id, auto &client) {
auto res = client.template Call<RegisterSubcursorsRpc>(subcursor_ids);
CHECK(res) << "RegisterSubcursors RPC failed!";
client.template Call<RegisterSubcursorsRpc>(subcursor_ids);
});
subcursor_storage_->Get(subcursor_ids.at(db_->WorkerId()))
->RegisterSubcursors(subcursor_ids);
@ -55,11 +53,10 @@ void BfsRpcClients::RegisterSubcursors(
void BfsRpcClients::ResetSubcursors(
const std::unordered_map<int16_t, int64_t> &subcursor_ids) {
auto futures = clients_->ExecuteOnWorkers<void>(
auto futures = coordination_->ExecuteOnWorkers<void>(
db_->WorkerId(), [&subcursor_ids](int worker_id, auto &client) {
auto res = client.template Call<ResetSubcursorRpc>(
client.template Call<ResetSubcursorRpc>(
subcursor_ids.at(worker_id));
CHECK(res) << "ResetSubcursor RPC failed!";
});
subcursor_storage_->Get(subcursor_ids.at(db_->WorkerId()))->Reset();
// Wait and get all of the replies.
@ -70,11 +67,10 @@ void BfsRpcClients::ResetSubcursors(
void BfsRpcClients::RemoveBfsSubcursors(
const std::unordered_map<int16_t, int64_t> &subcursor_ids) {
auto futures = clients_->ExecuteOnWorkers<void>(
auto futures = coordination_->ExecuteOnWorkers<void>(
db_->WorkerId(), [&subcursor_ids](int worker_id, auto &client) {
auto res = client.template Call<RemoveBfsSubcursorRpc>(
client.template Call<RemoveBfsSubcursorRpc>(
subcursor_ids.at(worker_id));
CHECK(res) << "RemoveBfsSubcursor RPC failed!";
});
subcursor_storage_->Erase(subcursor_ids.at(db_->WorkerId()));
// Wait and get all of the replies.
@ -89,25 +85,23 @@ std::experimental::optional<VertexAccessor> BfsRpcClients::Pull(
return subcursor_storage_->Get(subcursor_id)->Pull();
}
auto res = clients_->GetClientPool(worker_id).CallWithLoad<SubcursorPullRpc>(
auto res = coordination_->GetClientPool(worker_id)->CallWithLoad<SubcursorPullRpc>(
[this, dba](const auto &reader) {
SubcursorPullRes res;
res.Load(reader, dba, this->data_manager_);
return res;
},
subcursor_id);
CHECK(res) << "SubcursorPull RPC failed!";
return res->vertex;
return res.vertex;
}
bool BfsRpcClients::ExpandLevel(
const std::unordered_map<int16_t, int64_t> &subcursor_ids) {
auto futures = clients_->ExecuteOnWorkers<bool>(
auto futures = coordination_->ExecuteOnWorkers<bool>(
db_->WorkerId(), [&subcursor_ids](int worker_id, auto &client) {
auto res =
client.template Call<ExpandLevelRpc>(subcursor_ids.at(worker_id));
CHECK(res) << "ExpandLevel RPC failed!";
return res->member;
return res.member;
});
bool expanded =
subcursor_storage_->Get(subcursor_ids.at(db_->WorkerId()))->ExpandLevel();
@ -128,9 +122,8 @@ void BfsRpcClients::SetSource(
subcursor_storage_->Get(subcursor_ids.at(db_->WorkerId()))
->SetSource(source_address);
} else {
auto res = clients_->GetClientPool(worker_id).Call<SetSourceRpc>(
coordination_->GetClientPool(worker_id)->Call<SetSourceRpc>(
subcursor_ids.at(worker_id), source_address);
CHECK(res) << "SetSourceRpc failed!";
}
}
@ -140,11 +133,10 @@ bool BfsRpcClients::ExpandToRemoteVertex(
CHECK(!vertex.is_local())
<< "ExpandToRemoteVertex should not be called with local vertex";
int worker_id = vertex.address().worker_id();
auto res = clients_->GetClientPool(worker_id).Call<ExpandToRemoteVertexRpc>(
auto res = coordination_->GetClientPool(worker_id)->Call<ExpandToRemoteVertexRpc>(
subcursor_ids.at(worker_id), edge.GlobalAddress(),
vertex.GlobalAddress());
CHECK(res) << "ExpandToRemoteVertex RPC failed!";
return res->member;
return res.member;
}
PathSegment BfsRpcClients::ReconstructPath(
@ -157,15 +149,14 @@ PathSegment BfsRpcClients::ReconstructPath(
}
auto res =
clients_->GetClientPool(worker_id).CallWithLoad<ReconstructPathRpc>(
coordination_->GetClientPool(worker_id)->CallWithLoad<ReconstructPathRpc>(
[this, dba](const auto &reader) {
ReconstructPathRes res;
res.Load(reader, dba, this->data_manager_);
return res;
},
subcursor_ids.at(worker_id), vertex);
CHECK(res) << "ReconstructPath RPC failed!";
return PathSegment{res->edges, res->next_vertex, res->next_edge};
return PathSegment{res.edges, res.next_vertex, res.next_edge};
}
PathSegment BfsRpcClients::ReconstructPath(
@ -177,24 +168,22 @@ PathSegment BfsRpcClients::ReconstructPath(
->ReconstructPath(edge);
}
auto res =
clients_->GetClientPool(worker_id).CallWithLoad<ReconstructPathRpc>(
coordination_->GetClientPool(worker_id)->CallWithLoad<ReconstructPathRpc>(
[this, dba](const auto &reader) {
ReconstructPathRes res;
res.Load(reader, dba, this->data_manager_);
return res;
},
subcursor_ids.at(worker_id), edge);
CHECK(res) << "ReconstructPath RPC failed!";
return PathSegment{res->edges, res->next_vertex, res->next_edge};
return PathSegment{res.edges, res.next_vertex, res.next_edge};
}
void BfsRpcClients::PrepareForExpand(
const std::unordered_map<int16_t, int64_t> &subcursor_ids, bool clear) {
auto futures = clients_->ExecuteOnWorkers<void>(
auto futures = coordination_->ExecuteOnWorkers<void>(
db_->WorkerId(), [clear, &subcursor_ids](int worker_id, auto &client) {
auto res = client.template Call<PrepareForExpandRpc>(
client.template Call<PrepareForExpandRpc>(
subcursor_ids.at(worker_id), clear);
CHECK(res) << "PrepareForExpand RPC failed!";
});
subcursor_storage_->Get(subcursor_ids.at(db_->WorkerId()))
->PrepareForExpand(clear);

View File

@ -2,7 +2,7 @@
#pragma once
#include "distributed/bfs_subcursor.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "distributed/coordination.hpp"
#include "transactions/transaction.hpp"
namespace database {
@ -23,7 +23,7 @@ class BfsRpcClients {
public:
BfsRpcClients(database::DistributedGraphDb *db,
BfsSubcursorStorage *subcursor_storage,
RpcWorkerClients *clients,
Coordination *coordination,
DataManager *data_manager);
std::unordered_map<int16_t, int64_t> CreateBfsSubcursors(
@ -64,10 +64,10 @@ class BfsRpcClients {
const std::unordered_map<int16_t, int64_t> &subcursor_ids, bool clear);
private:
database::DistributedGraphDb *db_{nullptr};
distributed::BfsSubcursorStorage *subcursor_storage_{nullptr};
distributed::RpcWorkerClients *clients_{nullptr};
distributed::DataManager *data_manager_{nullptr};
database::DistributedGraphDb *db_;
distributed::BfsSubcursorStorage *subcursor_storage_;
distributed::Coordination *coordination_;
distributed::DataManager *data_manager_;
};
} // namespace distributed

View File

@ -12,16 +12,14 @@ namespace distributed {
using Server = communication::rpc::Server;
ClusterDiscoveryMaster::ClusterDiscoveryMaster(
Server &server, MasterCoordination &coordination,
RpcWorkerClients &rpc_worker_clients,
Server *server, MasterCoordination *coordination,
const std::string &durability_directory)
: server_(server),
coordination_(coordination),
rpc_worker_clients_(rpc_worker_clients),
durability_directory_(durability_directory) {
server_.Register<RegisterWorkerRpc>([this](const auto &endpoint,
const auto &req_reader,
auto *res_builder) {
server_->Register<RegisterWorkerRpc>([this](const auto &endpoint,
const auto &req_reader,
auto *res_builder) {
bool registration_successful = false;
bool durability_error = false;
@ -56,34 +54,38 @@ ClusterDiscoveryMaster::ClusterDiscoveryMaster(
// Register the worker if the durability check succeeded.
if (!durability_error) {
registration_successful = this->coordination_.RegisterWorker(
req.desired_worker_id, worker_endpoint);
registration_successful =
coordination_->RegisterWorker(req.desired_worker_id, worker_endpoint);
}
// Notify the cluster of the new worker if the registration succeeded.
if (registration_successful) {
rpc_worker_clients_.ExecuteOnWorkers<void>(
0, [req, worker_endpoint](
int worker_id, communication::rpc::ClientPool &client_pool) {
auto result = client_pool.Call<ClusterDiscoveryRpc>(
req.desired_worker_id, worker_endpoint);
CHECK(result) << "ClusterDiscoveryRpc failed";
});
coordination_->ExecuteOnWorkers<
void>(0, [req, worker_endpoint](
int worker_id,
communication::rpc::ClientPool &client_pool) {
try {
client_pool.Call<ClusterDiscoveryRpc>(req.desired_worker_id,
worker_endpoint);
} catch (const communication::rpc::RpcFailedException &) {
LOG(FATAL)
<< "Couldn't notify the cluster of the changed configuration!";
}
});
}
RegisterWorkerRes res(registration_successful, durability_error,
this->coordination_.RecoveredSnapshotTx(),
this->coordination_.GetWorkers());
coordination_->RecoveredSnapshotTx(),
coordination_->GetWorkers());
res.Save(res_builder);
});
server_.Register<NotifyWorkerRecoveredRpc>(
[this](const auto &req_reader, auto *res_builder) {
NotifyWorkerRecoveredReq req;
req.Load(req_reader);
this->coordination_.WorkerRecoveredSnapshot(req.worker_id,
req.recovery_info);
});
server_->Register<NotifyWorkerRecoveredRpc>([this](const auto &req_reader,
auto *res_builder) {
NotifyWorkerRecoveredReq req;
req.Load(req_reader);
coordination_->WorkerRecoveredSnapshot(req.worker_id, req.recovery_info);
});
}
} // namespace distributed

View File

@ -2,7 +2,6 @@
#include "communication/rpc/server.hpp"
#include "distributed/coordination_master.hpp"
#include "distributed/rpc_worker_clients.hpp"
namespace distributed {
using Server = communication::rpc::Server;
@ -15,14 +14,12 @@ using Server = communication::rpc::Server;
*/
class ClusterDiscoveryMaster final {
public:
ClusterDiscoveryMaster(Server &server, MasterCoordination &coordination,
RpcWorkerClients &rpc_worker_clients,
ClusterDiscoveryMaster(Server *server, MasterCoordination *coordination,
const std::string &durability_directory);
private:
Server &server_;
MasterCoordination &coordination_;
RpcWorkerClients &rpc_worker_clients_;
Server *server_;
MasterCoordination *coordination_;
std::string durability_directory_;
};

View File

@ -30,21 +30,24 @@ void ClusterDiscoveryWorker::RegisterWorker(
std::experimental::filesystem::canonical(durability_directory);
// Register to the master.
auto result = client_pool_.Call<RegisterWorkerRpc>(
worker_id, server_.endpoint().port(), full_durability_directory);
CHECK(result) << "RegisterWorkerRpc failed";
CHECK(!result->durability_error)
<< "This worker was started on the same machine and with the same "
"durability directory as the master! Please change the durability "
"directory for this worker.";
CHECK(result->registration_successful)
<< "Unable to assign requested ID (" << worker_id << ") to worker!";
try {
auto result = client_pool_.Call<RegisterWorkerRpc>(
worker_id, server_.endpoint().port(), full_durability_directory);
CHECK(!result.durability_error)
<< "This worker was started on the same machine and with the same "
"durability directory as the master! Please change the durability "
"directory for this worker.";
CHECK(result.registration_successful)
<< "Unable to assign requested ID (" << worker_id << ") to worker!";
worker_id_ = worker_id;
for (auto &kv : result->workers) {
coordination_.RegisterWorker(kv.first, kv.second);
worker_id_ = worker_id;
for (auto &kv : result.workers) {
coordination_.RegisterWorker(kv.first, kv.second);
}
snapshot_to_recover_ = result.snapshot_to_recover;
} catch (const communication::rpc::RpcFailedException &e) {
LOG(FATAL) << "Couldn't register to the master!";
}
snapshot_to_recover_ = result->snapshot_to_recover;
}
void ClusterDiscoveryWorker::NotifyWorkerRecovered(
@ -53,9 +56,11 @@ void ClusterDiscoveryWorker::NotifyWorkerRecovered(
CHECK(worker_id_ >= 0)
<< "Workers id is not yet assigned, preform registration before "
"notifying that the recovery finished";
auto result =
client_pool_.Call<NotifyWorkerRecoveredRpc>(worker_id_, recovery_info);
CHECK(result) << "NotifyWorkerRecoveredRpc failed";
try {
client_pool_.Call<NotifyWorkerRecoveredRpc>(worker_id_, recovery_info);
} catch (const communication::rpc::RpcFailedException &e) {
LOG(FATAL) << "Couldn't notify the master that we finished recovering!";
}
}
} // namespace distributed

View File

@ -1,34 +1,75 @@
#include "glog/logging.h"
#include <thread>
#include "distributed/coordination.hpp"
namespace distributed {
using Endpoint = io::network::Endpoint;
Coordination::Coordination(const Endpoint &master_endpoint) {
Coordination::Coordination(const io::network::Endpoint &master_endpoint,
int worker_id, int client_workers_count)
: worker_id_(worker_id), thread_pool_(client_workers_count, "RPC client") {
// The master is always worker 0.
workers_.emplace(0, master_endpoint);
}
Endpoint Coordination::GetEndpoint(int worker_id) {
Coordination::~Coordination() {}
io::network::Endpoint Coordination::GetEndpoint(int worker_id) {
std::lock_guard<std::mutex> guard(lock_);
auto found = workers_.find(worker_id);
CHECK(found != workers_.end()) << "No endpoint registered for worker id: "
<< worker_id;
// TODO (mferencevic): Handle this error situation differently.
CHECK(found != workers_.end())
<< "No endpoint registered for worker id: " << worker_id;
return found->second;
}
std::vector<int> Coordination::GetWorkerIds() const {
std::vector<int> Coordination::GetWorkerIds() {
std::lock_guard<std::mutex> guard(lock_);
std::vector<int> worker_ids;
for (auto worker : workers_) worker_ids.push_back(worker.first);
return worker_ids;
}
void Coordination::AddWorker(int worker_id, Endpoint endpoint) {
workers_.emplace(worker_id, endpoint);
}
std::unordered_map<int, Endpoint> Coordination::GetWorkers() {
std::unordered_map<int, io::network::Endpoint> Coordination::GetWorkers() {
std::lock_guard<std::mutex> guard(lock_);
return workers_;
}
communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) {
std::lock_guard<std::mutex> guard(lock_);
auto found = client_pools_.find(worker_id);
if (found != client_pools_.end()) return &found->second;
auto found_endpoint = workers_.find(worker_id);
// TODO (mferencevic): Handle this error situation differently.
CHECK(found_endpoint != workers_.end())
<< "No endpoint registered for worker id: " << worker_id;
auto &endpoint = found_endpoint->second;
return &client_pools_
.emplace(std::piecewise_construct,
std::forward_as_tuple(worker_id),
std::forward_as_tuple(endpoint))
.first->second;
}
void Coordination::AddWorker(int worker_id,
const io::network::Endpoint &endpoint) {
std::lock_guard<std::mutex> guard(lock_);
workers_.insert({worker_id, endpoint});
}
std::string Coordination::GetWorkerName(const io::network::Endpoint &endpoint) {
std::lock_guard<std::mutex> guard(lock_);
for (const auto &worker : workers_) {
if (worker.second == endpoint) {
if (worker.first == 0) {
return fmt::format("master ({})", worker.second);
} else {
return fmt::format("worker {} ({})", worker.first, worker.second);
}
}
}
return fmt::format("unknown worker ({})", endpoint);
}
} // namespace distributed

View File

@ -1,36 +1,86 @@
#pragma once
#include <functional>
#include <mutex>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "communication/rpc/client_pool.hpp"
#include "io/network/endpoint.hpp"
#include "utils/future.hpp"
#include "utils/thread.hpp"
namespace distributed {
/** Coordination base class. This class is not thread safe. */
/// Coordination base class. This class is thread safe.
class Coordination {
public:
explicit Coordination(const io::network::Endpoint &master_endpoint);
protected:
Coordination(const io::network::Endpoint &master_endpoint, int worker_id,
int client_workers_count = std::thread::hardware_concurrency());
~Coordination();
/** Gets the endpoint for the given worker ID from the master. */
public:
/// Gets the endpoint for the given worker ID from the master.
io::network::Endpoint GetEndpoint(int worker_id);
/** Returns all workers id, this includes master id(0). */
std::vector<int> GetWorkerIds() const;
/// Returns all worker ids; this includes the master (ID 0).
std::vector<int> GetWorkerIds();
/** Gets the mapping of worker id to worker endpoint including master (worker
* id = 0).
*/
/// Gets the mapping of worker id to worker endpoint including master (ID 0).
std::unordered_map<int, io::network::Endpoint> GetWorkers();
protected:
~Coordination() {}
/// Returns a cached `ClientPool` for the given `worker_id`.
communication::rpc::ClientPool *GetClientPool(int worker_id);
/** Adds a worker to coordination. */
void AddWorker(int worker_id, io::network::Endpoint endpoint);
/// Asynchronously executes the given function on the RPC client for the
/// given worker id. Returns a `utils::Future` of the given `execute`
/// function's return type.
template <typename TResult>
auto ExecuteOnWorker(
int worker_id,
std::function<TResult(int worker_id, communication::rpc::ClientPool &)>
execute) {
// TODO (mferencevic): Change this lambda to accept a pointer to
// `ClientPool` instead of a reference!
auto client_pool = GetClientPool(worker_id);
return thread_pool_.Run(execute, worker_id, std::ref(*client_pool));
}
/// Asynchronously executes the `execute` function on all worker RPC clients
/// except the one whose id is `skip_worker_id`. Returns a vector of futures
/// containing the results of the `execute` function.
template <typename TResult>
auto ExecuteOnWorkers(
int skip_worker_id,
std::function<TResult(int worker_id, communication::rpc::ClientPool &)>
execute) {
std::vector<utils::Future<TResult>> futures;
// TODO (mferencevic): GetWorkerIds always copies the vector of workers,
// this may be an issue...
for (auto &worker_id : GetWorkerIds()) {
if (worker_id == skip_worker_id) continue;
futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute)));
}
return futures;
}
protected:
/// Adds a worker to the coordination. This function can be called multiple
/// times to replace an existing worker.
void AddWorker(int worker_id, const io::network::Endpoint &endpoint);
/// Gets a worker name for the given endpoint.
std::string GetWorkerName(const io::network::Endpoint &endpoint);
private:
std::unordered_map<int, io::network::Endpoint> workers_;
mutable std::mutex lock_;
int worker_id_;
std::unordered_map<int, communication::rpc::ClientPool> client_pools_;
utils::ThreadPool thread_pool_;
};
} // namespace distributed

View File

@ -1,3 +1,4 @@
#include <algorithm>
#include <chrono>
#include <thread>
@ -7,11 +8,27 @@
#include "distributed/coordination_master.hpp"
#include "distributed/coordination_rpc_messages.hpp"
#include "io/network/utils.hpp"
#include "utils/string.hpp"
namespace distributed {
MasterCoordination::MasterCoordination(const Endpoint &master_endpoint)
: Coordination(master_endpoint) {}
// Send a heartbeat request to the workers every `kHeartbeatIntervalSeconds`.
// This constant must be at least 10x smaller than `kHeartbeatMaxDelaySeconds`
// that is defined in the worker coordination.
const int kHeartbeatIntervalSeconds = 1;
MasterCoordination::MasterCoordination(const Endpoint &master_endpoint,
int client_workers_count)
: Coordination(master_endpoint, 0, client_workers_count) {
// TODO (mferencevic): Move this to an explicit `Start` method.
scheduler_.Run("Heartbeat", std::chrono::seconds(kHeartbeatIntervalSeconds),
[this] { IssueHeartbeats(); });
}
MasterCoordination::~MasterCoordination() {
CHECK(!alive_) << "You must call Shutdown and AwaitShutdown on "
"distributed::MasterCoordination!";
}
bool MasterCoordination::RegisterWorker(int desired_worker_id,
Endpoint endpoint) {
@ -19,13 +36,13 @@ bool MasterCoordination::RegisterWorker(int desired_worker_id,
// ensure the whole cluster is in a consistent state.
while (true) {
{
std::lock_guard<std::mutex> guard(lock_);
std::lock_guard<std::mutex> guard(master_lock_);
if (recovery_done_) break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(200));
}
std::lock_guard<std::mutex> guard(lock_);
std::lock_guard<std::mutex> guard(master_lock_);
auto workers = GetWorkers();
// Check if the desired worker id already exists.
if (workers.find(desired_worker_id) != workers.end()) {
@ -48,36 +65,10 @@ void MasterCoordination::WorkerRecoveredSnapshot(
<< "Worker already notified about finishing recovery";
}
Endpoint MasterCoordination::GetEndpoint(int worker_id) {
std::lock_guard<std::mutex> guard(lock_);
return Coordination::GetEndpoint(worker_id);
}
MasterCoordination::~MasterCoordination() {
using namespace std::chrono_literals;
std::lock_guard<std::mutex> guard(lock_);
auto workers = GetWorkers();
for (const auto &kv : workers) {
// Skip master (self).
if (kv.first == 0) continue;
communication::rpc::Client client(kv.second);
auto result = client.Call<StopWorkerRpc>();
CHECK(result) << "StopWorkerRpc failed for worker: " << kv.first;
}
// Make sure all workers have died.
for (const auto &kv : workers) {
// Skip master (self).
if (kv.first == 0) continue;
while (io::network::CanEstablishConnection(kv.second))
std::this_thread::sleep_for(0.5s);
}
}
void MasterCoordination::SetRecoveredSnapshot(
std::experimental::optional<std::pair<int64_t, tx::TransactionId>>
recovered_snapshot_tx) {
std::lock_guard<std::mutex> guard(lock_);
std::lock_guard<std::mutex> guard(master_lock_);
recovery_done_ = true;
recovered_snapshot_tx_ = recovered_snapshot_tx;
}
@ -88,7 +79,7 @@ int MasterCoordination::CountRecoveredWorkers() const {
std::experimental::optional<std::pair<int64_t, tx::TransactionId>>
MasterCoordination::RecoveredSnapshotTx() const {
std::lock_guard<std::mutex> guard(lock_);
std::lock_guard<std::mutex> guard(master_lock_);
CHECK(recovery_done_) << "Recovered snapshot requested before it's available";
return recovered_snapshot_tx_;
}
@ -102,7 +93,7 @@ std::vector<tx::TransactionId> MasterCoordination::CommonWalTransactions(
}
{
std::lock_guard<std::mutex> guard(lock_);
std::lock_guard<std::mutex> guard(master_lock_);
for (auto worker : recovered_workers_) {
// If there is no recovery info we can just return an empty vector since
// we can't restore any transaction
@ -125,4 +116,105 @@ std::vector<tx::TransactionId> MasterCoordination::CommonWalTransactions(
return tx_intersection;
}
bool MasterCoordination::AwaitShutdown(
std::function<bool(bool)> call_before_shutdown) {
// Wait for a shutdown notification.
while (alive_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
// Copy the current value of the cluster state.
bool is_cluster_alive = cluster_alive_;
// Call the before shutdown callback.
bool ret = call_before_shutdown(is_cluster_alive);
// Stop the heartbeat scheduler so we don't cause any errors during shutdown.
// Also, we manually issue one final heartbeat to all workers so that their
// counters are reset. This must be done immediately before issuing shutdown
// requests to the workers. The `IssueHeartbeats` will ignore any errors that
// occur now because we are in the process of shutting the cluster down.
scheduler_.Stop();
IssueHeartbeats();
// Shutdown all workers.
auto workers = GetWorkers();
std::vector<std::pair<int, io::network::Endpoint>> workers_sorted(
workers.begin(), workers.end());
std::sort(workers_sorted.begin(), workers_sorted.end(),
[](const std::pair<int, io::network::Endpoint> &a,
const std::pair<int, io::network::Endpoint> &b) {
return a.first < b.first;
});
LOG(INFO) << "Starting shutdown of all workers.";
for (const auto &worker : workers_sorted) {
// Skip master (self).
if (worker.first == 0) continue;
auto client_pool = GetClientPool(worker.first);
try {
client_pool->Call<StopWorkerRpc>();
} catch (const communication::rpc::RpcFailedException &e) {
LOG(WARNING) << "Couldn't shut down " << GetWorkerName(e.endpoint());
}
}
// Make sure all workers have died.
while (true) {
std::vector<std::string> workers_alive;
std::this_thread::sleep_for(std::chrono::milliseconds(500));
for (const auto &worker : workers_sorted) {
// Skip master (self).
if (worker.first == 0) continue;
if (io::network::CanEstablishConnection(worker.second)) {
workers_alive.push_back(GetWorkerName(worker.second));
}
}
if (workers_alive.size() == 0) break;
LOG(INFO) << "Waiting for " << utils::Join(workers_alive, ", ")
<< " to finish shutting down...";
}
LOG(INFO) << "Shutdown of all workers is complete.";
// Return `true` if the cluster is alive and the `call_before_shutdown`
// succeeded.
return ret && is_cluster_alive;
}
void MasterCoordination::Shutdown() { alive_.store(false); }
bool MasterCoordination::IsClusterAlive() { return cluster_alive_; }
void MasterCoordination::IssueHeartbeats() {
std::lock_guard<std::mutex> guard(master_lock_);
auto workers = GetWorkers();
for (const auto &worker : workers) {
// Skip master (self).
if (worker.first == 0) continue;
auto client_pool = GetClientPool(worker.first);
try {
// TODO (mferencevic): Should we retry this call to ignore some transient
// communication errors?
client_pool->Call<HeartbeatRpc>();
} catch (const communication::rpc::RpcFailedException &e) {
// If we are not alive that means that we are in the process of a
// shutdown. We ignore any exceptions here to stop our Heartbeat from
// displaying warnings that the workers may have died (they should die,
// we are shutting them down). Note: The heartbeat scheduler must stay
// alive to ensure that the workers receive their heartbeat requests
// during shutdown (which may take a long time).
if (!alive_) continue;
LOG(WARNING) << "The " << GetWorkerName(e.endpoint())
<< " didn't respond to our heartbeat request. The cluster "
"is in a degraded state and we are starting a graceful "
"shutdown. Please check the logs on the worker for "
"more details.";
// Set the `cluster_alive_` flag to `false` to indicate that something
// in the cluster failed.
cluster_alive_.store(false);
// Shutdown the whole cluster.
Shutdown();
}
}
}
} // namespace distributed

View File

@ -1,6 +1,8 @@
#pragma once
#include <atomic>
#include <experimental/optional>
#include <functional>
#include <mutex>
#include <set>
#include <unordered_map>
@ -8,6 +10,7 @@
#include "distributed/coordination.hpp"
#include "durability/recovery.hpp"
#include "io/network/endpoint.hpp"
#include "utils/scheduler.hpp"
namespace distributed {
using Endpoint = io::network::Endpoint;
@ -16,9 +19,10 @@ using Endpoint = io::network::Endpoint;
* coordinated shutdown in a distributed memgraph. Master side. */
class MasterCoordination final : public Coordination {
public:
explicit MasterCoordination(const Endpoint &master_endpoint);
explicit MasterCoordination(
const Endpoint &master_endpoint,
int client_workers_count = std::thread::hardware_concurrency());
/** Shuts down all the workers and this master server. */
~MasterCoordination();
/** Registers a new worker with this master coordination.
@ -37,8 +41,6 @@ class MasterCoordination final : public Coordination {
int worker_id, const std::experimental::optional<durability::RecoveryInfo>
&recovery_info);
Endpoint GetEndpoint(int worker_id);
/// Sets the recovery info. nullopt indicates nothing was recovered.
void SetRecoveredSnapshot(
std::experimental::optional<std::pair<int64_t, tx::TransactionId>>
@ -52,19 +54,47 @@ class MasterCoordination final : public Coordination {
std::vector<tx::TransactionId> CommonWalTransactions(
const durability::RecoveryInfo &master_info) const;
private:
// Most master functions aren't thread-safe.
mutable std::mutex lock_;
/// Waits while the cluster is in a valid state or until the `Shutdown` method
/// is called (suitable for use with signal handlers). Blocks the calling
/// thread until the shutdown has finished.
/// @param call_before_shutdown function that should be called before
/// shutdown, the function gets a bool argument indicating whether the cluster
/// is alive and should return a bool indicating whether the shutdown
/// succeeded without any issues
/// @returns `true` if the shutdown was completed without any issues, `false`
/// otherwise
bool AwaitShutdown(std::function<bool(bool)> call_before_shutdown =
[](bool is_cluster_alive) -> bool { return true; });
/// Durability recovery info.
/// Indicates if the recovery phase is done.
/// Hints that the coordination should start shutting down the whole cluster.
void Shutdown();
/// Returns `true` if the cluster is in a consistent state.
bool IsClusterAlive();
private:
/// Sends a heartbeat request to all workers.
void IssueHeartbeats();
// Most master functions aren't thread-safe.
mutable std::mutex master_lock_;
// Durability recovery info.
// Indicates if the recovery phase is done.
bool recovery_done_{false};
/// Set of workers that successfully finished recovering the snapshot
// Set of workers that successfully finished recovering the snapshot
std::map<int, std::experimental::optional<durability::RecoveryInfo>>
recovered_workers_;
/// If nullopt nothing was recovered.
// If nullopt nothing was recovered.
std::experimental::optional<std::pair<int64_t, tx::TransactionId>>
recovered_snapshot_tx_;
// Scheduler that is used to periodically ping all registered workers.
utils::Scheduler scheduler_;
// Flags used for shutdown.
std::atomic<bool> alive_{true};
std::atomic<bool> cluster_alive_{true};
};
} // namespace distributed
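A brief usage sketch of the `AwaitShutdown` callback contract described above (illustration only; `coordination` stands for a constructed `MasterCoordination` and `PersistState` is a hypothetical cleanup step, not part of this change):

bool ok = coordination.AwaitShutdown([](bool is_cluster_alive) -> bool {
  if (!is_cluster_alive) LOG(WARNING) << "Shutting down a degraded cluster.";
  // Hypothetical cleanup; its success is propagated back to the caller.
  return PersistState();
});
// `ok` is true only if the cluster stayed alive and the callback succeeded.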

View File

@ -90,4 +90,8 @@ cpp<#
:capnp-type "Utils.Optional(Dur.RecoveryInfo)")))
(:response ()))
(lcp:define-rpc heartbeat
(:request ())
(:response ()))
(lcp:pop-namespace) ;; distributed

View File

@ -1,5 +1,4 @@
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>
@ -10,37 +9,91 @@
namespace distributed {
using namespace std::literals::chrono_literals;
// Expect that a heartbeat should be received in this time interval. If it is
// not received we assume that the communication is broken and start a shutdown.
const int kHeartbeatMaxDelaySeconds = 10;
WorkerCoordination::WorkerCoordination(communication::rpc::Server &server,
const Endpoint &master_endpoint)
: Coordination(master_endpoint), server_(server) {}
// Check whether a heartbeat is received every `kHeartbeatCheckSeconds`. It
// should be larger than `kHeartbeatIntervalSeconds` defined in the master
// coordination because it makes no sense to check more often than the heartbeat
// is sent. Also, it must be smaller than `kHeartbeatMaxDelaySeconds` to
// function properly.
const int kHeartbeatCheckSeconds = 2;
using namespace std::chrono_literals;
WorkerCoordination::WorkerCoordination(communication::rpc::Server *server,
const Endpoint &master_endpoint,
int worker_id, int client_workers_count)
: Coordination(master_endpoint, worker_id, client_workers_count),
server_(server) {
server_->Register<StopWorkerRpc>(
[&](const auto &req_reader, auto *res_builder) {
LOG(INFO) << "The master initiated shutdown of this worker.";
Shutdown();
});
server_->Register<HeartbeatRpc>([&](const auto &req_reader,
auto *res_builder) {
std::lock_guard<std::mutex> guard(heartbeat_lock_);
last_heartbeat_time_ = std::chrono::steady_clock::now();
if (!scheduler_.IsRunning()) {
scheduler_.Run(
"Heartbeat", std::chrono::seconds(kHeartbeatCheckSeconds), [this] {
std::lock_guard<std::mutex> guard(heartbeat_lock_);
auto duration =
std::chrono::steady_clock::now() - last_heartbeat_time_;
if (duration > std::chrono::seconds(kHeartbeatMaxDelaySeconds)) {
LOG(WARNING) << "The master hasn't given us a heartbeat request "
"for at least "
<< kHeartbeatMaxDelaySeconds
<< " seconds! We are shutting down...";
// Set the `cluster_alive_` flag to `false` to indicate that
// something in the cluster failed.
cluster_alive_ = false;
// Shutdown the worker.
Shutdown();
}
});
}
});
}
WorkerCoordination::~WorkerCoordination() {
CHECK(!alive_) << "You must call Shutdown and AwaitShutdown on "
"distributed::WorkerCoordination!";
}
void WorkerCoordination::RegisterWorker(int worker_id, Endpoint endpoint) {
std::lock_guard<std::mutex> guard(lock_);
AddWorker(worker_id, endpoint);
}
void WorkerCoordination::WaitForShutdown() {
using namespace std::chrono_literals;
std::mutex mutex;
std::condition_variable cv;
bool shutdown = false;
bool WorkerCoordination::AwaitShutdown(
std::function<bool(bool)> call_before_shutdown) {
// Wait for a shutdown notification.
while (alive_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
server_.Register<StopWorkerRpc>([&](const auto &req_reader, auto *res_builder) {
std::unique_lock<std::mutex> lk(mutex);
shutdown = true;
lk.unlock();
cv.notify_one();
});
// The first thing we need to do is to stop our heartbeat scheduler because
// the master stopped its scheduler immediately before issuing the shutdown
// request to the worker. This will prevent our heartbeat from timing out on a
// regular shutdown.
scheduler_.Stop();
std::unique_lock<std::mutex> lk(mutex);
cv.wait(lk, [&shutdown] { return shutdown; });
// Copy the current value of the cluster state.
bool is_cluster_alive = cluster_alive_;
// Call the before shutdown callback.
bool ret = call_before_shutdown(is_cluster_alive);
// All other cleanup must be done here.
// Return `true` if the cluster is alive and the `call_before_shutdown`
// succeeded.
return ret && is_cluster_alive;
}
io::network::Endpoint WorkerCoordination::GetEndpoint(int worker_id) {
std::lock_guard<std::mutex> guard(lock_);
return Coordination::GetEndpoint(worker_id);
}
void WorkerCoordination::Shutdown() { alive_.store(false); }
} // namespace distributed
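The timing constraints spelled out in the comments above (send period at least 10x smaller than the timeout, check period between the two) could be restated as compile-time checks. This is a sketch only: the constants live in separate translation units (master and worker coordination), so such asserts would need a shared header; the values are the ones used in this change.

constexpr int kHeartbeatIntervalSeconds = 1;   // master: send period
constexpr int kHeartbeatCheckSeconds = 2;      // worker: check period
constexpr int kHeartbeatMaxDelaySeconds = 10;  // worker: timeout
static_assert(kHeartbeatIntervalSeconds * 10 <= kHeartbeatMaxDelaySeconds,
              "the send period must be at least 10x smaller than the timeout");
static_assert(kHeartbeatIntervalSeconds < kHeartbeatCheckSeconds &&
              kHeartbeatCheckSeconds < kHeartbeatMaxDelaySeconds,
              "the check period must sit between the send period and the timeout");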

View File

@ -1,33 +1,59 @@
#pragma once
#include <atomic>
#include <mutex>
#include <unordered_map>
#include "communication/rpc/server.hpp"
#include "distributed/coordination.hpp"
#include "utils/scheduler.hpp"
namespace distributed {
/** Handles worker registration, getting of other workers' endpoints and
* coordinated shutdown in a distributed memgraph. Worker side. */
/// Handles worker registration, getting of other workers' endpoints and
/// coordinated shutdown in a distributed memgraph. Worker side.
class WorkerCoordination final : public Coordination {
using Endpoint = io::network::Endpoint;
public:
WorkerCoordination(communication::rpc::Server &server,
const Endpoint &master_endpoint);
WorkerCoordination(
communication::rpc::Server *server, const Endpoint &master_endpoint,
int worker_id,
int client_workers_count = std::thread::hardware_concurrency());
/** Registers the worker with the given endpoint. */
~WorkerCoordination();
/// Registers the worker with the given endpoint.
void RegisterWorker(int worker_id, Endpoint endpoint);
/** Starts listening for a remote shutdown command (issued by the master).
* Blocks the calling thread until that has finished. */
void WaitForShutdown();
/// Starts listening for a remote shutdown command (issued by the master) or
/// for the `Shutdown` method to be called (suitable for use with signal
/// handlers). Blocks the calling thread until that has finished.
/// @param call_before_shutdown function that should be called before
/// shutdown, the function gets a bool argument indicating whether the cluster
/// is alive and should return a bool indicating whether the shutdown
/// succeeded without any issues
/// @returns `true` if the shutdown was completed without any issues, `false`
/// otherwise
bool AwaitShutdown(std::function<bool(bool)> call_before_shutdown =
[](bool is_cluster_alive) -> bool { return true; });
Endpoint GetEndpoint(int worker_id);
/// Hints that the coordination should start shutting down the worker.
void Shutdown();
/// Returns `true` if the cluster is in a consistent state.
bool IsClusterAlive();
private:
communication::rpc::Server &server_;
mutable std::mutex lock_;
communication::rpc::Server *server_;
// Heartbeat variables
std::mutex heartbeat_lock_;
std::chrono::time_point<std::chrono::steady_clock> last_heartbeat_time_;
utils::Scheduler scheduler_;
// Flag used for shutdown.
std::atomic<bool> alive_{true};
std::atomic<bool> cluster_alive_{true};
};
} // namespace distributed

View File

@ -2,7 +2,6 @@
#include "distributed/data_rpc_clients.hpp"
#include "distributed/data_rpc_messages.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "storage/edge.hpp"
#include "storage/vertex.hpp"
@ -13,10 +12,9 @@ RemoteElementInfo<Edge> DataRpcClients::RemoteElement(int worker_id,
tx::TransactionId tx_id,
gid::Gid gid) {
auto response =
clients_.GetClientPool(worker_id).Call<EdgeRpc>(TxGidPair{tx_id, gid});
CHECK(response) << "EdgeRpc failed";
return RemoteElementInfo<Edge>(response->cypher_id,
std::move(response->edge_output));
coordination_->GetClientPool(worker_id)->Call<EdgeRpc>(TxGidPair{tx_id, gid});
return RemoteElementInfo<Edge>(response.cypher_id,
std::move(response.edge_output));
}
template <>
@ -24,19 +22,17 @@ RemoteElementInfo<Vertex> DataRpcClients::RemoteElement(int worker_id,
tx::TransactionId tx_id,
gid::Gid gid) {
auto response =
clients_.GetClientPool(worker_id).Call<VertexRpc>(TxGidPair{tx_id, gid});
CHECK(response) << "VertexRpc failed";
return RemoteElementInfo<Vertex>(response->cypher_id,
std::move(response->vertex_output));
coordination_->GetClientPool(worker_id)->Call<VertexRpc>(TxGidPair{tx_id, gid});
return RemoteElementInfo<Vertex>(response.cypher_id,
std::move(response.vertex_output));
}
std::unordered_map<int, int64_t> DataRpcClients::VertexCounts(
tx::TransactionId tx_id) {
auto future_results = clients_.ExecuteOnWorkers<std::pair<int, int64_t>>(
auto future_results = coordination_->ExecuteOnWorkers<std::pair<int, int64_t>>(
-1, [tx_id](int worker_id, communication::rpc::ClientPool &client_pool) {
auto response = client_pool.Call<VertexCountRpc>(tx_id);
CHECK(response) << "VertexCountRpc failed";
return std::make_pair(worker_id, response->member);
return std::make_pair(worker_id, response.member);
});
std::unordered_map<int, int64_t> results;

View File

@ -7,6 +7,7 @@
#include <unordered_map>
#include <utility>
#include "distributed/coordination.hpp"
#include "storage/gid.hpp"
#include "transactions/type.hpp"
@ -34,7 +35,8 @@ struct RemoteElementInfo {
/// Provides access to other worker's data.
class DataRpcClients {
public:
DataRpcClients(RpcWorkerClients &clients) : clients_(clients) {}
explicit DataRpcClients(Coordination *coordination)
: coordination_(coordination) {}
/// Returns a remote worker's record (vertex/edge) data for the given params.
/// That worker must own the vertex/edge for the given id, and that vertex
@ -49,7 +51,7 @@ class DataRpcClients {
std::unordered_map<int, int64_t> VertexCounts(tx::TransactionId tx_id);
private:
RpcWorkerClients &clients_;
Coordination *coordination_;
};
} // namespace distributed

View File

@ -7,11 +7,14 @@
namespace distributed {
utils::Future<bool> DurabilityRpcMaster::MakeSnapshot(tx::TransactionId tx) {
return utils::make_future(std::async(std::launch::async, [this, tx] {
auto futures = clients_.ExecuteOnWorkers<bool>(
auto futures = coordination_->ExecuteOnWorkers<bool>(
0, [tx](int worker_id, communication::rpc::ClientPool &client_pool) {
auto res = client_pool.Call<MakeSnapshotRpc>(tx);
if (!res) return false;
return res->member;
try {
auto res = client_pool.Call<MakeSnapshotRpc>(tx);
return res.member;
} catch (const communication::rpc::RpcFailedException &e) {
return false;
}
});
bool created = true;
@ -25,22 +28,25 @@ utils::Future<bool> DurabilityRpcMaster::MakeSnapshot(tx::TransactionId tx) {
utils::Future<bool> DurabilityRpcMaster::RecoverWalAndIndexes(
durability::RecoveryData *recovery_data) {
return utils::make_future(std::async(std::launch::async, [this,
recovery_data] {
auto futures = clients_.ExecuteOnWorkers<bool>(
0, [recovery_data](int worker_id,
communication::rpc::ClientPool &client_pool) {
auto res = client_pool.Call<RecoverWalAndIndexesRpc>(*recovery_data);
if (!res) return false;
return true;
});
return utils::make_future(
std::async(std::launch::async, [this, recovery_data] {
auto futures = coordination_->ExecuteOnWorkers<bool>(
0, [recovery_data](int worker_id,
communication::rpc::ClientPool &client_pool) {
try {
client_pool.Call<RecoverWalAndIndexesRpc>(*recovery_data);
return true;
} catch (const communication::rpc::RpcFailedException &e) {
return false;
}
});
bool recovered = true;
for (auto &future : futures) {
recovered &= future.get();
}
bool recovered = true;
for (auto &future : futures) {
recovered &= future.get();
}
return recovered;
}));
return recovered;
}));
}
} // namespace distributed

View File

@ -4,7 +4,7 @@
#include <mutex>
#include <utility>
#include "distributed/rpc_worker_clients.hpp"
#include "distributed/coordination.hpp"
#include "durability/recovery.hpp"
#include "storage/gid.hpp"
#include "transactions/type.hpp"
@ -14,7 +14,8 @@ namespace distributed {
/// Provides the ability to trigger snapshotting on other workers.
class DurabilityRpcMaster {
public:
explicit DurabilityRpcMaster(RpcWorkerClients &clients) : clients_(clients) {}
explicit DurabilityRpcMaster(Coordination *coordination)
: coordination_(coordination) {}
// Sends a snapshot request to workers and returns a future which becomes true
// if all workers successfully completed their snapshot creation, false
@ -26,7 +27,7 @@ class DurabilityRpcMaster {
durability::RecoveryData *recovery_data);
private:
RpcWorkerClients &clients_;
Coordination *coordination_;
};
} // namespace distributed

View File

@ -39,8 +39,7 @@ DynamicWorkerRegistration::DynamicWorkerRegistration(ClientPool *client_pool)
std::vector<std::pair<std::string, std::string>>
DynamicWorkerRegistration::GetIndicesToCreate() {
auto result = client_pool_->Call<DynamicWorkerRpc>();
CHECK(result) << "DynamicWorkerRpc failed";
return result->recover_indices;
return result.recover_indices;
}
} // namespace distributed

View File

@ -2,33 +2,30 @@
namespace distributed {
PlanDispatcher::PlanDispatcher(RpcWorkerClients &clients) : clients_(clients) {}
PlanDispatcher::PlanDispatcher(Coordination *coordination) : coordination_(coordination) {}
void PlanDispatcher::DispatchPlan(
int64_t plan_id, std::shared_ptr<query::plan::LogicalOperator> plan,
const query::SymbolTable &symbol_table) {
auto futures = clients_.ExecuteOnWorkers<void>(
auto futures = coordination_->ExecuteOnWorkers<void>(
0, [plan_id, plan, symbol_table](
int worker_id, communication::rpc::ClientPool &client_pool) {
auto result =
client_pool.Call<DispatchPlanRpc>(plan_id, plan, symbol_table);
CHECK(result) << "DispatchPlanRpc failed";
client_pool.Call<DispatchPlanRpc>(plan_id, plan, symbol_table);
});
for (auto &future : futures) {
future.wait();
future.get();
}
}
void PlanDispatcher::RemovePlan(int64_t plan_id) {
auto futures = clients_.ExecuteOnWorkers<void>(
auto futures = coordination_->ExecuteOnWorkers<void>(
0, [plan_id](int worker_id, communication::rpc::ClientPool &client_pool) {
auto result = client_pool.Call<RemovePlanRpc>(plan_id);
CHECK(result) << "Failed to remove plan from worker";
client_pool.Call<RemovePlanRpc>(plan_id);
});
for (auto &future : futures) {
future.wait();
future.get();
}
}

View File

@ -2,7 +2,6 @@
#include "distributed/coordination.hpp"
#include "distributed/plan_rpc_messages.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/plan/operator.hpp"
@ -13,7 +12,7 @@ namespace distributed {
*/
class PlanDispatcher {
public:
explicit PlanDispatcher(RpcWorkerClients &clients);
explicit PlanDispatcher(Coordination *coordination);
/** Dispatch a plan to all workers and wait for their acknowledgement. */
void DispatchPlan(int64_t plan_id,
@ -24,7 +23,7 @@ class PlanDispatcher {
void RemovePlan(int64_t plan_id);
private:
RpcWorkerClients &clients_;
Coordination *coordination_;
};
} // namespace distributed

View File

@ -12,7 +12,7 @@ utils::Future<PullData> PullRpcClients::Pull(
const query::EvaluationContext &evaluation_context,
const std::vector<query::Symbol> &symbols, bool accumulate,
int batch_size) {
return clients_->ExecuteOnWorker<PullData>(worker_id, [
return coordination_->ExecuteOnWorker<PullData>(worker_id, [
data_manager = data_manager_, dba, plan_id, command_id, evaluation_context,
symbols, accumulate, batch_size
](int worker_id, ClientPool &client_pool) {
@ -25,27 +25,25 @@ utils::Future<PullData> PullRpcClients::Pull(
load_pull_res, dba->transaction_id(), dba->transaction().snapshot(),
plan_id, command_id, evaluation_context, symbols, accumulate,
batch_size, storage::SendVersions::BOTH);
return PullData{result->data.pull_state, std::move(result->data.frames)};
return PullData{result.data.pull_state, std::move(result.data.frames)};
});
}
utils::Future<void> PullRpcClients::ResetCursor(database::GraphDbAccessor *dba,
int worker_id, int64_t plan_id,
tx::CommandId command_id) {
return clients_->ExecuteOnWorker<void>(
return coordination_->ExecuteOnWorker<void>(
worker_id, [dba, plan_id, command_id](int worker_id, auto &client) {
auto res = client.template Call<ResetCursorRpc>(dba->transaction_id(),
plan_id, command_id);
CHECK(res) << "ResetCursorRpc failed!";
client.template Call<ResetCursorRpc>(dba->transaction_id(), plan_id,
command_id);
});
}
std::vector<utils::Future<void>>
PullRpcClients::NotifyAllTransactionCommandAdvanced(tx::TransactionId tx_id) {
return clients_->ExecuteOnWorkers<void>(
return coordination_->ExecuteOnWorkers<void>(
0, [tx_id](int worker_id, auto &client) {
auto res = client.template Call<TransactionCommandAdvancedRpc>(tx_id);
CHECK(res) << "TransactionCommandAdvanceRpc failed";
client.template Call<TransactionCommandAdvancedRpc>(tx_id);
});
}

View File

@ -3,8 +3,8 @@
#include <vector>
#include "database/graph_db_accessor.hpp"
#include "distributed/coordination.hpp"
#include "distributed/pull_produce_rpc_messages.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "query/context.hpp"
#include "query/frontend/semantic/symbol.hpp"
#include "transactions/type.hpp"
@ -22,8 +22,8 @@ class PullRpcClients {
using ClientPool = communication::rpc::ClientPool;
public:
PullRpcClients(RpcWorkerClients *clients, DataManager *data_manager)
: clients_(clients), data_manager_(data_manager) {}
PullRpcClients(Coordination *coordination, DataManager *data_manager)
: coordination_(coordination), data_manager_(data_manager) {}
/// Calls a remote pull asynchronously. IMPORTANT: take care not to call this
/// function for the same (tx_id, worker_id, plan_id, command_id) before the
@ -42,14 +42,14 @@ class PullRpcClients {
utils::Future<void> ResetCursor(database::GraphDbAccessor *dba, int worker_id,
int64_t plan_id, tx::CommandId command_id);
auto GetWorkerIds() { return clients_->GetWorkerIds(); }
auto GetWorkerIds() { return coordination_->GetWorkerIds(); }
std::vector<utils::Future<void>> NotifyAllTransactionCommandAdvanced(
tx::TransactionId tx_id);
private:
RpcWorkerClients *clients_{nullptr};
DataManager *data_manager_{nullptr};
Coordination *coordination_;
DataManager *data_manager_;
};
} // namespace distributed

View File

@ -1,134 +0,0 @@
#pragma once
#include <functional>
#include <type_traits>
#include <unordered_map>
#include "communication/rpc/client_pool.hpp"
#include "distributed/coordination.hpp"
#include "distributed/index_rpc_messages.hpp"
#include "distributed/token_sharing_rpc_messages.hpp"
#include "storage/types.hpp"
#include "transactions/transaction.hpp"
#include "utils/future.hpp"
#include "utils/thread.hpp"
namespace distributed {
/** A cache of RPC clients (of the given name/kind) per MG distributed worker.
* Thread safe. */
class RpcWorkerClients {
public:
explicit RpcWorkerClients(Coordination &coordination)
: coordination_(coordination),
thread_pool_(std::thread::hardware_concurrency()) {}
RpcWorkerClients(const RpcWorkerClients &) = delete;
RpcWorkerClients(RpcWorkerClients &&) = delete;
RpcWorkerClients &operator=(const RpcWorkerClients &) = delete;
RpcWorkerClients &operator=(RpcWorkerClients &&) = delete;
auto &GetClientPool(int worker_id) {
std::lock_guard<std::mutex> guard{lock_};
auto found = client_pools_.find(worker_id);
if (found != client_pools_.end()) return found->second;
return client_pools_
.emplace(std::piecewise_construct, std::forward_as_tuple(worker_id),
std::forward_as_tuple(coordination_.GetEndpoint(worker_id)))
.first->second;
}
auto GetWorkerIds() { return coordination_.GetWorkerIds(); }
/** Asynchroniously executes the given function on the rpc client for the
* given worker id. Returns an `utils::Future` of the given `execute`
* function's
* return type. */
template <typename TResult>
auto ExecuteOnWorker(
int worker_id,
std::function<TResult(int worker_id, communication::rpc::ClientPool &)>
execute) {
auto &client_pool = GetClientPool(worker_id);
return thread_pool_.Run(execute, worker_id, std::ref(client_pool));
}
/** Asynchroniously executes the `execute` function on all worker rpc clients
* except the one whose id is `skip_worker_id`. Returns a vectore of futures
* contaning the results of the `execute` function. */
template <typename TResult>
auto ExecuteOnWorkers(
int skip_worker_id,
std::function<TResult(int worker_id, communication::rpc::ClientPool &)>
execute) {
std::vector<utils::Future<TResult>> futures;
for (auto &worker_id : coordination_.GetWorkerIds()) {
if (worker_id == skip_worker_id) continue;
futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute)));
}
return futures;
}
private:
// TODO make Coordination const, it's member GetEndpoint must be const too.
Coordination &coordination_;
std::unordered_map<int, communication::rpc::ClientPool> client_pools_;
std::mutex lock_;
utils::ThreadPool thread_pool_;
};
/** Wrapper class around a RPC call to build indices.
*/
class IndexRpcClients {
public:
explicit IndexRpcClients(RpcWorkerClients &clients) : clients_(clients) {}
auto GetPopulateIndexFutures(const storage::Label &label,
const storage::Property &property,
tx::TransactionId transaction_id,
int worker_id) {
return clients_.ExecuteOnWorkers<bool>(
worker_id,
[label, property, transaction_id](
int worker_id, communication::rpc::ClientPool &client_pool) {
return static_cast<bool>(client_pool.Call<PopulateIndexRpc>(
label, property, transaction_id));
});
}
auto GetCreateIndexFutures(const storage::Label &label,
const storage::Property &property, int worker_id) {
return clients_.ExecuteOnWorkers<bool>(
worker_id,
[label, property](int worker_id,
communication::rpc::ClientPool &client_pool) {
return static_cast<bool>(
client_pool.Call<CreateIndexRpc>(label, property));
});
}
private:
RpcWorkerClients &clients_;
};
/** Wrapper class around a RPC call to share token between workers.
*/
class TokenSharingRpcClients {
public:
explicit TokenSharingRpcClients(RpcWorkerClients *clients)
: clients_(clients) {}
auto TransferToken(int worker_id) {
return clients_->ExecuteOnWorker<void>(
worker_id,
[](int worker_id, communication::rpc::ClientPool &client_pool) {
CHECK(client_pool.Call<TokenTransferRpc>())
<< "Unable to transfer token";
});
}
private:
RpcWorkerClients *clients_;
};
} // namespace distributed

View File

@ -2,7 +2,7 @@
#pragma once
#include "distributed/rpc_worker_clients.hpp"
#include "distributed/coordination.hpp"
#include "distributed/dgp/partitioner.hpp"
namespace communication::rpc {
@ -29,12 +29,10 @@ class TokenSharingRpcServer {
public:
TokenSharingRpcServer(database::DistributedGraphDb *db, int worker_id,
distributed::Coordination *coordination,
communication::rpc::Server *server,
distributed::TokenSharingRpcClients *clients)
communication::rpc::Server *server)
: worker_id_(worker_id),
coordination_(coordination),
server_(server),
clients_(clients),
dgp_(db) {
server_->Register<distributed::TokenTransferRpc>(
[this](const auto &req_reader, auto *res_builder) { token_ = true; });
@ -73,7 +71,17 @@ class TokenSharingRpcServer {
next_worker = workers[0];
}
clients_->TransferToken(next_worker);
// Try to transfer the token until successful.
while (true) {
try {
coordination_->GetClientPool(next_worker)->Call<TokenTransferRpc>();
break;
} catch (const communication::rpc::RpcFailedException &e) {
DLOG(WARNING) << "Unable to transfer token to worker "
<< next_worker;
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
}
}
});
}
@ -104,7 +112,6 @@ class TokenSharingRpcServer {
int worker_id_;
distributed::Coordination *coordination_;
communication::rpc::Server *server_;
distributed::TokenSharingRpcClients *clients_;
std::atomic<bool> started_{false};
std::atomic<bool> token_{false};
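The retry-until-success loop above could be factored into a small helper under the new exception-based `Call` API. A minimal sketch, with the retry period hard-coded and the response discarded (the helper itself is not part of this change):

template <typename TRpc, typename... TArgs>
void CallWithRetry(communication::rpc::ClientPool *client_pool,
                   TArgs &&... args) {
  while (true) {
    try {
      // Forward the arguments to the RPC; a successful call ends the loop.
      client_pool->Call<TRpc>(std::forward<TArgs>(args)...);
      return;
    } catch (const communication::rpc::RpcFailedException &) {
      DLOG(WARNING) << "RPC failed, retrying in 500 ms...";
      std::this_thread::sleep_for(std::chrono::milliseconds(500));
    }
  }
}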

View File

@ -27,9 +27,7 @@ void RaiseIfRemoteError(UpdateResult result) {
UpdateResult UpdatesRpcClients::Update(int worker_id,
const database::StateDelta &delta) {
auto res = worker_clients_.GetClientPool(worker_id).Call<UpdateRpc>(delta);
CHECK(res) << "UpdateRpc failed on worker: " << worker_id;
return res->member;
return coordination_->GetClientPool(worker_id)->Call<UpdateRpc>(delta).member;
}
CreatedVertexInfo UpdatesRpcClients::CreateVertex(
@ -37,12 +35,11 @@ CreatedVertexInfo UpdatesRpcClients::CreateVertex(
const std::vector<storage::Label> &labels,
const std::unordered_map<storage::Property, PropertyValue> &properties,
std::experimental::optional<int64_t> cypher_id) {
auto res = worker_clients_.GetClientPool(worker_id).Call<CreateVertexRpc>(
auto res = coordination_->GetClientPool(worker_id)->Call<CreateVertexRpc>(
CreateVertexReqData{tx_id, labels, properties, cypher_id});
CHECK(res) << "CreateVertexRpc failed on worker: " << worker_id;
CHECK(res->member.result == UpdateResult::DONE)
CHECK(res.member.result == UpdateResult::DONE)
<< "Remote Vertex creation result not UpdateResult::DONE";
return CreatedVertexInfo(res->member.cypher_id, res->member.gid);
return CreatedVertexInfo(res.member.cypher_id, res.member.gid);
}
CreatedEdgeInfo UpdatesRpcClients::CreateEdge(
@ -52,13 +49,12 @@ CreatedEdgeInfo UpdatesRpcClients::CreateEdge(
CHECK(from.address().is_remote()) << "In CreateEdge `from` must be remote";
int from_worker = from.address().worker_id();
auto res =
worker_clients_.GetClientPool(from_worker)
.Call<CreateEdgeRpc>(CreateEdgeReqData{from.gid(), to.GlobalAddress(),
edge_type, tx_id, cypher_id});
CHECK(res) << "CreateEdge RPC failed on worker: " << from_worker;
RaiseIfRemoteError(res->member.result);
return CreatedEdgeInfo(res->member.cypher_id,
storage::EdgeAddress{res->member.gid, from_worker});
coordination_->GetClientPool(from_worker)
->Call<CreateEdgeRpc>(CreateEdgeReqData{
from.gid(), to.GlobalAddress(), edge_type, tx_id, cypher_id});
RaiseIfRemoteError(res.member.result);
return CreatedEdgeInfo(res.member.cypher_id,
storage::EdgeAddress{res.member.gid, from_worker});
}
void UpdatesRpcClients::AddInEdge(tx::TransactionId tx_id, VertexAccessor &from,
@ -70,47 +66,42 @@ void UpdatesRpcClients::AddInEdge(tx::TransactionId tx_id, VertexAccessor &from,
<< "AddInEdge should only be called when `to` is remote and "
"`from` is not on the same worker as `to`.";
auto worker_id = to.GlobalAddress().worker_id();
auto res = worker_clients_.GetClientPool(worker_id).Call<AddInEdgeRpc>(
auto res = coordination_->GetClientPool(worker_id)->Call<AddInEdgeRpc>(
AddInEdgeReqData{from.GlobalAddress(), edge_address, to.gid(), edge_type,
tx_id});
CHECK(res) << "AddInEdge RPC failed on worker: " << worker_id;
RaiseIfRemoteError(res->member);
RaiseIfRemoteError(res.member);
}
void UpdatesRpcClients::RemoveVertex(int worker_id, tx::TransactionId tx_id,
gid::Gid gid, bool check_empty) {
auto res = worker_clients_.GetClientPool(worker_id).Call<RemoveVertexRpc>(
auto res = coordination_->GetClientPool(worker_id)->Call<RemoveVertexRpc>(
RemoveVertexReqData{gid, tx_id, check_empty});
CHECK(res) << "RemoveVertex RPC failed on worker: " << worker_id;
RaiseIfRemoteError(res->member);
RaiseIfRemoteError(res.member);
}
void UpdatesRpcClients::RemoveEdge(tx::TransactionId tx_id, int worker_id,
gid::Gid edge_gid, gid::Gid vertex_from_id,
storage::VertexAddress vertex_to_addr) {
auto res = worker_clients_.GetClientPool(worker_id).Call<RemoveEdgeRpc>(
auto res = coordination_->GetClientPool(worker_id)->Call<RemoveEdgeRpc>(
RemoveEdgeData{tx_id, edge_gid, vertex_from_id, vertex_to_addr});
CHECK(res) << "RemoveEdge RPC failed on worker: " << worker_id;
RaiseIfRemoteError(res->member);
RaiseIfRemoteError(res.member);
}
void UpdatesRpcClients::RemoveInEdge(tx::TransactionId tx_id, int worker_id,
gid::Gid vertex_id,
storage::EdgeAddress edge_address) {
CHECK(edge_address.is_remote()) << "RemoveInEdge edge_address is local.";
auto res = worker_clients_.GetClientPool(worker_id).Call<RemoveInEdgeRpc>(
auto res = coordination_->GetClientPool(worker_id)->Call<RemoveInEdgeRpc>(
RemoveInEdgeData{tx_id, vertex_id, edge_address});
CHECK(res) << "RemoveInEdge RPC failed on worker: " << worker_id;
RaiseIfRemoteError(res->member);
RaiseIfRemoteError(res.member);
}
std::vector<utils::Future<UpdateResult>> UpdatesRpcClients::UpdateApplyAll(
int skip_worker_id, tx::TransactionId tx_id) {
return worker_clients_.ExecuteOnWorkers<UpdateResult>(
return coordination_->ExecuteOnWorkers<UpdateResult>(
skip_worker_id, [tx_id](int worker_id, auto &client) {
auto res = client.template Call<UpdateApplyRpc>(tx_id);
CHECK(res) << "UpdateApplyRpc failed";
return res->member;
return res.member;
});
}

View File

@ -4,7 +4,7 @@
#include <vector>
#include "database/state_delta.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "distributed/coordination.hpp"
#include "distributed/updates_rpc_messages.hpp"
#include "query/typed_value.hpp"
#include "storage/address_types.hpp"
@ -20,8 +20,8 @@ namespace distributed {
/// apply the accumulated deferred updates, or discard them.
class UpdatesRpcClients {
public:
explicit UpdatesRpcClients(RpcWorkerClients &clients)
: worker_clients_(clients) {}
explicit UpdatesRpcClients(Coordination *coordination)
: coordination_(coordination) {}
/// Sends an update delta to the given worker.
UpdateResult Update(int worker_id, const database::StateDelta &delta);
@ -76,7 +76,7 @@ class UpdatesRpcClients {
int skip_worker_id, tx::TransactionId tx_id);
private:
RpcWorkerClients &worker_clients_;
Coordination *coordination_;
};
} // namespace distributed

View File

@ -42,6 +42,10 @@ bool Endpoint::operator==(const Endpoint &other) const {
}
std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint) {
if (endpoint.family() == 6) {
return os << "[" << endpoint.address() << "]"
<< ":" << endpoint.port();
}
return os << endpoint.address() << ":" << endpoint.port();
}

View File

@ -164,20 +164,40 @@ void MasterMain() {
service_name, FLAGS_num_workers);
// Handler for regular termination signals
auto shutdown = [&server] {
// Server needs to be shutdown first and then the database. This prevents a
// race condition when a transaction is accepted during server shutdown.
server.Shutdown();
auto shutdown = [&db] {
// We call the shutdown method on the database so that we exit cleanly.
db.Shutdown();
};
InitSignalHandlers(shutdown);
server.AwaitShutdown();
// The return code of `AwaitShutdown` is ignored because we want the database
// to exit cleanly no matter what.
db.AwaitShutdown([&server] {
// Server needs to be shutdown first and then the database. This prevents a
// race condition when a transaction is accepted during server shutdown.
server.Shutdown();
server.AwaitShutdown();
});
}
void WorkerMain() {
google::SetUsageMessage("Memgraph distributed worker");
database::Worker db;
db.WaitForShutdown();
// Handler for regular termination signals
auto shutdown = [&db] {
// We call the shutdown method on the worker database so that we exit
// cleanly.
db.Shutdown();
};
InitSignalHandlers(shutdown);
// The return code of `AwaitShutdown` is ignored because we want the database
// to exit cleanly no matter what.
db.AwaitShutdown();
}
int main(int argc, char **argv) {

View File

@ -114,7 +114,7 @@ void KafkaStreamWriter(
try {
(*session_data.interpreter)(query, *dba, params_pv, false).PullAll(stream);
dba->Commit();
} catch (const query::QueryException &e) {
} catch (const utils::BasicException &e) {
LOG(WARNING) << "[Kafka] query execution failed with an exception: "
<< e.what();
dba->Abort();

View File

@ -27,7 +27,13 @@ class DistributedLogicalPlan final : public LogicalPlan {
~DistributedLogicalPlan() {
for (const auto &plan_pair : plan_.worker_plans) {
const auto &plan_id = plan_pair.first;
plan_dispatcher_->RemovePlan(plan_id);
try {
plan_dispatcher_->RemovePlan(plan_id);
} catch (const communication::rpc::RpcFailedException &) {
// We ignore RPC exceptions here because the other side may be shutting
// down. TODO: If that is not the case then something is really wrong with
// the cluster!
}
}
}

View File

@ -3207,6 +3207,8 @@ class CartesianCursor : public Cursor {
restore_frame(self_.right_symbols(), right_op_frame_);
}
if (context.db_accessor_.should_abort()) throw HintedAbortError();
restore_frame(self_.left_symbols(), *left_op_frames_it_);
left_op_frames_it_++;
return true;

View File

@ -56,8 +56,11 @@ void StatsDispatchMain(const io::network::Endpoint &endpoint) {
size_t sent = 0, total = 0;
auto flush_batch = [&] {
if (client.Call<BatchStatsRpc>(batch_request)) {
try {
client.Call<BatchStatsRpc>(batch_request);
sent += batch_request.requests.size();
} catch (const communication::rpc::RpcFailedException &) {
DLOG(WARNING) << "BatchStatsRpc failed!";
}
total += batch_request.requests.size();
batch_request.requests.clear();

View File

@ -10,16 +10,12 @@ namespace storage {
template <> \
type WorkerConcurrentIdMapper<type>::RpcValueToId( \
const std::string &value) { \
auto response = master_client_pool_.Call<type##IdRpc>(value); \
CHECK(response) << (#type "IdRpc failed"); \
return response->member; \
return master_client_pool_.Call<type##IdRpc>(value).member; \
} \
\
template <> \
std::string WorkerConcurrentIdMapper<type>::RpcIdToValue(type id) { \
auto response = master_client_pool_.Call<Id##type##Rpc>(id); \
CHECK(response) << ("Id" #type "Rpc failed"); \
return response->member; \
return master_client_pool_.Call<Id##type##Rpc>(id).member; \
}
using namespace storage;

View File

@ -3,8 +3,9 @@
#pragma once
#include "communication/rpc/server.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "distributed/coordination.hpp"
#include "transactions/engine.hpp"
#include "utils/exceptions.hpp"
namespace tx {
@ -22,13 +23,17 @@ class EngineDistributed : public Engine {
void StartTransactionalCacheCleanup() {
cache_clearing_scheduler_.Run("TX cache GC", kCacheReleasePeriod, [this]() {
std::lock_guard<std::mutex> guard(lock_);
// TODO (mferencevic): this has to be aware that `GlobalGcSnapshot` can
// throw!
auto oldest_active = GlobalGcSnapshot().back();
// Call all registered functions for cleanup.
for (auto &f : functions_) f(oldest_active);
// Clean our cache.
ClearTransactionalCache(oldest_active);
try {
auto oldest_active = GlobalGcSnapshot().back();
// Call all registered functions for cleanup.
for (auto &f : functions_) f(oldest_active);
// Clean our cache.
ClearTransactionalCache(oldest_active);
} catch (const utils::BasicException &e) {
DLOG(WARNING)
<< "Couldn't perform transactional cache cleanup due to exception: "
<< e.what();
}
});
}

View File

@ -9,36 +9,36 @@
namespace tx {
EngineMaster::EngineMaster(communication::rpc::Server &server,
distributed::RpcWorkerClients &rpc_worker_clients,
EngineMaster::EngineMaster(communication::rpc::Server *server,
distributed::Coordination *coordination,
durability::WriteAheadLog *wal)
: engine_single_node_(wal),
rpc_server_(server),
rpc_worker_clients_(rpc_worker_clients) {
rpc_server_.Register<BeginRpc>(
server_(server),
coordination_(coordination) {
server_->Register<BeginRpc>(
[this](const auto &req_reader, auto *res_builder) {
auto tx = this->Begin();
BeginRes res(TxAndSnapshot{tx->id_, tx->snapshot()});
res.Save(res_builder);
});
rpc_server_.Register<AdvanceRpc>(
server_->Register<AdvanceRpc>(
[this](const auto &req_reader, auto *res_builder) {
AdvanceRes res(this->Advance(req_reader.getMember()));
res.Save(res_builder);
});
rpc_server_.Register<CommitRpc>(
server_->Register<CommitRpc>(
[this](const auto &req_reader, auto *res_builder) {
this->Commit(*this->RunningTransaction(req_reader.getMember()));
});
rpc_server_.Register<AbortRpc>(
server_->Register<AbortRpc>(
[this](const auto &req_reader, auto *res_builder) {
this->Abort(*this->RunningTransaction(req_reader.getMember()));
});
rpc_server_.Register<SnapshotRpc>(
server_->Register<SnapshotRpc>(
[this](const auto &req_reader, auto *res_builder) {
// It is guaranteed that the Worker will not be requesting this for a
// transaction that's done, and that there are no race conditions here.
@ -47,7 +47,7 @@ EngineMaster::EngineMaster(communication::rpc::Server &server,
res.Save(res_builder);
});
rpc_server_.Register<CommandRpc>(
server_->Register<CommandRpc>(
[this](const auto &req_reader, auto *res_builder) {
// It is guaranteed that the Worker will not be requesting this for a
// transaction that's done, and that there are no race conditions here.
@ -55,30 +55,30 @@ EngineMaster::EngineMaster(communication::rpc::Server &server,
res.Save(res_builder);
});
rpc_server_.Register<GcSnapshotRpc>(
server_->Register<GcSnapshotRpc>(
[this](const auto &req_reader, auto *res_builder) {
GcSnapshotRes res(this->GlobalGcSnapshot());
res.Save(res_builder);
});
rpc_server_.Register<ClogInfoRpc>(
server_->Register<ClogInfoRpc>(
[this](const auto &req_reader, auto *res_builder) {
ClogInfoRes res(this->Info(req_reader.getMember()));
res.Save(res_builder);
});
rpc_server_.Register<ActiveTransactionsRpc>(
server_->Register<ActiveTransactionsRpc>(
[this](const auto &req_reader, auto *res_builder) {
ActiveTransactionsRes res(this->GlobalActiveTransactions());
res.Save(res_builder);
});
rpc_server_.Register<EnsureNextIdGreaterRpc>(
server_->Register<EnsureNextIdGreaterRpc>(
[this](const auto &req_reader, auto *res_builder) {
this->EnsureNextIdGreater(req_reader.getMember());
});
rpc_server_.Register<GlobalLastRpc>(
server_->Register<GlobalLastRpc>(
[this](const auto &req_reader, auto *res_builder) {
GlobalLastRes res(this->GlobalLast());
res.Save(res_builder);
@ -97,18 +97,15 @@ CommandId EngineMaster::UpdateCommand(TransactionId id) {
void EngineMaster::Commit(const Transaction &t) {
auto tx_id = t.id_;
auto futures = rpc_worker_clients_.ExecuteOnWorkers<void>(
auto futures = coordination_->ExecuteOnWorkers<void>(
0, [tx_id](int worker_id, communication::rpc::ClientPool &client_pool) {
auto result = client_pool.Call<NotifyCommittedRpc>(tx_id);
CHECK(result)
<< "[NotifyCommittedRpc] failed to notify that transaction "
<< tx_id << " committed";
client_pool.Call<NotifyCommittedRpc>(tx_id);
});
// We need to wait for all workers to destroy pending futures to avoid
// using already destroyed (released) transaction objects.
for (auto &future : futures) {
future.wait();
future.get();
}
engine_single_node_.Commit(t);

View File

@ -3,7 +3,7 @@
#pragma once
#include "communication/rpc/server.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "distributed/coordination.hpp"
#include "transactions/distributed/engine_distributed.hpp"
#include "transactions/single_node/engine_single_node.hpp"
@ -14,12 +14,11 @@ namespace tx {
class EngineMaster final : public EngineDistributed {
public:
/// @param server - Required. Used for rpc::Server construction.
/// @param rpc_worker_clients - Required. Used for
/// OngoingProduceJoinerRpcClients construction.
/// @param coordination - Required. Used for communication with the workers.
/// @param wal - Optional. If present, the Engine will write tx
/// Begin/Commit/Abort atomically (while under lock).
EngineMaster(communication::rpc::Server &server,
distributed::RpcWorkerClients &rpc_worker_clients,
EngineMaster(communication::rpc::Server *server,
distributed::Coordination *coordination,
durability::WriteAheadLog *wal = nullptr);
EngineMaster(const EngineMaster &) = delete;
@ -47,7 +46,7 @@ class EngineMaster final : public EngineDistributed {
private:
EngineSingleNode engine_single_node_;
communication::rpc::Server &rpc_server_;
distributed::RpcWorkerClients &rpc_worker_clients_;
communication::rpc::Server *server_;
distributed::Coordination *coordination_;
};
} // namespace tx

View File

@ -8,8 +8,8 @@
namespace tx {
EngineWorker::EngineWorker(communication::rpc::Server &server,
communication::rpc::ClientPool &master_client_pool,
EngineWorker::EngineWorker(communication::rpc::Server *server,
communication::rpc::ClientPool *master_client_pool,
durability::WriteAheadLog *wal)
: server_(server), master_client_pool_(master_client_pool), wal_(wal) {
// Register our `NotifyCommittedRpc` server. This RPC should only write the
@ -27,7 +27,7 @@ EngineWorker::EngineWorker(communication::rpc::Server &server,
// RPC call could fail on other workers which will cause the transaction to be
// aborted. This mismatch in committed/aborted across workers is resolved by
// using the master as a single source of truth when doing recovery.
server_.Register<NotifyCommittedRpc>(
server_->Register<NotifyCommittedRpc>(
[this](const auto &req_reader, auto *res_builder) {
auto tid = req_reader.getMember();
if (wal_) {
@ -43,9 +43,8 @@ EngineWorker::~EngineWorker() {
}
Transaction *EngineWorker::Begin() {
auto res = master_client_pool_.Call<BeginRpc>();
CHECK(res) << "BeginRpc failed";
auto &data = res->member;
auto res = master_client_pool_->Call<BeginRpc>();
auto &data = res.member;
UpdateOldestActive(data.snapshot, data.tx_id);
Transaction *tx = CreateTransaction(data.tx_id, data.snapshot);
auto insertion = active_.access().insert(data.tx_id, tx);
@ -55,20 +54,18 @@ Transaction *EngineWorker::Begin() {
}
CommandId EngineWorker::Advance(TransactionId tx_id) {
auto res = master_client_pool_.Call<AdvanceRpc>(tx_id);
CHECK(res) << "AdvanceRpc failed";
auto res = master_client_pool_->Call<AdvanceRpc>(tx_id);
auto access = active_.access();
auto found = access.find(tx_id);
CHECK(found != access.end())
<< "Can't advance a transaction not in local cache";
SetCommand(found->second, res->member);
return res->member;
SetCommand(found->second, res.member);
return res.member;
}
CommandId EngineWorker::UpdateCommand(TransactionId tx_id) {
auto res = master_client_pool_.Call<CommandRpc>(tx_id);
CHECK(res) << "CommandRpc failed";
auto cmd_id = res->member;
auto res = master_client_pool_->Call<CommandRpc>(tx_id);
auto cmd_id = res.member;
// Assume there is no concurrent work being done on this worker in the given
// transaction. This assumption is sound because command advancing needs to be
@ -86,15 +83,13 @@ CommandId EngineWorker::UpdateCommand(TransactionId tx_id) {
}
void EngineWorker::Commit(const Transaction &t) {
auto res = master_client_pool_.Call<CommitRpc>(t.id_);
CHECK(res) << "CommitRpc failed";
master_client_pool_->Call<CommitRpc>(t.id_);
VLOG(11) << "[Tx] Commiting worker transaction " << t.id_;
ClearSingleTransaction(t.id_);
}
void EngineWorker::Abort(const Transaction &t) {
auto res = master_client_pool_.Call<AbortRpc>(t.id_);
CHECK(res) << "AbortRpc failed";
master_client_pool_->Call<AbortRpc>(t.id_);
VLOG(11) << "[Tx] Aborting worker transaction " << t.id_;
ClearSingleTransaction(t.id_);
}
@ -106,9 +101,8 @@ CommitLog::Info EngineWorker::Info(TransactionId tid) const {
if (!(info.is_aborted() || info.is_committed())) {
// @review: this version of Call is just used because Info has no
// default constructor.
auto res = master_client_pool_.Call<ClogInfoRpc>(tid);
CHECK(res) << "ClogInfoRpc failed";
info = res->member;
auto res = master_client_pool_->Call<ClogInfoRpc>(tid);
info = res.member;
if (!info.is_active()) {
if (info.is_committed()) clog_.set_committed(tid);
if (info.is_aborted()) clog_.set_aborted(tid);
@ -120,17 +114,15 @@ CommitLog::Info EngineWorker::Info(TransactionId tid) const {
}
Snapshot EngineWorker::GlobalGcSnapshot() {
auto res = master_client_pool_.Call<GcSnapshotRpc>();
CHECK(res) << "GcSnapshotRpc failed";
auto snapshot = std::move(res->member);
auto res = master_client_pool_->Call<GcSnapshotRpc>();
auto snapshot = std::move(res.member);
UpdateOldestActive(snapshot, local_last_.load());
return snapshot;
}
Snapshot EngineWorker::GlobalActiveTransactions() {
auto res = master_client_pool_.Call<ActiveTransactionsRpc>();
CHECK(res) << "ActiveTransactionsRpc failed";
auto snapshot = std::move(res->member);
auto res = master_client_pool_->Call<ActiveTransactionsRpc>();
auto snapshot = std::move(res.member);
UpdateOldestActive(snapshot, local_last_.load());
return snapshot;
}
@ -138,9 +130,7 @@ Snapshot EngineWorker::GlobalActiveTransactions() {
TransactionId EngineWorker::LocalLast() const { return local_last_; }
TransactionId EngineWorker::GlobalLast() const {
auto res = master_client_pool_.Call<GlobalLastRpc>();
CHECK(res) << "GlobalLastRpc failed";
return res->member;
return master_client_pool_->Call<GlobalLastRpc>().member;
}
void EngineWorker::LocalForEachActiveTransaction(
@ -155,9 +145,8 @@ Transaction *EngineWorker::RunningTransaction(TransactionId tx_id) {
auto found = accessor.find(tx_id);
if (found != accessor.end()) return found->second;
auto res = master_client_pool_.Call<SnapshotRpc>(tx_id);
CHECK(res) << "SnapshotRpc failed";
auto snapshot = std::move(res->member);
auto res = master_client_pool_->Call<SnapshotRpc>(tx_id);
auto snapshot = std::move(res.member);
UpdateOldestActive(snapshot, local_last_.load());
return RunningTransaction(tx_id, snapshot);
}
@ -208,7 +197,7 @@ void EngineWorker::UpdateOldestActive(const Snapshot &snapshot,
}
void EngineWorker::EnsureNextIdGreater(TransactionId tx_id) {
master_client_pool_.Call<EnsureNextIdGreaterRpc>(tx_id);
master_client_pool_->Call<EnsureNextIdGreaterRpc>(tx_id);
}
void EngineWorker::GarbageCollectCommitLog(TransactionId tx_id) {

View File

@ -18,8 +18,8 @@ namespace tx {
* begin/advance/end transactions on the master. */
class EngineWorker final : public EngineDistributed {
public:
EngineWorker(communication::rpc::Server &server,
communication::rpc::ClientPool &master_client_pool,
EngineWorker(communication::rpc::Server *server,
communication::rpc::ClientPool *master_client_pool,
durability::WriteAheadLog *wal = nullptr);
~EngineWorker();
@ -60,10 +60,10 @@ class EngineWorker final : public EngineDistributed {
mutable CommitLog clog_;
// Our local RPC server.
communication::rpc::Server &server_;
communication::rpc::Server *server_;
// Communication to the transactional master.
communication::rpc::ClientPool &master_client_pool_;
communication::rpc::ClientPool *master_client_pool_;
// Write ahead log.
durability::WriteAheadLog *wal_;

View File

@ -77,6 +77,11 @@ class Scheduler {
if (thread_.joinable()) thread_.join();
}
/**
* Returns whether the scheduler is running.
*/
bool IsRunning() { return is_working_; }
~Scheduler() { Stop(); }
private:

View File

@ -2,6 +2,7 @@
#include <sys/prctl.h>
#include <fmt/format.h>
#include <glog/logging.h>
namespace utils {
@ -12,19 +13,10 @@ void ThreadSetName(const std::string &name) {
<< "Couldn't set thread name: " << name << "!";
}
Thread::Thread(Thread &&other) {
DCHECK(thread_id == UNINITIALIZED) << "Thread was initialized before.";
thread_id = other.thread_id;
thread = std::move(other.thread);
}
void Thread::join() { return thread.join(); }
std::atomic<unsigned> Thread::thread_counter{1};
ThreadPool::ThreadPool(size_t threads) {
ThreadPool::ThreadPool(size_t threads, const std::string &name) {
for (size_t i = 0; i < threads; ++i)
workers_.emplace_back([this] {
workers_.emplace_back([this, name, i] {
ThreadSetName(fmt::format("{} {}", name, i + 1));
while (true) {
std::function<void()> task;
{

View File

@ -21,47 +21,12 @@ namespace utils {
/// Beware, the name length limit is 16 characters!
void ThreadSetName(const std::string &name);
class Thread {
static std::atomic<unsigned> thread_counter;
public:
static size_t count(std::memory_order order = std::memory_order_seq_cst) {
return thread_counter.load(order);
}
static constexpr unsigned UNINITIALIZED = -1;
static constexpr unsigned MAIN_THREAD = 0;
template <class F>
explicit Thread(F f) {
thread_id = thread_counter.fetch_add(1, std::memory_order_acq_rel);
thread = std::thread([this, f]() { start_thread(f); });
}
Thread() = default;
Thread(const Thread &) = delete;
Thread(Thread &&other);
void join();
private:
unsigned thread_id = UNINITIALIZED;
std::thread thread;
template <class F, class... Args>
void start_thread(F &&f) {
// this_thread::id = thread_id;
f();
}
};
/// A thread pool for asynchronous task execution. Supports tasks that produce
/// return values by returning `utils::Future` objects.
class ThreadPool final {
public:
/// Creates a thread pool with the given number of threads.
explicit ThreadPool(size_t threads);
ThreadPool(size_t threads, const std::string &name);
~ThreadPool();
ThreadPool(const ThreadPool &) = delete;

View File

@ -80,6 +80,10 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
// cleanup clients
for (int i = 0; i < Nc; ++i) clients[i].join();
// shutdown server
server.Shutdown();
server.AwaitShutdown();
}
int main(int argc, char **argv) {

View File

@ -33,6 +33,10 @@ TEST(Network, Server) {
// cleanup clients
for (int i = 0; i < N; ++i) clients[i].join();
// shutdown server
server.Shutdown();
server.AwaitShutdown();
}
int main(int argc, char **argv) {

View File

@ -41,6 +41,10 @@ TEST(Network, SessionLeak) {
for (int i = 0; i < N; ++i) clients[i].join();
std::this_thread::sleep_for(2s);
// shutdown server
server.Shutdown();
server.AwaitShutdown();
}
// run with "valgrind --leak-check=full ./network_session_leak" to check for

View File

@ -55,7 +55,8 @@ class MgCluster:
"--db-recover-on-startup",
"--query-vertex-count-to-expand-existing", "-1",
"--num-workers", str(WORKERS),
"--rpc-num-workers", str(WORKERS),
"--rpc-num-client-workers", str(WORKERS),
"--rpc-num-server-workers", str(WORKERS),
"--recovering-cluster-size", str(len(self._workers) + 1)
])
@@ -74,7 +75,8 @@ class MgCluster:
"worker_" + str(i)),
"--db-recover-on-startup",
"--num-workers", str(WORKERS),
"--rpc-num-workers", str(WORKERS),
"--rpc-num-client-workers", str(WORKERS),
"--rpc-num-server-workers", str(WORKERS),
])
# sleep to allow the workers to startup

View File

@@ -12,3 +12,6 @@ add_subdirectory(kafka)
# auth test binaries
add_subdirectory(auth)
# distributed test binaries
add_subdirectory(distributed)

View File

@@ -42,3 +42,11 @@
- ../../../build_debug/memgraph # memgraph binary
- ../../../build_debug/tests/integration/auth/checker # checker binary
- ../../../build_debug/tests/integration/auth/tester # tester binary
- name: integration__distributed
cd: distributed
commands: TIMEOUT=480 ./runner.py
infiles:
- runner.py # runner script
- ../../../build_debug/memgraph # memgraph binary
- ../../../build_debug/tests/integration/distributed/tester # tester binary

View File

@@ -0,0 +1,6 @@
set(target_name memgraph__integration__distributed)
set(tester_target_name ${target_name}__tester)
add_executable(${tester_target_name} tester.cpp)
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
target_link_libraries(${tester_target_name} mg-communication)

View File

@@ -0,0 +1,152 @@
#!/usr/bin/python3
import argparse
import atexit
import json
import os
import subprocess
import tempfile
import time
import sys
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
workers = []
@atexit.register
def cleanup():
for worker in workers:
worker.kill()
worker.wait()
workers.clear()
def wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
while subprocess.call(cmd) != 0:
time.sleep(0.01)
time.sleep(delay)
def generate_args(memgraph_binary, temporary_dir, worker_id):
args = [memgraph_binary]
if worker_id == 0:
args.append("--master")
else:
args.extend(["--worker", "--worker-id", str(worker_id)])
args.extend(["--master-host", "127.0.0.1", "--master-port", "10000"])
if worker_id != 0:
args.extend(["--worker-host", "127.0.0.1", "--worker-port",
str(10000 + worker_id)])
# All garbage collectors must be set to their lowest intervals to ensure
# that they won't terminate the memgraph process when communication within
# the cluster fails.
args.extend(["--skiplist-gc-interval", "1", "--gc-cycle-sec", "1"])
# Each worker must have a unique durability directory.
args.extend(["--durability-directory",
os.path.join(temporary_dir, "worker" + str(worker_id))])
return args
def worker_id_to_name(worker_id):
if worker_id == 0:
return "master"
return "worker {}".format(worker_id)
def execute_test(memgraph_binary, tester_binary, cluster_size, disaster,
on_worker_id, execute_query):
args = {"cluster_size": cluster_size, "disaster": disaster,
"on_worker_id": on_worker_id, "execute_query": execute_query}
print("\033[1;36m~~ Executing test with arguments:",
json.dumps(args, sort_keys=True), "~~\033[0m")
# Get a temporary directory used for durability.
tempdir = tempfile.TemporaryDirectory()
# Start the cluster.
cleanup()
for worker_id in range(cluster_size):
workers.append(subprocess.Popen(
generate_args(memgraph_binary, tempdir.name, worker_id)))
time.sleep(0.2)
assert workers[worker_id].poll() is None, \
"The {} process died prematurely!".format(
worker_id_to_name(worker_id))
if worker_id == 0:
wait_for_server(10000)
# Wait for the cluster to startup.
wait_for_server(7687)
# Execute the query if required.
if execute_query:
time.sleep(1)
client = subprocess.Popen([tester_binary])
# Perform the disaster.
time.sleep(2)
if disaster == "terminate":
workers[on_worker_id].terminate()
else:
workers[on_worker_id].kill()
time.sleep(2)
# Array of exit codes.
codes = []
# Check what happened with query execution.
if execute_query:
try:
code = client.wait(timeout=30)
except subprocess.TimeoutExpired as e:
client.kill()
raise e
if code != 0:
print("The client process didn't exit cleanly!")
codes.append(code)
# Terminate the master and wait to see what happens with the cluster.
workers[0].terminate()
# Wait for all of the workers.
for worker_id in range(cluster_size):
code = workers[worker_id].wait(timeout=30)
if worker_id == on_worker_id and disaster == "kill":
if code == 0:
print("The", worker_id_to_name(worker_id),
"process should have died but it exited cleanly!")
codes.append(-1)
elif code != 0:
print("The", worker_id_to_name(worker_id),
"process didn't exit cleanly!")
codes.append(code)
assert not any(codes), "Something went wrong!"
if __name__ == "__main__":
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
if not os.path.exists(memgraph_binary):
memgraph_binary = os.path.join(PROJECT_DIR, "build_debug", "memgraph")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
"integration", "distributed", "tester")
if not os.path.exists(tester_binary):
tester_binary = os.path.join(PROJECT_DIR, "build_debug", "tests",
"integration", "distributed", "tester")
parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary)
parser.add_argument("--tester", default=tester_binary)
args = parser.parse_args()
for cluster_size in [3, 5]:
for worker_id in [0, 1]:
for disaster in ["terminate", "kill"]:
for execute_query in [False, True]:
execute_test(args.memgraph, args.tester, cluster_size,
disaster, worker_id, execute_query)
print("\033[1;32m~~ The test finished successfully ~~\033[0m")
sys.exit(0)

View File

@@ -0,0 +1,49 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/bolt/client.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/utils.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
/**
* This test creates a sample dataset in the database and then executes a query
* that has a long execution time so that we can see what happens if the cluster
* dies mid-execution.
*/
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
<< FLAGS_port;
}
client.Execute("UNWIND range(0, 10000) AS x CREATE ()", {});
try {
client.Execute("MATCH (a), (b), (c), (d), (e), (f) RETURN COUNT(*)", {});
LOG(FATAL)
<< "The long query shouldn't have finished successfully, but it did!";
} catch (const communication::bolt::ClientQueryException &e) {
LOG(WARNING) << e.what();
} catch (const communication::bolt::ClientFatalException &) {
LOG(WARNING) << "The server closed the connection to us!";
}
return 0;
}

View File

@@ -73,5 +73,9 @@ int main(int argc, char **argv) {
message.size()) == message)
<< "Received message isn't equal to sent message!";
// Shutdown the server.
server.Shutdown();
server.AwaitShutdown();
return 0;
}

View File

@@ -34,10 +34,10 @@ add_manual_test(card_fraud_generate_snapshot.cpp)
target_link_libraries(${test_prefix}card_fraud_generate_snapshot memgraph_lib kvstore_dummy_lib)
add_manual_test(card_fraud_local.cpp)
target_link_libraries(${test_prefix}card_fraud_local memgraph_lib kvstore_dummy_lib)
target_link_libraries(${test_prefix}card_fraud_local memgraph_lib kvstore_dummy_lib gtest)
add_manual_test(distributed_repl.cpp)
target_link_libraries(${test_prefix}distributed_repl memgraph_lib kvstore_dummy_lib)
target_link_libraries(${test_prefix}distributed_repl memgraph_lib kvstore_dummy_lib gtest)
add_manual_test(endinan.cpp)

View File

@@ -13,7 +13,7 @@ DEFINE_int32(tx_per_thread, 1000, "Number of transactions each thread creates");
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
Cluster cluster(5);
Cluster cluster(5, "card_fraud_local");
cluster.Execute("CREATE INDEX ON :Card(id)");
cluster.Execute("CREATE INDEX ON :Transaction(id)");

View File

@@ -1,19 +1,29 @@
#pragma once
#include <chrono>
#include <experimental/filesystem>
#include <vector>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "communication/result_stream_faker.hpp"
#include "database/distributed_graph_db.hpp"
#include "database/graph_db_accessor.hpp"
#include "glue/communication.hpp"
#include "query/distributed_interpreter.hpp"
#include "query/typed_value.hpp"
#include "utils/file.hpp"
DECLARE_string(durability_directory);
namespace fs = std::experimental::filesystem;
class WorkerInThread {
public:
explicit WorkerInThread(database::Config config) : worker_(config) {
thread_ = std::thread([this, config] { worker_.WaitForShutdown(); });
thread_ =
std::thread([this, config] { EXPECT_TRUE(worker_.AwaitShutdown()); });
}
~WorkerInThread() {
@@ -29,10 +39,17 @@ class Cluster {
const std::string kLocal = "127.0.0.1";
public:
Cluster(int worker_count) {
database::Config masterconfig;
masterconfig.master_endpoint = {kLocal, 0};
master_ = std::make_unique<database::Master>(masterconfig);
Cluster(int worker_count, const std::string &test_name) {
tmp_dir_ = fs::temp_directory_path() / "MG_test_unit_distributed_common_" /
test_name;
EXPECT_TRUE(utils::EnsureDir(tmp_dir_));
database::Config master_config;
master_config.master_endpoint = {kLocal, 0};
master_config.durability_directory = GetDurabilityDirectory(0);
// Flag needs to be updated due to props on disk storage.
FLAGS_durability_directory = GetDurabilityDirectory(0);
master_ = std::make_unique<database::Master>(master_config);
interpreter_ =
std::make_unique<query::DistributedInterpreter>(master_.get());
std::this_thread::sleep_for(kInitTime);
@@ -41,11 +58,14 @@ class Cluster {
database::Config config;
config.worker_id = worker_id;
config.master_endpoint = master_->endpoint();
config.durability_directory = GetDurabilityDirectory(worker_id);
config.worker_endpoint = {kLocal, 0};
return config;
};
for (int i = 0; i < worker_count; ++i) {
// Flag needs to be updated due to props on disk storage.
FLAGS_durability_directory = GetDurabilityDirectory(i + 1);
workers_.emplace_back(
std::make_unique<WorkerInThread>(worker_config(i + 1)));
std::this_thread::sleep_for(kInitTime);
@@ -54,9 +74,9 @@ class Cluster {
void Stop() {
interpreter_ = nullptr;
auto t = std::thread([this]() { master_ = nullptr; });
master_->Shutdown();
EXPECT_TRUE(master_->AwaitShutdown());
workers_.clear();
if (t.joinable()) t.join();
}
~Cluster() {
@@ -72,10 +92,17 @@ class Cluster {
return result.GetResults();
};
fs::path GetDurabilityDirectory(int worker_id) {
if (worker_id == 0) return tmp_dir_ / "master";
return tmp_dir_ / fmt::format("worker{}", worker_id);
}
private:
std::unique_ptr<database::Master> master_;
std::vector<std::unique_ptr<WorkerInThread>> workers_;
std::unique_ptr<query::DistributedInterpreter> interpreter_;
fs::path tmp_dir_{fs::temp_directory_path() /
"MG_test_manual_distributed_common"};
};
void CheckResults(

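A minimal sketch of how a manual test drives the updated Cluster helper, mirroring the card_fraud_local change above; the test name and queries are illustrative only:
#include "distributed_common.hpp"  // assumed include for the Cluster helper
int main() {
  // One master plus three workers; each node gets its own durability
  // directory under <tmp>/MG_test_unit_distributed_common_/example_test.
  Cluster cluster(3, "example_test");
  cluster.Execute("CREATE INDEX ON :Node(id)");
  auto results = cluster.Execute("MATCH (n) RETURN COUNT(*)");
  // Stop() (and the destructor) shuts the master down with
  // Shutdown/AwaitShutdown before the workers are cleared.
  return 0;
}
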
View File

@@ -1,27 +1,34 @@
#include <chrono>
#include <experimental/filesystem>
#include <iostream>
#include <memory>
#include <thread>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "database/distributed_graph_db.hpp"
#include "query/distributed_interpreter.hpp"
#include "query/repl.hpp"
#include "utils/file.hpp"
#include "utils/flag_validation.hpp"
DEFINE_VALIDATED_int32(worker_count, 1,
"The number of worker nodes in cluster.",
FLAG_IN_RANGE(1, 1000));
DECLARE_int32(min_log_level);
DECLARE_string(durability_directory);
namespace fs = std::experimental::filesystem;
const std::string kLocal = "127.0.0.1";
class WorkerInThread {
public:
explicit WorkerInThread(database::Config config) : worker_(config) {
thread_ = std::thread([this, config] { worker_.WaitForShutdown(); });
thread_ =
std::thread([this, config] { EXPECT_TRUE(worker_.AwaitShutdown()); });
}
~WorkerInThread() {
@@ -32,14 +39,26 @@ class WorkerInThread {
std::thread thread_;
};
fs::path GetDurabilityDirectory(const fs::path &path, int worker_id) {
if (worker_id == 0) return path / "master";
return path / fmt::format("worker{}", worker_id);
}
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
FLAGS_min_log_level = google::ERROR;
google::InitGoogleLogging(argv[0]);
fs::path tmp_dir =
fs::temp_directory_path() / "MG_test_manual_distributed_repl";
EXPECT_TRUE(utils::EnsureDir(tmp_dir));
// Start the master
database::Config master_config;
master_config.master_endpoint = {kLocal, 0};
master_config.durability_directory = GetDurabilityDirectory(tmp_dir, 0);
// Flag needs to be updated due to props on disk storage.
FLAGS_durability_directory = GetDurabilityDirectory(tmp_dir, 0);
auto master = std::make_unique<database::Master>(master_config);
// Allow the master to get initialized before making workers.
std::this_thread::sleep_for(std::chrono::milliseconds(250));
@@ -50,6 +69,9 @@ int main(int argc, char *argv[]) {
config.worker_id = i + 1;
config.master_endpoint = master->endpoint();
config.worker_endpoint = {kLocal, 0};
config.durability_directory = GetDurabilityDirectory(tmp_dir, i + 1);
// Flag needs to be updated due to props on disk storage.
FLAGS_durability_directory = GetDurabilityDirectory(tmp_dir, i + 1);
workers.emplace_back(std::make_unique<WorkerInThread>(config));
}

View File

@@ -70,5 +70,8 @@ int main(int argc, char **argv) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
server.Shutdown();
server.AwaitShutdown();
return 0;
}

View File

@@ -20,7 +20,8 @@ COMMON_FLAGS = ["--durability-enabled=false",
"--snapshot-on-exit=false",
"--db-recover-on-startup=false"]
DISTRIBUTED_FLAGS = ["--num-workers", str(6),
"--rpc-num-workers", str(6)]
"--rpc-num-client-workers", str(6),
"--rpc-num-server-workers", str(6)]
MASTER_FLAGS = ["--master",
"--master-port", "10000"]
MEMGRAPH_PORT = 7687

View File

@@ -191,9 +191,6 @@ target_link_libraries(${test_prefix}queue memgraph_lib kvstore_dummy_lib)
add_unit_test(record_edge_vertex_accessor.cpp)
target_link_libraries(${test_prefix}record_edge_vertex_accessor memgraph_lib kvstore_dummy_lib)
add_unit_test(rpc_worker_clients.cpp)
target_link_libraries(${test_prefix}rpc_worker_clients memgraph_lib kvstore_dummy_lib)
add_unit_test(serialization.cpp)
target_link_libraries(${test_prefix}serialization memgraph_lib kvstore_dummy_lib)

View File

@@ -26,6 +26,8 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test {
worker_mapper_.emplace(master_client_pool_.value());
}
void TearDown() override {
master_server_.Shutdown();
master_server_.AwaitShutdown();
worker_mapper_ = std::experimental::nullopt;
master_mapper_ = std::experimental::nullopt;
master_client_pool_ = std::experimental::nullopt;

View File

@@ -24,4 +24,7 @@ TEST(CountersDistributed, All) {
EXPECT_EQ(w2.Get("b"), 1);
w1.Set("b", 42);
EXPECT_EQ(w2.Get("b"), 42);
master_server.Shutdown();
master_server.AwaitShutdown();
}

View File

@@ -8,4 +8,6 @@ TEST(DatabaseMaster, Instantiate) {
config.master_endpoint = io::network::Endpoint("127.0.0.1", 0);
config.worker_id = 0;
database::Master master(config);
master.Shutdown();
EXPECT_TRUE(master.AwaitShutdown());
}

View File

@@ -20,7 +20,8 @@ namespace fs = std::experimental::filesystem;
class WorkerInThread {
public:
explicit WorkerInThread(database::Config config) : worker_(config) {
thread_ = std::thread([this, config] { worker_.WaitForShutdown(); });
thread_ =
std::thread([this, config] { EXPECT_TRUE(worker_.AwaitShutdown()); });
}
~WorkerInThread() {
@@ -84,11 +85,11 @@ class DistributedGraphDbTest : public ::testing::Test {
}
void ShutDown() {
// Kill master first because it will expect a shutdown response from the
// workers.
auto t = std::thread([this]() { master_ = nullptr; });
// Shutdown the master. It will send a shutdown signal to the workers.
master_->Shutdown();
EXPECT_TRUE(master_->AwaitShutdown());
// Wait for all workers to finish shutting down.
workers_.clear();
if (t.joinable()) t.join();
}
fs::path GetDurabilityDirectory(int worker_id) {
@@ -213,9 +214,9 @@ class Cluster {
Cluster &operator=(Cluster &&) = delete;
~Cluster() {
auto t = std::thread([this] { master_ = nullptr; });
master_->Shutdown();
EXPECT_TRUE(master_->AwaitShutdown());
workers_.clear();
if (t.joinable()) t.join();
if (fs::exists(tmp_dir_)) fs::remove_all(tmp_dir_);
}

View File

@@ -14,7 +14,6 @@
#include "distributed/cluster_discovery_worker.hpp"
#include "distributed/coordination_master.hpp"
#include "distributed/coordination_worker.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "io/network/endpoint.hpp"
#include "utils/file.hpp"
@@ -28,13 +27,16 @@ const std::string kLocal = "127.0.0.1";
class WorkerCoordinationInThread {
struct Worker {
Worker(Endpoint master_endpoint) : master_endpoint(master_endpoint) {}
Worker(Endpoint master_endpoint, int worker_id)
: master_endpoint(master_endpoint),
coord(&server, master_endpoint, worker_id),
worker_id(worker_id) {}
Endpoint master_endpoint;
Server server{{kLocal, 0}};
WorkerCoordination coord{server, master_endpoint};
WorkerCoordination coord;
ClientPool client_pool{master_endpoint};
ClusterDiscoveryWorker discovery{server, coord, client_pool};
std::atomic<int> worker_id_{0};
std::atomic<int> worker_id;
};
public:
@@ -44,18 +46,23 @@ class WorkerCoordinationInThread {
std::atomic<bool> init_done{false};
worker_thread_ = std::thread(
[this, master_endpoint, durability_directory, desired_id, &init_done] {
worker.emplace(master_endpoint);
worker.emplace(master_endpoint, desired_id);
worker->discovery.RegisterWorker(desired_id, durability_directory);
worker->worker_id_ = desired_id;
init_done = true;
worker->coord.WaitForShutdown();
// We don't shut down the worker coordination here because it will be
// shut down by the master. We only wait for the shutdown to be
// finished.
EXPECT_TRUE(worker->coord.AwaitShutdown());
// Shutdown the RPC server.
worker->server.Shutdown();
worker->server.AwaitShutdown();
worker = std::experimental::nullopt;
});
while (!init_done) std::this_thread::sleep_for(10ms);
}
int worker_id() const { return worker->worker_id_; }
int worker_id() const { return worker->worker_id; }
auto endpoint() const { return worker->server.endpoint(); }
auto worker_endpoint(int worker_id) {
return worker->coord.GetEndpoint(worker_id);
@@ -90,132 +97,146 @@ class Distributed : public ::testing::Test {
TEST_F(Distributed, Coordination) {
Server master_server({kLocal, 0});
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
{
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
RpcWorkerClients rpc_worker_clients(master_coord);
ClusterDiscoveryMaster master_discovery_(
master_server, master_coord, rpc_worker_clients, tmp_dir("master"));
for (int i = 1; i <= kWorkerCount; ++i)
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir(fmt::format("worker{}", i)), i));
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_server, &master_coord,
tmp_dir("master"));
// Expect that all workers have a different ID.
std::unordered_set<int> worker_ids;
for (const auto &w : workers) worker_ids.insert(w->worker_id());
ASSERT_EQ(worker_ids.size(), kWorkerCount);
for (int i = 1; i <= kWorkerCount; ++i)
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir(fmt::format("worker{}", i)), i));
// Check endpoints.
for (auto &w1 : workers) {
for (auto &w2 : workers) {
EXPECT_EQ(w1->worker_endpoint(w2->worker_id()), w2->endpoint());
}
// Expect that all workers have a different ID.
std::unordered_set<int> worker_ids;
for (const auto &w : workers) worker_ids.insert(w->worker_id());
ASSERT_EQ(worker_ids.size(), kWorkerCount);
// Check endpoints.
for (auto &w1 : workers) {
for (auto &w2 : workers) {
EXPECT_EQ(w1->worker_endpoint(w2->worker_id()), w2->endpoint());
}
}
// Coordinated shutdown.
master_coord.Shutdown();
EXPECT_TRUE(master_coord.AwaitShutdown());
for (auto &worker : workers) worker->join();
master_server.Shutdown();
master_server.AwaitShutdown();
}
TEST_F(Distributed, DesiredAndUniqueId) {
Server master_server({kLocal, 0});
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
{
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
RpcWorkerClients rpc_worker_clients(master_coord);
ClusterDiscoveryMaster master_discovery_(
master_server, master_coord, rpc_worker_clients, tmp_dir("master"));
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir("worker42"), 42));
EXPECT_EQ(workers[0]->worker_id(), 42);
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_server, &master_coord,
tmp_dir("master"));
EXPECT_DEATH(
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir("worker42"), 42)),
"");
}
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir("worker42"), 42));
EXPECT_EQ(workers[0]->worker_id(), 42);
EXPECT_DEATH(
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir("worker42"), 42)),
"");
// Coordinated shutdown.
master_coord.Shutdown();
EXPECT_TRUE(master_coord.AwaitShutdown());
for (auto &worker : workers) worker->join();
master_server.Shutdown();
master_server.AwaitShutdown();
}
TEST_F(Distributed, CoordinationWorkersId) {
Server master_server({kLocal, 0});
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
{
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
RpcWorkerClients rpc_worker_clients(master_coord);
ClusterDiscoveryMaster master_discovery_(
master_server, master_coord, rpc_worker_clients, tmp_dir("master"));
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir("worker42"), 42));
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir("worker43"), 43));
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_server, &master_coord,
tmp_dir("master"));
std::vector<int> ids;
ids.push_back(0);
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir("worker42"), 42));
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir("worker43"), 43));
for (auto &worker : workers) ids.push_back(worker->worker_id());
EXPECT_THAT(master_coord.GetWorkerIds(),
testing::UnorderedElementsAreArray(ids));
}
std::vector<int> ids;
ids.push_back(0);
for (auto &worker : workers) ids.push_back(worker->worker_id());
EXPECT_THAT(master_coord.GetWorkerIds(),
testing::UnorderedElementsAreArray(ids));
// Coordinated shutdown.
master_coord.Shutdown();
EXPECT_TRUE(master_coord.AwaitShutdown());
for (auto &worker : workers) worker->join();
master_server.Shutdown();
master_server.AwaitShutdown();
}
TEST_F(Distributed, ClusterDiscovery) {
Server master_server({kLocal, 0});
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
{
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
RpcWorkerClients rpc_worker_clients(master_coord);
ClusterDiscoveryMaster master_discovery_(
master_server, master_coord, rpc_worker_clients, tmp_dir("master"));
std::vector<int> ids;
int worker_count = 10;
ids.push_back(0);
for (int i = 1; i <= worker_count; ++i) {
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir(fmt::format("worker", i)), i));
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_server, &master_coord,
tmp_dir("master"));
std::vector<int> ids;
int worker_count = 10;
ids.push_back(i);
}
ids.push_back(0);
for (int i = 1; i <= worker_count; ++i) {
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir(fmt::format("worker", i)), i));
EXPECT_THAT(master_coord.GetWorkerIds(),
testing::UnorderedElementsAreArray(ids));
for (auto &worker : workers) {
EXPECT_THAT(worker->worker_ids(),
testing::UnorderedElementsAreArray(ids));
}
ids.push_back(i);
}
EXPECT_THAT(master_coord.GetWorkerIds(),
testing::UnorderedElementsAreArray(ids));
for (auto &worker : workers) {
EXPECT_THAT(worker->worker_ids(), testing::UnorderedElementsAreArray(ids));
}
// Coordinated shutdown.
master_coord.Shutdown();
EXPECT_TRUE(master_coord.AwaitShutdown());
for (auto &worker : workers) worker->join();
master_server.Shutdown();
master_server.AwaitShutdown();
}
TEST_F(Distributed, KeepsTrackOfRecovered) {
Server master_server({kLocal, 0});
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
{
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
RpcWorkerClients rpc_worker_clients(master_coord);
ClusterDiscoveryMaster master_discovery_(
master_server, master_coord, rpc_worker_clients, tmp_dir("master"));
int worker_count = 10;
for (int i = 1; i <= worker_count; ++i) {
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir(fmt::format("worker{}", i)), i));
workers.back()->NotifyWorkerRecovered();
EXPECT_THAT(master_coord.CountRecoveredWorkers(), i);
}
MasterCoordination master_coord(master_server.endpoint());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_server, &master_coord,
tmp_dir("master"));
int worker_count = 10;
for (int i = 1; i <= worker_count; ++i) {
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_server.endpoint(), tmp_dir(fmt::format("worker{}", i)), i));
workers.back()->NotifyWorkerRecovered();
EXPECT_THAT(master_coord.CountRecoveredWorkers(), i);
}
// Coordinated shutdown.
master_coord.Shutdown();
EXPECT_TRUE(master_coord.AwaitShutdown());
for (auto &worker : workers) worker->join();
master_server.Shutdown();
master_server.AwaitShutdown();
}
int main(int argc, char **argv) {

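Each test above ends with the same sequence: shut down the coordination, join the worker threads, then stop the master's RPC server. A minimal sketch of that ordering as a helper, using the types declared earlier in this file:
#include <memory>
#include <vector>
#include "gtest/gtest.h"
// MasterCoordination, Server and WorkerCoordinationInThread are the types
// used throughout this test file; only the call order matters here.
void ShutdownCluster(
    MasterCoordination *master_coord, Server *master_server,
    std::vector<std::unique_ptr<WorkerCoordinationInThread>> *workers) {
  // Start the coordinated shutdown and wait for it to finish cluster-wide.
  master_coord->Shutdown();
  EXPECT_TRUE(master_coord->AwaitShutdown());
  // Join the worker threads; each awaits its own coordination shutdown and
  // then stops its RPC server.
  for (auto &worker : *workers) worker->join();
  // Only then stop the master's RPC server.
  master_server->Shutdown();
  master_server->AwaitShutdown();
}
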
View File

@@ -105,9 +105,9 @@ TEST_F(DistributedDynamicWorker, IndexExistsOnNewWorker) {
EXPECT_TRUE(dba->LabelPropertyIndexExists(label, property));
}
auto t = std::thread([&]() { master = nullptr; });
master->Shutdown();
EXPECT_TRUE(master->AwaitShutdown());
worker1 = nullptr;
if (t.joinable()) t.join();
}
TEST_F(DistributedDynamicWorker, IndexExistsOnNewWorkerAfterRecovery) {
@@ -152,7 +152,9 @@ TEST_F(DistributedDynamicWorker, IndexExistsOnNewWorkerAfterRecovery) {
dba->BuildIndex(label, property);
EXPECT_TRUE(dba->LabelPropertyIndexExists(label, property));
}
master = nullptr;
master->Shutdown();
EXPECT_TRUE(master->AwaitShutdown());
}
{
@@ -182,8 +184,8 @@ TEST_F(DistributedDynamicWorker, IndexExistsOnNewWorkerAfterRecovery) {
EXPECT_TRUE(dba->LabelPropertyIndexExists(label, property));
}
auto t = std::thread([&]() { master = nullptr; });
master->Shutdown();
EXPECT_TRUE(master->AwaitShutdown());
worker1 = nullptr;
if (t.joinable()) t.join();
}
}

View File

@@ -813,6 +813,8 @@ TEST_F(Durability, WorkerIdRecovery) {
auto dba = recovered.Access();
EXPECT_NE(dba->VerticesCount(), 0);
EXPECT_NE(dba->EdgesCount(), 0);
recovered.Shutdown();
EXPECT_TRUE(recovered.AwaitShutdown());
}
// WorkerIds are not equal and recovery should fail
@@ -826,7 +828,12 @@ TEST_F(Durability, WorkerIdRecovery) {
auto dba = recovered.Access();
EXPECT_EQ(dba->VerticesCount(), 0);
EXPECT_EQ(dba->EdgesCount(), 0);
recovered.Shutdown();
EXPECT_TRUE(recovered.AwaitShutdown());
}
db.Shutdown();
EXPECT_TRUE(db.AwaitShutdown());
}
TEST_F(Durability, SequentialRecovery) {

View File

@@ -80,6 +80,10 @@ TEST(NetworkTimeouts, InactiveSession) {
// After this sleep the session should have timed out.
std::this_thread::sleep_for(3500ms);
ASSERT_FALSE(QueryServer(client, safe_query));
// Shutdown the server.
server.Shutdown();
server.AwaitShutdown();
}
TEST(NetworkTimeouts, ActiveSession) {
@@ -108,6 +112,9 @@ TEST(NetworkTimeouts, ActiveSession) {
std::this_thread::sleep_for(3500ms);
ASSERT_FALSE(QueryServer(client, safe_query));
// Shutdown the server.
server.Shutdown();
server.AwaitShutdown();
}
int main(int argc, char **argv) {

View File

@@ -97,8 +97,10 @@ TEST(Rpc, Call) {
Client client(server.endpoint());
auto sum = client.Call<Sum>(10, 20);
ASSERT_TRUE(sum);
EXPECT_EQ(sum->sum, 30);
EXPECT_EQ(sum.sum, 30);
server.Shutdown();
server.AwaitShutdown();
}
TEST(Rpc, Abort) {
@@ -121,11 +123,14 @@ TEST(Rpc, Abort) {
});
utils::Timer timer;
auto sum = client.Call<Sum>(10, 20);
EXPECT_FALSE(sum);
EXPECT_THROW(client.Call<Sum>(10, 20),
communication::rpc::RpcFailedException);
EXPECT_LT(timer.Elapsed(), 200ms);
thread.join();
server.Shutdown();
server.AwaitShutdown();
}
TEST(Rpc, ClientPool) {
@@ -145,8 +150,7 @@ TEST(Rpc, ClientPool) {
* client */
auto get_sum_client = [&client](int x, int y) {
auto sum = client.Call<Sum>(x, y);
ASSERT_TRUE(sum);
EXPECT_EQ(sum->sum, x + y);
EXPECT_EQ(sum.sum, x + y);
};
utils::Timer t1;
@@ -167,8 +171,7 @@ TEST(Rpc, ClientPool) {
* parallel */
auto get_sum = [&pool](int x, int y) {
auto sum = pool.Call<Sum>(x, y);
ASSERT_TRUE(sum);
EXPECT_EQ(sum->sum, x + y);
EXPECT_EQ(sum.sum, x + y);
};
utils::Timer t2;
@@ -179,6 +182,9 @@ TEST(Rpc, ClientPool) {
threads[i].join();
}
EXPECT_LE(t2.Elapsed(), 200ms);
server.Shutdown();
server.AwaitShutdown();
}
TEST(Rpc, LargeMessage) {
@@ -194,6 +200,8 @@ TEST(Rpc, LargeMessage) {
Client client(server.endpoint());
auto echo = client.Call<Echo>(testdata);
ASSERT_TRUE(echo);
EXPECT_EQ(echo->data, testdata);
EXPECT_EQ(echo.data, testdata);
server.Shutdown();
server.AwaitShutdown();
}

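The updated tests capture the new Call<...> contract: the call returns the response object directly and throws communication::rpc::RpcFailedException when the request cannot be completed. A minimal sketch of a caller under that contract (Sum is the RPC defined in this test file; the fallback value is illustrative):
#include "communication/rpc/client.hpp"
#include "io/network/endpoint.hpp"
// Returns the server-computed sum, or -1 if the RPC fails.
int SumOverRpc(const io::network::Endpoint &endpoint, int x, int y) {
  communication::rpc::Client client(endpoint);
  try {
    auto sum = client.Call<Sum>(x, y);  // the response itself, no optional
    return sum.sum;
  } catch (const communication::rpc::RpcFailedException &) {
    // A persistent communication failure; the caller decides how to recover.
    return -1;
  }
}
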
View File

@@ -1,158 +0,0 @@
#include <experimental/filesystem>
#include <mutex>
#include "capnp/serialize.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "communication/rpc/messages.hpp"
#include "communication/rpc/server.hpp"
#include "distributed/cluster_discovery_master.hpp"
#include "distributed/cluster_discovery_worker.hpp"
#include "distributed/coordination_master.hpp"
#include "distributed/coordination_worker.hpp"
#include "distributed/rpc_worker_clients.hpp"
#include "io/network/endpoint.hpp"
#include "utils/file.hpp"
namespace fs = std::experimental::filesystem;
using namespace std::literals::chrono_literals;
namespace distributed {
struct IncrementCounterReq {
using Capnp = ::capnp::AnyPointer;
static const communication::rpc::MessageType TypeInfo;
void Save(::capnp::AnyPointer::Builder *) const {}
void Load(const ::capnp::AnyPointer::Reader &) {}
};
const communication::rpc::MessageType IncrementCounterReq::TypeInfo{
0, "IncrementCounterReq"};
struct IncrementCounterRes {
using Capnp = ::capnp::AnyPointer;
static const communication::rpc::MessageType TypeInfo;
void Save(::capnp::AnyPointer::Builder *) const {}
void Load(const ::capnp::AnyPointer::Reader &) {}
};
const communication::rpc::MessageType IncrementCounterRes::TypeInfo{
1, "IncrementCounterRes"};
using IncrementCounterRpc =
communication::rpc::RequestResponse<IncrementCounterReq,
IncrementCounterRes>;
}; // namespace distributed
class RpcWorkerClientsTest : public ::testing::Test {
protected:
const io::network::Endpoint kLocalHost{"127.0.0.1", 0};
const int kWorkerCount = 2;
void SetUp() override {
ASSERT_TRUE(utils::EnsureDir(tmp_dir_));
master_coord_->SetRecoveredSnapshot(std::experimental::nullopt);
for (int i = 1; i <= kWorkerCount; ++i) {
workers_server_.emplace_back(
std::make_unique<communication::rpc::Server>(kLocalHost));
workers_coord_.emplace_back(
std::make_unique<distributed::WorkerCoordination>(
*workers_server_.back(), master_server_.endpoint()));
cluster_discovery_.emplace_back(
std::make_unique<distributed::ClusterDiscoveryWorker>(
*workers_server_.back(), *workers_coord_.back(),
rpc_workers_.GetClientPool(0)));
cluster_discovery_.back()->RegisterWorker(
i, tmp_dir(fmt::format("worker{}", i)));
workers_server_.back()->Register<distributed::IncrementCounterRpc>(
[this, i](const auto &req_reader, auto *res_builder) {
std::unique_lock<std::mutex> lock(mutex_);
workers_cnt_[i]++;
});
}
}
void TearDown() override {
std::vector<std::thread> wait_on_shutdown;
for (int i = 0; i < workers_coord_.size(); ++i) {
wait_on_shutdown.emplace_back([i, this]() {
workers_coord_[i]->WaitForShutdown();
workers_server_[i] = nullptr;
});
}
std::this_thread::sleep_for(300ms);
// Starts server shutdown and notifies the workers
master_coord_ = std::experimental::nullopt;
for (auto &worker : wait_on_shutdown) worker.join();
// Cleanup temporary directory
if (fs::exists(tmp_dir_)) fs::remove_all(tmp_dir_);
}
const fs::path tmp_dir(const fs::path &path) const { return tmp_dir_ / path; }
fs::path tmp_dir_{fs::temp_directory_path() /
"MG_test_unit_rpc_worker_clients"};
std::vector<std::unique_ptr<communication::rpc::Server>> workers_server_;
std::vector<std::unique_ptr<distributed::WorkerCoordination>> workers_coord_;
std::vector<std::unique_ptr<distributed::ClusterDiscoveryWorker>>
cluster_discovery_;
std::mutex mutex_;
std::unordered_map<int, int> workers_cnt_;
communication::rpc::Server master_server_{kLocalHost};
std::experimental::optional<distributed::MasterCoordination> master_coord_{
master_server_.endpoint()};
distributed::RpcWorkerClients rpc_workers_{*master_coord_};
distributed::ClusterDiscoveryMaster cluster_disocvery_{
master_server_, *master_coord_, rpc_workers_, tmp_dir("master")};
};
TEST_F(RpcWorkerClientsTest, GetWorkerIds) {
EXPECT_THAT(rpc_workers_.GetWorkerIds(), testing::UnorderedElementsAreArray(
master_coord_->GetWorkerIds()));
}
TEST_F(RpcWorkerClientsTest, GetClientPool) {
auto &pool1 = rpc_workers_.GetClientPool(1);
auto &pool2 = rpc_workers_.GetClientPool(2);
EXPECT_NE(&pool1, &pool2);
EXPECT_EQ(&pool1, &rpc_workers_.GetClientPool(1));
}
TEST_F(RpcWorkerClientsTest, ExecuteOnWorker) {
auto execute = [](int worker_id, auto &client) -> void {
ASSERT_TRUE(client.template Call<distributed::IncrementCounterRpc>());
};
rpc_workers_.ExecuteOnWorker<void>(1, execute).get();
EXPECT_EQ(workers_cnt_[0], 0);
EXPECT_EQ(workers_cnt_[1], 1);
EXPECT_EQ(workers_cnt_[2], 0);
}
TEST_F(RpcWorkerClientsTest, ExecuteOnWorkers) {
auto execute = [](int worker_id, auto &client) -> void {
ASSERT_TRUE(client.template Call<distributed::IncrementCounterRpc>());
};
// Skip master
for (auto &future : rpc_workers_.ExecuteOnWorkers<void>(0, execute))
future.get();
EXPECT_EQ(workers_cnt_[0], 0);
EXPECT_EQ(workers_cnt_[1], 1);
EXPECT_EQ(workers_cnt_[2], 1);
}

View File

@@ -20,15 +20,26 @@ class WorkerEngineTest : public testing::Test {
protected:
const std::string local{"127.0.0.1"};
void TearDown() override {
// First we shutdown the master.
master_coordination_.Shutdown();
EXPECT_TRUE(master_coordination_.AwaitShutdown());
// Shutdown the RPC servers.
master_server_.Shutdown();
master_server_.AwaitShutdown();
worker_server_.Shutdown();
worker_server_.AwaitShutdown();
}
Server master_server_{{local, 0}};
Server worker_server_{{local, 0}};
MasterCoordination master_coordination_{master_server_.endpoint()};
RpcWorkerClients rpc_worker_clients_{master_coordination_};
EngineMaster master_{&master_server_, &master_coordination_};
EngineMaster master_{master_server_, rpc_worker_clients_};
ClientPool master_client_pool{master_server_.endpoint()};
EngineWorker worker_{worker_server_, master_client_pool};
EngineWorker worker_{&worker_server_, &master_client_pool};
};
TEST_F(WorkerEngineTest, BeginOnWorker) {
@@ -66,8 +77,7 @@ TEST_F(WorkerEngineTest, RunningTransaction) {
worker_.LocalForEachActiveTransaction([&count](Transaction &t) {
++count;
if (t.id_ == 1) {
EXPECT_EQ(t.snapshot(),
tx::Snapshot(std::vector<tx::TransactionId>{}));
EXPECT_EQ(t.snapshot(), tx::Snapshot(std::vector<tx::TransactionId>{}));
} else {
EXPECT_EQ(t.snapshot(), tx::Snapshot({1}));
}

View File

@@ -10,7 +10,7 @@
#include "utils/timer.hpp"
TEST(ThreadPool, RunMany) {
utils::ThreadPool tp(10);
utils::ThreadPool tp(10, "Test");
const int kResults = 10000;
std::vector<utils::Future<int>> results;
for (int i = 0; i < kResults; ++i) {
@@ -26,7 +26,7 @@ TEST(ThreadPool, EnsureParallel) {
using namespace std::chrono_literals;
const int kSize = 10;
utils::ThreadPool tp(kSize);
utils::ThreadPool tp(kSize, "Test");
std::vector<utils::Future<void>> results;
utils::Timer t;