Support multiple UDP clients with connection tracking.

If an incoming UDP packet shares the same host and source port as
an existing connection, we consider it part of the input stream.
This commit is contained in:
Sam Hocevar 2017-09-07 16:29:23 +02:00
parent b5206eab4e
commit 6bb27c18a6
2 changed files with 47 additions and 22 deletions

View File

@ -338,7 +338,8 @@ static void setConnectionCount(int newCount)
closesocket(coInfo[i].local.fd); closesocket(coInfo[i].local.fd);
} }
if (coInfo[i].remote.fd != INVALID_SOCKET) { if (coInfo[i].remote.fd != INVALID_SOCKET) {
closesocket(coInfo[i].remote.fd); if (coInfo[i].remote.proto == protoTcp)
closesocket(coInfo[i].remote.fd);
} }
free(coInfo[i].local.buffer); free(coInfo[i].local.buffer);
} }
@ -453,8 +454,11 @@ static void selectPass(void)
for (int i = 0; i < coTotal; ++i) { for (int i = 0; i < coTotal; ++i) {
ConnectionInfo *cnx = &coInfo[i]; ConnectionInfo *cnx = &coInfo[i];
if (cnx->remote.fd != INVALID_SOCKET) { if (cnx->remote.fd != INVALID_SOCKET) {
if (FD_ISSET_EXT(cnx->remote.fd, readfds)) { /* Do not read on remote UDP sockets, the server does it. */
handleRead(cnx, &cnx->remote, &cnx->local); if (cnx->remote.proto != protoUdp) {
if (FD_ISSET_EXT(cnx->remote.fd, readfds)) {
handleRead(cnx, &cnx->remote, &cnx->local);
}
} }
} }
if (cnx->remote.fd != INVALID_SOCKET) { if (cnx->remote.fd != INVALID_SOCKET) {
@ -513,7 +517,8 @@ static void handleWrite(ConnectionInfo *cnx, Socket *socket, Socket *other_socke
if (cnx->coClosing && (socket->sentPos == other_socket->recvPos)) { if (cnx->coClosing && (socket->sentPos == other_socket->recvPos)) {
PERROR("rinetd: local closed and no more output"); PERROR("rinetd: local closed and no more output");
logEvent(cnx, cnx->server, cnx->coLog); logEvent(cnx, cnx->server, cnx->coLog);
closesocket(socket->fd); if (socket->proto == protoTcp)
closesocket(socket->fd);
socket->fd = INVALID_SOCKET; socket->fd = INVALID_SOCKET;
return; return;
} }
@ -564,18 +569,12 @@ static void handleClose(ConnectionInfo *cnx, Socket *socket, Socket *other_socke
static void handleAccept(ServerInfo const *srv) static void handleAccept(ServerInfo const *srv)
{ {
ConnectionInfo *cnx = findAvailableConnection();
if (!cnx) {
return;
}
struct sockaddr addr; struct sockaddr addr;
SOCKLEN_T addrlen = sizeof(addr); SOCKLEN_T addrlen = sizeof(addr);
SOCKET nfd; SOCKET nfd;
/* Get remote address using accept() in TCP mode, recvfrom()
in UDP mode. */
if (srv->fromProto == protoTcp) { if (srv->fromProto == protoTcp) {
/* In TCP mode, get remote address using accept(). */
nfd = accept(srv->fd, &addr, &addrlen); nfd = accept(srv->fd, &addr, &addrlen);
if (nfd == INVALID_SOCKET) { if (nfd == INVALID_SOCKET) {
syslog(LOG_ERR, "accept(%d): %m\n", srv->fd); syslog(LOG_ERR, "accept(%d): %m\n", srv->fd);
@ -590,14 +589,39 @@ static void handleAccept(ServerInfo const *srv)
setsockopt(nfd, SOL_SOCKET, SO_LINGER, &tmp, sizeof(tmp)); setsockopt(nfd, SOL_SOCKET, SO_LINGER, &tmp, sizeof(tmp));
#endif #endif
} else /* if (srv->fromProto == protoUdp) */ { } else /* if (srv->fromProto == protoUdp) */ {
/* In UDP mode, get remote address using recvfrom() and check
for an existing connection from this client. */
nfd = srv->fd; nfd = srv->fd;
ssize_t ret = recvfrom(srv->fd, NULL, 0, MSG_PEEK, ssize_t ret = recvfrom(srv->fd, NULL, 0, MSG_PEEK,
&addr, &addrlen); &addr, &addrlen);
if (ret < 0) { if (ret < 0) {
if (GetLastError() == WSAEWOULDBLOCK) {
return;
}
if (GetLastError() == WSAEINPROGRESS) {
return;
}
syslog(LOG_ERR, "recvfrom(%d): %m\n", srv->fd); syslog(LOG_ERR, "recvfrom(%d): %m\n", srv->fd);
logEvent(NULL, srv, logAcceptFailed); logEvent(NULL, srv, logAcceptFailed);
return; return;
} }
for (int i = 0; i < coTotal; ++i) {
ConnectionInfo *cnx = &coInfo[i];
struct sockaddr_in *addr_in = (struct sockaddr_in *)&addr;
if (cnx->remote.fd == nfd
&& cnx->remoteAddress.sin_family == addr_in->sin_family
&& cnx->remoteAddress.sin_port == addr_in->sin_port
&& cnx->remoteAddress.sin_addr.s_addr == addr_in->sin_addr.s_addr) {
handleRead(cnx, &cnx->remote, &cnx->local);
return;
}
}
}
ConnectionInfo *cnx = findAvailableConnection();
if (!cnx) {
return;
} }
cnx->local.fd = INVALID_SOCKET; cnx->local.fd = INVALID_SOCKET;
@ -608,7 +632,7 @@ static void handleAccept(ServerInfo const *srv)
cnx->remote.proto = srv->fromProto; cnx->remote.proto = srv->fromProto;
cnx->remote.recvPos = cnx->remote.sentPos = 0; cnx->remote.recvPos = cnx->remote.sentPos = 0;
cnx->remote.recvBytes = cnx->remote.sentBytes = 0; cnx->remote.recvBytes = cnx->remote.sentBytes = 0;
cnx->reAddresses.s_addr = ((struct sockaddr_in *)&addr)->sin_addr.s_addr; cnx->remoteAddress = *(struct sockaddr_in *)&addr;
cnx->coClosing = 0; cnx->coClosing = 0;
cnx->coLog = logUnknownError; cnx->coLog = logUnknownError;
cnx->server = srv; cnx->server = srv;
@ -617,7 +641,8 @@ static void handleAccept(ServerInfo const *srv)
if (logCode != logAllowed) { if (logCode != logAllowed) {
/* Local fd is not open yet, so only /* Local fd is not open yet, so only
close the remote socket. */ close the remote socket. */
closesocket(cnx->remote.fd); if (cnx->remote.proto == protoTcp)
closesocket(cnx->remote.fd);
cnx->remote.fd = INVALID_SOCKET; cnx->remote.fd = INVALID_SOCKET;
logEvent(cnx, cnx->server, logCode); logEvent(cnx, cnx->server, logCode);
return; return;
@ -632,7 +657,8 @@ static void handleAccept(ServerInfo const *srv)
: socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP); : socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP);
if (cnx->local.fd == INVALID_SOCKET) { if (cnx->local.fd == INVALID_SOCKET) {
syslog(LOG_ERR, "socket(): %m\n"); syslog(LOG_ERR, "socket(): %m\n");
closesocket(cnx->remote.fd); if (cnx->remote.proto == protoTcp)
closesocket(cnx->remote.fd);
cnx->remote.fd = INVALID_SOCKET; cnx->remote.fd = INVALID_SOCKET;
logEvent(cnx, srv, logLocalSocketFailed); logEvent(cnx, srv, logLocalSocketFailed);
return; return;
@ -645,7 +671,8 @@ static void handleAccept(ServerInfo const *srv)
saddr.sin_addr.s_addr = 0; saddr.sin_addr.s_addr = 0;
if (bind(cnx->local.fd, (struct sockaddr *) &saddr, sizeof(saddr)) == SOCKET_ERROR) { if (bind(cnx->local.fd, (struct sockaddr *) &saddr, sizeof(saddr)) == SOCKET_ERROR) {
closesocket(cnx->local.fd); closesocket(cnx->local.fd);
closesocket(cnx->remote.fd); if (cnx->remote.proto == protoTcp)
closesocket(cnx->remote.fd);
cnx->remote.fd = INVALID_SOCKET; cnx->remote.fd = INVALID_SOCKET;
cnx->local.fd = INVALID_SOCKET; cnx->local.fd = INVALID_SOCKET;
logEvent(cnx, srv, logLocalBindFailed); logEvent(cnx, srv, logLocalBindFailed);
@ -677,7 +704,8 @@ static void handleAccept(ServerInfo const *srv)
{ {
PERROR("rinetd: connect"); PERROR("rinetd: connect");
closesocket(cnx->local.fd); closesocket(cnx->local.fd);
closesocket(cnx->remote.fd); if (cnx->remote.proto == protoTcp)
closesocket(cnx->remote.fd);
cnx->remote.fd = INVALID_SOCKET; cnx->remote.fd = INVALID_SOCKET;
cnx->local.fd = INVALID_SOCKET; cnx->local.fd = INVALID_SOCKET;
logEvent(cnx, srv, logLocalConnectFailed); logEvent(cnx, srv, logLocalConnectFailed);
@ -700,9 +728,7 @@ static void handleAccept(ServerInfo const *srv)
static int checkConnectionAllowed(ConnectionInfo const *cnx) static int checkConnectionAllowed(ConnectionInfo const *cnx)
{ {
ServerInfo const *srv = cnx->server; ServerInfo const *srv = cnx->server;
struct in_addr address; char const *addressText = inet_ntoa(cnx->remoteAddress.sin_addr);
address.s_addr = cnx->reAddresses.s_addr;
char const *addressText = inet_ntoa(address);
/* 1. Check global allow rules. If there are no /* 1. Check global allow rules. If there are no
global allow rules, it's presumed OK at global allow rules, it's presumed OK at
@ -890,8 +916,7 @@ static void logEvent(ConnectionInfo const *cnx, ServerInfo const *srv, int resul
int bytesOutput = 0; int bytesOutput = 0;
int bytesInput = 0; int bytesInput = 0;
if (cnx != NULL) { if (cnx != NULL) {
struct in_addr const *reAddress = &cnx->reAddresses; addressText = inet_ntoa(cnx->remoteAddress.sin_addr);
addressText = inet_ntoa(*reAddress);
bytesOutput = cnx->remote.sentBytes; bytesOutput = cnx->remote.sentBytes;
bytesInput = cnx->remote.recvBytes; bytesInput = cnx->remote.recvBytes;
} }

View File

@ -58,7 +58,7 @@ typedef struct _connection_info ConnectionInfo;
struct _connection_info struct _connection_info
{ {
Socket remote, local; Socket remote, local;
struct in_addr reAddresses; struct sockaddr_in remoteAddress;
int coClosing; int coClosing;
int coLog; int coLog;
ServerInfo const *server; // only useful for logEvent ServerInfo const *server; // only useful for logEvent