diff --git a/src/raft/coordination.cpp b/src/raft/coordination.cpp index e6a23e2ac..2192fd93e 100644 --- a/src/raft/coordination.cpp +++ b/src/raft/coordination.cpp @@ -18,8 +18,11 @@ Coordination::Coordination( std::unordered_map workers) : server_(workers[worker_id], server_workers_count), worker_id_(worker_id), - workers_(workers), - thread_pool_(client_workers_count, "RPC client") {} + workers_(workers) { + for (const auto &worker : workers_) { + client_locks_[worker.first] = std::make_unique(); + } +} Coordination::~Coordination() { CHECK(!alive_) << "You must call Shutdown and AwaitShutdown on Coordination!"; @@ -71,24 +74,7 @@ std::vector Coordination::GetWorkerIds() { return worker_ids; } -communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) { - std::lock_guard guard(lock_); - auto found = client_pools_.find(worker_id); - if (found != client_pools_.end()) return &found->second; - auto found_endpoint = workers_.find(worker_id); - CHECK(found_endpoint != workers_.end()) - << "No endpoint registered for worker id: " << worker_id; - auto &endpoint = found_endpoint->second; - return &client_pools_ - .emplace(std::piecewise_construct, - std::forward_as_tuple(worker_id), - std::forward_as_tuple(endpoint)) - .first->second; -} - -uint16_t Coordination::WorkerCount() { - return workers_.size(); -} +uint16_t Coordination::WorkerCount() { return workers_.size(); } bool Coordination::Start() { if (!server_.Start()) return false; diff --git a/src/raft/coordination.hpp b/src/raft/coordination.hpp index 948e37ec0..85d9cdda1 100644 --- a/src/raft/coordination.hpp +++ b/src/raft/coordination.hpp @@ -5,13 +5,15 @@ #include #include #include +#include #include +#include #include #include #include #include -#include "communication/rpc/client_pool.hpp" +#include "communication/rpc/client.hpp" #include "communication/rpc/server.hpp" #include "io/network/endpoint.hpp" #include "raft/exceptions.hpp" @@ -57,36 +59,48 @@ class Coordination final { /// Returns all workers ids. std::vector GetWorkerIds(); - /// Returns a cached `ClientPool` for the given `worker_id`. - communication::rpc::ClientPool *GetClientPool(int worker_id); - uint16_t WorkerCount(); - /// Asynchronously executes the given function on the RPC client for the - /// given worker id. Returns an `std::future` of the given `execute` - /// function's return type. - template - auto ExecuteOnWorker( - int worker_id, - const std::function &execute) { - auto client_pool = GetClientPool(worker_id); - return thread_pool_.Run(execute, worker_id, std::ref(*client_pool)); - } - /// Asynchroniously executes the `execute` function on all worker rpc clients - /// except the one whose id is `skip_worker_id`. Returns a vector of futures - /// contaning the results of the `execute` function. - template - auto ExecuteOnWorkers( - int skip_worker_id, - const std::function &execute) { - std::vector> futures; - for (auto &worker_id : GetWorkerIds()) { - if (worker_id == skip_worker_id) continue; - futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute))); + /// Executes a RPC on another worker in the cluster. If the RPC execution + /// fails (because of underlying network issues) it returns a `std::nullopt`. + template + std::optional ExecuteOnOtherWorker( + uint16_t worker_id, Args &&... args) { + CHECK(worker_id != worker_id_) << "Trying to execute RPC on self!"; + + communication::rpc::Client *client = nullptr; + std::mutex *client_lock = nullptr; + { + std::lock_guard guard(lock_); + + auto found = clients_.find(worker_id); + if (found != clients_.end()) { + client = &found->second; + } else { + auto found_endpoint = workers_.find(worker_id); + CHECK(found_endpoint != workers_.end()) + << "No endpoint registered for worker id: " << worker_id; + auto &endpoint = found_endpoint->second; + auto it = clients_.emplace(worker_id, endpoint); + client = &it.first->second; + } + + auto lock_found = client_locks_.find(worker_id); + CHECK(lock_found != client_locks_.end()) + << "No client lock for worker id: " << worker_id; + client_lock = lock_found->second.get(); + } + + try { + std::lock_guard guard(*client_lock); + return client->Call(std::forward(args)...); + } catch (...) { + // Invalidate the client so that we reconnect next time. + std::lock_guard guard(lock_); + CHECK(clients_.erase(worker_id) == 1) + << "Couldn't remove client for worker id: " << worker_id; + return std::nullopt; } - return futures; } template @@ -131,8 +145,8 @@ class Coordination final { mutable std::mutex lock_; std::unordered_map workers_; - std::unordered_map client_pools_; - utils::ThreadPool thread_pool_; + std::unordered_map clients_; + std::unordered_map> client_locks_; // Flags used for shutdown. std::atomic alive_{true}; diff --git a/src/raft/raft_server.cpp b/src/raft/raft_server.cpp index 78c992c42..c980958c6 100644 --- a/src/raft/raft_server.cpp +++ b/src/raft/raft_server.cpp @@ -576,7 +576,7 @@ void RaftServer::Transition(const Mode &new_mode) { mode_ = Mode::CANDIDATE; - if (HasMajortyVote()) { + if (HasMajorityVote()) { Transition(Mode::LEADER); state_changed_.notify_all(); return; @@ -699,29 +699,20 @@ void RaftServer::SendLogEntries( if (next_index_[peer_id] <= log_size_ - 1) GetLogSuffix(next_index_[peer_id], request_entries); - bool unreachable_peer = false; - auto peer_future = coordination_->ExecuteOnWorker( - peer_id, [&](int worker_id, auto &client) { - try { - auto res = client.template Call( - server_id_, commit_index_, request_term, request_prev_log_index, - request_prev_log_term, request_entries); - return res; - } catch (...) { - // not being able to connect to peer means we need to retry. - // TODO(ipaljak): Consider backoff. - unreachable_peer = true; - return AppendEntriesRes(false, request_term); - } - }); + // Copy all internal variables before releasing the lock. + auto server_id = server_id_; + auto commit_index = commit_index_; VLOG(40) << "Entries size: " << request_entries.size(); - lock->unlock(); // Release lock while waiting for response. - auto reply = peer_future.get(); + // Execute the RPC. + lock->unlock(); + auto reply = coordination_->ExecuteOnOtherWorker( + peer_id, server_id, commit_index, request_term, request_prev_log_index, + request_prev_log_term, request_entries); lock->lock(); - if (unreachable_peer) { + if (!reply) { next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval; return; } @@ -730,7 +721,7 @@ void RaftServer::SendLogEntries( return; } - if (OutOfSync(reply.term)) { + if (OutOfSync(reply->term)) { state_changed_.notify_all(); return; } @@ -738,13 +729,13 @@ void RaftServer::SendLogEntries( DCHECK(mode_ == Mode::LEADER) << "Elected leader for term should never change."; - if (reply.term != current_term_) { + if (reply->term != current_term_) { VLOG(40) << "Server " << server_id_ << ": Ignoring stale AppendEntriesRPC reply from " << peer_id; return; } - if (!reply.success) { + if (!reply->success) { // Replication can fail for the first log entry if the peer that we're // sending the entry is in the process of shutting down. next_index_[peer_id] = std::max(next_index_[peer_id] - 1, 1UL); @@ -786,25 +777,17 @@ void RaftServer::SendSnapshot(uint16_t peer_id, VLOG(40) << "Snapshot size: " << snapshot_size << " bytes."; - bool unreachable_peer = false; - auto peer_future = coordination_->ExecuteOnWorker( - peer_id, [&](int worker_id, auto &client) { - try { - auto res = client.template Call( - server_id_, request_term, snapshot_metadata, std::move(snapshot), - snapshot_size); - return res; - } catch (...) { - unreachable_peer = true; - return InstallSnapshotRes(request_term); - } - }); + // Copy all internal variables before releasing the lock. + auto server_id = server_id_; + // Execute the RPC. lock->unlock(); - auto reply = peer_future.get(); + auto reply = coordination_->ExecuteOnOtherWorker( + peer_id, server_id, request_term, snapshot_metadata, std::move(snapshot), + snapshot_size); lock->lock(); - if (unreachable_peer) { + if (!reply) { next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval; return; } @@ -813,12 +796,12 @@ void RaftServer::SendSnapshot(uint16_t peer_id, return; } - if (OutOfSync(reply.term)) { + if (OutOfSync(reply->term)) { state_changed_.notify_all(); return; } - if (reply.term != current_term_) { + if (reply->term != current_term_) { VLOG(40) << "Server " << server_id_ << ": Ignoring stale InstallSnapshotRpc reply from " << peer_id; return; @@ -874,31 +857,26 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) { // TODO(ipaljak): Consider backoff. wait_until = TimePoint::max(); + // Copy all internal variables before releasing the lock. + auto server_id = server_id_; auto request_term = current_term_.load(); - auto peer_future = coordination_->ExecuteOnWorker( - peer_id, [&](int worker_id, auto &client) { - auto last_entry_data = LastEntryData(); - try { - auto res = client.template Call( - server_id_, request_term, last_entry_data.first, - last_entry_data.second); - return res; - } catch (...) { - // not being able to connect to peer defaults to a vote - // being denied from that peer. This is correct but not - // optimal. - // - // TODO(ipaljak): reconsider this decision :) - return RequestVoteRes(false, request_term); - } - }); + auto last_entry_data = LastEntryData(); vote_requested_[peer_id] = true; + // Execute the RPC. lock.unlock(); // Release lock while waiting for response - auto reply = peer_future.get(); + auto reply = coordination_->ExecuteOnOtherWorker( + peer_id, server_id, request_term, last_entry_data.first, + last_entry_data.second); lock.lock(); + // If the peer isn't reachable, it is the same as if he didn't grant + // us his vote. + if (!reply) { + reply = RequestVoteRes(false, request_term); + } + if (current_term_ != request_term || mode_ != Mode::CANDIDATE || exiting_) { VLOG(40) << "Server " << server_id_ @@ -906,16 +884,16 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) { break; } - if (OutOfSync(reply.term)) { + if (OutOfSync(reply->term)) { state_changed_.notify_all(); continue; } - if (reply.vote_granted) { + if (reply->vote_granted) { VLOG(40) << "Server " << server_id_ << ": Got vote from " << peer_id; ++granted_votes_; - if (HasMajortyVote()) Transition(Mode::LEADER); + if (HasMajorityVote()) Transition(Mode::LEADER); } else { VLOG(40) << "Server " << server_id_ << ": Denied vote from " << peer_id; @@ -1049,7 +1027,7 @@ void RaftServer::SetNextElectionTimePoint() { next_election_ = Clock::now() + wait_interval; } -bool RaftServer::HasMajortyVote() { +bool RaftServer::HasMajorityVote() { if (2 * granted_votes_ > coordination_->WorkerCount()) { VLOG(40) << "Server " << server_id_ << ": Obtained majority vote (Term: " << current_term_ << ")"; diff --git a/src/raft/raft_server.hpp b/src/raft/raft_server.hpp index c68dd297f..d301cb622 100644 --- a/src/raft/raft_server.hpp +++ b/src/raft/raft_server.hpp @@ -326,7 +326,7 @@ class RaftServer final : public RaftInterface { void SetNextElectionTimePoint(); /// Checks if the current server obtained enough votes to become a leader. - bool HasMajortyVote(); + bool HasMajorityVote(); /// Returns relevant metadata about the last entry in this server's Raft Log. /// More precisely, returns a pair consisting of an index of the last entry diff --git a/src/raft/storage_info.cpp b/src/raft/storage_info.cpp index 63469c757..9643f8c61 100644 --- a/src/raft/storage_info.cpp +++ b/src/raft/storage_info.cpp @@ -14,8 +14,6 @@ using namespace std::literals::chrono_literals; using Clock = std::chrono::system_clock; using TimePoint = std::chrono::system_clock::time_point; -const std::chrono::duration kRpcTimeout = 1s; - StorageInfo::StorageInfo(database::GraphDb *db, Coordination *coordination, uint16_t server_id) : db_(db), coordination_(coordination), server_id_(server_id) { @@ -56,44 +54,19 @@ StorageInfo::GetLocalStorageInfo() const { std::map>> StorageInfo::GetStorageInfo() const { std::map>> info; - std::map> remote_storage_info_futures; - std::map received_reply; auto peers = coordination_->GetWorkerIds(); for (auto id : peers) { - received_reply[id] = false; if (id == server_id_) { info.emplace(std::to_string(id), GetLocalStorageInfo()); - received_reply[id] = true; } else { - remote_storage_info_futures.emplace( - id, coordination_->ExecuteOnWorker( - id, [&](int worker_id, auto &client) { - try { - auto res = client.template Call(); - return res; - } catch (...) { - return StorageInfoRes(id, {}); - } - })); - } - } - - int16_t waiting_for = peers.size() - 1; - - TimePoint start = Clock::now(); - while (Clock::now() - start <= kRpcTimeout && waiting_for > 0) { - for (auto id : peers) { - if (received_reply[id]) continue; - auto &future = remote_storage_info_futures[id]; - if (!future.IsReady()) continue; - - auto reply = future.get(); - info.emplace(std::to_string(reply.server_id), - std::move(reply.storage_info)); - received_reply[id] = true; - waiting_for--; + auto reply = coordination_->ExecuteOnOtherWorker(id); + if (reply) { + info[std::to_string(id)] = std::move(reply->storage_info); + } else { + info[std::to_string(id)] = {}; + } } }