From c8dbaf59797c7585bf9e7f2e46464625da62cfd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sini=C5=A1a=20=C5=A0u=C5=A1njar?=
 <sinisa-susnjar@users.noreply.github.com>
Date: Fri, 8 Apr 2022 13:38:13 +0100
Subject: [PATCH] Small io network socket fixes (#360)

* Modernize AddrInfo

* Modernize Socket
---
 src/io/network/addrinfo.cpp |  56 +++++++++++------
 src/io/network/addrinfo.hpp |  42 +++++++++++--
 src/io/network/socket.cpp   | 118 +++++++++++++++---------------------
 src/io/network/socket.hpp   |  10 +--
 4 files changed, 128 insertions(+), 98 deletions(-)

diff --git a/src/io/network/addrinfo.cpp b/src/io/network/addrinfo.cpp
index 7b8ce9e9a..1623b834d 100644
--- a/src/io/network/addrinfo.cpp
+++ b/src/io/network/addrinfo.cpp
@@ -9,34 +9,52 @@
 // by the Apache License, Version 2.0, included in the file
 // licenses/APL.txt.
 
-#include <netdb.h>
-#include <cstring>
-
 #include "io/network/addrinfo.hpp"
 
+#include <concepts>
+#include <iterator>
+
 #include "io/network/network_error.hpp"
 
 namespace memgraph::io::network {
 
-AddrInfo::AddrInfo(struct addrinfo *info) : info(info) {}
+static_assert(std::forward_iterator<AddrInfo::Iterator> && std::equality_comparable<AddrInfo::Iterator>);
 
-AddrInfo::~AddrInfo() { freeaddrinfo(info); }
-
-AddrInfo AddrInfo::Get(const char *addr, const char *port) {
-  struct addrinfo hints;
-  memset(&hints, 0, sizeof(struct addrinfo));
-
-  hints.ai_family = AF_UNSPEC;      // IPv4 and IPv6
-  hints.ai_socktype = SOCK_STREAM;  // TCP socket
-  hints.ai_flags = AI_PASSIVE;
-
-  struct addrinfo *result;
-  auto status = getaddrinfo(addr, port, &hints, &result);
+AddrInfo::AddrInfo(const Endpoint &endpoint) : AddrInfo(endpoint.address, endpoint.port) {}
 
+AddrInfo::AddrInfo(const std::string &addr, uint16_t port) : info_{nullptr, nullptr} {
+  addrinfo hints{
+      .ai_flags = AI_PASSIVE,
+      .ai_family = AF_UNSPEC,     // IPv4 and IPv6
+      .ai_socktype = SOCK_STREAM  // TCP socket
+  };
+  addrinfo *info = nullptr;
+  auto status = getaddrinfo(addr.c_str(), std::to_string(port).c_str(), &hints, &info);
   if (status != 0) throw NetworkError(gai_strerror(status));
-
-  return AddrInfo(result);
+  info_ = std::unique_ptr<addrinfo, decltype(&freeaddrinfo)>(info, &freeaddrinfo);
 }
 
-AddrInfo::operator struct addrinfo *() { return info; }
+AddrInfo::Iterator::Iterator(addrinfo *p) noexcept : ptr_(p) {}
+
+AddrInfo::Iterator::reference AddrInfo::Iterator::operator*() const noexcept { return *ptr_; }
+
+AddrInfo::Iterator::pointer AddrInfo::Iterator::operator->() const noexcept { return ptr_; }
+
+// NOLINTNEXTLINE(cert-dcl21-cpp)
+AddrInfo::Iterator AddrInfo::Iterator::operator++(int) noexcept {
+  auto it = *this;
+  ++(*this);
+  return it;
+}
+AddrInfo::Iterator &AddrInfo::Iterator::operator++() noexcept {
+  ptr_ = ptr_->ai_next;
+  return *this;
+}
+
+bool operator==(const AddrInfo::Iterator &lhs, const AddrInfo::Iterator &rhs) noexcept { return lhs.ptr_ == rhs.ptr_; };
+
+bool operator!=(const AddrInfo::Iterator &lhs, const AddrInfo::Iterator &rhs) noexcept { return !(lhs == rhs); };
+
+void swap(AddrInfo::Iterator &lhs, AddrInfo::Iterator &rhs) noexcept { std::swap(lhs.ptr_, rhs.ptr_); };
+
 }  // namespace memgraph::io::network
diff --git a/src/io/network/addrinfo.hpp b/src/io/network/addrinfo.hpp
index 4828eea8b..8617c8697 100644
--- a/src/io/network/addrinfo.hpp
+++ b/src/io/network/addrinfo.hpp
@@ -11,6 +11,14 @@
 
 #pragma once
 
+#include <netdb.h>
+
+#include <iterator>
+#include <memory>
+#include <string>
+
+#include "io/network/endpoint.hpp"
+
 namespace memgraph::io::network {
 
 /**
@@ -18,16 +26,38 @@ namespace memgraph::io::network {
  * see: man 3 getaddrinfo
  */
 class AddrInfo {
-  explicit AddrInfo(struct addrinfo *info);
-
  public:
-  ~AddrInfo();
+  struct Iterator {
+    using iterator_category = std::forward_iterator_tag;
+    using value_type = addrinfo;
+    using difference_type = std::ptrdiff_t;
+    using pointer = addrinfo *;
+    using reference = addrinfo &;
 
-  static AddrInfo Get(const char *addr, const char *port);
+    Iterator() = default;
+    Iterator(const Iterator &) = default;
+    explicit Iterator(addrinfo *p) noexcept;
+    Iterator &operator=(const Iterator &) = default;
+    reference operator*() const noexcept;
+    pointer operator->() const noexcept;
+    Iterator operator++(int) noexcept;
+    Iterator &operator++() noexcept;
 
-  operator struct addrinfo *();
+    friend bool operator==(const Iterator &lhs, const Iterator &rhs) noexcept;
+    friend bool operator!=(const Iterator &lhs, const Iterator &rhs) noexcept;
+    friend void swap(Iterator &lhs, Iterator &rhs) noexcept;
+
+   private:
+    addrinfo *ptr_{nullptr};
+  };
+
+  AddrInfo(const std::string &addr, uint16_t port);
+  explicit AddrInfo(const Endpoint &endpoint);
+
+  auto begin() const noexcept { return Iterator(info_.get()); }
+  auto end() const noexcept { return Iterator{nullptr}; }
 
  private:
-  struct addrinfo *info;
+  std::unique_ptr<addrinfo, void (*)(addrinfo *)> info_;
 };
 }  // namespace memgraph::io::network
diff --git a/src/io/network/socket.cpp b/src/io/network/socket.cpp
index 4bc79da2c..df091c124 100644
--- a/src/io/network/socket.cpp
+++ b/src/io/network/socket.cpp
@@ -9,39 +9,25 @@
 // by the Apache License, Version 2.0, included in the file
 // licenses/APL.txt.
 
-#include "io/network/socket.hpp"
-
-#include <cstdio>
-#include <cstring>
-#include <iostream>
-#include <stdexcept>
-
 #include <arpa/inet.h>
-#include <errno.h>
 #include <fcntl.h>
-#include <netdb.h>
-#include <netinet/in.h>
 #include <netinet/tcp.h>
 #include <poll.h>
-#include <sys/epoll.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <unistd.h>
 
 #include "io/network/addrinfo.hpp"
+#include "io/network/socket.hpp"
 #include "utils/likely.hpp"
 #include "utils/logging.hpp"
 
 namespace memgraph::io::network {
 
-Socket::Socket(Socket &&other) {
-  socket_ = other.socket_;
-  endpoint_ = std::move(other.endpoint_);
+Socket::Socket(Socket &&other) noexcept : socket_(other.socket_), endpoint_(std::move(other.endpoint_)) {
   other.socket_ = -1;
 }
 
-Socket &Socket::operator=(Socket &&other) {
+Socket &Socket::operator=(Socket &&other) noexcept {
   if (this != &other) {
+    if (socket_ != -1) close(socket_);
     socket_ = other.socket_;
     endpoint_ = std::move(other.endpoint_);
     other.socket_ = -1;
@@ -49,9 +35,8 @@ Socket &Socket::operator=(Socket &&other) {
   return *this;
 }
 
-Socket::~Socket() {
-  if (socket_ == -1) return;
-  close(socket_);
+Socket::~Socket() noexcept {
+  if (socket_ != -1) close(socket_);
 }
 
 void Socket::Close() {
@@ -70,33 +55,27 @@ bool Socket::IsOpen() const { return socket_ != -1; }
 bool Socket::Connect(const Endpoint &endpoint) {
   if (socket_ != -1) return false;
 
-  auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
-
-  for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
-    int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
+  for (const auto &it : AddrInfo{endpoint}) {
+    int sfd = socket(it.ai_family, it.ai_socktype, it.ai_protocol);
     if (sfd == -1) continue;
-    if (connect(sfd, it->ai_addr, it->ai_addrlen) == 0) {
+    if (connect(sfd, it.ai_addr, it.ai_addrlen) == 0) {
       socket_ = sfd;
       endpoint_ = endpoint;
       break;
-    } else {
-      // If the connect failed close the file descriptor to prevent file
-      // descriptors being leaked
-      close(sfd);
     }
+    // If the connect failed close the file descriptor to prevent file
+    // descriptors being leaked
+    close(sfd);
   }
 
-  if (socket_ == -1) return false;
-  return true;
+  return !(socket_ == -1);
 }
 
 bool Socket::Bind(const Endpoint &endpoint) {
   if (socket_ != -1) return false;
 
-  auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
-
-  for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
-    int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
+  for (const auto &it : AddrInfo{endpoint}) {
+    int sfd = socket(it.ai_family, it.ai_socktype, it.ai_protocol);
     if (sfd == -1) continue;
 
     int on = 1;
@@ -107,14 +86,13 @@ bool Socket::Bind(const Endpoint &endpoint) {
       continue;
     }
 
-    if (bind(sfd, it->ai_addr, it->ai_addrlen) == 0) {
+    if (bind(sfd, it.ai_addr, it.ai_addrlen) == 0) {
       socket_ = sfd;
       break;
-    } else {
-      // If the bind failed close the file descriptor to prevent file
-      // descriptors being leaked
-      close(sfd);
     }
+    // If the bind failed close the file descriptor to prevent file
+    // descriptors being leaked
+    close(sfd);
   }
 
   if (socket_ == -1) return false;
@@ -122,7 +100,7 @@ bool Socket::Bind(const Endpoint &endpoint) {
   // detect bound port, used when the server binds to a random port
   struct sockaddr_in6 portdata;
   socklen_t portdatalen = sizeof(portdata);
-  if (getsockname(socket_, (struct sockaddr *)&portdata, &portdatalen) < 0) {
+  if (getsockname(socket_, reinterpret_cast<sockaddr *>(&portdata), &portdatalen) < 0) {
     // If the getsockname failed close the file descriptor to prevent file
     // descriptors being leaked
     close(socket_);
@@ -136,36 +114,35 @@ bool Socket::Bind(const Endpoint &endpoint) {
 }
 
 void Socket::SetNonBlocking() {
-  int flags = fcntl(socket_, F_GETFL, 0);
+  const unsigned flags = fcntl(socket_, F_GETFL);
+  constexpr unsigned o_nonblock = O_NONBLOCK;
   MG_ASSERT(flags != -1, "Can't get socket mode");
-  flags |= O_NONBLOCK;
-  MG_ASSERT(fcntl(socket_, F_SETFL, flags) != -1, "Can't set socket nonblocking");
+  MG_ASSERT(fcntl(socket_, F_SETFL, flags | o_nonblock) != -1, "Can't set socket nonblocking");
 }
 
 void Socket::SetKeepAlive() {
   int optval = 1;
-  socklen_t optlen = sizeof(optval);
-
-  MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen), "Can't set socket keep alive");
+  MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)), "Can't set socket keep alive");
 
   optval = 20;  // wait 20s before sending keep-alive packets
-  MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen), "Can't set socket keep alive");
+  MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, sizeof(optval)),
+            "Can't set socket keep alive");
 
   optval = 4;  // 4 keep-alive packets must fail to close
-  MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen), "Can't set socket keep alive");
+  MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, sizeof(optval)), "Can't set socket keep alive");
 
   optval = 15;  // send keep-alive packets every 15s
-  MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen), "Can't set socket keep alive");
+  MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, sizeof(optval)),
+            "Can't set socket keep alive");
 }
 
 void Socket::SetNoDelay() {
   int optval = 1;
-  socklen_t optlen = sizeof(optval);
-
-  MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen), "Can't set socket no delay");
+  MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, sizeof(optval)), "Can't set socket no delay");
 }
 
