diff --git a/src/net.c b/src/net.c index dea9c6d..2bb6936 100644 --- a/src/net.c +++ b/src/net.c @@ -32,8 +32,13 @@ void setSocketDefaults(SOCKET fd) { #endif } -int getSocketType(int protocol) { - return protocol == IPPROTO_UDP ? SOCK_DGRAM : SOCK_STREAM; +struct addrinfo getAddrInfoHint(int protocol) { + return (struct addrinfo) { + .ai_family = AF_UNSPEC, + .ai_protocol = protocol, + .ai_socktype = protocol == IPPROTO_UDP ? SOCK_DGRAM : SOCK_STREAM, + .ai_flags = AI_PASSIVE, + }; } int sameSocketAddress(struct sockaddr_storage *a, struct sockaddr_storage *b) { diff --git a/src/net.h b/src/net.h index bcb0cf2..7cf67d9 100644 --- a/src/net.h +++ b/src/net.h @@ -85,6 +85,6 @@ static inline int GetLastError(void) { #endif /* _WIN32 */ void setSocketDefaults(SOCKET fd); -int getSocketType(int protocol); +struct addrinfo getAddrInfoHint(int protocol); int sameSocketAddress(struct sockaddr_storage *a, struct sockaddr_storage *b); uint16_t getPort(struct addrinfo* ai); diff --git a/src/rinetd.c b/src/rinetd.c index 08b0672..d418eea 100644 --- a/src/rinetd.c +++ b/src/rinetd.c @@ -267,87 +267,68 @@ void addServer(char *bindAddress, char *bindPort, int bindProtocol, }; /* Make a server socket */ - struct addrinfo hints = - { - .ai_family = AF_UNSPEC, - .ai_protocol = bindProtocol, - .ai_socktype = getSocketType(bindProtocol), - .ai_flags = AI_PASSIVE, - }; - struct addrinfo *servinfo; - int ret = getaddrinfo(bindAddress, bindPort, &hints, &servinfo); + struct addrinfo hints = getAddrInfoHint(bindProtocol), *ai; + int ret = getaddrinfo(bindAddress, bindPort, &hints, &ai); if (ret != 0) { fprintf(stderr, "rinetd: getaddrinfo error: %s\n", gai_strerror(ret)); exit(1); } - for (struct addrinfo *it = servinfo; it; it = it->ai_next) { - si.fd = socket(it->ai_family, it->ai_socktype, it->ai_protocol); - if (si.fd == INVALID_SOCKET) { - syslog(LOG_ERR, "couldn't create server socket! (%m)\n"); - freeaddrinfo(servinfo); - exit(1); - } - - int tmp = 1; - setsockopt(si.fd, SOL_SOCKET, SO_REUSEADDR, (const char *)&tmp, sizeof(tmp)); - - if (bind(si.fd, it->ai_addr, it->ai_addrlen) == SOCKET_ERROR) { - syslog(LOG_ERR, "couldn't bind to address %s port %s (%m)\n", - bindAddress, bindPort); - closesocket(si.fd); - freeaddrinfo(servinfo); - exit(1); - } - - if (bindProtocol == IPPROTO_TCP) { - if (listen(si.fd, RINETD_LISTEN_BACKLOG) == SOCKET_ERROR) { - /* Warn -- don't exit. */ - syslog(LOG_ERR, "couldn't listen to address %s port %s (%m)\n", - bindAddress, bindPort); - /* XXX: check whether this is correct */ - closesocket(si.fd); - } - - /* Make socket nonblocking in TCP mode only, otherwise - we may miss some data. */ - setSocketDefaults(si.fd); - } - - si.fromAddrInfo = it; - break; + si.fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); + if (si.fd == INVALID_SOCKET) { + syslog(LOG_ERR, "couldn't create server socket! (%m)\n"); + freeaddrinfo(ai); + exit(1); } + int tmp = 1; + setsockopt(si.fd, SOL_SOCKET, SO_REUSEADDR, (const char *)&tmp, sizeof(tmp)); + + if (bind(si.fd, ai->ai_addr, ai->ai_addrlen) == SOCKET_ERROR) { + syslog(LOG_ERR, "couldn't bind to address %s port %s (%m)\n", + bindAddress, bindPort); + closesocket(si.fd); + freeaddrinfo(ai); + exit(1); + } + + if (bindProtocol == IPPROTO_TCP) { + if (listen(si.fd, RINETD_LISTEN_BACKLOG) == SOCKET_ERROR) { + /* Warn -- don't exit. */ + syslog(LOG_ERR, "couldn't listen to address %s port %s (%m)\n", + bindAddress, bindPort); + /* XXX: check whether this is correct */ + closesocket(si.fd); + } + + /* Make socket nonblocking in TCP mode only, otherwise + we may miss some data. */ + setSocketDefaults(si.fd); + } + si.fromAddrInfo = ai; + /* Resolve destination address. */ - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - hints.ai_protocol = connectProtocol; - hints.ai_socktype = getSocketType(connectProtocol); - hints.ai_flags = AI_PASSIVE; - ret = getaddrinfo(connectAddress, connectPort, &hints, &servinfo); + hints = getAddrInfoHint(connectProtocol); + ret = getaddrinfo(connectAddress, connectPort, &hints, &ai); if (ret != 0) { fprintf(stderr, "rinetd: getaddrinfo error: %s\n", gai_strerror(ret)); freeaddrinfo(si.fromAddrInfo); closesocket(si.fd); exit(1); } - si.toAddrInfo = servinfo; + si.toAddrInfo = ai; /* Resolve source address if applicable. */ if (sourceAddress) { - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC, - hints.ai_protocol = connectProtocol, - hints.ai_socktype = getSocketType(connectProtocol), - - ret = getaddrinfo(sourceAddress, NULL, &hints, &servinfo); + hints = getAddrInfoHint(connectProtocol); + ret = getaddrinfo(sourceAddress, NULL, &hints, &ai); if (ret != 0) { fprintf(stderr, "rinetd: getaddrinfo error: %s\n", gai_strerror(ret)); freeaddrinfo(si.fromAddrInfo); freeaddrinfo(si.toAddrInfo); exit(1); } - si.sourceAddrInfo = servinfo; + si.sourceAddrInfo = ai; } #ifndef _WIN32