Refactor HA RPC clients
Reviewers: msantl, ipaljak Reviewed By: ipaljak Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1977
This commit is contained in:
parent
e128a9f80f
commit
3ffed4bf6d
@ -18,8 +18,11 @@ Coordination::Coordination(
|
||||
std::unordered_map<uint16_t, io::network::Endpoint> workers)
|
||||
: server_(workers[worker_id], server_workers_count),
|
||||
worker_id_(worker_id),
|
||||
workers_(workers),
|
||||
thread_pool_(client_workers_count, "RPC client") {}
|
||||
workers_(workers) {
|
||||
for (const auto &worker : workers_) {
|
||||
client_locks_[worker.first] = std::make_unique<std::mutex>();
|
||||
}
|
||||
}
|
||||
|
||||
Coordination::~Coordination() {
|
||||
CHECK(!alive_) << "You must call Shutdown and AwaitShutdown on Coordination!";
|
||||
@ -71,24 +74,7 @@ std::vector<int> Coordination::GetWorkerIds() {
|
||||
return worker_ids;
|
||||
}
|
||||
|
||||
communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) {
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
auto found = client_pools_.find(worker_id);
|
||||
if (found != client_pools_.end()) return &found->second;
|
||||
auto found_endpoint = workers_.find(worker_id);
|
||||
CHECK(found_endpoint != workers_.end())
|
||||
<< "No endpoint registered for worker id: " << worker_id;
|
||||
auto &endpoint = found_endpoint->second;
|
||||
return &client_pools_
|
||||
.emplace(std::piecewise_construct,
|
||||
std::forward_as_tuple(worker_id),
|
||||
std::forward_as_tuple(endpoint))
|
||||
.first->second;
|
||||
}
|
||||
|
||||
uint16_t Coordination::WorkerCount() {
|
||||
return workers_.size();
|
||||
}
|
||||
uint16_t Coordination::WorkerCount() { return workers_.size(); }
|
||||
|
||||
bool Coordination::Start() {
|
||||
if (!server_.Start()) return false;
|
||||
|
@ -5,13 +5,15 @@
|
||||
#include <atomic>
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <thread>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "communication/rpc/client_pool.hpp"
|
||||
#include "communication/rpc/client.hpp"
|
||||
#include "communication/rpc/server.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "raft/exceptions.hpp"
|
||||
@ -57,36 +59,48 @@ class Coordination final {
|
||||
/// Returns all workers ids.
|
||||
std::vector<int> GetWorkerIds();
|
||||
|
||||
/// Returns a cached `ClientPool` for the given `worker_id`.
|
||||
communication::rpc::ClientPool *GetClientPool(int worker_id);
|
||||
|
||||
uint16_t WorkerCount();
|
||||
|
||||
/// Asynchronously executes the given function on the RPC client for the
|
||||
/// given worker id. Returns an `std::future` of the given `execute`
|
||||
/// function's return type.
|
||||
template <typename TResult>
|
||||
auto ExecuteOnWorker(
|
||||
int worker_id,
|
||||
const std::function<TResult(int worker_id,
|
||||
communication::rpc::ClientPool &)> &execute) {
|
||||
auto client_pool = GetClientPool(worker_id);
|
||||
return thread_pool_.Run(execute, worker_id, std::ref(*client_pool));
|
||||
}
|
||||
/// Asynchroniously executes the `execute` function on all worker rpc clients
|
||||
/// except the one whose id is `skip_worker_id`. Returns a vector of futures
|
||||
/// contaning the results of the `execute` function.
|
||||
template <typename TResult>
|
||||
auto ExecuteOnWorkers(
|
||||
int skip_worker_id,
|
||||
const std::function<TResult(int worker_id,
|
||||
communication::rpc::ClientPool &)> &execute) {
|
||||
std::vector<std::future<TResult>> futures;
|
||||
for (auto &worker_id : GetWorkerIds()) {
|
||||
if (worker_id == skip_worker_id) continue;
|
||||
futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute)));
|
||||
/// Executes a RPC on another worker in the cluster. If the RPC execution
|
||||
/// fails (because of underlying network issues) it returns a `std::nullopt`.
|
||||
template <class TRequestResponse, class... Args>
|
||||
std::optional<typename TRequestResponse::Response> ExecuteOnOtherWorker(
|
||||
uint16_t worker_id, Args &&... args) {
|
||||
CHECK(worker_id != worker_id_) << "Trying to execute RPC on self!";
|
||||
|
||||
communication::rpc::Client *client = nullptr;
|
||||
std::mutex *client_lock = nullptr;
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
|
||||
auto found = clients_.find(worker_id);
|
||||
if (found != clients_.end()) {
|
||||
client = &found->second;
|
||||
} else {
|
||||
auto found_endpoint = workers_.find(worker_id);
|
||||
CHECK(found_endpoint != workers_.end())
|
||||
<< "No endpoint registered for worker id: " << worker_id;
|
||||
auto &endpoint = found_endpoint->second;
|
||||
auto it = clients_.emplace(worker_id, endpoint);
|
||||
client = &it.first->second;
|
||||
}
|
||||
|
||||
auto lock_found = client_locks_.find(worker_id);
|
||||
CHECK(lock_found != client_locks_.end())
|
||||
<< "No client lock for worker id: " << worker_id;
|
||||
client_lock = lock_found->second.get();
|
||||
}
|
||||
|
||||
try {
|
||||
std::lock_guard<std::mutex> guard(*client_lock);
|
||||
return client->Call<TRequestResponse>(std::forward<Args>(args)...);
|
||||
} catch (...) {
|
||||
// Invalidate the client so that we reconnect next time.
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
CHECK(clients_.erase(worker_id) == 1)
|
||||
<< "Couldn't remove client for worker id: " << worker_id;
|
||||
return std::nullopt;
|
||||
}
|
||||
return futures;
|
||||
}
|
||||
|
||||
template <class TRequestResponse>
|
||||
@ -131,8 +145,8 @@ class Coordination final {
|
||||
mutable std::mutex lock_;
|
||||
std::unordered_map<uint16_t, io::network::Endpoint> workers_;
|
||||
|
||||
std::unordered_map<int, communication::rpc::ClientPool> client_pools_;
|
||||
utils::ThreadPool thread_pool_;
|
||||
std::unordered_map<uint16_t, communication::rpc::Client> clients_;
|
||||
std::unordered_map<uint16_t, std::unique_ptr<std::mutex>> client_locks_;
|
||||
|
||||
// Flags used for shutdown.
|
||||
std::atomic<bool> alive_{true};
|
||||
|
@ -576,7 +576,7 @@ void RaftServer::Transition(const Mode &new_mode) {
|
||||
|
||||
mode_ = Mode::CANDIDATE;
|
||||
|
||||
if (HasMajortyVote()) {
|
||||
if (HasMajorityVote()) {
|
||||
Transition(Mode::LEADER);
|
||||
state_changed_.notify_all();
|
||||
return;
|
||||
@ -699,29 +699,20 @@ void RaftServer::SendLogEntries(
|
||||
if (next_index_[peer_id] <= log_size_ - 1)
|
||||
GetLogSuffix(next_index_[peer_id], request_entries);
|
||||
|
||||
bool unreachable_peer = false;
|
||||
auto peer_future = coordination_->ExecuteOnWorker<AppendEntriesRes>(
|
||||
peer_id, [&](int worker_id, auto &client) {
|
||||
try {
|
||||
auto res = client.template Call<AppendEntriesRpc>(
|
||||
server_id_, commit_index_, request_term, request_prev_log_index,
|
||||
request_prev_log_term, request_entries);
|
||||
return res;
|
||||
} catch (...) {
|
||||
// not being able to connect to peer means we need to retry.
|
||||
// TODO(ipaljak): Consider backoff.
|
||||
unreachable_peer = true;
|
||||
return AppendEntriesRes(false, request_term);
|
||||
}
|
||||
});
|
||||
// Copy all internal variables before releasing the lock.
|
||||
auto server_id = server_id_;
|
||||
auto commit_index = commit_index_;
|
||||
|
||||
VLOG(40) << "Entries size: " << request_entries.size();
|
||||
|
||||
lock->unlock(); // Release lock while waiting for response.
|
||||
auto reply = peer_future.get();
|
||||
// Execute the RPC.
|
||||
lock->unlock();
|
||||
auto reply = coordination_->ExecuteOnOtherWorker<AppendEntriesRpc>(
|
||||
peer_id, server_id, commit_index, request_term, request_prev_log_index,
|
||||
request_prev_log_term, request_entries);
|
||||
lock->lock();
|
||||
|
||||
if (unreachable_peer) {
|
||||
if (!reply) {
|
||||
next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
|
||||
return;
|
||||
}
|
||||
@ -730,7 +721,7 @@ void RaftServer::SendLogEntries(
|
||||
return;
|
||||
}
|
||||
|
||||
if (OutOfSync(reply.term)) {
|
||||
if (OutOfSync(reply->term)) {
|
||||
state_changed_.notify_all();
|
||||
return;
|
||||
}
|
||||
@ -738,13 +729,13 @@ void RaftServer::SendLogEntries(
|
||||
DCHECK(mode_ == Mode::LEADER)
|
||||
<< "Elected leader for term should never change.";
|
||||
|
||||
if (reply.term != current_term_) {
|
||||
if (reply->term != current_term_) {
|
||||
VLOG(40) << "Server " << server_id_
|
||||
<< ": Ignoring stale AppendEntriesRPC reply from " << peer_id;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!reply.success) {
|
||||
if (!reply->success) {
|
||||
// Replication can fail for the first log entry if the peer that we're
|
||||
// sending the entry is in the process of shutting down.
|
||||
next_index_[peer_id] = std::max(next_index_[peer_id] - 1, 1UL);
|
||||
@ -786,25 +777,17 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
|
||||
|
||||
VLOG(40) << "Snapshot size: " << snapshot_size << " bytes.";
|
||||
|
||||
bool unreachable_peer = false;
|
||||
auto peer_future = coordination_->ExecuteOnWorker<InstallSnapshotRes>(
|
||||
peer_id, [&](int worker_id, auto &client) {
|
||||
try {
|
||||
auto res = client.template Call<InstallSnapshotRpc>(
|
||||
server_id_, request_term, snapshot_metadata, std::move(snapshot),
|
||||
snapshot_size);
|
||||
return res;
|
||||
} catch (...) {
|
||||
unreachable_peer = true;
|
||||
return InstallSnapshotRes(request_term);
|
||||
}
|
||||
});
|
||||
// Copy all internal variables before releasing the lock.
|
||||
auto server_id = server_id_;
|
||||
|
||||
// Execute the RPC.
|
||||
lock->unlock();
|
||||
auto reply = peer_future.get();
|
||||
auto reply = coordination_->ExecuteOnOtherWorker<InstallSnapshotRpc>(
|
||||
peer_id, server_id, request_term, snapshot_metadata, std::move(snapshot),
|
||||
snapshot_size);
|
||||
lock->lock();
|
||||
|
||||
if (unreachable_peer) {
|
||||
if (!reply) {
|
||||
next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
|
||||
return;
|
||||
}
|
||||
@ -813,12 +796,12 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
|
||||
return;
|
||||
}
|
||||
|
||||
if (OutOfSync(reply.term)) {
|
||||
if (OutOfSync(reply->term)) {
|
||||
state_changed_.notify_all();
|
||||
return;
|
||||
}
|
||||
|
||||
if (reply.term != current_term_) {
|
||||
if (reply->term != current_term_) {
|
||||
VLOG(40) << "Server " << server_id_
|
||||
<< ": Ignoring stale InstallSnapshotRpc reply from " << peer_id;
|
||||
return;
|
||||
@ -874,31 +857,26 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) {
|
||||
// TODO(ipaljak): Consider backoff.
|
||||
wait_until = TimePoint::max();
|
||||
|
||||
// Copy all internal variables before releasing the lock.
|
||||
auto server_id = server_id_;
|
||||
auto request_term = current_term_.load();
|
||||
auto peer_future = coordination_->ExecuteOnWorker<RequestVoteRes>(
|
||||
peer_id, [&](int worker_id, auto &client) {
|
||||
auto last_entry_data = LastEntryData();
|
||||
try {
|
||||
auto res = client.template Call<RequestVoteRpc>(
|
||||
server_id_, request_term, last_entry_data.first,
|
||||
last_entry_data.second);
|
||||
return res;
|
||||
} catch (...) {
|
||||
// not being able to connect to peer defaults to a vote
|
||||
// being denied from that peer. This is correct but not
|
||||
// optimal.
|
||||
//
|
||||
// TODO(ipaljak): reconsider this decision :)
|
||||
return RequestVoteRes(false, request_term);
|
||||
}
|
||||
});
|
||||
auto last_entry_data = LastEntryData();
|
||||
|
||||
vote_requested_[peer_id] = true;
|
||||
|
||||
// Execute the RPC.
|
||||
lock.unlock(); // Release lock while waiting for response
|
||||
auto reply = peer_future.get();
|
||||
auto reply = coordination_->ExecuteOnOtherWorker<RequestVoteRpc>(
|
||||
peer_id, server_id, request_term, last_entry_data.first,
|
||||
last_entry_data.second);
|
||||
lock.lock();
|
||||
|
||||
// If the peer isn't reachable, it is the same as if he didn't grant
|
||||
// us his vote.
|
||||
if (!reply) {
|
||||
reply = RequestVoteRes(false, request_term);
|
||||
}
|
||||
|
||||
if (current_term_ != request_term || mode_ != Mode::CANDIDATE ||
|
||||
exiting_) {
|
||||
VLOG(40) << "Server " << server_id_
|
||||
@ -906,16 +884,16 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (OutOfSync(reply.term)) {
|
||||
if (OutOfSync(reply->term)) {
|
||||
state_changed_.notify_all();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (reply.vote_granted) {
|
||||
if (reply->vote_granted) {
|
||||
VLOG(40) << "Server " << server_id_ << ": Got vote from "
|
||||
<< peer_id;
|
||||
++granted_votes_;
|
||||
if (HasMajortyVote()) Transition(Mode::LEADER);
|
||||
if (HasMajorityVote()) Transition(Mode::LEADER);
|
||||
} else {
|
||||
VLOG(40) << "Server " << server_id_ << ": Denied vote from "
|
||||
<< peer_id;
|
||||
@ -1049,7 +1027,7 @@ void RaftServer::SetNextElectionTimePoint() {
|
||||
next_election_ = Clock::now() + wait_interval;
|
||||
}
|
||||
|
||||
bool RaftServer::HasMajortyVote() {
|
||||
bool RaftServer::HasMajorityVote() {
|
||||
if (2 * granted_votes_ > coordination_->WorkerCount()) {
|
||||
VLOG(40) << "Server " << server_id_
|
||||
<< ": Obtained majority vote (Term: " << current_term_ << ")";
|
||||
|
@ -326,7 +326,7 @@ class RaftServer final : public RaftInterface {
|
||||
void SetNextElectionTimePoint();
|
||||
|
||||
/// Checks if the current server obtained enough votes to become a leader.
|
||||
bool HasMajortyVote();
|
||||
bool HasMajorityVote();
|
||||
|
||||
/// Returns relevant metadata about the last entry in this server's Raft Log.
|
||||
/// More precisely, returns a pair consisting of an index of the last entry
|
||||
|
@ -14,8 +14,6 @@ using namespace std::literals::chrono_literals;
|
||||
using Clock = std::chrono::system_clock;
|
||||
using TimePoint = std::chrono::system_clock::time_point;
|
||||
|
||||
const std::chrono::duration<int64_t> kRpcTimeout = 1s;
|
||||
|
||||
StorageInfo::StorageInfo(database::GraphDb *db, Coordination *coordination,
|
||||
uint16_t server_id)
|
||||
: db_(db), coordination_(coordination), server_id_(server_id) {
|
||||
@ -56,44 +54,19 @@ StorageInfo::GetLocalStorageInfo() const {
|
||||
std::map<std::string, std::vector<std::pair<std::string, std::string>>>
|
||||
StorageInfo::GetStorageInfo() const {
|
||||
std::map<std::string, std::vector<std::pair<std::string, std::string>>> info;
|
||||
std::map<uint16_t, utils::Future<StorageInfoRes>> remote_storage_info_futures;
|
||||
std::map<uint16_t, bool> received_reply;
|
||||
|
||||
auto peers = coordination_->GetWorkerIds();
|
||||
|
||||
for (auto id : peers) {
|
||||
received_reply[id] = false;
|
||||
if (id == server_id_) {
|
||||
info.emplace(std::to_string(id), GetLocalStorageInfo());
|
||||
received_reply[id] = true;
|
||||
} else {
|
||||
remote_storage_info_futures.emplace(
|
||||
id, coordination_->ExecuteOnWorker<StorageInfoRes>(
|
||||
id, [&](int worker_id, auto &client) {
|
||||
try {
|
||||
auto res = client.template Call<StorageInfoRpc>();
|
||||
return res;
|
||||
} catch (...) {
|
||||
return StorageInfoRes(id, {});
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
int16_t waiting_for = peers.size() - 1;
|
||||
|
||||
TimePoint start = Clock::now();
|
||||
while (Clock::now() - start <= kRpcTimeout && waiting_for > 0) {
|
||||
for (auto id : peers) {
|
||||
if (received_reply[id]) continue;
|
||||
auto &future = remote_storage_info_futures[id];
|
||||
if (!future.IsReady()) continue;
|
||||
|
||||
auto reply = future.get();
|
||||
info.emplace(std::to_string(reply.server_id),
|
||||
std::move(reply.storage_info));
|
||||
received_reply[id] = true;
|
||||
waiting_for--;
|
||||
auto reply = coordination_->ExecuteOnOtherWorker<StorageInfoRpc>(id);
|
||||
if (reply) {
|
||||
info[std::to_string(id)] = std::move(reply->storage_info);
|
||||
} else {
|
||||
info[std::to_string(id)] = {};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user