diff --git a/src/database/single_node_ha/graph_db.cpp b/src/database/single_node_ha/graph_db.cpp index 705634637..15600781f 100644 --- a/src/database/single_node_ha/graph_db.cpp +++ b/src/database/single_node_ha/graph_db.cpp @@ -19,6 +19,7 @@ GraphDb::GraphDb(Config config) : config_(config) {} void GraphDb::Start() { utils::CheckDir(config_.durability_directory); + raft_server_.Start(); CHECK(coordination_.Start()) << "Couldn't start coordination!"; // Start transaction killer. @@ -62,7 +63,10 @@ bool GraphDb::AwaitShutdown(std::function call_before_shutdown) { return ret; } -void GraphDb::Shutdown() { coordination_.Shutdown(); } +void GraphDb::Shutdown() { + raft_server_.Shutdown(); + coordination_.Shutdown(); +} std::unique_ptr GraphDb::Access() { // NOTE: We are doing a heap allocation to allow polymorphism. If this poses diff --git a/src/raft/raft_server.cpp b/src/raft/raft_server.cpp index 70e640e4e..9db016209 100644 --- a/src/raft/raft_server.cpp +++ b/src/raft/raft_server.cpp @@ -27,7 +27,9 @@ RaftServer::RaftServer(uint16_t server_id, const std::string &durability_dir, mode_(Mode::FOLLOWER), server_id_(server_id), disk_storage_(fs::path(durability_dir) / kRaftDir), - reset_callback_(reset_callback) { + reset_callback_(reset_callback) {} + +void RaftServer::Start() { // Persistent storage initialization/recovery. if (Log().empty()) { UpdateTerm(0); @@ -36,14 +38,14 @@ RaftServer::RaftServer(uint16_t server_id, const std::string &durability_dir, } // Peer state - int cluster_size = coordination_->WorkerCount(); + int cluster_size = coordination_->WorkerCount() + 1; next_index_.resize(cluster_size); match_index_.resize(cluster_size); next_heartbeat_.resize(cluster_size); backoff_until_.resize(cluster_size); // RPC registration - coordination->Register( + coordination_->Register( [this](const auto &req_reader, auto *res_builder) { std::lock_guard guard(lock_); RequestVoteReq req; @@ -64,8 +66,7 @@ RaftServer::RaftServer(uint16_t server_id, const std::string &durability_dir, // set currentTerm = T and convert to follower. if (req.term > current_term) { UpdateTerm(req.term); - if (mode_ != Mode::FOLLOWER) - Transition(Mode::FOLLOWER); + if (mode_ != Mode::FOLLOWER) Transition(Mode::FOLLOWER); } // [Raft paper 5.2, 5.4] @@ -105,8 +106,7 @@ RaftServer::RaftServer(uint16_t server_id, const std::string &durability_dir, // set currentTerm = T and convert to follower. if (req.term > current_term) { UpdateTerm(req.term); - if (mode_ != Mode::FOLLOWER) - Transition(Mode::FOLLOWER); + if (mode_ != Mode::FOLLOWER) Transition(Mode::FOLLOWER); } // respond positively to a heartbeat. @@ -162,11 +162,14 @@ RaftServer::RaftServer(uint16_t server_id, const std::string &durability_dir, } } -RaftServer::~RaftServer() { - exiting_ = true; +void RaftServer::Shutdown() { + { + std::lock_guard guard(lock_); + exiting_ = true; - state_changed_.notify_all(); - election_change_.notify_all(); + state_changed_.notify_all(); + election_change_.notify_all(); + } for (auto &peer_thread : peer_threads_) { if (peer_thread.joinable()) peer_thread.join(); @@ -449,6 +452,8 @@ void RaftServer::PeerThreadMain(int peer_id) { }); next_heartbeat_[peer_id] = Clock::now() + config_.heartbeat_interval; + state_changed_.notify_all(); + continue; } wait_until = next_heartbeat_[peer_id]; break; diff --git a/src/raft/raft_server.hpp b/src/raft/raft_server.hpp index 92cd74289..e3292e4ef 100644 --- a/src/raft/raft_server.hpp +++ b/src/raft/raft_server.hpp @@ -18,7 +18,7 @@ namespace raft { -using Clock = std::chrono::system_clock; +using Clock = std::chrono::system_clock; using TimePoint = std::chrono::system_clock::time_point; enum class Mode { FOLLOWER, CANDIDATE, LEADER }; @@ -55,7 +55,11 @@ class RaftServer final : public RaftInterface { const Config &config, raft::Coordination *coordination, std::function reset_callback); - ~RaftServer(); + /// Starts the RPC servers and starts mechanisms inside Raft protocol. + void Start(); + + /// Stops all threads responsible for the Raft protocol. + void Shutdown(); /// Retrieves the current term from persistent storage. /// @@ -117,7 +121,7 @@ class RaftServer final : public RaftInterface { bool IsStateDeltaTransactionEnd(const database::StateDelta &delta); }; - mutable std::mutex lock_; ///< Guards all internal state. + mutable std::mutex lock_; ///< Guards all internal state. ////////////////////////////////////////////////////////////////////////////// // volatile state on all servers @@ -138,13 +142,13 @@ class RaftServer final : public RaftInterface { /// log is ready for replication it will be discarded anyway. LogEntryBuffer log_entry_buffer_{this}; - std::vector peer_threads_; ///< One thread per peer which - ///< handles outgoing RPCs. + std::vector peer_threads_; ///< One thread per peer which + ///< handles outgoing RPCs. - std::condition_variable state_changed_; ///< Notifies all peer threads on - ///< relevant state change. + std::condition_variable state_changed_; ///< Notifies all peer threads on + ///< relevant state change. - bool exiting_ = false; ///< True on server shutdown. + bool exiting_ = false; ///< True on server shutdown. ////////////////////////////////////////////////////////////////////////////// // volatile state on followers and candidates @@ -153,8 +157,8 @@ class RaftServer final : public RaftInterface { std::thread election_thread_; ///< Timer thread for triggering elections. TimePoint next_election_; ///< Next election `TimePoint`. - std::condition_variable election_change_; ///> Used to notify election_thread - ///> on next_election_ change. + std::condition_variable election_change_; ///> Used to notify election_thread + ///> on next_election_ change. std::mt19937_64 rng_ = std::mt19937_64(std::random_device{}());