diff --git a/src/coordination/coordinator_client.cpp b/src/coordination/coordinator_client.cpp index bc7f42eaa..8530faff3 100644 --- a/src/coordination/coordinator_client.cpp +++ b/src/coordination/coordinator_client.cpp @@ -135,8 +135,7 @@ auto CoordinatorClient::SendSwapMainUUIDRpc(utils::UUID const &uuid) const -> bo auto CoordinatorClient::SendUnregisterReplicaRpc(std::string_view instance_name) const -> bool { try { - auto stream{rpc_client_.Stream( - std::string(instance_name))}; // TODO: (andi) Try to change to stream string_view and do just one copy later + auto stream{rpc_client_.Stream(instance_name)}; if (!stream.AwaitResponse().success) { spdlog::error("Failed to receive successful RPC response for unregistering replica!"); return false; diff --git a/src/coordination/coordinator_cluster_state.cpp b/src/coordination/coordinator_cluster_state.cpp index 60f0ca622..2213a052f 100644 --- a/src/coordination/coordinator_cluster_state.cpp +++ b/src/coordination/coordinator_cluster_state.cpp @@ -18,78 +18,87 @@ namespace memgraph::coordination { -using replication_coordination_glue::ReplicationRole; +void to_json(nlohmann::json &j, InstanceState const &instance_state) { + j = nlohmann::json{{"config", instance_state.config}, {"status", instance_state.status}}; +} -CoordinatorClusterState::CoordinatorClusterState(CoordinatorClusterState const &other) - : instance_roles_{other.instance_roles_} {} +void from_json(nlohmann::json const &j, InstanceState &instance_state) { + j.at("config").get_to(instance_state.config); + j.at("status").get_to(instance_state.status); +} + +CoordinatorClusterState::CoordinatorClusterState(std::map> instances) + : instances_{std::move(instances)} {} + +CoordinatorClusterState::CoordinatorClusterState(CoordinatorClusterState const &other) : instances_{other.instances_} {} CoordinatorClusterState &CoordinatorClusterState::operator=(CoordinatorClusterState const &other) { if (this == &other) { return *this; } - instance_roles_ = other.instance_roles_; + instances_ = other.instances_; return *this; } CoordinatorClusterState::CoordinatorClusterState(CoordinatorClusterState &&other) noexcept - : instance_roles_{std::move(other.instance_roles_)} {} + : instances_{std::move(other.instances_)} {} CoordinatorClusterState &CoordinatorClusterState::operator=(CoordinatorClusterState &&other) noexcept { if (this == &other) { return *this; } - instance_roles_ = std::move(other.instance_roles_); + instances_ = std::move(other.instances_); return *this; } auto CoordinatorClusterState::MainExists() const -> bool { auto lock = std::shared_lock{log_lock_}; - return std::ranges::any_of(instance_roles_, - [](auto const &entry) { return entry.second.role == ReplicationRole::MAIN; }); + return std::ranges::any_of(instances_, + [](auto const &entry) { return entry.second.status == ReplicationRole::MAIN; }); } auto CoordinatorClusterState::IsMain(std::string_view instance_name) const -> bool { auto lock = std::shared_lock{log_lock_}; - auto const it = instance_roles_.find(instance_name); - return it != instance_roles_.end() && it->second.role == ReplicationRole::MAIN; + auto const it = instances_.find(instance_name); + return it != instances_.end() && it->second.status == ReplicationRole::MAIN; } auto CoordinatorClusterState::IsReplica(std::string_view instance_name) const -> bool { auto lock = std::shared_lock{log_lock_}; - auto const it = instance_roles_.find(instance_name); - return it != instance_roles_.end() && it->second.role == ReplicationRole::REPLICA; + auto const it = instances_.find(instance_name); + return it != instances_.end() && it->second.status == ReplicationRole::REPLICA; } -auto CoordinatorClusterState::InsertInstance(std::string_view instance_name, ReplicationRole role) -> void { - auto lock = std::unique_lock{log_lock_}; - instance_roles_[instance_name.data()].role = role; +auto CoordinatorClusterState::InsertInstance(std::string instance_name, InstanceState instance_state) -> void { + auto lock = std::lock_guard{log_lock_}; + instances_.insert_or_assign(std::move(instance_name), std::move(instance_state)); } auto CoordinatorClusterState::DoAction(TRaftLog log_entry, RaftLogAction log_action) -> void { - auto lock = std::unique_lock{log_lock_}; + auto lock = std::lock_guard{log_lock_}; switch (log_action) { case RaftLogAction::REGISTER_REPLICATION_INSTANCE: { auto const &config = std::get(log_entry); - instance_roles_[config.instance_name] = InstanceState{config, ReplicationRole::REPLICA}; + instances_[config.instance_name] = InstanceState{config, ReplicationRole::REPLICA}; break; } case RaftLogAction::UNREGISTER_REPLICATION_INSTANCE: { auto const instance_name = std::get(log_entry); - instance_roles_.erase(instance_name); + instances_.erase(instance_name); break; } case RaftLogAction::SET_INSTANCE_AS_MAIN: { auto const instance_name = std::get(log_entry); - auto it = instance_roles_.find(instance_name); - MG_ASSERT(it != instance_roles_.end(), "Instance does not exist as part of raft state!"); - it->second.role = ReplicationRole::MAIN; + auto it = instances_.find(instance_name); + MG_ASSERT(it != instances_.end(), "Instance does not exist as part of raft state!"); + it->second.status = ReplicationRole::MAIN; break; } case RaftLogAction::SET_INSTANCE_AS_REPLICA: { auto const instance_name = std::get(log_entry); - auto it = instance_roles_.find(instance_name); - MG_ASSERT(it != instance_roles_.end(), "Instance does not exist as part of raft state!"); - it->second.role = ReplicationRole::REPLICA; + auto it = instances_.find(instance_name); + MG_ASSERT(it != instances_.end(), "Instance does not exist as part of raft state!"); + it->second.status = ReplicationRole::REPLICA; break; } case RaftLogAction::UPDATE_UUID: { @@ -99,64 +108,37 @@ auto CoordinatorClusterState::DoAction(TRaftLog log_entry, RaftLogAction log_act } } -// TODO: (andi) Improve based on Gareth's comments auto CoordinatorClusterState::Serialize(ptr &data) -> void { auto lock = std::shared_lock{log_lock_}; - auto const role_to_string = [](auto const &role) -> std::string_view { - switch (role) { - case ReplicationRole::MAIN: - return "main"; - case ReplicationRole::REPLICA: - return "replica"; - } - }; - auto const entry_to_string = [&role_to_string](auto const &entry) { - return fmt::format("{}_{}", entry.first, role_to_string(entry.second.role)); - }; + // .at(0) is hack to solve the problem with json serialization of map + auto const log = nlohmann::json{instances_}.at(0).dump(); - auto instances_str_view = instance_roles_ | ranges::views::transform(entry_to_string); - uint32_t size = - std::accumulate(instances_str_view.begin(), instances_str_view.end(), 0, - [](uint32_t acc, auto const &entry) { return acc + sizeof(uint32_t) + entry.size(); }); - - data = buffer::alloc(size); + data = buffer::alloc(sizeof(uint32_t) + log.size()); buffer_serializer bs(data); - std::for_each(instances_str_view.begin(), instances_str_view.end(), [&bs](auto const &entry) { bs.put_str(entry); }); + bs.put_str(log); } auto CoordinatorClusterState::Deserialize(buffer &data) -> CoordinatorClusterState { - auto const str_to_role = [](auto const &str) -> ReplicationRole { - if (str == "main") { - return ReplicationRole::MAIN; - } - return ReplicationRole::REPLICA; - }; - - CoordinatorClusterState cluster_state; buffer_serializer bs(data); - while (bs.size() > 0) { - auto const entry = bs.get_str(); - auto const first_dash = entry.find('_'); - auto const instance_name = entry.substr(0, first_dash); - auto const role_str = entry.substr(first_dash + 1); - cluster_state.InsertInstance(instance_name, str_to_role(role_str)); - } - return cluster_state; + auto const j = nlohmann::json::parse(bs.get_str()); + auto instances = j.get>>(); + + return CoordinatorClusterState{std::move(instances)}; } auto CoordinatorClusterState::GetInstances() const -> std::vector { auto lock = std::shared_lock{log_lock_}; - return instance_roles_ | ranges::views::values | ranges::to>; + return instances_ | ranges::views::values | ranges::to>; } auto CoordinatorClusterState::GetUUID() const -> utils::UUID { return uuid_; } auto CoordinatorClusterState::FindCurrentMainInstanceName() const -> std::optional { auto lock = std::shared_lock{log_lock_}; - auto const it = std::ranges::find_if(instance_roles_, - [](auto const &entry) { return entry.second.role == ReplicationRole::MAIN; }); - if (it == instance_roles_.end()) { + auto const it = + std::ranges::find_if(instances_, [](auto const &entry) { return entry.second.status == ReplicationRole::MAIN; }); + if (it == instances_.end()) { return {}; } return it->first; diff --git a/src/coordination/coordinator_instance.cpp b/src/coordination/coordinator_instance.cpp index 9a00ca87c..920fea3cb 100644 --- a/src/coordination/coordinator_instance.cpp +++ b/src/coordination/coordinator_instance.cpp @@ -36,7 +36,7 @@ CoordinatorInstance::CoordinatorInstance() spdlog::info("Leader changed, starting all replication instances!"); auto const instances = raft_state_.GetInstances(); auto replicas = instances | ranges::views::filter([](auto const &instance) { - return instance.role == ReplicationRole::REPLICA; + return instance.status == ReplicationRole::REPLICA; }); std::ranges::for_each(replicas, [this](auto &replica) { @@ -47,10 +47,7 @@ CoordinatorInstance::CoordinatorInstance() }); auto main = instances | ranges::views::filter( - [](auto const &instance) { return instance.role == ReplicationRole::MAIN; }); - - // TODO: (andi) Add support for this - // MG_ASSERT(std::ranges::distance(main) == 1, "There should be exactly one main instance"); + [](auto const &instance) { return instance.status == ReplicationRole::MAIN; }); std::ranges::for_each(main, [this](auto &main_instance) { spdlog::info("Starting main instance {}", main_instance.config.instance_name); @@ -60,7 +57,7 @@ CoordinatorInstance::CoordinatorInstance() }); std::ranges::for_each(repl_instances_, [this](auto &instance) { - instance.SetNewMainUUID(raft_state_.GetUUID()); // TODO: (andi) Rename + instance.SetNewMainUUID(raft_state_.GetUUID()); instance.StartFrequentCheck(); }); }, @@ -69,13 +66,13 @@ CoordinatorInstance::CoordinatorInstance() repl_instances_.clear(); })) { client_succ_cb_ = [](CoordinatorInstance *self, std::string_view repl_instance_name) -> void { - auto lock = std::unique_lock{self->coord_instance_lock_}; + auto lock = std::lock_guard{self->coord_instance_lock_}; auto &repl_instance = self->FindReplicationInstance(repl_instance_name); std::invoke(repl_instance.GetSuccessCallback(), self, repl_instance_name); }; client_fail_cb_ = [](CoordinatorInstance *self, std::string_view repl_instance_name) -> void { - auto lock = std::unique_lock{self->coord_instance_lock_}; + auto lock = std::lock_guard{self->coord_instance_lock_}; auto &repl_instance = self->FindReplicationInstance(repl_instance_name); std::invoke(repl_instance.GetFailCallback(), self, repl_instance_name); }; @@ -98,7 +95,6 @@ auto CoordinatorInstance::ShowInstances() const -> std::vector { .raft_socket_address = instance->get_endpoint(), .cluster_role = "coordinator", .health = "unknown"}; // TODO: (andi) Get this info from RAFT and test it or when we will move - // CoordinatorState to every instance, we can be smarter about this using our RPC. }; auto instances_status = utils::fmap(raft_state_.GetAllCoordinators(), coord_instance_to_status); @@ -126,14 +122,14 @@ auto CoordinatorInstance::ShowInstances() const -> std::vector { std::ranges::transform(repl_instances_, std::back_inserter(instances_status), process_repl_instance_as_leader); } } else { - auto const stringify_repl_role = [](ReplicationRole role) -> std::string { - return role == ReplicationRole::MAIN ? "main" : "replica"; + auto const stringify_inst_status = [](ReplicationRole status) -> std::string { + return status == ReplicationRole::MAIN ? "main" : "replica"; }; // TODO: (andi) Add capability that followers can also return socket addresses - auto process_repl_instance_as_follower = [&stringify_repl_role](auto const &instance) -> InstanceStatus { + auto process_repl_instance_as_follower = [&stringify_inst_status](auto const &instance) -> InstanceStatus { return {.instance_name = instance.config.instance_name, - .cluster_role = stringify_repl_role(instance.role), + .cluster_role = stringify_inst_status(instance.status), .health = "unknown"}; }; @@ -355,11 +351,11 @@ auto CoordinatorInstance::UnregisterReplicationInstance(std::string_view instanc return UnregisterInstanceCoordinatorStatus::NO_INSTANCE_WITH_NAME; } - // TODO: (andi) Change so that RaftLogState is the central place for asking who is main... + auto const is_main = [this](ReplicationInstance const &instance) { + return IsMain(instance.InstanceName()) && instance.GetMainUUID() == raft_state_.GetUUID() && instance.IsAlive(); + }; - auto const is_main = [this](ReplicationInstance const &instance) { return IsMain(instance.InstanceName()); }; - - if (is_main(*inst_to_remove) && inst_to_remove->IsAlive()) { + if (is_main(*inst_to_remove)) { return UnregisterInstanceCoordinatorStatus::IS_MAIN; } diff --git a/src/coordination/include/coordination/coordinator_rpc.hpp b/src/coordination/include/coordination/coordinator_rpc.hpp index 2bf88fe46..d799b2955 100644 --- a/src/coordination/include/coordination/coordinator_rpc.hpp +++ b/src/coordination/include/coordination/coordinator_rpc.hpp @@ -90,7 +90,7 @@ struct UnregisterReplicaReq { static void Load(UnregisterReplicaReq *self, memgraph::slk::Reader *reader); static void Save(UnregisterReplicaReq const &self, memgraph::slk::Builder *builder); - explicit UnregisterReplicaReq(std::string instance_name) : instance_name(std::move(instance_name)) {} + explicit UnregisterReplicaReq(std::string_view inst_name) : instance_name(inst_name) {} UnregisterReplicaReq() = default; diff --git a/src/coordination/include/coordination/raft_state.hpp b/src/coordination/include/coordination/raft_state.hpp index d702697f1..34da3e2a6 100644 --- a/src/coordination/include/coordination/raft_state.hpp +++ b/src/coordination/include/coordination/raft_state.hpp @@ -14,6 +14,7 @@ #ifdef MG_ENTERPRISE #include +#include "io/network/endpoint.hpp" #include "nuraft/coordinator_state_machine.hpp" #include "nuraft/coordinator_state_manager.hpp" @@ -79,9 +80,8 @@ class RaftState { private: // TODO: (andi) I think variables below can be abstracted/clean them. + io::network::Endpoint raft_endpoint_; uint32_t raft_server_id_; - uint32_t raft_port_; - std::string raft_address_; ptr state_machine_; ptr state_manager_; diff --git a/src/coordination/include/nuraft/coordinator_cluster_state.hpp b/src/coordination/include/nuraft/coordinator_cluster_state.hpp index f38d00073..11d539a14 100644 --- a/src/coordination/include/nuraft/coordinator_cluster_state.hpp +++ b/src/coordination/include/nuraft/coordinator_cluster_state.hpp @@ -21,6 +21,7 @@ #include #include +#include "json/json.hpp" #include #include @@ -33,9 +34,16 @@ using replication_coordination_glue::ReplicationRole; struct InstanceState { CoordinatorClientConfig config; - ReplicationRole role; + ReplicationRole status; + + friend auto operator==(InstanceState const &lhs, InstanceState const &rhs) -> bool { + return lhs.config == rhs.config && lhs.status == rhs.status; + } }; +void to_json(nlohmann::json &j, InstanceState const &instance_state); +void from_json(nlohmann::json const &j, InstanceState &instance_state); + using TRaftLog = std::variant; using nuraft::buffer; @@ -45,6 +53,8 @@ using nuraft::ptr; class CoordinatorClusterState { public: CoordinatorClusterState() = default; + explicit CoordinatorClusterState(std::map> instances); + CoordinatorClusterState(CoordinatorClusterState const &); CoordinatorClusterState &operator=(CoordinatorClusterState const &); @@ -60,7 +70,7 @@ class CoordinatorClusterState { auto IsReplica(std::string_view instance_name) const -> bool; - auto InsertInstance(std::string_view instance_name, ReplicationRole role) -> void; + auto InsertInstance(std::string instance_name, InstanceState instance_state) -> void; auto DoAction(TRaftLog log_entry, RaftLogAction log_action) -> void; @@ -73,7 +83,7 @@ class CoordinatorClusterState { auto GetUUID() const -> utils::UUID; private: - std::map> instance_roles_; + std::map> instances_{}; utils::UUID uuid_{}; mutable utils::ResourceLock log_lock_{}; }; diff --git a/src/coordination/include/nuraft/raft_log_action.hpp b/src/coordination/include/nuraft/raft_log_action.hpp index 953049038..3f1b26dfa 100644 --- a/src/coordination/include/nuraft/raft_log_action.hpp +++ b/src/coordination/include/nuraft/raft_log_action.hpp @@ -38,26 +38,5 @@ NLOHMANN_JSON_SERIALIZE_ENUM(RaftLogAction, { {RaftLogAction::UPDATE_UUID, "update_uuid"}, }) -inline auto ParseRaftLogAction(std::string_view action) -> RaftLogAction { - if (action == "register") { - return RaftLogAction::REGISTER_REPLICATION_INSTANCE; - } - if (action == "unregister") { - return RaftLogAction::UNREGISTER_REPLICATION_INSTANCE; - } - if (action == "promote") { - return RaftLogAction::SET_INSTANCE_AS_MAIN; - } - if (action == "demote") { - return RaftLogAction::SET_INSTANCE_AS_REPLICA; - } - - if (action == "update_uuid") { - return RaftLogAction::UPDATE_UUID; - } - - throw InvalidRaftLogActionException("Invalid Raft log action: {}.", action); -} - } // namespace memgraph::coordination #endif diff --git a/src/coordination/raft_state.cpp b/src/coordination/raft_state.cpp index 365388b06..d4d65cc36 100644 --- a/src/coordination/raft_state.cpp +++ b/src/coordination/raft_state.cpp @@ -32,12 +32,10 @@ using raft_result = cmd_result>; RaftState::RaftState(BecomeLeaderCb become_leader_cb, BecomeFollowerCb become_follower_cb, uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) - : raft_server_id_(raft_server_id), - raft_port_(raft_port), - raft_address_(std::move(raft_address)), + : raft_endpoint_(raft_address, raft_port), + raft_server_id_(raft_server_id), state_machine_(cs_new()), - state_manager_( - cs_new(raft_server_id_, raft_address_ + ":" + std::to_string(raft_port_))), + state_manager_(cs_new(raft_server_id_, raft_endpoint_.SocketAddress())), logger_(nullptr), become_leader_cb_(std::move(become_leader_cb)), become_follower_cb_(std::move(become_follower_cb)) {} @@ -71,11 +69,11 @@ auto RaftState::InitRaftServer() -> void { raft_launcher launcher; - raft_server_ = launcher.init(state_machine_, state_manager_, logger_, static_cast(raft_port_), asio_opts, params, - init_opts); + raft_server_ = + launcher.init(state_machine_, state_manager_, logger_, raft_endpoint_.port, asio_opts, params, init_opts); if (!raft_server_) { - throw RaftServerStartException("Failed to launch raft server on {}:{}", raft_address_, raft_port_); + throw RaftServerStartException("Failed to launch raft server on {}", raft_endpoint_.SocketAddress()); } auto maybe_stop = utils::ResettableCounter<20>(); @@ -86,7 +84,7 @@ auto RaftState::InitRaftServer() -> void { std::this_thread::sleep_for(std::chrono::milliseconds(250)); } while (!maybe_stop()); - throw RaftServerStartException("Failed to initialize raft server on {}:{}", raft_address_, raft_port_); + throw RaftServerStartException("Failed to initialize raft server on {}", raft_endpoint_.SocketAddress()); } auto RaftState::MakeRaftState(BecomeLeaderCb &&become_leader_cb, BecomeFollowerCb &&become_follower_cb) -> RaftState { @@ -102,9 +100,11 @@ auto RaftState::MakeRaftState(BecomeLeaderCb &&become_leader_cb, BecomeFollowerC RaftState::~RaftState() { launcher_.shutdown(); } -auto RaftState::InstanceName() const -> std::string { return "coordinator_" + std::to_string(raft_server_id_); } +auto RaftState::InstanceName() const -> std::string { + return fmt::format("coordinator_{}", std::to_string(raft_server_id_)); +} -auto RaftState::RaftSocketAddress() const -> std::string { return raft_address_ + ":" + std::to_string(raft_port_); } +auto RaftState::RaftSocketAddress() const -> std::string { return raft_endpoint_.SocketAddress(); } auto RaftState::AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string_view raft_address) -> void { diff --git a/src/io/network/endpoint.cpp b/src/io/network/endpoint.cpp index bb6dcfd10..6ed4a6753 100644 --- a/src/io/network/endpoint.cpp +++ b/src/io/network/endpoint.cpp @@ -22,113 +22,15 @@ #include "utils/message.hpp" #include "utils/string.hpp" +namespace { +constexpr std::string_view delimiter = ":"; +} // namespace + namespace memgraph::io::network { -Endpoint::IpFamily Endpoint::GetIpFamily(std::string_view address) { - in_addr addr4; - in6_addr addr6; - int ipv4_result = inet_pton(AF_INET, address.data(), &addr4); - int ipv6_result = inet_pton(AF_INET6, address.data(), &addr6); - if (ipv4_result == 1) { - return IpFamily::IP4; - } - if (ipv6_result == 1) { - return IpFamily::IP6; - } - return IpFamily::NONE; -} - -std::optional> Endpoint::ParseSocketOrIpAddress( - std::string_view address, const std::optional default_port) { - /// expected address format: - /// - "ip_address:port_number" - /// - "ip_address" - /// We parse the address first. If it's an IP address, a default port must - // be given, or we return nullopt. If it's a socket address, we try to parse - // it into an ip address and a port number; even if a default port is given, - // it won't be used, as we expect that it is given in the address string. - const std::string delimiter = ":"; - std::string ip_address; - - std::vector parts = utils::Split(address, delimiter); - if (parts.size() == 1) { - if (default_port) { - if (GetIpFamily(address) == IpFamily::NONE) { - return std::nullopt; - } - return std::pair{std::string(address), *default_port}; // TODO: (andi) Optimize throughout the code - } - } else if (parts.size() == 2) { - ip_address = std::move(parts[0]); - if (GetIpFamily(ip_address) == IpFamily::NONE) { - return std::nullopt; - } - int64_t int_port{0}; - try { - int_port = utils::ParseInt(parts[1]); - } catch (utils::BasicException &e) { - spdlog::error(utils::MessageWithLink("Invalid port number {}.", parts[1], "https://memgr.ph/ports")); - return std::nullopt; - } - if (int_port < 0) { - spdlog::error(utils::MessageWithLink("Invalid port number {}. The port number must be a positive integer.", - int_port, "https://memgr.ph/ports")); - return std::nullopt; - } - if (int_port > std::numeric_limits::max()) { - spdlog::error(utils::MessageWithLink("Invalid port number. The port number exceedes the maximum possible size.", - "https://memgr.ph/ports")); - return std::nullopt; - } - - return std::pair{ip_address, static_cast(int_port)}; - } - - return std::nullopt; -} - -std::optional> Endpoint::ParseHostname( - std::string_view address, const std::optional default_port = {}) { - const std::string delimiter = ":"; - std::string ip_address; - std::vector parts = utils::Split(address, delimiter); - if (parts.size() == 1) { - if (default_port) { - if (!IsResolvableAddress(address, *default_port)) { - return std::nullopt; - } - return std::pair{std::string(address), *default_port}; // TODO: (andi) Optimize throughout the code - } - } else if (parts.size() == 2) { - int64_t int_port{0}; - auto hostname = std::move(parts[0]); - try { - int_port = utils::ParseInt(parts[1]); - } catch (utils::BasicException &e) { - spdlog::error(utils::MessageWithLink("Invalid port number {}.", parts[1], "https://memgr.ph/ports")); - return std::nullopt; - } - if (int_port < 0) { - spdlog::error(utils::MessageWithLink("Invalid port number {}. The port number must be a positive integer.", - int_port, "https://memgr.ph/ports")); - return std::nullopt; - } - if (int_port > std::numeric_limits::max()) { - spdlog::error(utils::MessageWithLink("Invalid port number. The port number exceedes the maximum possible size.", - "https://memgr.ph/ports")); - return std::nullopt; - } - if (IsResolvableAddress(hostname, static_cast(int_port))) { - return std::pair{hostname, static_cast(int_port)}; - } - } - return std::nullopt; -} - -std::string Endpoint::SocketAddress() const { - auto ip_address = address.empty() ? "EMPTY" : address; - return ip_address + ":" + std::to_string(port); -} +// NOLINTNEXTLINE +Endpoint::Endpoint(needs_resolving_t, std::string hostname, uint16_t port) + : address(std::move(hostname)), port(port), family{GetIpFamily(address)} {} Endpoint::Endpoint(std::string ip_address, uint16_t port) : address(std::move(ip_address)), port(port) { IpFamily ip_family = GetIpFamily(address); @@ -138,9 +40,23 @@ Endpoint::Endpoint(std::string ip_address, uint16_t port) : address(std::move(ip family = ip_family; } -// NOLINTNEXTLINE -Endpoint::Endpoint(needs_resolving_t, std::string hostname, uint16_t port) - : address(std::move(hostname)), port(port), family{GetIpFamily(address)} {} +std::string Endpoint::SocketAddress() const { return fmt::format("{}:{}", address, port); } + +Endpoint::IpFamily Endpoint::GetIpFamily(std::string_view address) { + // Ensure null-terminated + auto const tmp = std::string(address); + in_addr addr4; + in6_addr addr6; + int ipv4_result = inet_pton(AF_INET, tmp.c_str(), &addr4); + int ipv6_result = inet_pton(AF_INET6, tmp.c_str(), &addr6); + if (ipv4_result == 1) { + return IpFamily::IP4; + } + if (ipv6_result == 1) { + return IpFamily::IP6; + } + return IpFamily::NONE; +} std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint) { // no need to cover the IpFamily::NONE case, as you can't even construct an @@ -153,6 +69,7 @@ std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint) { return os << endpoint.address << ":" << endpoint.port; } +// NOTE: Intentional copy to ensure null-terminated string bool Endpoint::IsResolvableAddress(std::string_view address, uint16_t port) { addrinfo hints{ .ai_flags = AI_PASSIVE, @@ -160,28 +77,65 @@ bool Endpoint::IsResolvableAddress(std::string_view address, uint16_t port) { .ai_socktype = SOCK_STREAM // TCP socket }; addrinfo *info = nullptr; - auto status = getaddrinfo(address.data(), std::to_string(port).c_str(), &hints, &info); + auto status = getaddrinfo(std::string(address).c_str(), std::to_string(port).c_str(), &hints, &info); if (info) freeaddrinfo(info); return status == 0; } -std::optional> Endpoint::ParseSocketOrAddress( - std::string_view address, const std::optional default_port) { - const std::string delimiter = ":"; - std::vector parts = utils::Split(address, delimiter); - if (parts.size() == 1) { - if (GetIpFamily(address) == IpFamily::NONE) { - return ParseHostname(address, default_port); - } - return ParseSocketOrIpAddress(address, default_port); +std::optional Endpoint::ParseSocketOrAddress(std::string_view address, + std::optional default_port) { + auto const parts = utils::SplitView(address, delimiter); + + if (parts.size() > 2) { + return std::nullopt; } - if (parts.size() == 2) { - if (GetIpFamily(parts[0]) == IpFamily::NONE) { - return ParseHostname(address, default_port); + + auto const port = [default_port, &parts]() -> std::optional { + if (parts.size() == 2) { + return static_cast(utils::ParseInt(parts[1])); } - return ParseSocketOrIpAddress(address, default_port); + return default_port; + }(); + + if (!ValidatePort(port)) { + return std::nullopt; } - return std::nullopt; + + auto const addr = [address, &parts]() { + if (parts.size() == 2) { + return parts[0]; + } + return address; + }(); + + if (GetIpFamily(addr) == IpFamily::NONE) { + if (IsResolvableAddress(addr, *port)) { // NOLINT + return std::pair{addr, *port}; // NOLINT + } + return std::nullopt; + } + + return std::pair{addr, *port}; // NOLINT +} + +auto Endpoint::ValidatePort(std::optional port) -> bool { + if (!port) { + return false; + } + + if (port < 0) { + spdlog::error(utils::MessageWithLink("Invalid port number {}. The port number must be a positive integer.", *port, + "https://memgr.ph/ports")); + return false; + } + + if (port > std::numeric_limits::max()) { + spdlog::error(utils::MessageWithLink("Invalid port number. The port number exceedes the maximum possible size.", + "https://memgr.ph/ports")); + return false; + } + + return true; } } // namespace memgraph::io::network diff --git a/src/io/network/endpoint.hpp b/src/io/network/endpoint.hpp index b0201240b..f46d28ace 100644 --- a/src/io/network/endpoint.hpp +++ b/src/io/network/endpoint.hpp @@ -19,11 +19,8 @@ namespace memgraph::io::network { -/** - * This class represents a network endpoint that is used in Socket. - * It is used when connecting to an address and to get the current - * connection address. - */ +using ParsedAddress = std::pair; + struct Endpoint { static const struct needs_resolving_t { } needs_resolving; @@ -31,59 +28,35 @@ struct Endpoint { Endpoint() = default; Endpoint(std::string ip_address, uint16_t port); Endpoint(needs_resolving_t, std::string hostname, uint16_t port); + Endpoint(Endpoint const &) = default; Endpoint(Endpoint &&) noexcept = default; + Endpoint &operator=(Endpoint const &) = default; Endpoint &operator=(Endpoint &&) noexcept = default; + ~Endpoint() = default; enum class IpFamily : std::uint8_t { NONE, IP4, IP6 }; - std::string SocketAddress() const; + static std::optional ParseSocketOrAddress(std::string_view address, + std::optional default_port = {}); - bool operator==(const Endpoint &other) const = default; - friend std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint); + std::string SocketAddress() const; std::string address; uint16_t port{0}; IpFamily family{IpFamily::NONE}; - static std::optional> ParseSocketOrAddress( - std::string_view address, std::optional default_port = {}); - - /** - * Tries to parse the given string as either a socket address or ip address. - * Expected address format: - * - "ip_address:port_number" - * - "ip_address" - * We parse the address first. If it's an IP address, a default port must - * be given, or we return nullopt. If it's a socket address, we try to parse - * it into an ip address and a port number; even if a default port is given, - * it won't be used, as we expect that it is given in the address string. - */ - static std::optional> ParseSocketOrIpAddress( - std::string_view address, std::optional default_port = {}); - - /** - * Tries to parse given string as either socket address or hostname. - * Expected address format: - * - "hostname:port_number" - * - "hostname" - * After we parse hostname and port we try to resolve the hostname into an ip_address. - */ - static std::optional> ParseHostname(std::string_view address, - std::optional default_port); + bool operator==(const Endpoint &other) const = default; + friend std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint); + private: static IpFamily GetIpFamily(std::string_view address); static bool IsResolvableAddress(std::string_view address, uint16_t port); - /** - * Tries to resolve hostname to its corresponding IP address. - * Given a DNS hostname, this function performs resolution and returns - * the IP address associated with the hostname. - */ - static std::string ResolveHostnameIntoIpAddress(const std::string &address, uint16_t port); + static auto ValidatePort(std::optional port) -> bool; }; } // namespace memgraph::io::network diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index e6d39ab9a..e51620bf6 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -355,7 +355,7 @@ class ReplQueryHandler { const auto replication_config = replication::ReplicationClientConfig{.name = name, .mode = repl_mode, - .ip_address = ip, + .ip_address = std::string(ip), .port = port, .replica_check_frequency = replica_check_frequency, .ssl = std::nullopt}; @@ -454,12 +454,12 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler { const auto repl_config = coordination::CoordinatorClientConfig::ReplicationClientInfo{ .instance_name = std::string(instance_name), .replication_mode = convertFromCoordinatorToReplicationMode(sync_mode), - .replication_ip_address = replication_ip, + .replication_ip_address = std::string(replication_ip), .replication_port = replication_port}; auto coordinator_client_config = coordination::CoordinatorClientConfig{.instance_name = std::string(instance_name), - .ip_address = coordinator_server_ip, + .ip_address = std::string(coordinator_server_ip), .port = coordinator_server_port, .instance_health_check_frequency_sec = instance_check_frequency, .instance_down_timeout_sec = instance_down_timeout, @@ -1212,7 +1212,7 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param }; notifications->emplace_back( - SeverityLevel::INFO, NotificationCode::REGISTER_COORDINATOR_SERVER, + SeverityLevel::INFO, NotificationCode::REGISTER_REPLICATION_INSTANCE, fmt::format("Coordinator has registered coordinator server on {} for instance {}.", coordinator_socket_address_tv.ValueString(), coordinator_query->instance_name_)); return callback; diff --git a/src/query/metadata.cpp b/src/query/metadata.cpp index e339aad57..af3b8d15f 100644 --- a/src/query/metadata.cpp +++ b/src/query/metadata.cpp @@ -67,8 +67,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { case NotificationCode::REGISTER_REPLICA: return "RegisterReplica"sv; #ifdef MG_ENTERPRISE - case NotificationCode::REGISTER_COORDINATOR_SERVER: - return "RegisterCoordinatorServer"sv; + case NotificationCode::REGISTER_REPLICATION_INSTANCE: + return "RegisterReplicationInstance"sv; case NotificationCode::ADD_COORDINATOR_INSTANCE: return "AddCoordinatorInstance"sv; case NotificationCode::UNREGISTER_INSTANCE: diff --git a/src/query/metadata.hpp b/src/query/metadata.hpp index dd8c2db07..fba672f4b 100644 --- a/src/query/metadata.hpp +++ b/src/query/metadata.hpp @@ -43,7 +43,7 @@ enum class NotificationCode : uint8_t { REPLICA_PORT_WARNING, REGISTER_REPLICA, #ifdef MG_ENTERPRISE - REGISTER_COORDINATOR_SERVER, // TODO: (andi) What is this? + REGISTER_REPLICATION_INSTANCE, ADD_COORDINATOR_INSTANCE, UNREGISTER_INSTANCE, #endif diff --git a/src/replication_coordination_glue/role.hpp b/src/replication_coordination_glue/role.hpp index d472cb454..3fbf522ba 100644 --- a/src/replication_coordination_glue/role.hpp +++ b/src/replication_coordination_glue/role.hpp @@ -12,8 +12,14 @@ #pragma once #include + +#include "json/json.hpp" + namespace memgraph::replication_coordination_glue { // TODO: figure out a way of ensuring that usage of this type is never uninitialed/defaulted incorrectly to MAIN enum class ReplicationRole : uint8_t { MAIN, REPLICA }; + +NLOHMANN_JSON_SERIALIZE_ENUM(ReplicationRole, {{ReplicationRole::MAIN, "main"}, {ReplicationRole::REPLICA, "replica"}}) + } // namespace memgraph::replication_coordination_glue diff --git a/tests/e2e/replication/common.hpp b/tests/e2e/replication/common.hpp index f5113ac37..1938eb0f3 100644 --- a/tests/e2e/replication/common.hpp +++ b/tests/e2e/replication/common.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -34,12 +34,13 @@ DEFINE_double(reads_duration_limit, 10.0, "How long should the client perform re namespace mg::e2e::replication { auto ParseDatabaseEndpoints(const std::string &database_endpoints_str) { - const auto db_endpoints_strs = memgraph::utils::Split(database_endpoints_str, ","); + const auto db_endpoints_strs = memgraph::utils::SplitView(database_endpoints_str, ","); std::vector database_endpoints; for (const auto &db_endpoint_str : db_endpoints_strs) { - const auto maybe_host_port = memgraph::io::network::Endpoint::ParseSocketOrIpAddress(db_endpoint_str, 7687); + const auto maybe_host_port = memgraph::io::network::Endpoint::ParseSocketOrAddress(db_endpoint_str, 7687); MG_ASSERT(maybe_host_port); - database_endpoints.emplace_back(maybe_host_port->first, maybe_host_port->second); + auto const [ip, port] = *maybe_host_port; + database_endpoints.emplace_back(std::string(ip), port); } return database_endpoints; } diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index f1afcdf15..44b24b6f6 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -445,3 +445,10 @@ add_unit_test(raft_log_serialization.cpp) target_link_libraries(${test_prefix}raft_log_serialization gflags mg-coordination mg-repl_coord_glue) target_include_directories(${test_prefix}raft_log_serialization PRIVATE ${CMAKE_SOURCE_DIR}/include) endif() + +# Test Raft log serialization +if(MG_ENTERPRISE) +add_unit_test(coordinator_cluster_state.cpp) +target_link_libraries(${test_prefix}coordinator_cluster_state gflags mg-coordination mg-repl_coord_glue) +target_include_directories(${test_prefix}coordinator_cluster_state PRIVATE ${CMAKE_SOURCE_DIR}/include) +endif() diff --git a/tests/unit/coordinator_cluster_state.cpp b/tests/unit/coordinator_cluster_state.cpp new file mode 100644 index 000000000..8df2797f2 --- /dev/null +++ b/tests/unit/coordinator_cluster_state.cpp @@ -0,0 +1,163 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "nuraft/coordinator_cluster_state.hpp" +#include "nuraft/coordinator_state_machine.hpp" +#include "replication_coordination_glue/role.hpp" + +#include "utils/file.hpp" + +#include +#include +#include "json/json.hpp" + +#include "libnuraft/nuraft.hxx" + +using memgraph::coordination::CoordinatorClientConfig; +using memgraph::coordination::CoordinatorClusterState; +using memgraph::coordination::CoordinatorStateMachine; +using memgraph::coordination::InstanceState; +using memgraph::coordination::RaftLogAction; +using memgraph::replication_coordination_glue::ReplicationMode; +using memgraph::replication_coordination_glue::ReplicationRole; +using nuraft::buffer; +using nuraft::buffer_serializer; +using nuraft::ptr; + +class CoordinatorClusterStateTest : public ::testing::Test { + protected: + void SetUp() override {} + + void TearDown() override {} + + std::filesystem::path test_folder_{std::filesystem::temp_directory_path() / + "MG_tests_unit_coordinator_cluster_state"}; +}; + +TEST_F(CoordinatorClusterStateTest, InstanceStateSerialization) { + InstanceState instance_state{ + CoordinatorClientConfig{"instance3", + "127.0.0.1", + 10112, + std::chrono::seconds{1}, + std::chrono::seconds{5}, + std::chrono::seconds{10}, + {"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10001}, + .ssl = std::nullopt}, + ReplicationRole::MAIN}; + + nlohmann::json j = instance_state; + InstanceState deserialized_instance_state = j.get(); + + EXPECT_EQ(instance_state.config, deserialized_instance_state.config); + EXPECT_EQ(instance_state.status, deserialized_instance_state.status); +} + +TEST_F(CoordinatorClusterStateTest, DoActionRegisterInstances) { + auto coordinator_cluster_state = memgraph::coordination::CoordinatorClusterState{}; + + { + CoordinatorClientConfig config{"instance1", + "127.0.0.1", + 10111, + std::chrono::seconds{1}, + std::chrono::seconds{5}, + std::chrono::seconds{10}, + {"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10001}, + .ssl = std::nullopt}; + + auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config); + auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer); + + coordinator_cluster_state.DoAction(payload, action); + } + { + CoordinatorClientConfig config{"instance2", + "127.0.0.1", + 10112, + std::chrono::seconds{1}, + std::chrono::seconds{5}, + std::chrono::seconds{10}, + {"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10002}, + .ssl = std::nullopt}; + + auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config); + auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer); + + coordinator_cluster_state.DoAction(payload, action); + } + { + CoordinatorClientConfig config{"instance3", + "127.0.0.1", + 10113, + std::chrono::seconds{1}, + std::chrono::seconds{5}, + std::chrono::seconds{10}, + {"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10003}, + .ssl = std::nullopt}; + + auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config); + auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer); + + coordinator_cluster_state.DoAction(payload, action); + } + { + CoordinatorClientConfig config{"instance4", + "127.0.0.1", + 10114, + std::chrono::seconds{1}, + std::chrono::seconds{5}, + std::chrono::seconds{10}, + {"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10004}, + .ssl = std::nullopt}; + + auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config); + auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer); + + coordinator_cluster_state.DoAction(payload, action); + } + { + CoordinatorClientConfig config{"instance5", + "127.0.0.1", + 10115, + std::chrono::seconds{1}, + std::chrono::seconds{5}, + std::chrono::seconds{10}, + {"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10005}, + .ssl = std::nullopt}; + + auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config); + auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer); + + coordinator_cluster_state.DoAction(payload, action); + } + { + CoordinatorClientConfig config{"instance6", + "127.0.0.1", + 10116, + std::chrono::seconds{1}, + std::chrono::seconds{5}, + std::chrono::seconds{10}, + {"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10006}, + .ssl = std::nullopt}; + + auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config); + auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer); + + coordinator_cluster_state.DoAction(payload, action); + } + + ptr data; + coordinator_cluster_state.Serialize(data); + + auto deserialized_coordinator_cluster_state = CoordinatorClusterState::Deserialize(*data); + ASSERT_EQ(coordinator_cluster_state.GetInstances(), deserialized_coordinator_cluster_state.GetInstances()); +}