Assert on endpoint failures

Summary: .

Reviewers: mferencevic, florijan

Reviewed By: mferencevic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1045
This commit is contained in:
Mislav Bradac 2017-12-12 12:25:15 +01:00
parent 0d40f6a759
commit eb272f0b67
14 changed files with 54 additions and 129 deletions

View File

@ -41,12 +41,7 @@ void System::StartServer(int worker_count) {
}
// Initialize endpoint.
Endpoint endpoint;
try {
endpoint = Endpoint(address_.c_str(), port_);
} catch (io::network::NetworkEndpointException &e) {
LOG(FATAL) << e.what();
}
Endpoint endpoint(address_.c_str(), port_);
// Initialize server.
server_ = std::make_unique<ServerT>(endpoint, protocol_data_);

View File

@ -67,15 +67,8 @@ void SendMessage(const std::string &address, uint16_t port,
CHECK(message) << "Trying to send nullptr instead of message";
// Initialize endpoint.
Endpoint endpoint;
try {
endpoint = Endpoint(address.c_str(), port);
} catch (io::network::NetworkEndpointException &e) {
LOG(ERROR) << "Address {} is invalid!";
return;
}
Endpoint endpoint(address.c_str(), port);
// Initialize socket.
Socket socket;
if (!socket.Connect(endpoint)) {
LOG(INFO) << "Couldn't connect to remote address: " << address << ":"

View File

@ -98,13 +98,7 @@ bool SendLength(Socket &socket, SizeT length) {
void SendMessage(std::string address, uint16_t port, std::string reactor,
std::string channel, std::unique_ptr<Message> message) {
// Initialize endpoint.
Endpoint endpoint;
try {
endpoint = Endpoint(address.c_str(), port);
} catch (io::network::NetworkEndpointException &e) {
LOG(INFO) << "Address is invalid!";
return;
}
Endpoint endpoint(address.c_str(), port);
// Initialize socket.
Socket socket;

View File

@ -152,12 +152,7 @@ class Network {
}
// Initialize endpoint.
Endpoint endpoint;
try {
endpoint = Endpoint(FLAGS_reactor_address.c_str(), FLAGS_reactor_port);
} catch (io::network::NetworkEndpointException &e) {
LOG(FATAL) << e.what();
}
Endpoint endpoint(FLAGS_reactor_address.c_str(), FLAGS_reactor_port);
// Initialize server
server_ = std::make_unique<ServerT>(endpoint, protocol_data_);

View File

@ -49,9 +49,7 @@ class Server {
LOG(FATAL) << "Cannot bind to socket on " << endpoint.address() << " at "
<< endpoint.port();
}
if (!socket_.SetNonBlocking()) {
LOG(FATAL) << "Cannot set socket to non blocking!";
}
socket_.SetNonBlocking();
if (!socket_.Listen(1024)) {
LOG(FATAL) << "Cannot listen on socket!";
}
@ -126,9 +124,9 @@ class Server {
s->fd(), s->endpoint().address(), s->endpoint().family(),
s->endpoint().port());
if (!s->SetTimeout(1, 0)) return std::experimental::nullopt;
if (!s->SetKeepAlive()) return std::experimental::nullopt;
if (!s->SetNoDelay()) return std::experimental::nullopt;
s->SetTimeout(1, 0);
s->SetKeepAlive();
s->SetNoDelay();
return s;
}

View File

@ -1,5 +1,6 @@
#include "io/network/network_endpoint.hpp"
#include "io/network/network_error.hpp"
#include "glog/logging.h"
#include <arpa/inet.h>
#include <netdb.h>
@ -12,8 +13,7 @@ NetworkEndpoint::NetworkEndpoint() : port_(0), family_(0) {
}
NetworkEndpoint::NetworkEndpoint(const char *addr, const char *port) {
if (addr == nullptr) throw NetworkEndpointException("Address can't be null!");
if (port == nullptr) throw NetworkEndpointException("Port can't be null!");
if (!addr || !port) LOG(FATAL) << "Address or port is nullptr";
// strncpy isn't used because it does not guarantee an ending null terminator
snprintf(address_, sizeof address_, "%s", addr);
@ -24,16 +24,16 @@ NetworkEndpoint::NetworkEndpoint(const char *addr, const char *port) {
int ret = inet_pton(AF_INET, address_, &addr4);
if (ret != 1) {
ret = inet_pton(AF_INET6, address_, &addr6);
if (ret != 1)
throw NetworkEndpointException(
"Address isn't a valid IPv4 or IPv6 address!");
else
family_ = 6;
} else
if (ret != 1) {
LOG(FATAL) << "Not a valid IPv4 or IPv6 address: " << *addr;
}
family_ = 6;
} else {
family_ = 4;
}
ret = sscanf(port, "%hu", &port_);
if (ret != 1) throw NetworkEndpointException("Port isn't valid!");
if (ret != 1) LOG(FATAL) << "Not a valid port: " << *port;
}
NetworkEndpoint::NetworkEndpoint(const std::string &addr,

View File

@ -7,11 +7,6 @@
namespace io::network {
class NetworkEndpointException : public utils::BasicException {
public:
using utils::BasicException::BasicException;
};
/**
* This class represents a network endpoint that is used in Socket.
* It is used when connecting to an address and to get the current

View File

@ -16,6 +16,8 @@
#include <sys/types.h>
#include <unistd.h>
#include "glog/logging.h"
#include "io/network/addrinfo.hpp"
#include "threading/sync/cpu_relax.hpp"
#include "utils/likely.hpp"
@ -102,62 +104,51 @@ bool Socket::Bind(const NetworkEndpoint &endpoint) {
return true;
}
bool Socket::SetNonBlocking() {
void Socket::SetNonBlocking() {
int flags = fcntl(socket_, F_GETFL, 0);
if (UNLIKELY(flags == -1)) return false;
CHECK(flags != -1) << "Can't get socket mode";
flags |= O_NONBLOCK;
int ret = fcntl(socket_, F_SETFL, flags);
if (UNLIKELY(ret == -1)) return false;
return true;
CHECK(fcntl(socket_, F_SETFL, flags) != -1) << "Can't set socket nonblocking";
}
bool Socket::SetKeepAlive() {
void Socket::SetKeepAlive() {
int optval = 1;
socklen_t optlen = sizeof(optval);
if (setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen) < 0)
return false;
CHECK(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen))
<< "Can't set socket keep alive";
optval = 20; // wait 120s before seding keep-alive packets
if (setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen) < 0)
return false;
CHECK(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen))
<< "Can't set socket keep alive";
optval = 4; // 4 keep-alive packets must fail to close
if (setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen) < 0)
return false;
CHECK(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen))
<< "Can't set socket keep alive";
optval = 15; // send keep-alive packets every 15s
if (setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen) < 0)
return false;
return true;
CHECK(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen))
<< "Can't set socket keep alive";
}
bool Socket::SetNoDelay() {
void Socket::SetNoDelay() {
int optval = 1;
socklen_t optlen = sizeof(optval);
if (setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen) < 0)
return false;
return true;
CHECK(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen))
<< "Can't set socket no delay";
}
bool Socket::SetTimeout(long sec, long usec) {
void Socket::SetTimeout(long sec, long usec) {
struct timeval tv;
tv.tv_sec = sec;
tv.tv_usec = usec;
if (setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
return false;
CHECK(!setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
<< "Can't set socket timeout";
if (setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) < 0)
return false;
return true;
CHECK(!setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)))
<< "Can't set socket timeout";
}
bool Socket::Listen(int backlog) { return listen(socket_, backlog) == 0; }
@ -182,12 +173,7 @@ std::experimental::optional<Socket> Socket::Accept() {
inet_ntop(addr.ss_family, addr_src, addr_decoded, INET6_ADDRSTRLEN);
NetworkEndpoint endpoint;
try {
endpoint = NetworkEndpoint(addr_decoded, port);
} catch (NetworkEndpointException &e) {
return std::experimental::nullopt;
}
NetworkEndpoint endpoint(addr_decoded, port);
return Socket(sfd, endpoint);
}

View File

@ -81,44 +81,28 @@ class Socket {
/**
* Sets the socket to non-blocking.
*
* @return set non-blocking success status:
* true if the socket was successfully set to non-blocking
* false if the socket was not set to non-blocking
*/
bool SetNonBlocking();
void SetNonBlocking();
/**
* Enables TCP keep-alive on the socket.
*
* @return enable keep-alive success status:
* true if keep-alive was successfully enabled on the socket
* false if keep-alive was not enabled
*/
bool SetKeepAlive();
void SetKeepAlive();
/**
* Enables TCP no_delay on the socket.
* When enabled, the socket doesn't wait for an ACK of every data packet
* before sending the next packet.
*
* @return enable no_delay success status:
* true if no_delay was successfully enabled on the socket
* false if no_delay was not enabled
*/
bool SetNoDelay();
void SetNoDelay();
/**
* Sets the socket timeout.
*
* @param sec timeout seconds value
* @param usec timeout microseconds value
* @return set socket timeout status:
* true if the timeout was successfully set to
* sec seconds + usec microseconds
* false if the timeout was not set
*/
bool SetTimeout(long sec, long usec);
void SetTimeout(long sec, long usec);
/**
* Returns the socket file descriptor.

View File

@ -69,13 +69,7 @@ int main(int argc, char **argv) {
SessionData session_data;
// Initialize endpoint.
NetworkEndpoint endpoint = [&] {
try {
return NetworkEndpoint(FLAGS_interface, FLAGS_port);
} catch (io::network::NetworkEndpointException &e) {
LOG(FATAL) << e.what();
}
}();
NetworkEndpoint endpoint(FLAGS_interface, FLAGS_port);
// Initialize server.
ServerT server(endpoint, session_data);

View File

@ -75,7 +75,7 @@ void client_run(int num, const char *interface, const char *port,
NetworkEndpoint endpoint(interface, port);
Socket socket;
ASSERT_TRUE(socket.Connect(endpoint));
ASSERT_TRUE(socket.SetTimeout(2, 0));
socket.SetTimeout(2, 0);
DLOG(INFO) << "Socket create: " << socket.fd();
for (int len = lo; len <= hi; len += 100) {
have = 0;

View File

@ -58,7 +58,7 @@ void client_run(int num, const char *interface, const char *port) {
Socket socket;
uint8_t data = 0x00;
ASSERT_TRUE(socket.Connect(endpoint));
ASSERT_TRUE(socket.SetTimeout(1, 0));
socket.SetTimeout(1, 0);
// set socket timeout to 1s
ASSERT_TRUE(socket.Write((uint8_t *)"\xAA", 1));
ASSERT_TRUE(socket.Read(&data, 1));

View File

@ -20,13 +20,8 @@ class BoltClient {
const std::string &username, const std::string &password,
const std::string & = "") {
SocketT socket;
EndpointT endpoint;
EndpointT endpoint(address, port);
try {
endpoint = EndpointT(address, port);
} catch (const io::network::NetworkEndpointException &e) {
LOG(FATAL) << "Invalid address or port: " << address << ":" << port;
}
if (!socket.Connect(endpoint)) {
LOG(FATAL) << "Could not connect to: " << address << ":" << port;
}

View File

@ -6,7 +6,6 @@
#include "io/network/network_error.hpp"
using endpoint_t = io::network::NetworkEndpoint;
using exception_t = io::network::NetworkEndpointException;
TEST(NetworkEndpoint, IPv4) {
endpoint_t endpoint;
@ -34,13 +33,13 @@ TEST(NetworkEndpoint, IPv4) {
EXPECT_EQ(endpoint.family(), 4);
// test address null
EXPECT_THROW(endpoint_t(nullptr, nullptr), exception_t);
EXPECT_DEATH(endpoint_t(nullptr, nullptr), "null");
// test address invalid
EXPECT_THROW(endpoint_t("invalid", "12345"), exception_t);
EXPECT_DEATH(endpoint_t("invalid", "12345"), "addres");
// test port invalid
EXPECT_THROW(endpoint_t("127.0.0.1", "invalid"), exception_t);
EXPECT_DEATH(endpoint_t("127.0.0.1", "invalid"), "port");
}
TEST(NetworkEndpoint, IPv6) {
@ -68,14 +67,11 @@ TEST(NetworkEndpoint, IPv6) {
EXPECT_EQ(endpoint.port(), 12347);
EXPECT_EQ(endpoint.family(), 6);
// test address null
EXPECT_THROW(endpoint_t(nullptr, nullptr), exception_t);
// test address invalid
EXPECT_THROW(endpoint_t("::g", "12345"), exception_t);
EXPECT_DEATH(endpoint_t("::g", "12345"), "address");
// test port invalid
EXPECT_THROW(endpoint_t("::1", "invalid"), exception_t);
EXPECT_DEATH(endpoint_t("::1", "invalid"), "port");
}
int main(int argc, char **argv) {