Small io network socket fixes (#360)
* Modernize AddrInfo * Modernize Socket
This commit is contained in:
parent
17049ada09
commit
c8dbaf5979
@ -9,34 +9,52 @@
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <netdb.h>
|
||||
#include <cstring>
|
||||
|
||||
#include "io/network/addrinfo.hpp"
|
||||
|
||||
#include <concepts>
|
||||
#include <iterator>
|
||||
|
||||
#include "io/network/network_error.hpp"
|
||||
|
||||
namespace memgraph::io::network {
|
||||
|
||||
AddrInfo::AddrInfo(struct addrinfo *info) : info(info) {}
|
||||
static_assert(std::forward_iterator<AddrInfo::Iterator> && std::equality_comparable<AddrInfo::Iterator>);
|
||||
|
||||
AddrInfo::~AddrInfo() { freeaddrinfo(info); }
|
||||
|
||||
AddrInfo AddrInfo::Get(const char *addr, const char *port) {
|
||||
struct addrinfo hints;
|
||||
memset(&hints, 0, sizeof(struct addrinfo));
|
||||
|
||||
hints.ai_family = AF_UNSPEC; // IPv4 and IPv6
|
||||
hints.ai_socktype = SOCK_STREAM; // TCP socket
|
||||
hints.ai_flags = AI_PASSIVE;
|
||||
|
||||
struct addrinfo *result;
|
||||
auto status = getaddrinfo(addr, port, &hints, &result);
|
||||
AddrInfo::AddrInfo(const Endpoint &endpoint) : AddrInfo(endpoint.address, endpoint.port) {}
|
||||
|
||||
AddrInfo::AddrInfo(const std::string &addr, uint16_t port) : info_{nullptr, nullptr} {
|
||||
addrinfo hints{
|
||||
.ai_flags = AI_PASSIVE,
|
||||
.ai_family = AF_UNSPEC, // IPv4 and IPv6
|
||||
.ai_socktype = SOCK_STREAM // TCP socket
|
||||
};
|
||||
addrinfo *info = nullptr;
|
||||
auto status = getaddrinfo(addr.c_str(), std::to_string(port).c_str(), &hints, &info);
|
||||
if (status != 0) throw NetworkError(gai_strerror(status));
|
||||
|
||||
return AddrInfo(result);
|
||||
info_ = std::unique_ptr<addrinfo, decltype(&freeaddrinfo)>(info, &freeaddrinfo);
|
||||
}
|
||||
|
||||
AddrInfo::operator struct addrinfo *() { return info; }
|
||||
AddrInfo::Iterator::Iterator(addrinfo *p) noexcept : ptr_(p) {}
|
||||
|
||||
AddrInfo::Iterator::reference AddrInfo::Iterator::operator*() const noexcept { return *ptr_; }
|
||||
|
||||
AddrInfo::Iterator::pointer AddrInfo::Iterator::operator->() const noexcept { return ptr_; }
|
||||
|
||||
// NOLINTNEXTLINE(cert-dcl21-cpp)
|
||||
AddrInfo::Iterator AddrInfo::Iterator::operator++(int) noexcept {
|
||||
auto it = *this;
|
||||
++(*this);
|
||||
return it;
|
||||
}
|
||||
AddrInfo::Iterator &AddrInfo::Iterator::operator++() noexcept {
|
||||
ptr_ = ptr_->ai_next;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(const AddrInfo::Iterator &lhs, const AddrInfo::Iterator &rhs) noexcept { return lhs.ptr_ == rhs.ptr_; };
|
||||
|
||||
bool operator!=(const AddrInfo::Iterator &lhs, const AddrInfo::Iterator &rhs) noexcept { return !(lhs == rhs); };
|
||||
|
||||
void swap(AddrInfo::Iterator &lhs, AddrInfo::Iterator &rhs) noexcept { std::swap(lhs.ptr_, rhs.ptr_); };
|
||||
|
||||
} // namespace memgraph::io::network
|
||||
|
@ -11,6 +11,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <netdb.h>
|
||||
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "io/network/endpoint.hpp"
|
||||
|
||||
namespace memgraph::io::network {
|
||||
|
||||
/**
|
||||
@ -18,16 +26,38 @@ namespace memgraph::io::network {
|
||||
* see: man 3 getaddrinfo
|
||||
*/
|
||||
class AddrInfo {
|
||||
explicit AddrInfo(struct addrinfo *info);
|
||||
|
||||
public:
|
||||
~AddrInfo();
|
||||
struct Iterator {
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
using value_type = addrinfo;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using pointer = addrinfo *;
|
||||
using reference = addrinfo &;
|
||||
|
||||
static AddrInfo Get(const char *addr, const char *port);
|
||||
Iterator() = default;
|
||||
Iterator(const Iterator &) = default;
|
||||
explicit Iterator(addrinfo *p) noexcept;
|
||||
Iterator &operator=(const Iterator &) = default;
|
||||
reference operator*() const noexcept;
|
||||
pointer operator->() const noexcept;
|
||||
Iterator operator++(int) noexcept;
|
||||
Iterator &operator++() noexcept;
|
||||
|
||||
operator struct addrinfo *();
|
||||
friend bool operator==(const Iterator &lhs, const Iterator &rhs) noexcept;
|
||||
friend bool operator!=(const Iterator &lhs, const Iterator &rhs) noexcept;
|
||||
friend void swap(Iterator &lhs, Iterator &rhs) noexcept;
|
||||
|
||||
private:
|
||||
addrinfo *ptr_{nullptr};
|
||||
};
|
||||
|
||||
AddrInfo(const std::string &addr, uint16_t port);
|
||||
explicit AddrInfo(const Endpoint &endpoint);
|
||||
|
||||
auto begin() const noexcept { return Iterator(info_.get()); }
|
||||
auto end() const noexcept { return Iterator{nullptr}; }
|
||||
|
||||
private:
|
||||
struct addrinfo *info;
|
||||
std::unique_ptr<addrinfo, void (*)(addrinfo *)> info_;
|
||||
};
|
||||
} // namespace memgraph::io::network
|
||||
|
@ -9,39 +9,25 @@
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <errno.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <poll.h>
|
||||
#include <sys/epoll.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "io/network/addrinfo.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
#include "utils/likely.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
|
||||
namespace memgraph::io::network {
|
||||
|
||||
Socket::Socket(Socket &&other) {
|
||||
socket_ = other.socket_;
|
||||
endpoint_ = std::move(other.endpoint_);
|
||||
Socket::Socket(Socket &&other) noexcept : socket_(other.socket_), endpoint_(std::move(other.endpoint_)) {
|
||||
other.socket_ = -1;
|
||||
}
|
||||
|
||||
Socket &Socket::operator=(Socket &&other) {
|
||||
Socket &Socket::operator=(Socket &&other) noexcept {
|
||||
if (this != &other) {
|
||||
if (socket_ != -1) close(socket_);
|
||||
socket_ = other.socket_;
|
||||
endpoint_ = std::move(other.endpoint_);
|
||||
other.socket_ = -1;
|
||||
@ -49,9 +35,8 @@ Socket &Socket::operator=(Socket &&other) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
Socket::~Socket() {
|
||||
if (socket_ == -1) return;
|
||||
close(socket_);
|
||||
Socket::~Socket() noexcept {
|
||||
if (socket_ != -1) close(socket_);
|
||||
}
|
||||
|
||||
void Socket::Close() {
|
||||
@ -70,33 +55,27 @@ bool Socket::IsOpen() const { return socket_ != -1; }
|
||||
bool Socket::Connect(const Endpoint &endpoint) {
|
||||
if (socket_ != -1) return false;
|
||||
|
||||
auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
|
||||
|
||||
for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
|
||||
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
|
||||
for (const auto &it : AddrInfo{endpoint}) {
|
||||
int sfd = socket(it.ai_family, it.ai_socktype, it.ai_protocol);
|
||||
if (sfd == -1) continue;
|
||||
if (connect(sfd, it->ai_addr, it->ai_addrlen) == 0) {
|
||||
if (connect(sfd, it.ai_addr, it.ai_addrlen) == 0) {
|
||||
socket_ = sfd;
|
||||
endpoint_ = endpoint;
|
||||
break;
|
||||
} else {
|
||||
// If the connect failed close the file descriptor to prevent file
|
||||
// descriptors being leaked
|
||||
close(sfd);
|
||||
}
|
||||
// If the connect failed close the file descriptor to prevent file
|
||||
// descriptors being leaked
|
||||
close(sfd);
|
||||
}
|
||||
|
||||
if (socket_ == -1) return false;
|
||||
return true;
|
||||
return !(socket_ == -1);
|
||||
}
|
||||
|
||||
bool Socket::Bind(const Endpoint &endpoint) {
|
||||
if (socket_ != -1) return false;
|
||||
|
||||
auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
|
||||
|
||||
for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
|
||||
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
|
||||
for (const auto &it : AddrInfo{endpoint}) {
|
||||
int sfd = socket(it.ai_family, it.ai_socktype, it.ai_protocol);
|
||||
if (sfd == -1) continue;
|
||||
|
||||
int on = 1;
|
||||
@ -107,14 +86,13 @@ bool Socket::Bind(const Endpoint &endpoint) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (bind(sfd, it->ai_addr, it->ai_addrlen) == 0) {
|
||||
if (bind(sfd, it.ai_addr, it.ai_addrlen) == 0) {
|
||||
socket_ = sfd;
|
||||
break;
|
||||
} else {
|
||||
// If the bind failed close the file descriptor to prevent file
|
||||
// descriptors being leaked
|
||||
close(sfd);
|
||||
}
|
||||
// If the bind failed close the file descriptor to prevent file
|
||||
// descriptors being leaked
|
||||
close(sfd);
|
||||
}
|
||||
|
||||
if (socket_ == -1) return false;
|
||||
@ -122,7 +100,7 @@ bool Socket::Bind(const Endpoint &endpoint) {
|
||||
// detect bound port, used when the server binds to a random port
|
||||
struct sockaddr_in6 portdata;
|
||||
socklen_t portdatalen = sizeof(portdata);
|
||||
if (getsockname(socket_, (struct sockaddr *)&portdata, &portdatalen) < 0) {
|
||||
if (getsockname(socket_, reinterpret_cast<sockaddr *>(&portdata), &portdatalen) < 0) {
|
||||
// If the getsockname failed close the file descriptor to prevent file
|
||||
// descriptors being leaked
|
||||
close(socket_);
|
||||
@ -136,36 +114,35 @@ bool Socket::Bind(const Endpoint &endpoint) {
|
||||
}
|
||||
|
||||
void Socket::SetNonBlocking() {
|
||||
int flags = fcntl(socket_, F_GETFL, 0);
|
||||
const unsigned flags = fcntl(socket_, F_GETFL);
|
||||
constexpr unsigned o_nonblock = O_NONBLOCK;
|
||||
MG_ASSERT(flags != -1, "Can't get socket mode");
|
||||
flags |= O_NONBLOCK;
|
||||
MG_ASSERT(fcntl(socket_, F_SETFL, flags) != -1, "Can't set socket nonblocking");
|
||||
MG_ASSERT(fcntl(socket_, F_SETFL, flags | o_nonblock) != -1, "Can't set socket nonblocking");
|
||||
}
|
||||
|
||||
void Socket::SetKeepAlive() {
|
||||
int optval = 1;
|
||||
socklen_t optlen = sizeof(optval);
|
||||
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen), "Can't set socket keep alive");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)), "Can't set socket keep alive");
|
||||
|
||||
optval = 20; // wait 20s before sending keep-alive packets
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen), "Can't set socket keep alive");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, sizeof(optval)),
|
||||
"Can't set socket keep alive");
|
||||
|
||||
optval = 4; // 4 keep-alive packets must fail to close
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen), "Can't set socket keep alive");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, sizeof(optval)), "Can't set socket keep alive");
|
||||
|
||||
optval = 15; // send keep-alive packets every 15s
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen), "Can't set socket keep alive");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, sizeof(optval)),
|
||||
"Can't set socket keep alive");
|
||||
}
|
||||
|
||||
void Socket::SetNoDelay() {
|
||||
int optval = 1;
|
||||
socklen_t optlen = sizeof(optval);
|
||||
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen), "Can't set socket no delay");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, sizeof(optval)), "Can't set socket no delay");
|
||||
}
|
||||
|
||||
void Socket::SetTimeout(long sec, long usec) {
|
||||
// NOLINTNEXTLINE(readability-make-member-function-const)
|
||||
void Socket::SetTimeout(int64_t sec, int64_t usec) {
|
||||
struct timeval tv;
|
||||
tv.tv_sec = sec;
|
||||
tv.tv_usec = usec;
|
||||
@ -176,7 +153,7 @@ void Socket::SetTimeout(long sec, long usec) {
|
||||
}
|
||||
|
||||
int Socket::ErrorStatus() const {
|
||||
int optval;
|
||||
int optval = 0;
|
||||
socklen_t optlen = sizeof(optval);
|
||||
auto status = getsockopt(socket_, SOL_SOCKET, SO_ERROR, &optval, &optlen);
|
||||
MG_ASSERT(!status, "getsockopt failed");
|
||||
@ -189,21 +166,22 @@ std::optional<Socket> Socket::Accept() {
|
||||
sockaddr_storage addr;
|
||||
socklen_t addr_size = sizeof addr;
|
||||
char addr_decoded[INET6_ADDRSTRLEN];
|
||||
void *addr_src;
|
||||
unsigned short port;
|
||||
|
||||
int sfd = accept(socket_, (struct sockaddr *)&addr, &addr_size);
|
||||
int sfd = accept(socket_, reinterpret_cast<sockaddr *>(&addr), &addr_size);
|
||||
if (UNLIKELY(sfd == -1)) return std::nullopt;
|
||||
|
||||
void *addr_src = nullptr;
|
||||
uint16_t port = 0;
|
||||
|
||||
if (addr.ss_family == AF_INET) {
|
||||
addr_src = (void *)&(((sockaddr_in *)&addr)->sin_addr);
|
||||
port = ntohs(((sockaddr_in *)&addr)->sin_port);
|
||||
addr_src = &reinterpret_cast<sockaddr_in &>(addr).sin_addr;
|
||||
port = ntohs(reinterpret_cast<sockaddr_in &>(addr).sin_port);
|
||||
} else {
|
||||
addr_src = (void *)&(((sockaddr_in6 *)&addr)->sin6_addr);
|
||||
port = ntohs(((sockaddr_in6 *)&addr)->sin6_port);
|
||||
addr_src = &reinterpret_cast<sockaddr_in6 &>(addr).sin6_addr;
|
||||
port = ntohs(reinterpret_cast<sockaddr_in6 &>(addr).sin6_port);
|
||||
}
|
||||
|
||||
inet_ntop(addr.ss_family, addr_src, addr_decoded, INET6_ADDRSTRLEN);
|
||||
inet_ntop(addr.ss_family, addr_src, addr_decoded, sizeof(addr_decoded));
|
||||
|
||||
Endpoint endpoint(addr_decoded, port);
|
||||
|
||||
@ -213,9 +191,11 @@ std::optional<Socket> Socket::Accept() {
|
||||
bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
|
||||
// MSG_NOSIGNAL is here to disable raising a SIGPIPE signal when a
|
||||
// connection dies mid-write, the socket will only return an EPIPE error.
|
||||
int flags = MSG_NOSIGNAL | (have_more ? MSG_MORE : 0);
|
||||
constexpr unsigned msg_nosignal = MSG_NOSIGNAL;
|
||||
constexpr unsigned msg_more = MSG_MORE;
|
||||
const unsigned flags = msg_nosignal | (have_more ? msg_more : 0);
|
||||
while (len > 0) {
|
||||
auto written = send(socket_, data, len, flags);
|
||||
auto written = send(socket_, data, len, static_cast<int>(flags));
|
||||
if (written == -1) {
|
||||
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
|
||||
// Terminal error, return failure.
|
||||
@ -253,7 +233,8 @@ bool Socket::WaitForReadyRead() {
|
||||
// event occurs.
|
||||
int ret = poll(&p, 1, -1);
|
||||
if (ret < 1) return false;
|
||||
return p.revents & POLLIN;
|
||||
constexpr unsigned pollin = POLLIN;
|
||||
return static_cast<unsigned>(p.revents) & pollin;
|
||||
}
|
||||
|
||||
bool Socket::WaitForReadyWrite() {
|
||||
@ -265,7 +246,8 @@ bool Socket::WaitForReadyWrite() {
|
||||
// event occurs.
|
||||
int ret = poll(&p, 1, -1);
|
||||
if (ret < 1) return false;
|
||||
return p.revents & POLLOUT;
|
||||
constexpr unsigned pollout = POLLOUT;
|
||||
return static_cast<unsigned>(p.revents) & pollout;
|
||||
}
|
||||
|
||||
} // namespace memgraph::io::network
|
||||
|
@ -27,12 +27,12 @@ namespace memgraph::io::network {
|
||||
*/
|
||||
class Socket {
|
||||
public:
|
||||
Socket() = default;
|
||||
Socket() noexcept = default;
|
||||
Socket(const Socket &) = delete;
|
||||
Socket &operator=(const Socket &) = delete;
|
||||
Socket(Socket &&);
|
||||
Socket &operator=(Socket &&);
|
||||
~Socket();
|
||||
Socket(Socket &&) noexcept;
|
||||
Socket &operator=(Socket &&) noexcept;
|
||||
~Socket() noexcept;
|
||||
|
||||
/**
|
||||
* Closes the socket if it is open.
|
||||
@ -118,7 +118,7 @@ class Socket {
|
||||
* @param sec timeout seconds value
|
||||
* @param usec timeout microseconds value
|
||||
*/
|
||||
void SetTimeout(long sec, long usec);
|
||||
void SetTimeout(int64_t sec, int64_t usec);
|
||||
|
||||
/**
|
||||
* Checks if there are any errors on a socket. Returns 0 if there are none.
|
||||
|
Loading…
Reference in New Issue
Block a user