Review notes (free text ahead of the first "diff --git" header; patch tools are
expected to skip it -- confirm with your tooling before applying):

  * The line structure of this patch was reconstructed from a copy in which
    newlines had been collapsed. Every hunk below has the number of old/new
    lines announced by its "@@" header. Leading whitespace was reconstructed
    too and may need adjusting to match the tree.
  * Template argument lists and <...> include names had been stripped. The
    ones reinstated from evidence inside the patch are: "class... Args",
    std::forward<Args>, std::make_unique<Client>,
    std::vector<std::unique_ptr<Client>>, both dynamic_cast targets,
    utils::Downcast<InfoQuery>, <chrono> and <thread>.
  * Assumptions to confirm against the tree (not derivable from the patch):
    the element type io::network::Endpoint of `endpoints`, the parameter map
    type std::map<std::string, Value>, and the include <glog/logging.h>.
  * "¶meters" was a mis-encoded "&parameters" and has been repaired.
  * The retry loop counters in HAClient::Execute were `int` while being
    compared with the uint64_t `num_retries_`; they are now uint64_t.

diff --git a/src/communication/bolt/client.hpp b/src/communication/bolt/client.hpp
index c2cfe47ea..a03cc90f4 100644
--- a/src/communication/bolt/client.hpp
+++ b/src/communication/bolt/client.hpp
@@ -19,7 +19,17 @@ namespace communication::bolt {
 class ClientQueryException : public utils::BasicException {
  public:
   using utils::BasicException::BasicException;
+
   ClientQueryException() : utils::BasicException("Couldn't execute query!") {}
+
+  template <class... Args>
+  ClientQueryException(const std::string &code, Args &&... args)
+      : utils::BasicException(std::forward<Args>(args)...), code_(code) {}
+
+  const std::string &code() const { return code_; }
+
+ private:
+  std::string code_;
 };
 
 /// This exception is thrown whenever a fatal error occurs during query
@@ -154,7 +164,13 @@ class Client final {
       auto &tmp = fields.ValueMap();
       auto it = tmp.find("message");
       if (it != tmp.end()) {
-        throw ClientQueryException(it->second.ValueString());
+        auto it_code = tmp.find("code");
+        if (it_code != tmp.end()) {
+          throw ClientQueryException(it_code->second.ValueString(),
+                                     it->second.ValueString());
+        } else {
+          throw ClientQueryException("", it->second.ValueString());
+        }
       }
       throw ClientQueryException();
     } else if (signature != Signature::Success) {
@@ -192,7 +208,13 @@ class Client final {
         auto &tmp = data.ValueMap();
         auto it = tmp.find("message");
         if (it != tmp.end()) {
-          throw ClientQueryException(it->second.ValueString());
+          auto it_code = tmp.find("code");
+          if (it_code != tmp.end()) {
+            throw ClientQueryException(it_code->second.ValueString(),
+                                       it->second.ValueString());
+          } else {
+            throw ClientQueryException("", it->second.ValueString());
+          }
         }
         throw ClientQueryException();
       } else {
diff --git a/src/communication/bolt/ha_client.hpp b/src/communication/bolt/ha_client.hpp
new file mode 100644
index 000000000..d5f94e1f3
--- /dev/null
+++ b/src/communication/bolt/ha_client.hpp
@@ -0,0 +1,165 @@
+#pragma once
+
+#include <chrono>
+#include <thread>
+
+#include <glog/logging.h>
+
+#include "communication/bolt/client.hpp"
+
+namespace communication::bolt {
+
+/// HA Bolt client.
+/// It has methods used to execute queries against a cluster of servers. It
+/// supports both SSL and plaintext connections.
+class HAClient final {
+ public:
+  HAClient(const std::vector<io::network::Endpoint> &endpoints,
+           communication::ClientContext *context, const std::string &username,
+           const std::string &password, uint64_t num_retries,
+           const std::chrono::milliseconds &retry_delay,
+           const std::string &client_name = "memgraph-bolt")
+      : endpoints_(endpoints),
+        context_(context),
+        username_(username),
+        password_(password),
+        num_retries_(num_retries),
+        retry_delay_(retry_delay),
+        client_name_(client_name) {
+    if (endpoints.size() < 3) {
+      throw ClientFatalException(
+          "You should specify at least three server endpoints to connect to!");
+    }
+    // Create all clients.
+    for (size_t i = 0; i < endpoints.size(); ++i) {
+      clients_.push_back(std::make_unique<Client>(context_));
+    }
+  }
+
+  HAClient(const HAClient &) = delete;
+  HAClient(HAClient &&) = delete;
+  HAClient &operator=(const HAClient &) = delete;
+  HAClient &operator=(HAClient &&) = delete;
+
+  /// Function used to execute queries against the leader server.
+  /// @throws ClientQueryException when there is some transient error while
+  ///                              executing the query (eg. mistyped query,
+  ///                              etc.)
+  /// @throws ClientFatalException when we couldn't communicate with the leader
+  ///                              server even after `num_retries` tries
+  QueryData Execute(const std::string &query,
+                    const std::map<std::string, Value> &parameters) {
+    for (uint64_t i = 0; i < num_retries_; ++i) {
+      // Try to find a leader.
+      if (!leader_) {
+        for (uint64_t j = 0; j < num_retries_; ++j) {
+          if (!(i == 0 && j == 0)) {
+            std::this_thread::sleep_for(
+                std::chrono::milliseconds(retry_delay_));
+          }
+          try {
+            FindLeader();
+            break;
+          } catch (const ClientFatalException &e) {
+            continue;
+          }
+        }
+        if (!leader_) {
+          throw ClientFatalException("Couldn't find leader after {} tries!",
+                                     num_retries_);
+        }
+      }
+      // Try to execute the query.
+      try {
+        return leader_->Execute(query, parameters);
+      } catch (const utils::BasicException &e) {
+        // Check if this is a cluster failure or a Raft failure.
+        auto qe = dynamic_cast<const ClientQueryException *>(&e);
+        if (dynamic_cast<const ClientFatalException *>(&e) ||
+            (qe && qe->code() == "Memgraph.DatabaseError.Raft.Error")) {
+          // We need to look for a new leader.
+          leader_ = nullptr;
+          continue;
+        }
+        // If it isn't just forward the exception to the client.
+        throw;
+      }
+    }
+    throw ClientFatalException("Couldn't execute query after {} tries!",
+                               num_retries_);
+  }
+
+ private:
+  void FindLeader() {
+    // Reconnect clients that aren't available
+    bool connected = false;
+    for (size_t i = 0; i < clients_.size(); ++i) {
+      const auto &ep = endpoints_[i];
+      const auto &client = clients_[i];
+      try {
+        client->Execute("SHOW RAFT INFO", {});
+        connected = true;
+        continue;
+      } catch (const ClientQueryException &e) {
+        continue;
+      } catch (const ClientFatalException &e) {
+        client->Close();
+        try {
+          client->Connect(ep, username_, password_, client_name_);
+          connected = true;
+        } catch (const utils::BasicException &) {
+          // Suppress any exceptions.
+        }
+      }
+    }
+    if (!connected) {
+      throw ClientFatalException("Couldn't connect to any server!");
+    }
+
+    // Determine which server is the leader
+    leader_ = nullptr;
+    int64_t leader_id = -1;
+    for (const auto &client : clients_) {
+      try {
+        auto ret = client->Execute("SHOW RAFT INFO", {});
+        int64_t term_id = -1;
+        bool is_leader = false;
+        for (const auto &rec : ret.records) {
+          if (rec.size() != 2) continue;
+          if (!rec[0].IsString()) continue;
+          const auto &key = rec[0].ValueString();
+          if (key == "term_id") {
+            if (!rec[1].IsInt()) continue;
+            term_id = rec[1].ValueInt();
+          } else if (key == "is_leader") {
+            if (!rec[1].IsBool()) continue;
+            is_leader = rec[1].ValueBool();
+          } else {
+            continue;
+          }
+        }
+        if (is_leader && term_id > leader_id) {
+          leader_id = term_id;
+          leader_ = client.get();
+        }
+      } catch (const utils::BasicException &) {
+        continue;
+      }
+    }
+    if (!leader_) {
+      throw ClientFatalException("Couldn't find leader server!");
+    }
+  }
+
+  std::vector<io::network::Endpoint> endpoints_;
+  communication::ClientContext *context_;
+  std::string username_;
+  std::string password_;
+  uint64_t num_retries_;
+  std::chrono::milliseconds retry_delay_;
+  std::string client_name_;
+
+  Client *leader_ = nullptr;
+  std::vector<std::unique_ptr<Client>> clients_;
+};
+}  // namespace communication::bolt
diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp
index 92b1b1890..3445beec7 100644
--- a/src/query/interpreter.cpp
+++ b/src/query/interpreter.cpp
@@ -18,6 +18,9 @@
 #include "query/plan/planner.hpp"
 #include "query/plan/profile.hpp"
 #include "query/plan/vertex_count_cache.hpp"
+#ifdef MG_SINGLE_NODE_HA
+#include "raft/exceptions.hpp"
+#endif
 #include "utils/exceptions.hpp"
 #include "utils/flag_validation.hpp"
 #include "utils/string.hpp"
@@ -719,8 +722,7 @@ Interpreter::Results Interpreter::operator()(
     if (!db_accessor.raft()->IsLeader() &&
         (!(info_query = utils::Downcast<InfoQuery>(parsed_query.query)) ||
          info_query->info_type_ != InfoQuery::InfoType::RAFT)) {
-      throw QueryException(
-          "Memgraph High Availability: Can't execute queries if not leader.");
+      throw raft::CantExecuteQueries();
     }
   }
 #endif
diff --git a/src/raft/exceptions.hpp b/src/raft/exceptions.hpp
index bf038bd08..add00f93b 100644
--- a/src/raft/exceptions.hpp
+++ b/src/raft/exceptions.hpp
@@ -78,4 +78,15 @@ class ReplicationTimeoutException : public RaftException {
       : RaftException("Raft Log replication is taking too long. ") {}
 };
 
+/// This exception is thrown when a client tries to execute a query on a server
+/// that isn't a leader.
+class CantExecuteQueries : public RaftException {
+ public:
+  using RaftException::RaftException;
+  CantExecuteQueries()
+      : RaftException(
+            "Memgraph High Availability: Can't execute queries if not "
+            "leader.") {}
+};
+
 }  // namespace raft