-void Socket::SetTimeout(long sec, long usec) {
+// NOLINTNEXTLINE(readability-make-member-function-const)
+void Socket::SetTimeout(int64_t sec, int64_t usec) {
   struct timeval tv;
   tv.tv_sec = sec;
   tv.tv_usec = usec;
@@ -176,7 +153,7 @@ void Socket::SetTimeout(long sec, long usec) {
 }
 
 int Socket::ErrorStatus() const {
-  int optval;
+  int optval = 0;
   socklen_t optlen = sizeof(optval);
   auto status = getsockopt(socket_, SOL_SOCKET, SO_ERROR, &optval, &optlen);
   MG_ASSERT(!status, "getsockopt failed");
@@ -189,21 +166,22 @@ std::optional<Socket> Socket::Accept() {
   sockaddr_storage addr;
   socklen_t addr_size = sizeof addr;
   char addr_decoded[INET6_ADDRSTRLEN];
-  void *addr_src;
-  unsigned short port;
 
-  int sfd = accept(socket_, (struct sockaddr *)&addr, &addr_size);
+  int sfd = accept(socket_, reinterpret_cast<sockaddr *>(&addr), &addr_size);
   if (UNLIKELY(sfd == -1)) return std::nullopt;
 
+  void *addr_src = nullptr;
+  uint16_t port = 0;
+
   if (addr.ss_family == AF_INET) {
-    addr_src = (void *)&(((sockaddr_in *)&addr)->sin_addr);
-    port = ntohs(((sockaddr_in *)&addr)->sin_port);
+    addr_src = &reinterpret_cast<sockaddr_in &>(addr).sin_addr;
+    port = ntohs(reinterpret_cast<sockaddr_in &>(addr).sin_port);
   } else {
-    addr_src = (void *)&(((sockaddr_in6 *)&addr)->sin6_addr);
-    port = ntohs(((sockaddr_in6 *)&addr)->sin6_port);
+    addr_src = &reinterpret_cast<sockaddr_in6 &>(addr).sin6_addr;
+    port = ntohs(reinterpret_cast<sockaddr_in6 &>(addr).sin6_port);
   }
 
-  inet_ntop(addr.ss_family, addr_src, addr_decoded, INET6_ADDRSTRLEN);
+  inet_ntop(addr.ss_family, addr_src, addr_decoded, sizeof(addr_decoded));
 
   Endpoint endpoint(addr_decoded, port);
 
@@ -213,9 +191,11 @@ std::optional<Socket> Socket::Accept() {
 bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
   // MSG_NOSIGNAL is here to disable raising a SIGPIPE signal when a
   // connection dies mid-write, the socket will only return an EPIPE error.
-  int flags = MSG_NOSIGNAL | (have_more ? MSG_MORE : 0);
+  constexpr unsigned msg_nosignal = MSG_NOSIGNAL;
+  constexpr unsigned msg_more = MSG_MORE;
+  const unsigned flags = msg_nosignal | (have_more ? msg_more : 0);
   while (len > 0) {
-    auto written = send(socket_, data, len, flags);
+    auto written = send(socket_, data, len, static_cast<int>(flags));
     if (written == -1) {
       if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
         // Terminal error, return failure.
@@ -253,7 +233,8 @@ bool Socket::WaitForReadyRead() {
   // event occurs.
   int ret = poll(&p, 1, -1);
   if (ret < 1) return false;
-  return p.revents & POLLIN;
+  constexpr unsigned pollin = POLLIN;
+  return static_cast<unsigned>(p.revents) & pollin;
 }
 
 bool Socket::WaitForReadyWrite() {
@@ -265,7 +246,8 @@ bool Socket::WaitForReadyWrite() {
   // event occurs.
   int ret = poll(&p, 1, -1);
   if (ret < 1) return false;
-  return p.revents & POLLOUT;
+  constexpr unsigned pollout = POLLOUT;
+  return static_cast<unsigned>(p.revents) & pollout;
 }
 
 }  // namespace memgraph::io::network
diff --git a/src/io/network/socket.hpp b/src/io/network/socket.hpp
index faa127399..bc6a98bd3 100644
--- a/src/io/network/socket.hpp
+++ b/src/io/network/socket.hpp
@@ -27,12 +27,12 @@ namespace memgraph::io::network {
  */
 class Socket {
  public:
-  Socket() = default;
+  Socket() noexcept = default;
   Socket(const Socket &) = delete;
   Socket &operator=(const Socket &) = delete;
-  Socket(Socket &&);
-  Socket &operator=(Socket &&);
-  ~Socket();
+  Socket(Socket &&) noexcept;
+  Socket &operator=(Socket &&) noexcept;
+  ~Socket() noexcept;
 
   /**
    * Closes the socket if it is open.
@@ -118,7 +118,7 @@ class Socket {
    * @param sec timeout seconds value
    * @param usec timeout microseconds value
    */
-  void SetTimeout(long sec, long usec);
+  void SetTimeout(int64_t sec, int64_t usec);
 
   /**
    * Checks if there are any errors on a socket. Returns 0 if there are none.