From c8dbaf59797c7585bf9e7f2e46464625da62cfd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sini=C5=A1a=20=C5=A0u=C5=A1njar?= <sinisa-susnjar@users.noreply.github.com> Date: Fri, 8 Apr 2022 13:38:13 +0100 Subject: [PATCH] Small io network socket fixes (#360) * Modernize AddrInfo * Modernize Socket --- src/io/network/addrinfo.cpp | 56 +++++++++++------ src/io/network/addrinfo.hpp | 42 +++++++++++-- src/io/network/socket.cpp | 118 +++++++++++++++--------------------- src/io/network/socket.hpp | 10 +-- 4 files changed, 128 insertions(+), 98 deletions(-) diff --git a/src/io/network/addrinfo.cpp b/src/io/network/addrinfo.cpp index 7b8ce9e9a..1623b834d 100644 --- a/src/io/network/addrinfo.cpp +++ b/src/io/network/addrinfo.cpp @@ -9,34 +9,52 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include <netdb.h> -#include <cstring> - #include "io/network/addrinfo.hpp" +#include <concepts> +#include <iterator> + #include "io/network/network_error.hpp" namespace memgraph::io::network { -AddrInfo::AddrInfo(struct addrinfo *info) : info(info) {} +static_assert(std::forward_iterator<AddrInfo::Iterator> && std::equality_comparable<AddrInfo::Iterator>); -AddrInfo::~AddrInfo() { freeaddrinfo(info); } - -AddrInfo AddrInfo::Get(const char *addr, const char *port) { - struct addrinfo hints; - memset(&hints, 0, sizeof(struct addrinfo)); - - hints.ai_family = AF_UNSPEC; // IPv4 and IPv6 - hints.ai_socktype = SOCK_STREAM; // TCP socket - hints.ai_flags = AI_PASSIVE; - - struct addrinfo *result; - auto status = getaddrinfo(addr, port, &hints, &result); +AddrInfo::AddrInfo(const Endpoint &endpoint) : AddrInfo(endpoint.address, endpoint.port) {} +AddrInfo::AddrInfo(const std::string &addr, uint16_t port) : info_{nullptr, nullptr} { + addrinfo hints{ + .ai_flags = AI_PASSIVE, + .ai_family = AF_UNSPEC, // IPv4 and IPv6 + .ai_socktype = SOCK_STREAM // TCP socket + }; + addrinfo *info = nullptr; + auto status = getaddrinfo(addr.c_str(), std::to_string(port).c_str(), &hints, &info); if (status != 0) throw NetworkError(gai_strerror(status)); - - return AddrInfo(result); + info_ = std::unique_ptr<addrinfo, decltype(&freeaddrinfo)>(info, &freeaddrinfo); } -AddrInfo::operator struct addrinfo *() { return info; } +AddrInfo::Iterator::Iterator(addrinfo *p) noexcept : ptr_(p) {} + +AddrInfo::Iterator::reference AddrInfo::Iterator::operator*() const noexcept { return *ptr_; } + +AddrInfo::Iterator::pointer AddrInfo::Iterator::operator->() const noexcept { return ptr_; } + +// NOLINTNEXTLINE(cert-dcl21-cpp) +AddrInfo::Iterator AddrInfo::Iterator::operator++(int) noexcept { + auto it = *this; + ++(*this); + return it; +} +AddrInfo::Iterator &AddrInfo::Iterator::operator++() noexcept { + ptr_ = ptr_->ai_next; + return *this; +} + +bool operator==(const AddrInfo::Iterator &lhs, const AddrInfo::Iterator &rhs) noexcept { return lhs.ptr_ == rhs.ptr_; }; + +bool operator!=(const AddrInfo::Iterator &lhs, const AddrInfo::Iterator &rhs) noexcept { return !(lhs == rhs); }; + +void swap(AddrInfo::Iterator &lhs, AddrInfo::Iterator &rhs) noexcept { std::swap(lhs.ptr_, rhs.ptr_); }; + } // namespace memgraph::io::network diff --git a/src/io/network/addrinfo.hpp b/src/io/network/addrinfo.hpp index 4828eea8b..8617c8697 100644 --- a/src/io/network/addrinfo.hpp +++ b/src/io/network/addrinfo.hpp @@ -11,6 +11,14 @@ #pragma once +#include <netdb.h> + +#include <iterator> +#include <memory> +#include <string> + +#include "io/network/endpoint.hpp" + namespace memgraph::io::network { /** @@ -18,16 +26,38 @@ namespace memgraph::io::network { * see: man 3 getaddrinfo */ class AddrInfo { - explicit AddrInfo(struct addrinfo *info); - public: - ~AddrInfo(); + struct Iterator { + using iterator_category = std::forward_iterator_tag; + using value_type = addrinfo; + using difference_type = std::ptrdiff_t; + using pointer = addrinfo *; + using reference = addrinfo &; - static AddrInfo Get(const char *addr, const char *port); + Iterator() = default; + Iterator(const Iterator &) = default; + explicit Iterator(addrinfo *p) noexcept; + Iterator &operator=(const Iterator &) = default; + reference operator*() const noexcept; + pointer operator->() const noexcept; + Iterator operator++(int) noexcept; + Iterator &operator++() noexcept; - operator struct addrinfo *(); + friend bool operator==(const Iterator &lhs, const Iterator &rhs) noexcept; + friend bool operator!=(const Iterator &lhs, const Iterator &rhs) noexcept; + friend void swap(Iterator &lhs, Iterator &rhs) noexcept; + + private: + addrinfo *ptr_{nullptr}; + }; + + AddrInfo(const std::string &addr, uint16_t port); + explicit AddrInfo(const Endpoint &endpoint); + + auto begin() const noexcept { return Iterator(info_.get()); } + auto end() const noexcept { return Iterator{nullptr}; } private: - struct addrinfo *info; + std::unique_ptr<addrinfo, void (*)(addrinfo *)> info_; }; } // namespace memgraph::io::network diff --git a/src/io/network/socket.cpp b/src/io/network/socket.cpp index 4bc79da2c..df091c124 100644 --- a/src/io/network/socket.cpp +++ b/src/io/network/socket.cpp @@ -9,39 +9,25 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include "io/network/socket.hpp" - -#include <cstdio> -#include <cstring> -#include <iostream> -#include <stdexcept> - #include <arpa/inet.h> -#include <errno.h> #include <fcntl.h> -#include <netdb.h> -#include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> -#include <sys/epoll.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> #include "io/network/addrinfo.hpp" +#include "io/network/socket.hpp" #include "utils/likely.hpp" #include "utils/logging.hpp" namespace memgraph::io::network { -Socket::Socket(Socket &&other) { - socket_ = other.socket_; - endpoint_ = std::move(other.endpoint_); +Socket::Socket(Socket &&other) noexcept : socket_(other.socket_), endpoint_(std::move(other.endpoint_)) { other.socket_ = -1; } -Socket &Socket::operator=(Socket &&other) { +Socket &Socket::operator=(Socket &&other) noexcept { if (this != &other) { + if (socket_ != -1) close(socket_); socket_ = other.socket_; endpoint_ = std::move(other.endpoint_); other.socket_ = -1; @@ -49,9 +35,8 @@ Socket &Socket::operator=(Socket &&other) { return *this; } -Socket::~Socket() { - if (socket_ == -1) return; - close(socket_); +Socket::~Socket() noexcept { + if (socket_ != -1) close(socket_); } void Socket::Close() { @@ -70,33 +55,27 @@ bool Socket::IsOpen() const { return socket_ != -1; } bool Socket::Connect(const Endpoint &endpoint) { if (socket_ != -1) return false; - auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str()); - - for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) { - int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol); + for (const auto &it : AddrInfo{endpoint}) { + int sfd = socket(it.ai_family, it.ai_socktype, it.ai_protocol); if (sfd == -1) continue; - if (connect(sfd, it->ai_addr, it->ai_addrlen) == 0) { + if (connect(sfd, it.ai_addr, it.ai_addrlen) == 0) { socket_ = sfd; endpoint_ = endpoint; break; - } else { - // If the connect failed close the file descriptor to prevent file - // descriptors being leaked - close(sfd); } + // If the connect failed close the file descriptor to prevent file + // descriptors being leaked + close(sfd); } - if (socket_ == -1) return false; - return true; + return !(socket_ == -1); } bool Socket::Bind(const Endpoint &endpoint) { if (socket_ != -1) return false; - auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str()); - - for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) { - int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol); + for (const auto &it : AddrInfo{endpoint}) { + int sfd = socket(it.ai_family, it.ai_socktype, it.ai_protocol); if (sfd == -1) continue; int on = 1; @@ -107,14 +86,13 @@ bool Socket::Bind(const Endpoint &endpoint) { continue; } - if (bind(sfd, it->ai_addr, it->ai_addrlen) == 0) { + if (bind(sfd, it.ai_addr, it.ai_addrlen) == 0) { socket_ = sfd; break; - } else { - // If the bind failed close the file descriptor to prevent file - // descriptors being leaked - close(sfd); } + // If the bind failed close the file descriptor to prevent file + // descriptors being leaked + close(sfd); } if (socket_ == -1) return false; @@ -122,7 +100,7 @@ bool Socket::Bind(const Endpoint &endpoint) { // detect bound port, used when the server binds to a random port struct sockaddr_in6 portdata; socklen_t portdatalen = sizeof(portdata); - if (getsockname(socket_, (struct sockaddr *)&portdata, &portdatalen) < 0) { + if (getsockname(socket_, reinterpret_cast<sockaddr *>(&portdata), &portdatalen) < 0) { // If the getsockname failed close the file descriptor to prevent file // descriptors being leaked close(socket_); @@ -136,36 +114,35 @@ bool Socket::Bind(const Endpoint &endpoint) { } void Socket::SetNonBlocking() { - int flags = fcntl(socket_, F_GETFL, 0); + const unsigned flags = fcntl(socket_, F_GETFL); + constexpr unsigned o_nonblock = O_NONBLOCK; MG_ASSERT(flags != -1, "Can't get socket mode"); - flags |= O_NONBLOCK; - MG_ASSERT(fcntl(socket_, F_SETFL, flags) != -1, "Can't set socket nonblocking"); + MG_ASSERT(fcntl(socket_, F_SETFL, flags | o_nonblock) != -1, "Can't set socket nonblocking"); } void Socket::SetKeepAlive() { int optval = 1; - socklen_t optlen = sizeof(optval); - - MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen), "Can't set socket keep alive"); + MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)), "Can't set socket keep alive"); optval = 20; // wait 20s before sending keep-alive packets - MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen), "Can't set socket keep alive"); + MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, sizeof(optval)), + "Can't set socket keep alive"); optval = 4; // 4 keep-alive packets must fail to close - MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen), "Can't set socket keep alive"); + MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, sizeof(optval)), "Can't set socket keep alive"); optval = 15; // send keep-alive packets every 15s - MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen), "Can't set socket keep alive"); + MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, sizeof(optval)), + "Can't set socket keep alive"); } void Socket::SetNoDelay() { int optval = 1; - socklen_t optlen = sizeof(optval); - - MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen), "Can't set socket no delay"); + MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, sizeof(optval)), "Can't set socket no delay"); } -void Socket::SetTimeout(long sec, long usec) { +// NOLINTNEXTLINE(readability-make-member-function-const) +void Socket::SetTimeout(int64_t sec, int64_t usec) { struct timeval tv; tv.tv_sec = sec; tv.tv_usec = usec; @@ -176,7 +153,7 @@ void Socket::SetTimeout(long sec, long usec) { } int Socket::ErrorStatus() const { - int optval; + int optval = 0; socklen_t optlen = sizeof(optval); auto status = getsockopt(socket_, SOL_SOCKET, SO_ERROR, &optval, &optlen); MG_ASSERT(!status, "getsockopt failed"); @@ -189,21 +166,22 @@ std::optional<Socket> Socket::Accept() { sockaddr_storage addr; socklen_t addr_size = sizeof addr; char addr_decoded[INET6_ADDRSTRLEN]; - void *addr_src; - unsigned short port; - int sfd = accept(socket_, (struct sockaddr *)&addr, &addr_size); + int sfd = accept(socket_, reinterpret_cast<sockaddr *>(&addr), &addr_size); if (UNLIKELY(sfd == -1)) return std::nullopt; + void *addr_src = nullptr; + uint16_t port = 0; + if (addr.ss_family == AF_INET) { - addr_src = (void *)&(((sockaddr_in *)&addr)->sin_addr); - port = ntohs(((sockaddr_in *)&addr)->sin_port); + addr_src = &reinterpret_cast<sockaddr_in &>(addr).sin_addr; + port = ntohs(reinterpret_cast<sockaddr_in &>(addr).sin_port); } else { - addr_src = (void *)&(((sockaddr_in6 *)&addr)->sin6_addr); - port = ntohs(((sockaddr_in6 *)&addr)->sin6_port); + addr_src = &reinterpret_cast<sockaddr_in6 &>(addr).sin6_addr; + port = ntohs(reinterpret_cast<sockaddr_in6 &>(addr).sin6_port); } - inet_ntop(addr.ss_family, addr_src, addr_decoded, INET6_ADDRSTRLEN); + inet_ntop(addr.ss_family, addr_src, addr_decoded, sizeof(addr_decoded)); Endpoint endpoint(addr_decoded, port); @@ -213,9 +191,11 @@ std::optional<Socket> Socket::Accept() { bool Socket::Write(const uint8_t *data, size_t len, bool have_more) { // MSG_NOSIGNAL is here to disable raising a SIGPIPE signal when a // connection dies mid-write, the socket will only return an EPIPE error. - int flags = MSG_NOSIGNAL | (have_more ? MSG_MORE : 0); + constexpr unsigned msg_nosignal = MSG_NOSIGNAL; + constexpr unsigned msg_more = MSG_MORE; + const unsigned flags = msg_nosignal | (have_more ? msg_more : 0); while (len > 0) { - auto written = send(socket_, data, len, flags); + auto written = send(socket_, data, len, static_cast<int>(flags)); if (written == -1) { if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { // Terminal error, return failure. @@ -253,7 +233,8 @@ bool Socket::WaitForReadyRead() { // event occurs. int ret = poll(&p, 1, -1); if (ret < 1) return false; - return p.revents & POLLIN; + constexpr unsigned pollin = POLLIN; + return static_cast<unsigned>(p.revents) & pollin; } bool Socket::WaitForReadyWrite() { @@ -265,7 +246,8 @@ bool Socket::WaitForReadyWrite() { // event occurs. int ret = poll(&p, 1, -1); if (ret < 1) return false; - return p.revents & POLLOUT; + constexpr unsigned pollout = POLLOUT; + return static_cast<unsigned>(p.revents) & pollout; } } // namespace memgraph::io::network diff --git a/src/io/network/socket.hpp b/src/io/network/socket.hpp index faa127399..bc6a98bd3 100644 --- a/src/io/network/socket.hpp +++ b/src/io/network/socket.hpp @@ -27,12 +27,12 @@ namespace memgraph::io::network { */ class Socket { public: - Socket() = default; + Socket() noexcept = default; Socket(const Socket &) = delete; Socket &operator=(const Socket &) = delete; - Socket(Socket &&); - Socket &operator=(Socket &&); - ~Socket(); + Socket(Socket &&) noexcept; + Socket &operator=(Socket &&) noexcept; + ~Socket() noexcept; /** * Closes the socket if it is open. @@ -118,7 +118,7 @@ class Socket { * @param sec timeout seconds value * @param usec timeout microseconds value */ - void SetTimeout(long sec, long usec); + void SetTimeout(int64_t sec, int64_t usec); /** * Checks if there are any errors on a socket. Returns 0 if there are none.