Refactor code to avoid duplication.

This commit is contained in:
Sam Hocevar 2021-02-14 11:54:48 +01:00
parent 3b33f45925
commit bb0d64ed5d
3 changed files with 47 additions and 61 deletions

View File

@ -32,8 +32,13 @@ void setSocketDefaults(SOCKET fd) {
#endif
}
int getSocketType(int protocol) {
return protocol == IPPROTO_UDP ? SOCK_DGRAM : SOCK_STREAM;
struct addrinfo getAddrInfoHint(int protocol) {
return (struct addrinfo) {
.ai_family = AF_UNSPEC,
.ai_protocol = protocol,
.ai_socktype = protocol == IPPROTO_UDP ? SOCK_DGRAM : SOCK_STREAM,
.ai_flags = AI_PASSIVE,
};
}
int sameSocketAddress(struct sockaddr_storage *a, struct sockaddr_storage *b) {

View File

@ -85,6 +85,6 @@ static inline int GetLastError(void) {
#endif /* _WIN32 */
void setSocketDefaults(SOCKET fd);
int getSocketType(int protocol);
struct addrinfo getAddrInfoHint(int protocol);
int sameSocketAddress(struct sockaddr_storage *a, struct sockaddr_storage *b);
uint16_t getPort(struct addrinfo* ai);

View File

@ -267,87 +267,68 @@ void addServer(char *bindAddress, char *bindPort, int bindProtocol,
};
/* Make a server socket */
struct addrinfo hints =
{
.ai_family = AF_UNSPEC,
.ai_protocol = bindProtocol,
.ai_socktype = getSocketType(bindProtocol),
.ai_flags = AI_PASSIVE,
};
struct addrinfo *servinfo;
int ret = getaddrinfo(bindAddress, bindPort, &hints, &servinfo);
struct addrinfo hints = getAddrInfoHint(bindProtocol), *ai;
int ret = getaddrinfo(bindAddress, bindPort, &hints, &ai);
if (ret != 0) {
fprintf(stderr, "rinetd: getaddrinfo error: %s\n", gai_strerror(ret));
exit(1);
}
for (struct addrinfo *it = servinfo; it; it = it->ai_next) {
si.fd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
if (si.fd == INVALID_SOCKET) {
syslog(LOG_ERR, "couldn't create server socket! (%m)\n");
freeaddrinfo(servinfo);
exit(1);
}
int tmp = 1;
setsockopt(si.fd, SOL_SOCKET, SO_REUSEADDR, (const char *)&tmp, sizeof(tmp));
if (bind(si.fd, it->ai_addr, it->ai_addrlen) == SOCKET_ERROR) {
syslog(LOG_ERR, "couldn't bind to address %s port %s (%m)\n",
bindAddress, bindPort);
closesocket(si.fd);
freeaddrinfo(servinfo);
exit(1);
}
if (bindProtocol == IPPROTO_TCP) {
if (listen(si.fd, RINETD_LISTEN_BACKLOG) == SOCKET_ERROR) {
/* Warn -- don't exit. */
syslog(LOG_ERR, "couldn't listen to address %s port %s (%m)\n",
bindAddress, bindPort);
/* XXX: check whether this is correct */
closesocket(si.fd);
}
/* Make socket nonblocking in TCP mode only, otherwise
we may miss some data. */
setSocketDefaults(si.fd);
}
si.fromAddrInfo = it;
break;
si.fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
if (si.fd == INVALID_SOCKET) {
syslog(LOG_ERR, "couldn't create server socket! (%m)\n");
freeaddrinfo(ai);
exit(1);
}
int tmp = 1;
setsockopt(si.fd, SOL_SOCKET, SO_REUSEADDR, (const char *)&tmp, sizeof(tmp));
if (bind(si.fd, ai->ai_addr, ai->ai_addrlen) == SOCKET_ERROR) {
syslog(LOG_ERR, "couldn't bind to address %s port %s (%m)\n",
bindAddress, bindPort);
closesocket(si.fd);
freeaddrinfo(ai);
exit(1);
}
if (bindProtocol == IPPROTO_TCP) {
if (listen(si.fd, RINETD_LISTEN_BACKLOG) == SOCKET_ERROR) {
/* Warn -- don't exit. */
syslog(LOG_ERR, "couldn't listen to address %s port %s (%m)\n",
bindAddress, bindPort);
/* XXX: check whether this is correct */
closesocket(si.fd);
}
/* Make socket nonblocking in TCP mode only, otherwise
we may miss some data. */
setSocketDefaults(si.fd);
}
si.fromAddrInfo = ai;
/* Resolve destination address. */
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_protocol = connectProtocol;
hints.ai_socktype = getSocketType(connectProtocol);
hints.ai_flags = AI_PASSIVE;
ret = getaddrinfo(connectAddress, connectPort, &hints, &servinfo);
hints = getAddrInfoHint(connectProtocol);
ret = getaddrinfo(connectAddress, connectPort, &hints, &ai);
if (ret != 0) {
fprintf(stderr, "rinetd: getaddrinfo error: %s\n", gai_strerror(ret));
freeaddrinfo(si.fromAddrInfo);
closesocket(si.fd);
exit(1);
}
si.toAddrInfo = servinfo;
si.toAddrInfo = ai;
/* Resolve source address if applicable. */
if (sourceAddress) {
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC,
hints.ai_protocol = connectProtocol,
hints.ai_socktype = getSocketType(connectProtocol),
ret = getaddrinfo(sourceAddress, NULL, &hints, &servinfo);
hints = getAddrInfoHint(connectProtocol);
ret = getaddrinfo(sourceAddress, NULL, &hints, &ai);
if (ret != 0) {
fprintf(stderr, "rinetd: getaddrinfo error: %s\n", gai_strerror(ret));
freeaddrinfo(si.fromAddrInfo);
freeaddrinfo(si.toAddrInfo);
exit(1);
}
si.sourceAddrInfo = servinfo;
si.sourceAddrInfo = ai;
}
#ifndef _WIN32