Refactor HA RPC clients

Reviewers: msantl, ipaljak

Reviewed By: ipaljak

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1977
This commit is contained in:
Matej Ferencevic 2019-04-24 13:11:05 +02:00
parent e128a9f80f
commit 3ffed4bf6d
5 changed files with 96 additions and 145 deletions

View File

@ -18,8 +18,11 @@ Coordination::Coordination(
std::unordered_map<uint16_t, io::network::Endpoint> workers) std::unordered_map<uint16_t, io::network::Endpoint> workers)
: server_(workers[worker_id], server_workers_count), : server_(workers[worker_id], server_workers_count),
worker_id_(worker_id), worker_id_(worker_id),
workers_(workers), workers_(workers) {
thread_pool_(client_workers_count, "RPC client") {} for (const auto &worker : workers_) {
client_locks_[worker.first] = std::make_unique<std::mutex>();
}
}
Coordination::~Coordination() { Coordination::~Coordination() {
CHECK(!alive_) << "You must call Shutdown and AwaitShutdown on Coordination!"; CHECK(!alive_) << "You must call Shutdown and AwaitShutdown on Coordination!";
@ -71,24 +74,7 @@ std::vector<int> Coordination::GetWorkerIds() {
return worker_ids; return worker_ids;
} }
communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) { uint16_t Coordination::WorkerCount() { return workers_.size(); }
std::lock_guard<std::mutex> guard(lock_);
auto found = client_pools_.find(worker_id);
if (found != client_pools_.end()) return &found->second;
auto found_endpoint = workers_.find(worker_id);
CHECK(found_endpoint != workers_.end())
<< "No endpoint registered for worker id: " << worker_id;
auto &endpoint = found_endpoint->second;
return &client_pools_
.emplace(std::piecewise_construct,
std::forward_as_tuple(worker_id),
std::forward_as_tuple(endpoint))
.first->second;
}
uint16_t Coordination::WorkerCount() {
return workers_.size();
}
bool Coordination::Start() { bool Coordination::Start() {
if (!server_.Start()) return false; if (!server_.Start()) return false;

View File

@ -5,13 +5,15 @@
#include <atomic> #include <atomic>
#include <filesystem> #include <filesystem>
#include <functional> #include <functional>
#include <memory>
#include <mutex> #include <mutex>
#include <optional>
#include <thread> #include <thread>
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "communication/rpc/client_pool.hpp" #include "communication/rpc/client.hpp"
#include "communication/rpc/server.hpp" #include "communication/rpc/server.hpp"
#include "io/network/endpoint.hpp" #include "io/network/endpoint.hpp"
#include "raft/exceptions.hpp" #include "raft/exceptions.hpp"
@ -57,36 +59,48 @@ class Coordination final {
/// Returns all workers ids. /// Returns all workers ids.
std::vector<int> GetWorkerIds(); std::vector<int> GetWorkerIds();
/// Returns a cached `ClientPool` for the given `worker_id`.
communication::rpc::ClientPool *GetClientPool(int worker_id);
uint16_t WorkerCount(); uint16_t WorkerCount();
/// Asynchronously executes the given function on the RPC client for the /// Executes a RPC on another worker in the cluster. If the RPC execution
/// given worker id. Returns an `std::future` of the given `execute` /// fails (because of underlying network issues) it returns a `std::nullopt`.
/// function's return type. template <class TRequestResponse, class... Args>
template <typename TResult> std::optional<typename TRequestResponse::Response> ExecuteOnOtherWorker(
auto ExecuteOnWorker( uint16_t worker_id, Args &&... args) {
int worker_id, CHECK(worker_id != worker_id_) << "Trying to execute RPC on self!";
const std::function<TResult(int worker_id,
communication::rpc::ClientPool &)> &execute) { communication::rpc::Client *client = nullptr;
auto client_pool = GetClientPool(worker_id); std::mutex *client_lock = nullptr;
return thread_pool_.Run(execute, worker_id, std::ref(*client_pool)); {
std::lock_guard<std::mutex> guard(lock_);
auto found = clients_.find(worker_id);
if (found != clients_.end()) {
client = &found->second;
} else {
auto found_endpoint = workers_.find(worker_id);
CHECK(found_endpoint != workers_.end())
<< "No endpoint registered for worker id: " << worker_id;
auto &endpoint = found_endpoint->second;
auto it = clients_.emplace(worker_id, endpoint);
client = &it.first->second;
} }
/// Asynchroniously executes the `execute` function on all worker rpc clients
/// except the one whose id is `skip_worker_id`. Returns a vector of futures auto lock_found = client_locks_.find(worker_id);
/// contaning the results of the `execute` function. CHECK(lock_found != client_locks_.end())
template <typename TResult> << "No client lock for worker id: " << worker_id;
auto ExecuteOnWorkers( client_lock = lock_found->second.get();
int skip_worker_id, }
const std::function<TResult(int worker_id,
communication::rpc::ClientPool &)> &execute) { try {
std::vector<std::future<TResult>> futures; std::lock_guard<std::mutex> guard(*client_lock);
for (auto &worker_id : GetWorkerIds()) { return client->Call<TRequestResponse>(std::forward<Args>(args)...);
if (worker_id == skip_worker_id) continue; } catch (...) {
futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute))); // Invalidate the client so that we reconnect next time.
std::lock_guard<std::mutex> guard(lock_);
CHECK(clients_.erase(worker_id) == 1)
<< "Couldn't remove client for worker id: " << worker_id;
return std::nullopt;
} }
return futures;
} }
template <class TRequestResponse> template <class TRequestResponse>
@ -131,8 +145,8 @@ class Coordination final {
mutable std::mutex lock_; mutable std::mutex lock_;
std::unordered_map<uint16_t, io::network::Endpoint> workers_; std::unordered_map<uint16_t, io::network::Endpoint> workers_;
std::unordered_map<int, communication::rpc::ClientPool> client_pools_; std::unordered_map<uint16_t, communication::rpc::Client> clients_;
utils::ThreadPool thread_pool_; std::unordered_map<uint16_t, std::unique_ptr<std::mutex>> client_locks_;
// Flags used for shutdown. // Flags used for shutdown.
std::atomic<bool> alive_{true}; std::atomic<bool> alive_{true};

View File

@ -576,7 +576,7 @@ void RaftServer::Transition(const Mode &new_mode) {
mode_ = Mode::CANDIDATE; mode_ = Mode::CANDIDATE;
if (HasMajortyVote()) { if (HasMajorityVote()) {
Transition(Mode::LEADER); Transition(Mode::LEADER);
state_changed_.notify_all(); state_changed_.notify_all();
return; return;
@ -699,29 +699,20 @@ void RaftServer::SendLogEntries(
if (next_index_[peer_id] <= log_size_ - 1) if (next_index_[peer_id] <= log_size_ - 1)
GetLogSuffix(next_index_[peer_id], request_entries); GetLogSuffix(next_index_[peer_id], request_entries);
bool unreachable_peer = false; // Copy all internal variables before releasing the lock.
auto peer_future = coordination_->ExecuteOnWorker<AppendEntriesRes>( auto server_id = server_id_;
peer_id, [&](int worker_id, auto &client) { auto commit_index = commit_index_;
try {
auto res = client.template Call<AppendEntriesRpc>(
server_id_, commit_index_, request_term, request_prev_log_index,
request_prev_log_term, request_entries);
return res;
} catch (...) {
// not being able to connect to peer means we need to retry.
// TODO(ipaljak): Consider backoff.
unreachable_peer = true;
return AppendEntriesRes(false, request_term);
}
});
VLOG(40) << "Entries size: " << request_entries.size(); VLOG(40) << "Entries size: " << request_entries.size();
lock->unlock(); // Release lock while waiting for response. // Execute the RPC.
auto reply = peer_future.get(); lock->unlock();
auto reply = coordination_->ExecuteOnOtherWorker<AppendEntriesRpc>(
peer_id, server_id, commit_index, request_term, request_prev_log_index,
request_prev_log_term, request_entries);
lock->lock(); lock->lock();
if (unreachable_peer) { if (!reply) {
next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval; next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
return; return;
} }
@ -730,7 +721,7 @@ void RaftServer::SendLogEntries(
return; return;
} }
if (OutOfSync(reply.term)) { if (OutOfSync(reply->term)) {
state_changed_.notify_all(); state_changed_.notify_all();
return; return;
} }
@ -738,13 +729,13 @@ void RaftServer::SendLogEntries(
DCHECK(mode_ == Mode::LEADER) DCHECK(mode_ == Mode::LEADER)
<< "Elected leader for term should never change."; << "Elected leader for term should never change.";
if (reply.term != current_term_) { if (reply->term != current_term_) {
VLOG(40) << "Server " << server_id_ VLOG(40) << "Server " << server_id_
<< ": Ignoring stale AppendEntriesRPC reply from " << peer_id; << ": Ignoring stale AppendEntriesRPC reply from " << peer_id;
return; return;
} }
if (!reply.success) { if (!reply->success) {
// Replication can fail for the first log entry if the peer that we're // Replication can fail for the first log entry if the peer that we're
// sending the entry is in the process of shutting down. // sending the entry is in the process of shutting down.
next_index_[peer_id] = std::max(next_index_[peer_id] - 1, 1UL); next_index_[peer_id] = std::max(next_index_[peer_id] - 1, 1UL);
@ -786,25 +777,17 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
VLOG(40) << "Snapshot size: " << snapshot_size << " bytes."; VLOG(40) << "Snapshot size: " << snapshot_size << " bytes.";
bool unreachable_peer = false; // Copy all internal variables before releasing the lock.
auto peer_future = coordination_->ExecuteOnWorker<InstallSnapshotRes>( auto server_id = server_id_;
peer_id, [&](int worker_id, auto &client) {
try {
auto res = client.template Call<InstallSnapshotRpc>(
server_id_, request_term, snapshot_metadata, std::move(snapshot),
snapshot_size);
return res;
} catch (...) {
unreachable_peer = true;
return InstallSnapshotRes(request_term);
}
});
// Execute the RPC.
lock->unlock(); lock->unlock();
auto reply = peer_future.get(); auto reply = coordination_->ExecuteOnOtherWorker<InstallSnapshotRpc>(
peer_id, server_id, request_term, snapshot_metadata, std::move(snapshot),
snapshot_size);
lock->lock(); lock->lock();
if (unreachable_peer) { if (!reply) {
next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval; next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
return; return;
} }
@ -813,12 +796,12 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
return; return;
} }
if (OutOfSync(reply.term)) { if (OutOfSync(reply->term)) {
state_changed_.notify_all(); state_changed_.notify_all();
return; return;
} }
if (reply.term != current_term_) { if (reply->term != current_term_) {
VLOG(40) << "Server " << server_id_ VLOG(40) << "Server " << server_id_
<< ": Ignoring stale InstallSnapshotRpc reply from " << peer_id; << ": Ignoring stale InstallSnapshotRpc reply from " << peer_id;
return; return;
@ -874,31 +857,26 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) {
// TODO(ipaljak): Consider backoff. // TODO(ipaljak): Consider backoff.
wait_until = TimePoint::max(); wait_until = TimePoint::max();
// Copy all internal variables before releasing the lock.
auto server_id = server_id_;
auto request_term = current_term_.load(); auto request_term = current_term_.load();
auto peer_future = coordination_->ExecuteOnWorker<RequestVoteRes>(
peer_id, [&](int worker_id, auto &client) {
auto last_entry_data = LastEntryData(); auto last_entry_data = LastEntryData();
try {
auto res = client.template Call<RequestVoteRpc>(
server_id_, request_term, last_entry_data.first,
last_entry_data.second);
return res;
} catch (...) {
// not being able to connect to peer defaults to a vote
// being denied from that peer. This is correct but not
// optimal.
//
// TODO(ipaljak): reconsider this decision :)
return RequestVoteRes(false, request_term);
}
});
vote_requested_[peer_id] = true; vote_requested_[peer_id] = true;
// Execute the RPC.
lock.unlock(); // Release lock while waiting for response lock.unlock(); // Release lock while waiting for response
auto reply = peer_future.get(); auto reply = coordination_->ExecuteOnOtherWorker<RequestVoteRpc>(
peer_id, server_id, request_term, last_entry_data.first,
last_entry_data.second);
lock.lock(); lock.lock();
// If the peer isn't reachable, it is the same as if he didn't grant
// us his vote.
if (!reply) {
reply = RequestVoteRes(false, request_term);
}
if (current_term_ != request_term || mode_ != Mode::CANDIDATE || if (current_term_ != request_term || mode_ != Mode::CANDIDATE ||
exiting_) { exiting_) {
VLOG(40) << "Server " << server_id_ VLOG(40) << "Server " << server_id_
@ -906,16 +884,16 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) {
break; break;
} }
if (OutOfSync(reply.term)) { if (OutOfSync(reply->term)) {
state_changed_.notify_all(); state_changed_.notify_all();
continue; continue;
} }
if (reply.vote_granted) { if (reply->vote_granted) {
VLOG(40) << "Server " << server_id_ << ": Got vote from " VLOG(40) << "Server " << server_id_ << ": Got vote from "
<< peer_id; << peer_id;
++granted_votes_; ++granted_votes_;
if (HasMajortyVote()) Transition(Mode::LEADER); if (HasMajorityVote()) Transition(Mode::LEADER);
} else { } else {
VLOG(40) << "Server " << server_id_ << ": Denied vote from " VLOG(40) << "Server " << server_id_ << ": Denied vote from "
<< peer_id; << peer_id;
@ -1049,7 +1027,7 @@ void RaftServer::SetNextElectionTimePoint() {
next_election_ = Clock::now() + wait_interval; next_election_ = Clock::now() + wait_interval;
} }
bool RaftServer::HasMajortyVote() { bool RaftServer::HasMajorityVote() {
if (2 * granted_votes_ > coordination_->WorkerCount()) { if (2 * granted_votes_ > coordination_->WorkerCount()) {
VLOG(40) << "Server " << server_id_ VLOG(40) << "Server " << server_id_
<< ": Obtained majority vote (Term: " << current_term_ << ")"; << ": Obtained majority vote (Term: " << current_term_ << ")";

View File

@ -326,7 +326,7 @@ class RaftServer final : public RaftInterface {
void SetNextElectionTimePoint(); void SetNextElectionTimePoint();
/// Checks if the current server obtained enough votes to become a leader. /// Checks if the current server obtained enough votes to become a leader.
bool HasMajortyVote(); bool HasMajorityVote();
/// Returns relevant metadata about the last entry in this server's Raft Log. /// Returns relevant metadata about the last entry in this server's Raft Log.
/// More precisely, returns a pair consisting of an index of the last entry /// More precisely, returns a pair consisting of an index of the last entry

View File

@ -14,8 +14,6 @@ using namespace std::literals::chrono_literals;
using Clock = std::chrono::system_clock; using Clock = std::chrono::system_clock;
using TimePoint = std::chrono::system_clock::time_point; using TimePoint = std::chrono::system_clock::time_point;
const std::chrono::duration<int64_t> kRpcTimeout = 1s;
StorageInfo::StorageInfo(database::GraphDb *db, Coordination *coordination, StorageInfo::StorageInfo(database::GraphDb *db, Coordination *coordination,
uint16_t server_id) uint16_t server_id)
: db_(db), coordination_(coordination), server_id_(server_id) { : db_(db), coordination_(coordination), server_id_(server_id) {
@ -56,44 +54,19 @@ StorageInfo::GetLocalStorageInfo() const {
std::map<std::string, std::vector<std::pair<std::string, std::string>>> std::map<std::string, std::vector<std::pair<std::string, std::string>>>
StorageInfo::GetStorageInfo() const { StorageInfo::GetStorageInfo() const {
std::map<std::string, std::vector<std::pair<std::string, std::string>>> info; std::map<std::string, std::vector<std::pair<std::string, std::string>>> info;
std::map<uint16_t, utils::Future<StorageInfoRes>> remote_storage_info_futures;
std::map<uint16_t, bool> received_reply;
auto peers = coordination_->GetWorkerIds(); auto peers = coordination_->GetWorkerIds();
for (auto id : peers) { for (auto id : peers) {
received_reply[id] = false;
if (id == server_id_) { if (id == server_id_) {
info.emplace(std::to_string(id), GetLocalStorageInfo()); info.emplace(std::to_string(id), GetLocalStorageInfo());
received_reply[id] = true;
} else { } else {
remote_storage_info_futures.emplace( auto reply = coordination_->ExecuteOnOtherWorker<StorageInfoRpc>(id);
id, coordination_->ExecuteOnWorker<StorageInfoRes>( if (reply) {
id, [&](int worker_id, auto &client) { info[std::to_string(id)] = std::move(reply->storage_info);
try { } else {
auto res = client.template Call<StorageInfoRpc>(); info[std::to_string(id)] = {};
return res;
} catch (...) {
return StorageInfoRes(id, {});
} }
}));
}
}
int16_t waiting_for = peers.size() - 1;
TimePoint start = Clock::now();
while (Clock::now() - start <= kRpcTimeout && waiting_for > 0) {
for (auto id : peers) {
if (received_reply[id]) continue;
auto &future = remote_storage_info_futures[id];
if (!future.IsReady()) continue;
auto reply = future.get();
info.emplace(std::to_string(reply.server_id),
std::move(reply.storage_info));
received_reply[id] = true;
waiting_for--;
} }
} }