Small io network socket fixes (#360)

* Modernize AddrInfo

* Modernize Socket
This commit is contained in:
Siniša Šušnjar 2022-04-08 13:38:13 +01:00 committed by GitHub
parent 17049ada09
commit c8dbaf5979
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 128 additions and 98 deletions

View File

@ -9,34 +9,52 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <netdb.h>
#include <cstring>
#include "io/network/addrinfo.hpp"
#include <concepts>
#include <iterator>
#include "io/network/network_error.hpp"
namespace memgraph::io::network {
AddrInfo::AddrInfo(struct addrinfo *info) : info(info) {}
static_assert(std::forward_iterator<AddrInfo::Iterator> && std::equality_comparable<AddrInfo::Iterator>);
AddrInfo::~AddrInfo() { freeaddrinfo(info); }
AddrInfo AddrInfo::Get(const char *addr, const char *port) {
struct addrinfo hints;
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_family = AF_UNSPEC; // IPv4 and IPv6
hints.ai_socktype = SOCK_STREAM; // TCP socket
hints.ai_flags = AI_PASSIVE;
struct addrinfo *result;
auto status = getaddrinfo(addr, port, &hints, &result);
AddrInfo::AddrInfo(const Endpoint &endpoint) : AddrInfo(endpoint.address, endpoint.port) {}
AddrInfo::AddrInfo(const std::string &addr, uint16_t port) : info_{nullptr, nullptr} {
addrinfo hints{
.ai_flags = AI_PASSIVE,
.ai_family = AF_UNSPEC, // IPv4 and IPv6
.ai_socktype = SOCK_STREAM // TCP socket
};
addrinfo *info = nullptr;
auto status = getaddrinfo(addr.c_str(), std::to_string(port).c_str(), &hints, &info);
if (status != 0) throw NetworkError(gai_strerror(status));
return AddrInfo(result);
info_ = std::unique_ptr<addrinfo, decltype(&freeaddrinfo)>(info, &freeaddrinfo);
}
AddrInfo::operator struct addrinfo *() { return info; }
AddrInfo::Iterator::Iterator(addrinfo *p) noexcept : ptr_(p) {}
AddrInfo::Iterator::reference AddrInfo::Iterator::operator*() const noexcept { return *ptr_; }
AddrInfo::Iterator::pointer AddrInfo::Iterator::operator->() const noexcept { return ptr_; }
// NOLINTNEXTLINE(cert-dcl21-cpp)
AddrInfo::Iterator AddrInfo::Iterator::operator++(int) noexcept {
auto it = *this;
++(*this);
return it;
}
AddrInfo::Iterator &AddrInfo::Iterator::operator++() noexcept {
ptr_ = ptr_->ai_next;
return *this;
}
bool operator==(const AddrInfo::Iterator &lhs, const AddrInfo::Iterator &rhs) noexcept { return lhs.ptr_ == rhs.ptr_; };
bool operator!=(const AddrInfo::Iterator &lhs, const AddrInfo::Iterator &rhs) noexcept { return !(lhs == rhs); };
void swap(AddrInfo::Iterator &lhs, AddrInfo::Iterator &rhs) noexcept { std::swap(lhs.ptr_, rhs.ptr_); };
} // namespace memgraph::io::network

View File

@ -11,6 +11,14 @@
#pragma once
#include <netdb.h>
#include <iterator>
#include <memory>
#include <string>
#include "io/network/endpoint.hpp"
namespace memgraph::io::network {
/**
@ -18,16 +26,38 @@ namespace memgraph::io::network {
* see: man 3 getaddrinfo
*/
class AddrInfo {
explicit AddrInfo(struct addrinfo *info);
public:
~AddrInfo();
struct Iterator {
using iterator_category = std::forward_iterator_tag;
using value_type = addrinfo;
using difference_type = std::ptrdiff_t;
using pointer = addrinfo *;
using reference = addrinfo &;
static AddrInfo Get(const char *addr, const char *port);
Iterator() = default;
Iterator(const Iterator &) = default;
explicit Iterator(addrinfo *p) noexcept;
Iterator &operator=(const Iterator &) = default;
reference operator*() const noexcept;
pointer operator->() const noexcept;
Iterator operator++(int) noexcept;
Iterator &operator++() noexcept;
operator struct addrinfo *();
friend bool operator==(const Iterator &lhs, const Iterator &rhs) noexcept;
friend bool operator!=(const Iterator &lhs, const Iterator &rhs) noexcept;
friend void swap(Iterator &lhs, Iterator &rhs) noexcept;
private:
addrinfo *ptr_{nullptr};
};
AddrInfo(const std::string &addr, uint16_t port);
explicit AddrInfo(const Endpoint &endpoint);
auto begin() const noexcept { return Iterator(info_.get()); }
auto end() const noexcept { return Iterator{nullptr}; }
private:
struct addrinfo *info;
std::unique_ptr<addrinfo, void (*)(addrinfo *)> info_;
};
} // namespace memgraph::io::network

View File

@ -9,39 +9,25 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "io/network/socket.hpp"
#include <cstdio>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include "io/network/addrinfo.hpp"
#include "io/network/socket.hpp"
#include "utils/likely.hpp"
#include "utils/logging.hpp"
namespace memgraph::io::network {
Socket::Socket(Socket &&other) {
socket_ = other.socket_;
endpoint_ = std::move(other.endpoint_);
Socket::Socket(Socket &&other) noexcept : socket_(other.socket_), endpoint_(std::move(other.endpoint_)) {
other.socket_ = -1;
}
Socket &Socket::operator=(Socket &&other) {
Socket &Socket::operator=(Socket &&other) noexcept {
if (this != &other) {
if (socket_ != -1) close(socket_);
socket_ = other.socket_;
endpoint_ = std::move(other.endpoint_);
other.socket_ = -1;
@ -49,9 +35,8 @@ Socket &Socket::operator=(Socket &&other) {
return *this;
}
Socket::~Socket() {
if (socket_ == -1) return;
close(socket_);
Socket::~Socket() noexcept {
if (socket_ != -1) close(socket_);
}
void Socket::Close() {
@ -70,33 +55,27 @@ bool Socket::IsOpen() const { return socket_ != -1; }
bool Socket::Connect(const Endpoint &endpoint) {
if (socket_ != -1) return false;
auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
for (const auto &it : AddrInfo{endpoint}) {
int sfd = socket(it.ai_family, it.ai_socktype, it.ai_protocol);
if (sfd == -1) continue;
if (connect(sfd, it->ai_addr, it->ai_addrlen) == 0) {
if (connect(sfd, it.ai_addr, it.ai_addrlen) == 0) {
socket_ = sfd;
endpoint_ = endpoint;
break;
} else {
// If the connect failed close the file descriptor to prevent file
// descriptors being leaked
close(sfd);
}
// If the connect failed close the file descriptor to prevent file
// descriptors being leaked
close(sfd);
}
if (socket_ == -1) return false;
return true;
return !(socket_ == -1);
}
bool Socket::Bind(const Endpoint &endpoint) {
if (socket_ != -1) return false;
auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
for (const auto &it : AddrInfo{endpoint}) {
int sfd = socket(it.ai_family, it.ai_socktype, it.ai_protocol);
if (sfd == -1) continue;
int on = 1;
@ -107,14 +86,13 @@ bool Socket::Bind(const Endpoint &endpoint) {
continue;
}
if (bind(sfd, it->ai_addr, it->ai_addrlen) == 0) {
if (bind(sfd, it.ai_addr, it.ai_addrlen) == 0) {
socket_ = sfd;
break;
} else {
// If the bind failed close the file descriptor to prevent file
// descriptors being leaked
close(sfd);
}
// If the bind failed close the file descriptor to prevent file
// descriptors being leaked
close(sfd);
}
if (socket_ == -1) return false;
@ -122,7 +100,7 @@ bool Socket::Bind(const Endpoint &endpoint) {
// detect bound port, used when the server binds to a random port
struct sockaddr_in6 portdata;
socklen_t portdatalen = sizeof(portdata);
if (getsockname(socket_, (struct sockaddr *)&portdata, &portdatalen) < 0) {
if (getsockname(socket_, reinterpret_cast<sockaddr *>(&portdata), &portdatalen) < 0) {
// If the getsockname failed close the file descriptor to prevent file
// descriptors being leaked
close(socket_);
@ -136,36 +114,35 @@ bool Socket::Bind(const Endpoint &endpoint) {
}
void Socket::SetNonBlocking() {
int flags = fcntl(socket_, F_GETFL, 0);
const unsigned flags = fcntl(socket_, F_GETFL);
constexpr unsigned o_nonblock = O_NONBLOCK;
MG_ASSERT(flags != -1, "Can't get socket mode");
flags |= O_NONBLOCK;
MG_ASSERT(fcntl(socket_, F_SETFL, flags) != -1, "Can't set socket nonblocking");
MG_ASSERT(fcntl(socket_, F_SETFL, flags | o_nonblock) != -1, "Can't set socket nonblocking");
}
void Socket::SetKeepAlive() {
int optval = 1;
socklen_t optlen = sizeof(optval);
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen), "Can't set socket keep alive");
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)), "Can't set socket keep alive");
optval = 20; // wait 20s before sending keep-alive packets
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen), "Can't set socket keep alive");
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, sizeof(optval)),
"Can't set socket keep alive");
optval = 4; // 4 keep-alive packets must fail to close
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen), "Can't set socket keep alive");
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, sizeof(optval)), "Can't set socket keep alive");
optval = 15; // send keep-alive packets every 15s
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen), "Can't set socket keep alive");
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, sizeof(optval)),
"Can't set socket keep alive");
}
void Socket::SetNoDelay() {
int optval = 1;
socklen_t optlen = sizeof(optval);
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen), "Can't set socket no delay");
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, sizeof(optval)), "Can't set socket no delay");
}
void Socket::SetTimeout(long sec, long usec) {
// NOLINTNEXTLINE(readability-make-member-function-const)
void Socket::SetTimeout(int64_t sec, int64_t usec) {
struct timeval tv;
tv.tv_sec = sec;
tv.tv_usec = usec;
@ -176,7 +153,7 @@ void Socket::SetTimeout(long sec, long usec) {
}
int Socket::ErrorStatus() const {
int optval;
int optval = 0;
socklen_t optlen = sizeof(optval);
auto status = getsockopt(socket_, SOL_SOCKET, SO_ERROR, &optval, &optlen);
MG_ASSERT(!status, "getsockopt failed");
@ -189,21 +166,22 @@ std::optional<Socket> Socket::Accept() {
sockaddr_storage addr;
socklen_t addr_size = sizeof addr;
char addr_decoded[INET6_ADDRSTRLEN];
void *addr_src;
unsigned short port;
int sfd = accept(socket_, (struct sockaddr *)&addr, &addr_size);
int sfd = accept(socket_, reinterpret_cast<sockaddr *>(&addr), &addr_size);
if (UNLIKELY(sfd == -1)) return std::nullopt;
void *addr_src = nullptr;
uint16_t port = 0;
if (addr.ss_family == AF_INET) {
addr_src = (void *)&(((sockaddr_in *)&addr)->sin_addr);
port = ntohs(((sockaddr_in *)&addr)->sin_port);
addr_src = &reinterpret_cast<sockaddr_in &>(addr).sin_addr;
port = ntohs(reinterpret_cast<sockaddr_in &>(addr).sin_port);
} else {
addr_src = (void *)&(((sockaddr_in6 *)&addr)->sin6_addr);
port = ntohs(((sockaddr_in6 *)&addr)->sin6_port);
addr_src = &reinterpret_cast<sockaddr_in6 &>(addr).sin6_addr;
port = ntohs(reinterpret_cast<sockaddr_in6 &>(addr).sin6_port);
}
inet_ntop(addr.ss_family, addr_src, addr_decoded, INET6_ADDRSTRLEN);
inet_ntop(addr.ss_family, addr_src, addr_decoded, sizeof(addr_decoded));
Endpoint endpoint(addr_decoded, port);
@ -213,9 +191,11 @@ std::optional<Socket> Socket::Accept() {
bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
// MSG_NOSIGNAL is here to disable raising a SIGPIPE signal when a
// connection dies mid-write, the socket will only return an EPIPE error.
int flags = MSG_NOSIGNAL | (have_more ? MSG_MORE : 0);
constexpr unsigned msg_nosignal = MSG_NOSIGNAL;
constexpr unsigned msg_more = MSG_MORE;
const unsigned flags = msg_nosignal | (have_more ? msg_more : 0);
while (len > 0) {
auto written = send(socket_, data, len, flags);
auto written = send(socket_, data, len, static_cast<int>(flags));
if (written == -1) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
// Terminal error, return failure.
@ -253,7 +233,8 @@ bool Socket::WaitForReadyRead() {
// event occurs.
int ret = poll(&p, 1, -1);
if (ret < 1) return false;
return p.revents & POLLIN;
constexpr unsigned pollin = POLLIN;
return static_cast<unsigned>(p.revents) & pollin;
}
bool Socket::WaitForReadyWrite() {
@ -265,7 +246,8 @@ bool Socket::WaitForReadyWrite() {
// event occurs.
int ret = poll(&p, 1, -1);
if (ret < 1) return false;
return p.revents & POLLOUT;
constexpr unsigned pollout = POLLOUT;
return static_cast<unsigned>(p.revents) & pollout;
}
} // namespace memgraph::io::network

View File

@ -27,12 +27,12 @@ namespace memgraph::io::network {
*/
class Socket {
public:
Socket() = default;
Socket() noexcept = default;
Socket(const Socket &) = delete;
Socket &operator=(const Socket &) = delete;
Socket(Socket &&);
Socket &operator=(Socket &&);
~Socket();
Socket(Socket &&) noexcept;
Socket &operator=(Socket &&) noexcept;
~Socket() noexcept;
/**
* Closes the socket if it is open.
@ -118,7 +118,7 @@ class Socket {
* @param sec timeout seconds value
* @param usec timeout microseconds value
*/
void SetTimeout(long sec, long usec);
void SetTimeout(int64_t sec, int64_t usec);
/**
* Checks if there are any errors on a socket. Returns 0 if there are none.