From 3ffed4bf6df260a76469a5b622309182a1f9f75b Mon Sep 17 00:00:00 2001
From: Matej Ferencevic <matej.ferencevic@memgraph.io>
Date: Wed, 24 Apr 2019 13:11:05 +0200
Subject: [PATCH] Refactor HA RPC clients
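
Replace the per-worker `ClientPool` + `ThreadPool` combination in
`Coordination` with a single synchronous
`ExecuteOnOtherWorker<TRequestResponse>(worker_id, args...)` helper that
returns `std::nullopt` when the peer is unreachable, so call sites no longer
need their own try/catch around `Call`.

A minimal caller sketch (illustrative only, mirroring the SendLogEntries call
site in this diff; the surrounding variables such as `peer_id` are
placeholders from that call site, not new API):

  auto reply = coordination_->ExecuteOnOtherWorker<AppendEntriesRpc>(
      peer_id, server_id, commit_index, request_term, request_prev_log_index,
      request_prev_log_term, request_entries);
  if (!reply) {
    // Peer unreachable: retry on the next heartbeat.
    next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
    return;
  }
  if (reply->term != current_term_) return;  // Stale reply, ignore it.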

Reviewers: msantl, ipaljak

Reviewed By: ipaljak

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1977
---
 src/raft/coordination.cpp |  26 +++-------
 src/raft/coordination.hpp |  74 ++++++++++++++++------------
 src/raft/raft_server.cpp  | 100 +++++++++++++++-----------------------
 src/raft/raft_server.hpp  |   2 +-
 src/raft/storage_info.cpp |  39 +++------------
 5 files changed, 96 insertions(+), 145 deletions(-)

diff --git a/src/raft/coordination.cpp b/src/raft/coordination.cpp
index e6a23e2ac..2192fd93e 100644
--- a/src/raft/coordination.cpp
+++ b/src/raft/coordination.cpp
@@ -18,8 +18,11 @@ Coordination::Coordination(
     std::unordered_map<uint16_t, io::network::Endpoint> workers)
     : server_(workers[worker_id], server_workers_count),
       worker_id_(worker_id),
-      workers_(workers),
-      thread_pool_(client_workers_count, "RPC client") {}
+      workers_(workers) {
+  for (const auto &worker : workers_) {
+    client_locks_[worker.first] = std::make_unique<std::mutex>();
+  }
+}
 
 Coordination::~Coordination() {
   CHECK(!alive_) << "You must call Shutdown and AwaitShutdown on Coordination!";
@@ -71,24 +74,7 @@ std::vector<int> Coordination::GetWorkerIds() {
   return worker_ids;
 }
 
-communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) {
-  std::lock_guard<std::mutex> guard(lock_);
-  auto found = client_pools_.find(worker_id);
-  if (found != client_pools_.end()) return &found->second;
-  auto found_endpoint = workers_.find(worker_id);
-  CHECK(found_endpoint != workers_.end())
-      << "No endpoint registered for worker id: " << worker_id;
-  auto &endpoint = found_endpoint->second;
-  return &client_pools_
-              .emplace(std::piecewise_construct,
-                       std::forward_as_tuple(worker_id),
-                       std::forward_as_tuple(endpoint))
-              .first->second;
-}
-
-uint16_t Coordination::WorkerCount() {
-  return workers_.size();
-}
+uint16_t Coordination::WorkerCount() { return workers_.size(); }
 
 bool Coordination::Start() {
   if (!server_.Start()) return false;
diff --git a/src/raft/coordination.hpp b/src/raft/coordination.hpp
index 948e37ec0..85d9cdda1 100644
--- a/src/raft/coordination.hpp
+++ b/src/raft/coordination.hpp
@@ -5,13 +5,15 @@
 #include <atomic>
 #include <filesystem>
 #include <functional>
+#include <memory>
 #include <mutex>
+#include <optional>
 #include <thread>
 #include <type_traits>
 #include <unordered_map>
 #include <vector>
 
-#include "communication/rpc/client_pool.hpp"
+#include "communication/rpc/client.hpp"
 #include "communication/rpc/server.hpp"
 #include "io/network/endpoint.hpp"
 #include "raft/exceptions.hpp"
@@ -57,36 +59,48 @@ class Coordination final {
   /// Returns all workers ids.
   std::vector<int> GetWorkerIds();
 
-  /// Returns a cached `ClientPool` for the given `worker_id`.
-  communication::rpc::ClientPool *GetClientPool(int worker_id);
-
   uint16_t WorkerCount();
 
-  /// Asynchronously executes the given function on the RPC client for the
-  /// given worker id. Returns an `std::future` of the given `execute`
-  /// function's return type.
-  template <typename TResult>
-  auto ExecuteOnWorker(
-      int worker_id,
-      const std::function<TResult(int worker_id,
-                                  communication::rpc::ClientPool &)> &execute) {
-    auto client_pool = GetClientPool(worker_id);
-    return thread_pool_.Run(execute, worker_id, std::ref(*client_pool));
-  }
-  /// Asynchroniously executes the `execute` function on all worker rpc clients
-  /// except the one whose id is `skip_worker_id`. Returns a vector of futures
-  /// contaning the results of the `execute` function.
-  template <typename TResult>
-  auto ExecuteOnWorkers(
-      int skip_worker_id,
-      const std::function<TResult(int worker_id,
-                                  communication::rpc::ClientPool &)> &execute) {
-    std::vector<std::future<TResult>> futures;
-    for (auto &worker_id : GetWorkerIds()) {
-      if (worker_id == skip_worker_id) continue;
-      futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute)));
+  /// Executes an RPC on another worker in the cluster. If the RPC execution
+  /// fails (because of underlying network issues), it returns `std::nullopt`.
+  template <class TRequestResponse, class... Args>
+  std::optional<typename TRequestResponse::Response> ExecuteOnOtherWorker(
+      uint16_t worker_id, Args &&... args) {
+    CHECK(worker_id != worker_id_) << "Trying to execute RPC on self!";
+
+    communication::rpc::Client *client = nullptr;
+    std::mutex *client_lock = nullptr;
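+
+    // The global `lock_` only guards the `clients_` and `client_locks_` maps;
+    // the per-worker mutex then serializes calls on that worker's client,
+    // presumably so a slow RPC to one peer doesn't block RPCs to the others.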
+    {
+      std::lock_guard<std::mutex> guard(lock_);
+
+      auto found = clients_.find(worker_id);
+      if (found != clients_.end()) {
+        client = &found->second;
+      } else {
+        auto found_endpoint = workers_.find(worker_id);
+        CHECK(found_endpoint != workers_.end())
+            << "No endpoint registered for worker id: " << worker_id;
+        auto &endpoint = found_endpoint->second;
+        auto it = clients_.emplace(worker_id, endpoint);
+        client = &it.first->second;
+      }
+
+      auto lock_found = client_locks_.find(worker_id);
+      CHECK(lock_found != client_locks_.end())
+          << "No client lock for worker id: " << worker_id;
+      client_lock = lock_found->second.get();
+    }
+
+    try {
+      std::lock_guard<std::mutex> guard(*client_lock);
+      return client->Call<TRequestResponse>(std::forward<Args>(args)...);
+    } catch (...) {
+      // Invalidate the client so that we reconnect next time.
+      std::lock_guard<std::mutex> guard(lock_);
+      CHECK(clients_.erase(worker_id) == 1)
+          << "Couldn't remove client for worker id: " << worker_id;
+      return std::nullopt;
     }
-    return futures;
   }
 
   template <class TRequestResponse>
@@ -131,8 +145,8 @@ class Coordination final {
   mutable std::mutex lock_;
   std::unordered_map<uint16_t, io::network::Endpoint> workers_;
 
-  std::unordered_map<int, communication::rpc::ClientPool> client_pools_;
-  utils::ThreadPool thread_pool_;
+  std::unordered_map<uint16_t, communication::rpc::Client> clients_;
+  std::unordered_map<uint16_t, std::unique_ptr<std::mutex>> client_locks_;
 
   // Flags used for shutdown.
   std::atomic<bool> alive_{true};
diff --git a/src/raft/raft_server.cpp b/src/raft/raft_server.cpp
index 78c992c42..c980958c6 100644
--- a/src/raft/raft_server.cpp
+++ b/src/raft/raft_server.cpp
@@ -576,7 +576,7 @@ void RaftServer::Transition(const Mode &new_mode) {
 
       mode_ = Mode::CANDIDATE;
 
-      if (HasMajortyVote()) {
+      if (HasMajorityVote()) {
         Transition(Mode::LEADER);
         state_changed_.notify_all();
         return;
@@ -699,29 +699,20 @@ void RaftServer::SendLogEntries(
   if (next_index_[peer_id] <= log_size_ - 1)
     GetLogSuffix(next_index_[peer_id], request_entries);
 
-  bool unreachable_peer = false;
-  auto peer_future = coordination_->ExecuteOnWorker<AppendEntriesRes>(
-      peer_id, [&](int worker_id, auto &client) {
-        try {
-          auto res = client.template Call<AppendEntriesRpc>(
-              server_id_, commit_index_, request_term, request_prev_log_index,
-              request_prev_log_term, request_entries);
-          return res;
-        } catch (...) {
-          // not being able to connect to peer means we need to retry.
-          // TODO(ipaljak): Consider backoff.
-          unreachable_peer = true;
-          return AppendEntriesRes(false, request_term);
-        }
-      });
+  // Copy all internal variables before releasing the lock.
+  auto server_id = server_id_;
+  auto commit_index = commit_index_;
 
   VLOG(40) << "Entries size: " << request_entries.size();
 
-  lock->unlock();  // Release lock while waiting for response.
-  auto reply = peer_future.get();
+  // Execute the RPC, releasing the lock while waiting for the response.
+  lock->unlock();
+  auto reply = coordination_->ExecuteOnOtherWorker<AppendEntriesRpc>(
+      peer_id, server_id, commit_index, request_term, request_prev_log_index,
+      request_prev_log_term, request_entries);
   lock->lock();
 
-  if (unreachable_peer) {
+  if (!reply) {
     next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
     return;
   }
@@ -730,7 +721,7 @@ void RaftServer::SendLogEntries(
     return;
   }
 
-  if (OutOfSync(reply.term)) {
+  if (OutOfSync(reply->term)) {
     state_changed_.notify_all();
     return;
   }
@@ -738,13 +729,13 @@ void RaftServer::SendLogEntries(
   DCHECK(mode_ == Mode::LEADER)
       << "Elected leader for term should never change.";
 
-  if (reply.term != current_term_) {
+  if (reply->term != current_term_) {
     VLOG(40) << "Server " << server_id_
              << ": Ignoring stale AppendEntriesRPC reply from " << peer_id;
     return;
   }
 
-  if (!reply.success) {
+  if (!reply->success) {
     // Replication can fail for the first log entry if the peer that we're
     // sending the entry is in the process of shutting down.
     next_index_[peer_id] = std::max(next_index_[peer_id] - 1, 1UL);
@@ -786,25 +777,17 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
 
   VLOG(40) << "Snapshot size: " << snapshot_size << " bytes.";
 
-  bool unreachable_peer = false;
-  auto peer_future = coordination_->ExecuteOnWorker<InstallSnapshotRes>(
-      peer_id, [&](int worker_id, auto &client) {
-        try {
-          auto res = client.template Call<InstallSnapshotRpc>(
-              server_id_, request_term, snapshot_metadata, std::move(snapshot),
-              snapshot_size);
-          return res;
-        } catch (...) {
-          unreachable_peer = true;
-          return InstallSnapshotRes(request_term);
-        }
-      });
+  // Copy all internal variables before releasing the lock.
+  auto server_id = server_id_;
 
+  // Execute the RPC, releasing the lock while waiting for the response.
   lock->unlock();
-  auto reply = peer_future.get();
+  auto reply = coordination_->ExecuteOnOtherWorker<InstallSnapshotRpc>(
+      peer_id, server_id, request_term, snapshot_metadata, std::move(snapshot),
+      snapshot_size);
   lock->lock();
 
-  if (unreachable_peer) {
+  if (!reply) {
     next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval;
     return;
   }
@@ -813,12 +796,12 @@ void RaftServer::SendSnapshot(uint16_t peer_id,
     return;
   }
 
-  if (OutOfSync(reply.term)) {
+  if (OutOfSync(reply->term)) {
     state_changed_.notify_all();
     return;
   }
 
-  if (reply.term != current_term_) {
+  if (reply->term != current_term_) {
     VLOG(40) << "Server " << server_id_
              << ": Ignoring stale InstallSnapshotRpc reply from " << peer_id;
     return;
@@ -874,31 +857,26 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) {
           // TODO(ipaljak): Consider backoff.
           wait_until = TimePoint::max();
 
+          // Copy all internal variables before releasing the lock.
+          auto server_id = server_id_;
           auto request_term = current_term_.load();
-          auto peer_future = coordination_->ExecuteOnWorker<RequestVoteRes>(
-              peer_id, [&](int worker_id, auto &client) {
-                auto last_entry_data = LastEntryData();
-                try {
-                  auto res = client.template Call<RequestVoteRpc>(
-                      server_id_, request_term, last_entry_data.first,
-                      last_entry_data.second);
-                  return res;
-                } catch (...) {
-                  // not being able to connect to peer defaults to a vote
-                  // being denied from that peer. This is correct but not
-                  // optimal.
-                  //
-                  // TODO(ipaljak): reconsider this decision :)
-                  return RequestVoteRes(false, request_term);
-                }
-              });
+          auto last_entry_data = LastEntryData();
 
           vote_requested_[peer_id] = true;
 
+          // Execute the RPC.
           lock.unlock();  // Release lock while waiting for response
-          auto reply = peer_future.get();
+          auto reply = coordination_->ExecuteOnOtherWorker<RequestVoteRpc>(
+              peer_id, server_id, request_term, last_entry_data.first,
+              last_entry_data.second);
           lock.lock();
 
+          // If the peer isn't reachable, treat it the same as if it had
+          // denied us its vote.
+          if (!reply) {
+            reply = RequestVoteRes(false, request_term);
+          }
+
           if (current_term_ != request_term || mode_ != Mode::CANDIDATE ||
               exiting_) {
             VLOG(40) << "Server " << server_id_
@@ -906,16 +884,16 @@ void RaftServer::PeerThreadMain(uint16_t peer_id) {
             break;
           }
 
-          if (OutOfSync(reply.term)) {
+          if (OutOfSync(reply->term)) {
             state_changed_.notify_all();
             continue;
           }
 
-          if (reply.vote_granted) {
+          if (reply->vote_granted) {
             VLOG(40) << "Server " << server_id_ << ": Got vote from "
                      << peer_id;
             ++granted_votes_;
-            if (HasMajortyVote()) Transition(Mode::LEADER);
+            if (HasMajorityVote()) Transition(Mode::LEADER);
           } else {
             VLOG(40) << "Server " << server_id_ << ": Denied vote from "
                      << peer_id;
@@ -1049,7 +1027,7 @@ void RaftServer::SetNextElectionTimePoint() {
   next_election_ = Clock::now() + wait_interval;
 }
 
-bool RaftServer::HasMajortyVote() {
+bool RaftServer::HasMajorityVote() {
   if (2 * granted_votes_ > coordination_->WorkerCount()) {
     VLOG(40) << "Server " << server_id_
              << ": Obtained majority vote (Term: " << current_term_ << ")";
diff --git a/src/raft/raft_server.hpp b/src/raft/raft_server.hpp
index c68dd297f..d301cb622 100644
--- a/src/raft/raft_server.hpp
+++ b/src/raft/raft_server.hpp
@@ -326,7 +326,7 @@ class RaftServer final : public RaftInterface {
   void SetNextElectionTimePoint();
 
   /// Checks if the current server obtained enough votes to become a leader.
-  bool HasMajortyVote();
+  bool HasMajorityVote();
 
   /// Returns relevant metadata about the last entry in this server's Raft Log.
   /// More precisely, returns a pair consisting of an index of the last entry
diff --git a/src/raft/storage_info.cpp b/src/raft/storage_info.cpp
index 63469c757..9643f8c61 100644
--- a/src/raft/storage_info.cpp
+++ b/src/raft/storage_info.cpp
@@ -14,8 +14,6 @@ using namespace std::literals::chrono_literals;
 using Clock = std::chrono::system_clock;
 using TimePoint = std::chrono::system_clock::time_point;
 
-const std::chrono::duration<int64_t> kRpcTimeout = 1s;
-
 StorageInfo::StorageInfo(database::GraphDb *db, Coordination *coordination,
                          uint16_t server_id)
     : db_(db), coordination_(coordination), server_id_(server_id) {
@@ -56,44 +54,19 @@ StorageInfo::GetLocalStorageInfo() const {
 std::map<std::string, std::vector<std::pair<std::string, std::string>>>
 StorageInfo::GetStorageInfo() const {
   std::map<std::string, std::vector<std::pair<std::string, std::string>>> info;
-  std::map<uint16_t, utils::Future<StorageInfoRes>> remote_storage_info_futures;
-  std::map<uint16_t, bool> received_reply;
 
   auto peers = coordination_->GetWorkerIds();
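+  // Peers are queried sequentially; an unreachable peer simply contributes an
+  // empty info entry.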
 
   for (auto id : peers) {
-    received_reply[id] = false;
     if (id == server_id_) {
       info.emplace(std::to_string(id), GetLocalStorageInfo());
-      received_reply[id] = true;
     } else {
-      remote_storage_info_futures.emplace(
-          id, coordination_->ExecuteOnWorker<StorageInfoRes>(
-                  id, [&](int worker_id, auto &client) {
-                    try {
-                      auto res = client.template Call<StorageInfoRpc>();
-                      return res;
-                    } catch (...) {
-                      return StorageInfoRes(id, {});
-                    }
-                  }));
-    }
-  }
-
-  int16_t waiting_for = peers.size() - 1;
-
-  TimePoint start = Clock::now();
-  while (Clock::now() - start <= kRpcTimeout && waiting_for > 0) {
-    for (auto id : peers) {
-      if (received_reply[id]) continue;
-      auto &future = remote_storage_info_futures[id];
-      if (!future.IsReady()) continue;
-
-      auto reply = future.get();
-      info.emplace(std::to_string(reply.server_id),
-                   std::move(reply.storage_info));
-      received_reply[id] = true;
-      waiting_for--;
+      auto reply = coordination_->ExecuteOnOtherWorker<StorageInfoRpc>(id);
+      if (reply) {
+        info[std::to_string(id)] = std::move(reply->storage_info);
+      } else {
+        info[std::to_string(id)] = {};
+      }
     }
   }