Improve in-memory RAFT state (#1782)

This commit is contained in:
Andi 2024-03-06 09:16:46 +01:00 committed by GitHub
parent d4d4660af0
commit 75aad72984
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 363 additions and 293 deletions

View File

@ -135,8 +135,7 @@ auto CoordinatorClient::SendSwapMainUUIDRpc(utils::UUID const &uuid) const -> bo
auto CoordinatorClient::SendUnregisterReplicaRpc(std::string_view instance_name) const -> bool {
try {
auto stream{rpc_client_.Stream<UnregisterReplicaRpc>(
std::string(instance_name))}; // TODO: (andi) Try to change to stream string_view and do just one copy later
auto stream{rpc_client_.Stream<UnregisterReplicaRpc>(instance_name)};
if (!stream.AwaitResponse().success) {
spdlog::error("Failed to receive successful RPC response for unregistering replica!");
return false;

View File

@ -18,78 +18,87 @@
namespace memgraph::coordination {
using replication_coordination_glue::ReplicationRole;
void to_json(nlohmann::json &j, InstanceState const &instance_state) {
j = nlohmann::json{{"config", instance_state.config}, {"status", instance_state.status}};
}
CoordinatorClusterState::CoordinatorClusterState(CoordinatorClusterState const &other)
: instance_roles_{other.instance_roles_} {}
void from_json(nlohmann::json const &j, InstanceState &instance_state) {
j.at("config").get_to(instance_state.config);
j.at("status").get_to(instance_state.status);
}
CoordinatorClusterState::CoordinatorClusterState(std::map<std::string, InstanceState, std::less<>> instances)
: instances_{std::move(instances)} {}
CoordinatorClusterState::CoordinatorClusterState(CoordinatorClusterState const &other) : instances_{other.instances_} {}
CoordinatorClusterState &CoordinatorClusterState::operator=(CoordinatorClusterState const &other) {
if (this == &other) {
return *this;
}
instance_roles_ = other.instance_roles_;
instances_ = other.instances_;
return *this;
}
CoordinatorClusterState::CoordinatorClusterState(CoordinatorClusterState &&other) noexcept
: instance_roles_{std::move(other.instance_roles_)} {}
: instances_{std::move(other.instances_)} {}
CoordinatorClusterState &CoordinatorClusterState::operator=(CoordinatorClusterState &&other) noexcept {
if (this == &other) {
return *this;
}
instance_roles_ = std::move(other.instance_roles_);
instances_ = std::move(other.instances_);
return *this;
}
auto CoordinatorClusterState::MainExists() const -> bool {
auto lock = std::shared_lock{log_lock_};
return std::ranges::any_of(instance_roles_,
[](auto const &entry) { return entry.second.role == ReplicationRole::MAIN; });
return std::ranges::any_of(instances_,
[](auto const &entry) { return entry.second.status == ReplicationRole::MAIN; });
}
auto CoordinatorClusterState::IsMain(std::string_view instance_name) const -> bool {
auto lock = std::shared_lock{log_lock_};
auto const it = instance_roles_.find(instance_name);
return it != instance_roles_.end() && it->second.role == ReplicationRole::MAIN;
auto const it = instances_.find(instance_name);
return it != instances_.end() && it->second.status == ReplicationRole::MAIN;
}
auto CoordinatorClusterState::IsReplica(std::string_view instance_name) const -> bool {
auto lock = std::shared_lock{log_lock_};
auto const it = instance_roles_.find(instance_name);
return it != instance_roles_.end() && it->second.role == ReplicationRole::REPLICA;
auto const it = instances_.find(instance_name);
return it != instances_.end() && it->second.status == ReplicationRole::REPLICA;
}
auto CoordinatorClusterState::InsertInstance(std::string_view instance_name, ReplicationRole role) -> void {
auto lock = std::unique_lock{log_lock_};
instance_roles_[instance_name.data()].role = role;
auto CoordinatorClusterState::InsertInstance(std::string instance_name, InstanceState instance_state) -> void {
auto lock = std::lock_guard{log_lock_};
instances_.insert_or_assign(std::move(instance_name), std::move(instance_state));
}
auto CoordinatorClusterState::DoAction(TRaftLog log_entry, RaftLogAction log_action) -> void {
auto lock = std::unique_lock{log_lock_};
auto lock = std::lock_guard{log_lock_};
switch (log_action) {
case RaftLogAction::REGISTER_REPLICATION_INSTANCE: {
auto const &config = std::get<CoordinatorClientConfig>(log_entry);
instance_roles_[config.instance_name] = InstanceState{config, ReplicationRole::REPLICA};
instances_[config.instance_name] = InstanceState{config, ReplicationRole::REPLICA};
break;
}
case RaftLogAction::UNREGISTER_REPLICATION_INSTANCE: {
auto const instance_name = std::get<std::string>(log_entry);
instance_roles_.erase(instance_name);
instances_.erase(instance_name);
break;
}
case RaftLogAction::SET_INSTANCE_AS_MAIN: {
auto const instance_name = std::get<std::string>(log_entry);
auto it = instance_roles_.find(instance_name);
MG_ASSERT(it != instance_roles_.end(), "Instance does not exist as part of raft state!");
it->second.role = ReplicationRole::MAIN;
auto it = instances_.find(instance_name);
MG_ASSERT(it != instances_.end(), "Instance does not exist as part of raft state!");
it->second.status = ReplicationRole::MAIN;
break;
}
case RaftLogAction::SET_INSTANCE_AS_REPLICA: {
auto const instance_name = std::get<std::string>(log_entry);
auto it = instance_roles_.find(instance_name);
MG_ASSERT(it != instance_roles_.end(), "Instance does not exist as part of raft state!");
it->second.role = ReplicationRole::REPLICA;
auto it = instances_.find(instance_name);
MG_ASSERT(it != instances_.end(), "Instance does not exist as part of raft state!");
it->second.status = ReplicationRole::REPLICA;
break;
}
case RaftLogAction::UPDATE_UUID: {
@ -99,64 +108,37 @@ auto CoordinatorClusterState::DoAction(TRaftLog log_entry, RaftLogAction log_act
}
}
// TODO: (andi) Improve based on Gareth's comments
auto CoordinatorClusterState::Serialize(ptr<buffer> &data) -> void {
auto lock = std::shared_lock{log_lock_};
auto const role_to_string = [](auto const &role) -> std::string_view {
switch (role) {
case ReplicationRole::MAIN:
return "main";
case ReplicationRole::REPLICA:
return "replica";
}
};
auto const entry_to_string = [&role_to_string](auto const &entry) {
return fmt::format("{}_{}", entry.first, role_to_string(entry.second.role));
};
// .at(0) is hack to solve the problem with json serialization of map
auto const log = nlohmann::json{instances_}.at(0).dump();
auto instances_str_view = instance_roles_ | ranges::views::transform(entry_to_string);
uint32_t size =
std::accumulate(instances_str_view.begin(), instances_str_view.end(), 0,
[](uint32_t acc, auto const &entry) { return acc + sizeof(uint32_t) + entry.size(); });
data = buffer::alloc(size);
data = buffer::alloc(sizeof(uint32_t) + log.size());
buffer_serializer bs(data);
std::for_each(instances_str_view.begin(), instances_str_view.end(), [&bs](auto const &entry) { bs.put_str(entry); });
bs.put_str(log);
}
auto CoordinatorClusterState::Deserialize(buffer &data) -> CoordinatorClusterState {
auto const str_to_role = [](auto const &str) -> ReplicationRole {
if (str == "main") {
return ReplicationRole::MAIN;
}
return ReplicationRole::REPLICA;
};
CoordinatorClusterState cluster_state;
buffer_serializer bs(data);
while (bs.size() > 0) {
auto const entry = bs.get_str();
auto const first_dash = entry.find('_');
auto const instance_name = entry.substr(0, first_dash);
auto const role_str = entry.substr(first_dash + 1);
cluster_state.InsertInstance(instance_name, str_to_role(role_str));
}
return cluster_state;
auto const j = nlohmann::json::parse(bs.get_str());
auto instances = j.get<std::map<std::string, InstanceState, std::less<>>>();
return CoordinatorClusterState{std::move(instances)};
}
auto CoordinatorClusterState::GetInstances() const -> std::vector<InstanceState> {
auto lock = std::shared_lock{log_lock_};
return instance_roles_ | ranges::views::values | ranges::to<std::vector<InstanceState>>;
return instances_ | ranges::views::values | ranges::to<std::vector<InstanceState>>;
}
auto CoordinatorClusterState::GetUUID() const -> utils::UUID { return uuid_; }
auto CoordinatorClusterState::FindCurrentMainInstanceName() const -> std::optional<std::string> {
auto lock = std::shared_lock{log_lock_};
auto const it = std::ranges::find_if(instance_roles_,
[](auto const &entry) { return entry.second.role == ReplicationRole::MAIN; });
if (it == instance_roles_.end()) {
auto const it =
std::ranges::find_if(instances_, [](auto const &entry) { return entry.second.status == ReplicationRole::MAIN; });
if (it == instances_.end()) {
return {};
}
return it->first;

View File

@ -36,7 +36,7 @@ CoordinatorInstance::CoordinatorInstance()
spdlog::info("Leader changed, starting all replication instances!");
auto const instances = raft_state_.GetInstances();
auto replicas = instances | ranges::views::filter([](auto const &instance) {
return instance.role == ReplicationRole::REPLICA;
return instance.status == ReplicationRole::REPLICA;
});
std::ranges::for_each(replicas, [this](auto &replica) {
@ -47,10 +47,7 @@ CoordinatorInstance::CoordinatorInstance()
});
auto main = instances | ranges::views::filter(
[](auto const &instance) { return instance.role == ReplicationRole::MAIN; });
// TODO: (andi) Add support for this
// MG_ASSERT(std::ranges::distance(main) == 1, "There should be exactly one main instance");
[](auto const &instance) { return instance.status == ReplicationRole::MAIN; });
std::ranges::for_each(main, [this](auto &main_instance) {
spdlog::info("Starting main instance {}", main_instance.config.instance_name);
@ -60,7 +57,7 @@ CoordinatorInstance::CoordinatorInstance()
});
std::ranges::for_each(repl_instances_, [this](auto &instance) {
instance.SetNewMainUUID(raft_state_.GetUUID()); // TODO: (andi) Rename
instance.SetNewMainUUID(raft_state_.GetUUID());
instance.StartFrequentCheck();
});
},
@ -69,13 +66,13 @@ CoordinatorInstance::CoordinatorInstance()
repl_instances_.clear();
})) {
client_succ_cb_ = [](CoordinatorInstance *self, std::string_view repl_instance_name) -> void {
auto lock = std::unique_lock{self->coord_instance_lock_};
auto lock = std::lock_guard{self->coord_instance_lock_};
auto &repl_instance = self->FindReplicationInstance(repl_instance_name);
std::invoke(repl_instance.GetSuccessCallback(), self, repl_instance_name);
};
client_fail_cb_ = [](CoordinatorInstance *self, std::string_view repl_instance_name) -> void {
auto lock = std::unique_lock{self->coord_instance_lock_};
auto lock = std::lock_guard{self->coord_instance_lock_};
auto &repl_instance = self->FindReplicationInstance(repl_instance_name);
std::invoke(repl_instance.GetFailCallback(), self, repl_instance_name);
};
@ -98,7 +95,6 @@ auto CoordinatorInstance::ShowInstances() const -> std::vector<InstanceStatus> {
.raft_socket_address = instance->get_endpoint(),
.cluster_role = "coordinator",
.health = "unknown"}; // TODO: (andi) Get this info from RAFT and test it or when we will move
// CoordinatorState to every instance, we can be smarter about this using our RPC.
};
auto instances_status = utils::fmap(raft_state_.GetAllCoordinators(), coord_instance_to_status);
@ -126,14 +122,14 @@ auto CoordinatorInstance::ShowInstances() const -> std::vector<InstanceStatus> {
std::ranges::transform(repl_instances_, std::back_inserter(instances_status), process_repl_instance_as_leader);
}
} else {
auto const stringify_repl_role = [](ReplicationRole role) -> std::string {
return role == ReplicationRole::MAIN ? "main" : "replica";
auto const stringify_inst_status = [](ReplicationRole status) -> std::string {
return status == ReplicationRole::MAIN ? "main" : "replica";
};
// TODO: (andi) Add capability that followers can also return socket addresses
auto process_repl_instance_as_follower = [&stringify_repl_role](auto const &instance) -> InstanceStatus {
auto process_repl_instance_as_follower = [&stringify_inst_status](auto const &instance) -> InstanceStatus {
return {.instance_name = instance.config.instance_name,
.cluster_role = stringify_repl_role(instance.role),
.cluster_role = stringify_inst_status(instance.status),
.health = "unknown"};
};
@ -355,11 +351,11 @@ auto CoordinatorInstance::UnregisterReplicationInstance(std::string_view instanc
return UnregisterInstanceCoordinatorStatus::NO_INSTANCE_WITH_NAME;
}
// TODO: (andi) Change so that RaftLogState is the central place for asking who is main...
auto const is_main = [this](ReplicationInstance const &instance) {
return IsMain(instance.InstanceName()) && instance.GetMainUUID() == raft_state_.GetUUID() && instance.IsAlive();
};
auto const is_main = [this](ReplicationInstance const &instance) { return IsMain(instance.InstanceName()); };
if (is_main(*inst_to_remove) && inst_to_remove->IsAlive()) {
if (is_main(*inst_to_remove)) {
return UnregisterInstanceCoordinatorStatus::IS_MAIN;
}

View File

@ -90,7 +90,7 @@ struct UnregisterReplicaReq {
static void Load(UnregisterReplicaReq *self, memgraph::slk::Reader *reader);
static void Save(UnregisterReplicaReq const &self, memgraph::slk::Builder *builder);
explicit UnregisterReplicaReq(std::string instance_name) : instance_name(std::move(instance_name)) {}
explicit UnregisterReplicaReq(std::string_view inst_name) : instance_name(inst_name) {}
UnregisterReplicaReq() = default;

View File

@ -14,6 +14,7 @@
#ifdef MG_ENTERPRISE
#include <flags/replication.hpp>
#include "io/network/endpoint.hpp"
#include "nuraft/coordinator_state_machine.hpp"
#include "nuraft/coordinator_state_manager.hpp"
@ -79,9 +80,8 @@ class RaftState {
private:
// TODO: (andi) I think variables below can be abstracted/clean them.
io::network::Endpoint raft_endpoint_;
uint32_t raft_server_id_;
uint32_t raft_port_;
std::string raft_address_;
ptr<CoordinatorStateMachine> state_machine_;
ptr<CoordinatorStateManager> state_manager_;

View File

@ -21,6 +21,7 @@
#include <libnuraft/nuraft.hxx>
#include <range/v3/view.hpp>
#include "json/json.hpp"
#include <map>
#include <numeric>
@ -33,9 +34,16 @@ using replication_coordination_glue::ReplicationRole;
struct InstanceState {
CoordinatorClientConfig config;
ReplicationRole role;
ReplicationRole status;
friend auto operator==(InstanceState const &lhs, InstanceState const &rhs) -> bool {
return lhs.config == rhs.config && lhs.status == rhs.status;
}
};
void to_json(nlohmann::json &j, InstanceState const &instance_state);
void from_json(nlohmann::json const &j, InstanceState &instance_state);
using TRaftLog = std::variant<CoordinatorClientConfig, std::string, utils::UUID>;
using nuraft::buffer;
@ -45,6 +53,8 @@ using nuraft::ptr;
class CoordinatorClusterState {
public:
CoordinatorClusterState() = default;
explicit CoordinatorClusterState(std::map<std::string, InstanceState, std::less<>> instances);
CoordinatorClusterState(CoordinatorClusterState const &);
CoordinatorClusterState &operator=(CoordinatorClusterState const &);
@ -60,7 +70,7 @@ class CoordinatorClusterState {
auto IsReplica(std::string_view instance_name) const -> bool;
auto InsertInstance(std::string_view instance_name, ReplicationRole role) -> void;
auto InsertInstance(std::string instance_name, InstanceState instance_state) -> void;
auto DoAction(TRaftLog log_entry, RaftLogAction log_action) -> void;
@ -73,7 +83,7 @@ class CoordinatorClusterState {
auto GetUUID() const -> utils::UUID;
private:
std::map<std::string, InstanceState, std::less<>> instance_roles_;
std::map<std::string, InstanceState, std::less<>> instances_{};
utils::UUID uuid_{};
mutable utils::ResourceLock log_lock_{};
};

View File

@ -38,26 +38,5 @@ NLOHMANN_JSON_SERIALIZE_ENUM(RaftLogAction, {
{RaftLogAction::UPDATE_UUID, "update_uuid"},
})
inline auto ParseRaftLogAction(std::string_view action) -> RaftLogAction {
if (action == "register") {
return RaftLogAction::REGISTER_REPLICATION_INSTANCE;
}
if (action == "unregister") {
return RaftLogAction::UNREGISTER_REPLICATION_INSTANCE;
}
if (action == "promote") {
return RaftLogAction::SET_INSTANCE_AS_MAIN;
}
if (action == "demote") {
return RaftLogAction::SET_INSTANCE_AS_REPLICA;
}
if (action == "update_uuid") {
return RaftLogAction::UPDATE_UUID;
}
throw InvalidRaftLogActionException("Invalid Raft log action: {}.", action);
}
} // namespace memgraph::coordination
#endif

View File

@ -32,12 +32,10 @@ using raft_result = cmd_result<ptr<buffer>>;
RaftState::RaftState(BecomeLeaderCb become_leader_cb, BecomeFollowerCb become_follower_cb, uint32_t raft_server_id,
uint32_t raft_port, std::string raft_address)
: raft_server_id_(raft_server_id),
raft_port_(raft_port),
raft_address_(std::move(raft_address)),
: raft_endpoint_(raft_address, raft_port),
raft_server_id_(raft_server_id),
state_machine_(cs_new<CoordinatorStateMachine>()),
state_manager_(
cs_new<CoordinatorStateManager>(raft_server_id_, raft_address_ + ":" + std::to_string(raft_port_))),
state_manager_(cs_new<CoordinatorStateManager>(raft_server_id_, raft_endpoint_.SocketAddress())),
logger_(nullptr),
become_leader_cb_(std::move(become_leader_cb)),
become_follower_cb_(std::move(become_follower_cb)) {}
@ -71,11 +69,11 @@ auto RaftState::InitRaftServer() -> void {
raft_launcher launcher;
raft_server_ = launcher.init(state_machine_, state_manager_, logger_, static_cast<int>(raft_port_), asio_opts, params,
init_opts);
raft_server_ =
launcher.init(state_machine_, state_manager_, logger_, raft_endpoint_.port, asio_opts, params, init_opts);
if (!raft_server_) {
throw RaftServerStartException("Failed to launch raft server on {}:{}", raft_address_, raft_port_);
throw RaftServerStartException("Failed to launch raft server on {}", raft_endpoint_.SocketAddress());
}
auto maybe_stop = utils::ResettableCounter<20>();
@ -86,7 +84,7 @@ auto RaftState::InitRaftServer() -> void {
std::this_thread::sleep_for(std::chrono::milliseconds(250));
} while (!maybe_stop());
throw RaftServerStartException("Failed to initialize raft server on {}:{}", raft_address_, raft_port_);
throw RaftServerStartException("Failed to initialize raft server on {}", raft_endpoint_.SocketAddress());
}
auto RaftState::MakeRaftState(BecomeLeaderCb &&become_leader_cb, BecomeFollowerCb &&become_follower_cb) -> RaftState {
@ -102,9 +100,11 @@ auto RaftState::MakeRaftState(BecomeLeaderCb &&become_leader_cb, BecomeFollowerC
RaftState::~RaftState() { launcher_.shutdown(); }
auto RaftState::InstanceName() const -> std::string { return "coordinator_" + std::to_string(raft_server_id_); }
auto RaftState::InstanceName() const -> std::string {
return fmt::format("coordinator_{}", std::to_string(raft_server_id_));
}
auto RaftState::RaftSocketAddress() const -> std::string { return raft_address_ + ":" + std::to_string(raft_port_); }
auto RaftState::RaftSocketAddress() const -> std::string { return raft_endpoint_.SocketAddress(); }
auto RaftState::AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string_view raft_address)
-> void {

View File

@ -22,113 +22,15 @@
#include "utils/message.hpp"
#include "utils/string.hpp"
namespace {
constexpr std::string_view delimiter = ":";
} // namespace
namespace memgraph::io::network {
Endpoint::IpFamily Endpoint::GetIpFamily(std::string_view address) {
in_addr addr4;
in6_addr addr6;
int ipv4_result = inet_pton(AF_INET, address.data(), &addr4);
int ipv6_result = inet_pton(AF_INET6, address.data(), &addr6);
if (ipv4_result == 1) {
return IpFamily::IP4;
}
if (ipv6_result == 1) {
return IpFamily::IP6;
}
return IpFamily::NONE;
}
std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseSocketOrIpAddress(
std::string_view address, const std::optional<uint16_t> default_port) {
/// expected address format:
/// - "ip_address:port_number"
/// - "ip_address"
/// We parse the address first. If it's an IP address, a default port must
// be given, or we return nullopt. If it's a socket address, we try to parse
// it into an ip address and a port number; even if a default port is given,
// it won't be used, as we expect that it is given in the address string.
const std::string delimiter = ":";
std::string ip_address;
std::vector<std::string> parts = utils::Split(address, delimiter);
if (parts.size() == 1) {
if (default_port) {
if (GetIpFamily(address) == IpFamily::NONE) {
return std::nullopt;
}
return std::pair{std::string(address), *default_port}; // TODO: (andi) Optimize throughout the code
}
} else if (parts.size() == 2) {
ip_address = std::move(parts[0]);
if (GetIpFamily(ip_address) == IpFamily::NONE) {
return std::nullopt;
}
int64_t int_port{0};
try {
int_port = utils::ParseInt(parts[1]);
} catch (utils::BasicException &e) {
spdlog::error(utils::MessageWithLink("Invalid port number {}.", parts[1], "https://memgr.ph/ports"));
return std::nullopt;
}
if (int_port < 0) {
spdlog::error(utils::MessageWithLink("Invalid port number {}. The port number must be a positive integer.",
int_port, "https://memgr.ph/ports"));
return std::nullopt;
}
if (int_port > std::numeric_limits<uint16_t>::max()) {
spdlog::error(utils::MessageWithLink("Invalid port number. The port number exceedes the maximum possible size.",
"https://memgr.ph/ports"));
return std::nullopt;
}
return std::pair{ip_address, static_cast<uint16_t>(int_port)};
}
return std::nullopt;
}
std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseHostname(
std::string_view address, const std::optional<uint16_t> default_port = {}) {
const std::string delimiter = ":";
std::string ip_address;
std::vector<std::string> parts = utils::Split(address, delimiter);
if (parts.size() == 1) {
if (default_port) {
if (!IsResolvableAddress(address, *default_port)) {
return std::nullopt;
}
return std::pair{std::string(address), *default_port}; // TODO: (andi) Optimize throughout the code
}
} else if (parts.size() == 2) {
int64_t int_port{0};
auto hostname = std::move(parts[0]);
try {
int_port = utils::ParseInt(parts[1]);
} catch (utils::BasicException &e) {
spdlog::error(utils::MessageWithLink("Invalid port number {}.", parts[1], "https://memgr.ph/ports"));
return std::nullopt;
}
if (int_port < 0) {
spdlog::error(utils::MessageWithLink("Invalid port number {}. The port number must be a positive integer.",
int_port, "https://memgr.ph/ports"));
return std::nullopt;
}
if (int_port > std::numeric_limits<uint16_t>::max()) {
spdlog::error(utils::MessageWithLink("Invalid port number. The port number exceedes the maximum possible size.",
"https://memgr.ph/ports"));
return std::nullopt;
}
if (IsResolvableAddress(hostname, static_cast<uint16_t>(int_port))) {
return std::pair{hostname, static_cast<u_int16_t>(int_port)};
}
}
return std::nullopt;
}
std::string Endpoint::SocketAddress() const {
auto ip_address = address.empty() ? "EMPTY" : address;
return ip_address + ":" + std::to_string(port);
}
// NOLINTNEXTLINE
Endpoint::Endpoint(needs_resolving_t, std::string hostname, uint16_t port)
: address(std::move(hostname)), port(port), family{GetIpFamily(address)} {}
Endpoint::Endpoint(std::string ip_address, uint16_t port) : address(std::move(ip_address)), port(port) {
IpFamily ip_family = GetIpFamily(address);
@ -138,9 +40,23 @@ Endpoint::Endpoint(std::string ip_address, uint16_t port) : address(std::move(ip
family = ip_family;
}
// NOLINTNEXTLINE
Endpoint::Endpoint(needs_resolving_t, std::string hostname, uint16_t port)
: address(std::move(hostname)), port(port), family{GetIpFamily(address)} {}
std::string Endpoint::SocketAddress() const { return fmt::format("{}:{}", address, port); }
Endpoint::IpFamily Endpoint::GetIpFamily(std::string_view address) {
// Ensure null-terminated
auto const tmp = std::string(address);
in_addr addr4;
in6_addr addr6;
int ipv4_result = inet_pton(AF_INET, tmp.c_str(), &addr4);
int ipv6_result = inet_pton(AF_INET6, tmp.c_str(), &addr6);
if (ipv4_result == 1) {
return IpFamily::IP4;
}
if (ipv6_result == 1) {
return IpFamily::IP6;
}
return IpFamily::NONE;
}
std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint) {
// no need to cover the IpFamily::NONE case, as you can't even construct an
@ -153,6 +69,7 @@ std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint) {
return os << endpoint.address << ":" << endpoint.port;
}
// NOTE: Intentional copy to ensure null-terminated string
bool Endpoint::IsResolvableAddress(std::string_view address, uint16_t port) {
addrinfo hints{
.ai_flags = AI_PASSIVE,
@ -160,28 +77,65 @@ bool Endpoint::IsResolvableAddress(std::string_view address, uint16_t port) {
.ai_socktype = SOCK_STREAM // TCP socket
};
addrinfo *info = nullptr;
auto status = getaddrinfo(address.data(), std::to_string(port).c_str(), &hints, &info);
auto status = getaddrinfo(std::string(address).c_str(), std::to_string(port).c_str(), &hints, &info);
if (info) freeaddrinfo(info);
return status == 0;
}
std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseSocketOrAddress(
std::string_view address, const std::optional<uint16_t> default_port) {
const std::string delimiter = ":";
std::vector<std::string> parts = utils::Split(address, delimiter);
if (parts.size() == 1) {
if (GetIpFamily(address) == IpFamily::NONE) {
return ParseHostname(address, default_port);
}
return ParseSocketOrIpAddress(address, default_port);
std::optional<ParsedAddress> Endpoint::ParseSocketOrAddress(std::string_view address,
std::optional<uint16_t> default_port) {
auto const parts = utils::SplitView(address, delimiter);
if (parts.size() > 2) {
return std::nullopt;
}
if (parts.size() == 2) {
if (GetIpFamily(parts[0]) == IpFamily::NONE) {
return ParseHostname(address, default_port);
auto const port = [default_port, &parts]() -> std::optional<uint16_t> {
if (parts.size() == 2) {
return static_cast<uint16_t>(utils::ParseInt(parts[1]));
}
return ParseSocketOrIpAddress(address, default_port);
return default_port;
}();
if (!ValidatePort(port)) {
return std::nullopt;
}
return std::nullopt;
auto const addr = [address, &parts]() {
if (parts.size() == 2) {
return parts[0];
}
return address;
}();
if (GetIpFamily(addr) == IpFamily::NONE) {
if (IsResolvableAddress(addr, *port)) { // NOLINT
return std::pair{addr, *port}; // NOLINT
}
return std::nullopt;
}
return std::pair{addr, *port}; // NOLINT
}
auto Endpoint::ValidatePort(std::optional<uint16_t> port) -> bool {
if (!port) {
return false;
}
if (port < 0) {
spdlog::error(utils::MessageWithLink("Invalid port number {}. The port number must be a positive integer.", *port,
"https://memgr.ph/ports"));
return false;
}
if (port > std::numeric_limits<uint16_t>::max()) {
spdlog::error(utils::MessageWithLink("Invalid port number. The port number exceedes the maximum possible size.",
"https://memgr.ph/ports"));
return false;
}
return true;
}
} // namespace memgraph::io::network

View File

@ -19,11 +19,8 @@
namespace memgraph::io::network {
/**
* This class represents a network endpoint that is used in Socket.
* It is used when connecting to an address and to get the current
* connection address.
*/
using ParsedAddress = std::pair<std::string_view, uint16_t>;
struct Endpoint {
static const struct needs_resolving_t {
} needs_resolving;
@ -31,59 +28,35 @@ struct Endpoint {
Endpoint() = default;
Endpoint(std::string ip_address, uint16_t port);
Endpoint(needs_resolving_t, std::string hostname, uint16_t port);
Endpoint(Endpoint const &) = default;
Endpoint(Endpoint &&) noexcept = default;
Endpoint &operator=(Endpoint const &) = default;
Endpoint &operator=(Endpoint &&) noexcept = default;
~Endpoint() = default;
enum class IpFamily : std::uint8_t { NONE, IP4, IP6 };
std::string SocketAddress() const;
static std::optional<ParsedAddress> ParseSocketOrAddress(std::string_view address,
std::optional<uint16_t> default_port = {});
bool operator==(const Endpoint &other) const = default;
friend std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint);
std::string SocketAddress() const;
std::string address;
uint16_t port{0};
IpFamily family{IpFamily::NONE};
static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrAddress(
std::string_view address, std::optional<uint16_t> default_port = {});
/**
* Tries to parse the given string as either a socket address or ip address.
* Expected address format:
* - "ip_address:port_number"
* - "ip_address"
* We parse the address first. If it's an IP address, a default port must
* be given, or we return nullopt. If it's a socket address, we try to parse
* it into an ip address and a port number; even if a default port is given,
* it won't be used, as we expect that it is given in the address string.
*/
static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrIpAddress(
std::string_view address, std::optional<uint16_t> default_port = {});
/**
* Tries to parse given string as either socket address or hostname.
* Expected address format:
* - "hostname:port_number"
* - "hostname"
* After we parse hostname and port we try to resolve the hostname into an ip_address.
*/
static std::optional<std::pair<std::string, uint16_t>> ParseHostname(std::string_view address,
std::optional<uint16_t> default_port);
bool operator==(const Endpoint &other) const = default;
friend std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint);
private:
static IpFamily GetIpFamily(std::string_view address);
static bool IsResolvableAddress(std::string_view address, uint16_t port);
/**
* Tries to resolve hostname to its corresponding IP address.
* Given a DNS hostname, this function performs resolution and returns
* the IP address associated with the hostname.
*/
static std::string ResolveHostnameIntoIpAddress(const std::string &address, uint16_t port);
static auto ValidatePort(std::optional<uint16_t> port) -> bool;
};
} // namespace memgraph::io::network

View File

@ -355,7 +355,7 @@ class ReplQueryHandler {
const auto replication_config =
replication::ReplicationClientConfig{.name = name,
.mode = repl_mode,
.ip_address = ip,
.ip_address = std::string(ip),
.port = port,
.replica_check_frequency = replica_check_frequency,
.ssl = std::nullopt};
@ -454,12 +454,12 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler {
const auto repl_config = coordination::CoordinatorClientConfig::ReplicationClientInfo{
.instance_name = std::string(instance_name),
.replication_mode = convertFromCoordinatorToReplicationMode(sync_mode),
.replication_ip_address = replication_ip,
.replication_ip_address = std::string(replication_ip),
.replication_port = replication_port};
auto coordinator_client_config =
coordination::CoordinatorClientConfig{.instance_name = std::string(instance_name),
.ip_address = coordinator_server_ip,
.ip_address = std::string(coordinator_server_ip),
.port = coordinator_server_port,
.instance_health_check_frequency_sec = instance_check_frequency,
.instance_down_timeout_sec = instance_down_timeout,
@ -1212,7 +1212,7 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
};
notifications->emplace_back(
SeverityLevel::INFO, NotificationCode::REGISTER_COORDINATOR_SERVER,
SeverityLevel::INFO, NotificationCode::REGISTER_REPLICATION_INSTANCE,
fmt::format("Coordinator has registered coordinator server on {} for instance {}.",
coordinator_socket_address_tv.ValueString(), coordinator_query->instance_name_));
return callback;

View File

@ -67,8 +67,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) {
case NotificationCode::REGISTER_REPLICA:
return "RegisterReplica"sv;
#ifdef MG_ENTERPRISE
case NotificationCode::REGISTER_COORDINATOR_SERVER:
return "RegisterCoordinatorServer"sv;
case NotificationCode::REGISTER_REPLICATION_INSTANCE:
return "RegisterReplicationInstance"sv;
case NotificationCode::ADD_COORDINATOR_INSTANCE:
return "AddCoordinatorInstance"sv;
case NotificationCode::UNREGISTER_INSTANCE:

View File

@ -43,7 +43,7 @@ enum class NotificationCode : uint8_t {
REPLICA_PORT_WARNING,
REGISTER_REPLICA,
#ifdef MG_ENTERPRISE
REGISTER_COORDINATOR_SERVER, // TODO: (andi) What is this?
REGISTER_REPLICATION_INSTANCE,
ADD_COORDINATOR_INSTANCE,
UNREGISTER_INSTANCE,
#endif

View File

@ -12,8 +12,14 @@
#pragma once
#include <cstdint>
#include "json/json.hpp"
namespace memgraph::replication_coordination_glue {
// TODO: figure out a way of ensuring that usage of this type is never uninitialed/defaulted incorrectly to MAIN
enum class ReplicationRole : uint8_t { MAIN, REPLICA };
NLOHMANN_JSON_SERIALIZE_ENUM(ReplicationRole, {{ReplicationRole::MAIN, "main"}, {ReplicationRole::REPLICA, "replica"}})
} // namespace memgraph::replication_coordination_glue

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -34,12 +34,13 @@ DEFINE_double(reads_duration_limit, 10.0, "How long should the client perform re
namespace mg::e2e::replication {
auto ParseDatabaseEndpoints(const std::string &database_endpoints_str) {
const auto db_endpoints_strs = memgraph::utils::Split(database_endpoints_str, ",");
const auto db_endpoints_strs = memgraph::utils::SplitView(database_endpoints_str, ",");
std::vector<memgraph::io::network::Endpoint> database_endpoints;
for (const auto &db_endpoint_str : db_endpoints_strs) {
const auto maybe_host_port = memgraph::io::network::Endpoint::ParseSocketOrIpAddress(db_endpoint_str, 7687);
const auto maybe_host_port = memgraph::io::network::Endpoint::ParseSocketOrAddress(db_endpoint_str, 7687);
MG_ASSERT(maybe_host_port);
database_endpoints.emplace_back(maybe_host_port->first, maybe_host_port->second);
auto const [ip, port] = *maybe_host_port;
database_endpoints.emplace_back(std::string(ip), port);
}
return database_endpoints;
}

View File

@ -445,3 +445,10 @@ add_unit_test(raft_log_serialization.cpp)
target_link_libraries(${test_prefix}raft_log_serialization gflags mg-coordination mg-repl_coord_glue)
target_include_directories(${test_prefix}raft_log_serialization PRIVATE ${CMAKE_SOURCE_DIR}/include)
endif()
# Test Raft log serialization
if(MG_ENTERPRISE)
add_unit_test(coordinator_cluster_state.cpp)
target_link_libraries(${test_prefix}coordinator_cluster_state gflags mg-coordination mg-repl_coord_glue)
target_include_directories(${test_prefix}coordinator_cluster_state PRIVATE ${CMAKE_SOURCE_DIR}/include)
endif()

View File

@ -0,0 +1,163 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "nuraft/coordinator_cluster_state.hpp"
#include "nuraft/coordinator_state_machine.hpp"
#include "replication_coordination_glue/role.hpp"
#include "utils/file.hpp"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "json/json.hpp"
#include "libnuraft/nuraft.hxx"
using memgraph::coordination::CoordinatorClientConfig;
using memgraph::coordination::CoordinatorClusterState;
using memgraph::coordination::CoordinatorStateMachine;
using memgraph::coordination::InstanceState;
using memgraph::coordination::RaftLogAction;
using memgraph::replication_coordination_glue::ReplicationMode;
using memgraph::replication_coordination_glue::ReplicationRole;
using nuraft::buffer;
using nuraft::buffer_serializer;
using nuraft::ptr;
class CoordinatorClusterStateTest : public ::testing::Test {
protected:
void SetUp() override {}
void TearDown() override {}
std::filesystem::path test_folder_{std::filesystem::temp_directory_path() /
"MG_tests_unit_coordinator_cluster_state"};
};
TEST_F(CoordinatorClusterStateTest, InstanceStateSerialization) {
InstanceState instance_state{
CoordinatorClientConfig{"instance3",
"127.0.0.1",
10112,
std::chrono::seconds{1},
std::chrono::seconds{5},
std::chrono::seconds{10},
{"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10001},
.ssl = std::nullopt},
ReplicationRole::MAIN};
nlohmann::json j = instance_state;
InstanceState deserialized_instance_state = j.get<InstanceState>();
EXPECT_EQ(instance_state.config, deserialized_instance_state.config);
EXPECT_EQ(instance_state.status, deserialized_instance_state.status);
}
TEST_F(CoordinatorClusterStateTest, DoActionRegisterInstances) {
auto coordinator_cluster_state = memgraph::coordination::CoordinatorClusterState{};
{
CoordinatorClientConfig config{"instance1",
"127.0.0.1",
10111,
std::chrono::seconds{1},
std::chrono::seconds{5},
std::chrono::seconds{10},
{"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10001},
.ssl = std::nullopt};
auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config);
auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer);
coordinator_cluster_state.DoAction(payload, action);
}
{
CoordinatorClientConfig config{"instance2",
"127.0.0.1",
10112,
std::chrono::seconds{1},
std::chrono::seconds{5},
std::chrono::seconds{10},
{"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10002},
.ssl = std::nullopt};
auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config);
auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer);
coordinator_cluster_state.DoAction(payload, action);
}
{
CoordinatorClientConfig config{"instance3",
"127.0.0.1",
10113,
std::chrono::seconds{1},
std::chrono::seconds{5},
std::chrono::seconds{10},
{"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10003},
.ssl = std::nullopt};
auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config);
auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer);
coordinator_cluster_state.DoAction(payload, action);
}
{
CoordinatorClientConfig config{"instance4",
"127.0.0.1",
10114,
std::chrono::seconds{1},
std::chrono::seconds{5},
std::chrono::seconds{10},
{"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10004},
.ssl = std::nullopt};
auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config);
auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer);
coordinator_cluster_state.DoAction(payload, action);
}
{
CoordinatorClientConfig config{"instance5",
"127.0.0.1",
10115,
std::chrono::seconds{1},
std::chrono::seconds{5},
std::chrono::seconds{10},
{"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10005},
.ssl = std::nullopt};
auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config);
auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer);
coordinator_cluster_state.DoAction(payload, action);
}
{
CoordinatorClientConfig config{"instance6",
"127.0.0.1",
10116,
std::chrono::seconds{1},
std::chrono::seconds{5},
std::chrono::seconds{10},
{"instance_name", ReplicationMode::ASYNC, "replication_ip_address", 10006},
.ssl = std::nullopt};
auto buffer = CoordinatorStateMachine::SerializeRegisterInstance(config);
auto [payload, action] = CoordinatorStateMachine::DecodeLog(*buffer);
coordinator_cluster_state.DoAction(payload, action);
}
ptr<buffer> data;
coordinator_cluster_state.Serialize(data);
auto deserialized_coordinator_cluster_state = CoordinatorClusterState::Deserialize(*data);
ASSERT_EQ(coordinator_cluster_state.GetInstances(), deserialized_coordinator_cluster_state.GetInstances());
}