Refactor HA RPC clients

Reviewers: msantl, ipaljak

Reviewed By: ipaljak

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1977
This commit is contained in:
Matej Ferencevic 2019-04-24 13:11:05 +02:00
parent e128a9f80f
commit 3ffed4bf6d
5 changed files with 96 additions and 145 deletions

View File

@ -18,8 +18,11 @@ Coordination::Coordination(
std::unordered_map<uint16_t, io::network::Endpoint> workers)
: server_(workers[worker_id], server_workers_count),
worker_id_(worker_id),
workers_(workers),
thread_pool_(client_workers_count, "RPC client") {}
workers_(workers) {
for (const auto &worker : workers_) {
client_locks_[worker.first] = std::make_unique<std::mutex>();
}
}
Coordination::~Coordination() {
CHECK(!alive_) << "You must call Shutdown and AwaitShutdown on Coordination!";
@ -71,24 +74,7 @@ std::vector<int> Coordination::GetWorkerIds() {
return worker_ids;
}
communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) {
std::lock_guard<std::mutex> guard(lock_);
auto found = client_pools_.find(worker_id);
if (found != client_pools_.end()) return &found->second;
auto found_endpoint = workers_.find(worker_id);
CHECK(found_endpoint != workers_.end())
<< "No endpoint registered for worker id: " << worker_id;
auto &endpoint = found_endpoint->second;
return &client_pools_
.emplace(std::piecewise_construct,
std::forward_as_tuple(worker_id),
std::forward_as_tuple(endpoint))
.first->second;
}
uint16_t Coordination::WorkerCount() {
return workers_.size();
}
uint16_t Coordination::WorkerCount() { return workers_.size(); }
bool Coordination::Start() {
if (!server_.Start()) return false;

View File

@ -5,13 +5,15 @@
#include <atomic>
#include <filesystem>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "communication/rpc/client_pool.hpp"
#include "communication/rpc/client.hpp"
#include "communication/rpc/server.hpp"
#include "io/network/endpoint.hpp"
#include "raft/exceptions.hpp"
@ -57,36 +59,48 @@ class Coordination final {
/// Returns all workers ids.
std::vector<int> GetWorkerIds();
/// Returns a cached `ClientPool` for the given `worker_id`.
communication::rpc::ClientPool *GetClientPool(int worker_id);
uint16_t WorkerCount();
/// Asynchronously executes the given function on the RPC client for the
/// given worker id. Returns an `std::future` of the given `execute`
/// function's return type.
template <typename TResult>
auto ExecuteOnWorker(
int worker_id,
const std::function<TResult(int worker_id,
communication::rpc::ClientPool &)> &execute) {
auto client_pool = GetClientPool(worker_id);
return thread_pool_.Run(execute, worker_id, std::ref(*client_pool));
/// Executes a RPC on another worker in the cluster. If the RPC execution
/// fails (because of underlying network issues) it returns a `std::nullopt`.
template <class TRequestResponse, class... Args>
std::optional<typename TRequestResponse::Response> ExecuteOnOtherWorker(
uint16_t worker_id, Args &&... args) {
CHECK(worker_id != worker_id_) << "Trying to execute RPC on self!";
communication::rpc::Client *client = nullptr;
std::mutex *client_lock = nullptr;
{
std::lock_guard<std::mutex> guard(lock_);
auto found = clients_.find(worker_id);
if (found != clients_.end()) {
client = &found->second;
} else {
auto found_endpoint = workers_.find(worker_id);
CHECK(found_endpoint != workers_.end())
<< "No endpoint registered for worker id: " << worker_id;
auto &endpoint = found_endpoint->second;
auto it = clients_.emplace(worker_id, endpoint);
client = &it.first->second;
}
/// Asynchroniously executes the `execute` function on all worker rpc clients
/// except the one whose id is `skip_worker_id`. Returns a vector of futures
/// contaning the results of the `execute` function.
template <typename TResult>
auto ExecuteOnWorkers(
int skip_worker_id,
const std::function<TResult(int worker_id,
communication::rpc::ClientPool &)> &execute) {
std::vector<std::future<TResult>> futures;
for (auto &worker_id : GetWorkerIds()) {
if (worker_id == skip_worker_id) continue;
futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute)));
auto lock_found = client_locks_.find(worker_id);
CHECK(lock_found != client_locks_.end())
<< "No client lock for worker id: " << worker_id;
client_lock = lock_found->second.get();
}
try {
std::lock_guard<std::mutex> guard(*client_lock);
return client->Call<TRequestResponse>(std::forward<Args>(args)...);
} catch (...) {
// Invalidate the client so that we reconnect next time.
std::lock_guard<std::mutex> guard(lock_);
CHECK(clients_.erase(worker_id) == 1)
<< "Couldn't remove client for worker id: " << worker_id;
return std::nullopt;
}
return futures;
}
template <class TRequestResponse>
@ -131,8 +145,8 @@ class Coordination final {
mutable std::mutex lock_;
std::unordered_map<uint16_t, io::network::Endpoint> workers_;
std::unordered_map<int, communication::rpc::ClientPool> client_pools_;
utils::ThreadPool thread_pool_;
std::unordered_map<uint16_t, communication::rpc::Client> clients_;
std::unordered_map<uint16_t, std::unique_ptr<std::mutex>> client_locks_;
// Flags used for shutdown.
std::atomic<bool> alive_{true};

View File

@ -576,7 +576,7 @@ void RaftServer::Transition(const Mode &new_mode) {
mode_ = Mode::CANDIDATE;
if (HasMajortyVote()) {
if (HasMajorityVote()) {
Transition(Mode::LEADER);
state_changed_.notify_all();
return;
@ -699,29 +699,20 @@ void RaftServer::SendLogEntries(
if (next_index_[peer_id] <= log_size_ - 1)
GetLogSuffix(next_index_[peer_id], request_entries);
bool unreachable_peer = false;
auto peer_future = coordination_->ExecuteOnWorker<AppendEntriesRes>(
peer_id, [&](int worker_id, auto &client) {
try {
auto res = client.template Call<AppendEntriesRpc>(
server_id_, commit_index_, request_term, request_prev_log_index,
request_prev_log_term, request_entries);
return res;
} catch (...) {
// not being able to connect to peer means we need to retry.
// TODO(ipaljak): Consider backoff.
unreachable_peer = true;
return AppendEntriesRes(false, request_term);
}
});
// Copy all internal variables before releasing the lock.
auto server_id = server_id_;
auto commit_index = commit_index_;
VLOG(40) << "Entries size: " << request_entries.size();
lock->unlock(); // Release lock while waiting for response.
auto reply = peer_future.get();
// Execute the RPC.
lock->unlock();
auto reply = coordination_->ExecuteOnOtherWorker<AppendEntriesRpc>(
peer_id, server_id, commit_index, request_term, request_prev_log_index,
request_prev_log_term, request_entries);
lock->lock();
if (unreachable_peer) {
if (!reply) {
next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
return;
}
@ -730,7 +721,7 @@ void RaftServer::SendLogEntries(
return;
}
if (OutOfSync(reply.term)) {
if (OutOfSync(reply->term)) {
state_changed_.notify_all();
return;
}
@ -738,13 +729,13 @@ void RaftServer::SendLogEntries(
DCHECK(mode_ == Mode::LEADER)
<< "Elected leader for term should never change.";
if (reply.term != current_term_) {
if (reply->term != current_term_) {
VLOG(40) << "Server " << server_id_
<< ": Ignoring stale AppendEntriesRPC reply from " << peer_id;
return;
}
if (!reply.success) {
if (!reply->success) {
// Replication can fail for the first log entry if the peer that we're
// sending the entry is in the process of shutting down.
next_index_[peer_id] = std::max(next_index_[peer_id] - 1, 1UL);
@ -786,25 +777,17 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
VLOG(40) << "Snapshot size: " << snapshot_size << " bytes.";
bool unreachable_peer = false;
auto peer_future = coordination_->ExecuteOnWorker<InstallSnapshotRes>(
peer_id, [&](int worker_id, auto &client) {
try {
auto res = client.template Call<InstallSnapshotRpc>(
server_id_, request_term, snapshot_metadata, std::move(snapshot),
snapshot_size);
return res;
} catch (...) {
unreachable_peer = true;
return InstallSnapshotRes(request_term);
}
});
// Copy all internal variables before releasing the lock.
auto server_id = server_id_;
// Execute the RPC.
lock->unlock();
auto reply = peer_future.get();
auto reply = coordination_->ExecuteOnOtherWorker<InstallSnapshotRpc>(
peer_id, server_id, request_term, snapshot_metadata, std::move(snapshot),
snapshot_size);
lock->lock();
if (unreachable_peer) {
if (!reply) {
next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
return;
}
@ -813,12 +796,12 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
return;
}
if (OutOfSync(reply.term)) {
if (OutOfSync(reply->term)) {
state_changed_.notify_all();
return;
}
if (reply.term != current_term_) {
if (reply->term != current_term_) {
VLOG(40) << "Server " << server_id_
<< ": Ignoring stale InstallSnapshotRpc reply from " << peer_id;
return;
@ -874,31 +857,26 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) {
// TODO(ipaljak): Consider backoff.
wait_until = TimePoint::max();
// Copy all internal variables before releasing the lock.
auto server_id = server_id_;
auto request_term = current_term_.load();
auto peer_future = coordination_->ExecuteOnWorker<RequestVoteRes>(
peer_id, [&](int worker_id, auto &client) {
auto last_entry_data = LastEntryData();
try {
auto res = client.template Call<RequestVoteRpc>(
server_id_, request_term, last_entry_data.first,
last_entry_data.second);
return res;
} catch (...) {
// not being able to connect to peer defaults to a vote
// being denied from that peer. This is correct but not
// optimal.
//
// TODO(ipaljak): reconsider this decision :)
return RequestVoteRes(false, request_term);
}
});
vote_requested_[peer_id] = true;
// Execute the RPC.
lock.unlock(); // Release lock while waiting for response
auto reply = peer_future.get();
auto reply = coordination_->ExecuteOnOtherWorker<RequestVoteRpc>(
peer_id, server_id, request_term, last_entry_data.first,
last_entry_data.second);
lock.lock();
// If the peer isn't reachable, it is the same as if he didn't grant
// us his vote.
if (!reply) {
reply = RequestVoteRes(false, request_term);
}
if (current_term_ != request_term || mode_ != Mode::CANDIDATE ||
exiting_) {
VLOG(40) << "Server " << server_id_
@ -906,16 +884,16 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) {
break;
}
if (OutOfSync(reply.term)) {
if (OutOfSync(reply->term)) {
state_changed_.notify_all();
continue;
}
if (reply.vote_granted) {
if (reply->vote_granted) {
VLOG(40) << "Server " << server_id_ << ": Got vote from "
<< peer_id;
++granted_votes_;
if (HasMajortyVote()) Transition(Mode::LEADER);
if (HasMajorityVote()) Transition(Mode::LEADER);
} else {
VLOG(40) << "Server " << server_id_ << ": Denied vote from "
<< peer_id;
@ -1049,7 +1027,7 @@ void RaftServer::SetNextElectionTimePoint() {
next_election_ = Clock::now() + wait_interval;
}
bool RaftServer::HasMajortyVote() {
bool RaftServer::HasMajorityVote() {
if (2 * granted_votes_ > coordination_->WorkerCount()) {
VLOG(40) << "Server " << server_id_
<< ": Obtained majority vote (Term: " << current_term_ << ")";

View File

@ -326,7 +326,7 @@ class RaftServer final : public RaftInterface {
void SetNextElectionTimePoint();
/// Checks if the current server obtained enough votes to become a leader.
bool HasMajortyVote();
bool HasMajorityVote();
/// Returns relevant metadata about the last entry in this server's Raft Log.
/// More precisely, returns a pair consisting of an index of the last entry

View File

@ -14,8 +14,6 @@ using namespace std::literals::chrono_literals;
using Clock = std::chrono::system_clock;
using TimePoint = std::chrono::system_clock::time_point;
const std::chrono::duration<int64_t> kRpcTimeout = 1s;
StorageInfo::StorageInfo(database::GraphDb *db, Coordination *coordination,
uint16_t server_id)
: db_(db), coordination_(coordination), server_id_(server_id) {
@ -56,44 +54,19 @@ StorageInfo::GetLocalStorageInfo() const {
std::map<std::string, std::vector<std::pair<std::string, std::string>>>
StorageInfo::GetStorageInfo() const {
std::map<std::string, std::vector<std::pair<std::string, std::string>>> info;
std::map<uint16_t, utils::Future<StorageInfoRes>> remote_storage_info_futures;
std::map<uint16_t, bool> received_reply;
auto peers = coordination_->GetWorkerIds();
for (auto id : peers) {
received_reply[id] = false;
if (id == server_id_) {
info.emplace(std::to_string(id), GetLocalStorageInfo());
received_reply[id] = true;
} else {
remote_storage_info_futures.emplace(
id, coordination_->ExecuteOnWorker<StorageInfoRes>(
id, [&](int worker_id, auto &client) {
try {
auto res = client.template Call<StorageInfoRpc>();
return res;
} catch (...) {
return StorageInfoRes(id, {});
auto reply = coordination_->ExecuteOnOtherWorker<StorageInfoRpc>(id);
if (reply) {
info[std::to_string(id)] = std::move(reply->storage_info);
} else {
info[std::to_string(id)] = {};
}
}));
}
}
int16_t waiting_for = peers.size() - 1;
TimePoint start = Clock::now();
while (Clock::now() - start <= kRpcTimeout && waiting_for > 0) {
for (auto id : peers) {
if (received_reply[id]) continue;
auto &future = remote_storage_info_futures[id];
if (!future.IsReady()) continue;
auto reply = future.get();
info.emplace(std::to_string(reply.server_id),
std::move(reply.storage_info));
received_reply[id] = true;
waiting_for--;
}
}