Add DNS support for cluster replica address (#1323)

This commit is contained in:
DavIvek 2023-10-24 13:11:36 +02:00 committed by GitHub
parent 1d45016217
commit 98680b04c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 150 additions and 8 deletions

View File

@ -15,6 +15,7 @@
#include <algorithm>
#include "endpoint.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/network_error.hpp"
#include "utils/logging.hpp"
@ -23,11 +24,11 @@
namespace memgraph::io::network {
Endpoint::IpFamily Endpoint::GetIpFamily(const std::string &ip_address) {
Endpoint::IpFamily Endpoint::GetIpFamily(const std::string &address) {
in_addr addr4;
in6_addr addr6;
int ipv4_result = inet_pton(AF_INET, ip_address.c_str(), &addr4);
int ipv6_result = inet_pton(AF_INET6, ip_address.c_str(), &addr6);
int ipv4_result = inet_pton(AF_INET, address.c_str(), &addr4);
int ipv6_result = inet_pton(AF_INET6, address.c_str(), &addr6);
if (ipv4_result == 1) {
return IpFamily::IP4;
} else if (ipv6_result == 1) {
@ -86,6 +87,44 @@ std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseSocketOrIpAddress
return std::nullopt;
}
std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseHostname(
const std::string &address, const std::optional<uint16_t> default_port = {}) {
const std::string delimiter = ":";
std::string ip_address;
std::vector<std::string> parts = utils::Split(address, delimiter);
if (parts.size() == 1) {
if (default_port) {
if (!IsResolvableAddress(address, *default_port)) {
return std::nullopt;
}
return std::pair{address, *default_port};
}
} else if (parts.size() == 2) {
int64_t int_port{0};
auto hostname = std::move(parts[0]);
try {
int_port = utils::ParseInt(parts[1]);
} catch (utils::BasicException &e) {
spdlog::error(utils::MessageWithLink("Invalid port number {}.", parts[1], "https://memgr.ph/ports"));
return std::nullopt;
}
if (int_port < 0) {
spdlog::error(utils::MessageWithLink("Invalid port number {}. The port number must be a positive integer.",
int_port, "https://memgr.ph/ports"));
return std::nullopt;
}
if (int_port > std::numeric_limits<uint16_t>::max()) {
spdlog::error(utils::MessageWithLink("Invalid port number. The port number exceedes the maximum possible size.",
"https://memgr.ph/ports"));
return std::nullopt;
}
if (IsResolvableAddress(hostname, static_cast<uint16_t>(int_port))) {
return std::pair{hostname, static_cast<u_int16_t>(int_port)};
}
}
return std::nullopt;
}
std::string Endpoint::SocketAddress() const {
auto ip_address = address.empty() ? "EMPTY" : address;
return ip_address + ":" + std::to_string(port);
@ -99,6 +138,16 @@ Endpoint::Endpoint(std::string ip_address, uint16_t port) : address(std::move(ip
family = ip_family;
}
// NOLINTNEXTLINE
Endpoint::Endpoint(needs_resolving_t, std::string hostname, uint16_t port) : port(port) {
address = ResolveHostnameIntoIpAddress(hostname, port);
IpFamily ip_family = GetIpFamily(address);
if (ip_family == IpFamily::NONE) {
throw NetworkError("Not a valid IPv4 or IPv6 address: {}", address);
}
family = ip_family;
}
std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint) {
// no need to cover the IpFamily::NONE case, as you can't even construct an
// Endpoint object if the IpFamily is NONE (i.e. the IP address is invalid)
@ -109,4 +158,65 @@ std::ostream &operator<<(std::ostream &os, const Endpoint &endpoint) {
return os << endpoint.address << ":" << endpoint.port;
}
bool Endpoint::IsResolvableAddress(const std::string &address, uint16_t port) {
addrinfo hints{
.ai_flags = AI_PASSIVE,
.ai_family = AF_UNSPEC, // IPv4 and IPv6
.ai_socktype = SOCK_STREAM // TCP socket
};
addrinfo *info = nullptr;
auto status = getaddrinfo(address.c_str(), std::to_string(port).c_str(), &hints, &info);
freeaddrinfo(info);
return status == 0;
}
std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseSocketOrAddress(
const std::string &address, const std::optional<uint16_t> default_port = {}) {
const std::string delimiter = ":";
std::vector<std::string> parts = utils::Split(address, delimiter);
if (parts.size() == 1) {
if (GetIpFamily(address) == IpFamily::NONE) {
return ParseHostname(address, default_port);
}
return ParseSocketOrIpAddress(address, default_port);
}
if (parts.size() == 2) {
if (GetIpFamily(parts[0]) == IpFamily::NONE) {
return ParseHostname(address, default_port);
}
return ParseSocketOrIpAddress(address, default_port);
}
return std::nullopt;
}
std::string Endpoint::ResolveHostnameIntoIpAddress(const std::string &address, uint16_t port) {
addrinfo hints{
.ai_flags = AI_PASSIVE,
.ai_family = AF_UNSPEC, // IPv4 and IPv6
.ai_socktype = SOCK_STREAM // TCP socket
};
addrinfo *info = nullptr;
auto status = getaddrinfo(address.c_str(), std::to_string(port).c_str(), &hints, &info);
if (status != 0) throw NetworkError(gai_strerror(status));
for (auto *result = info; result != nullptr; result = result->ai_next) {
if (result->ai_family == AF_INET) {
char ipstr[INET_ADDRSTRLEN];
auto *ipv4 = reinterpret_cast<struct sockaddr_in *>(result->ai_addr);
inet_ntop(AF_INET, &(ipv4->sin_addr), ipstr, sizeof(ipstr));
freeaddrinfo(info);
return ipstr;
}
if (result->ai_family == AF_INET6) {
char ipstr[INET6_ADDRSTRLEN];
auto *ipv6 = reinterpret_cast<struct sockaddr_in6 *>(result->ai_addr);
inet_ntop(AF_INET6, &(ipv6->sin6_addr), ipstr, sizeof(ipstr));
freeaddrinfo(info);
return ipstr;
}
}
freeaddrinfo(info);
throw NetworkError("Not a valid address: {}", address);
}
} // namespace memgraph::io::network

View File

@ -25,8 +25,12 @@ namespace memgraph::io::network {
* connection address.
*/
struct Endpoint {
static const struct needs_resolving_t {
} needs_resolving;
Endpoint() = default;
Endpoint(std::string ip_address, uint16_t port);
Endpoint(needs_resolving_t, std::string hostname, uint16_t port);
Endpoint(Endpoint const &) = default;
Endpoint(Endpoint &&) noexcept = default;
Endpoint &operator=(Endpoint const &) = default;
@ -44,6 +48,9 @@ struct Endpoint {
uint16_t port{0};
IpFamily family{IpFamily::NONE};
static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrAddress(
const std::string &address, const std::optional<uint16_t> default_port);
/**
* Tries to parse the given string as either a socket address or ip address.
* Expected address format:
@ -57,7 +64,26 @@ struct Endpoint {
static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrIpAddress(
const std::string &address, const std::optional<uint16_t> default_port);
static IpFamily GetIpFamily(const std::string &ip_address);
/**
* Tries to parse given string as either socket address or hostname.
* Expected address format:
* - "hostname:port_number"
* - "hostname"
* After we parse hostname and port we try to resolve the hostname into an ip_address.
*/
static std::optional<std::pair<std::string, uint16_t>> ParseHostname(const std::string &address,
const std::optional<uint16_t> default_port);
static IpFamily GetIpFamily(const std::string &address);
static bool IsResolvableAddress(const std::string &address, uint16_t port);
/**
* Tries to resolve hostname to its corresponding IP address.
* Given a DNS hostname, this function performs resolution and returns
* the IP address associated with the hostname.
*/
static std::string ResolveHostnameIntoIpAddress(const std::string &address, uint16_t port);
};
} // namespace memgraph::io::network

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source

View File

@ -318,7 +318,7 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
auto repl_mode = convertToReplicationMode(sync_mode);
auto maybe_ip_and_port =
io::network::Endpoint::ParseSocketOrIpAddress(socket_address, memgraph::replication::kDefaultReplicationPort);
io::network::Endpoint::ParseSocketOrAddress(socket_address, memgraph::replication::kDefaultReplicationPort);
if (maybe_ip_and_port) {
auto [ip, port] = *maybe_ip_and_port;
auto config = replication::ReplicationClientConfig{.name = name,

View File

@ -29,7 +29,8 @@ ReplicationClient::ReplicationClient(Storage *storage, const memgraph::replicati
const memgraph::replication::ReplicationEpoch *epoch)
: name_{config.name},
rpc_context_{CreateClientContext(config)},
rpc_client_{io::network::Endpoint(config.ip_address, config.port), &rpc_context_},
rpc_client_{io::network::Endpoint(io::network::Endpoint::needs_resolving, config.ip_address, config.port),
&rpc_context_},
replica_check_frequency_{config.replica_check_frequency},
mode_{config.mode},
storage_{storage},

View File

@ -68,7 +68,12 @@ memgraph::utils::BasicResult<RegisterReplicaError> ReplicationHandler::RegisterR
return std::any_of(clients.begin(), clients.end(), name_matches);
};
auto desired_endpoint = io::network::Endpoint{config.ip_address, config.port};
io::network::Endpoint desired_endpoint;
if (io::network::Endpoint::GetIpFamily(config.ip_address) == io::network::Endpoint::IpFamily::NONE) {
desired_endpoint = io::network::Endpoint{io::network::Endpoint::needs_resolving, config.ip_address, config.port};
} else {
desired_endpoint = io::network::Endpoint{config.ip_address, config.port};
}
auto endpoint_check = [&](auto &clients) {
auto endpoint_matches = [&](const auto &client) { return client->Endpoint() == desired_endpoint; };
return std::any_of(clients.begin(), clients.end(), endpoint_matches);