From 37437a5c5b72130a423227a79d7ebecba3b88c3c Mon Sep 17 00:00:00 2001 From: Sam Hocevar Date: Fri, 8 Sep 2017 13:28:09 +0200 Subject: [PATCH] Improve UDP connection handling. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Send a zero-sized UDP packet when first accessing the server; this makes TCP→UDP→TCP redirects work in both directions. Properly close UDP connections when the other end disconnects (TCP) or timeouts (UDP). --- configure.ac | 2 +- rinetd.c | 83 ++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 65 insertions(+), 20 deletions(-) diff --git a/configure.ac b/configure.ac index 2a46188..b4a0f62 100644 --- a/configure.ac +++ b/configure.ac @@ -1,6 +1,6 @@ # Process this file with autoconf to produce a configure script. AC_PREREQ(2.52) -AC_INIT(rinetd, 0.63, sam@hocevar.net) +AC_INIT(rinetd, 0.63.test, sam@hocevar.net) AC_CONFIG_AUX_DIR(.auto) AC_CONFIG_SRCDIR([getopt.h]) AC_CONFIG_HEADER([config.h]) diff --git a/rinetd.c b/rinetd.c index 04d165b..167788f 100644 --- a/rinetd.c +++ b/rinetd.c @@ -67,6 +67,9 @@ int const maxfd = 0; int maxfd = 0; #endif +/* Global static buffer for UDP data. */ +static char globalUdpBuffer[65536]; + char *logFileName = NULL; char *pidLogFileName = NULL; int logFormatCommon = 0; @@ -107,6 +110,7 @@ RinetdOptions options = { static void selectPass(void); static void handleWrite(ConnectionInfo *cnx, Socket *socket, Socket *other_socket); static void handleRead(ConnectionInfo *cnx, Socket *socket, Socket *other_socket); +static void handleUdpRead(ConnectionInfo *cnx, char const *buffer, int bytes); static void handleClose(ConnectionInfo *cnx, Socket *socket, Socket *other_socket); static void handleAccept(ServerInfo const *srv); static ConnectionInfo *findAvailableConnection(void); @@ -246,7 +250,8 @@ static void readConfiguration(char const *file) { void addServer(char *bindAddress, int bindPort, int bindProto, char *connectAddress, int connectPort, int connectProto, - int serverTimeout) { + int serverTimeout) +{ /* Turn all of this stuff into reasonable addresses */ struct in_addr iaddr; if (getAddress(bindAddress, &iaddr) < 0) { @@ -271,8 +276,7 @@ void addServer(char *bindAddress, int bindPort, int bindProto, setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (const char *) &tmp, sizeof(tmp)); if (bind(fd, (struct sockaddr *) - &saddr, sizeof(saddr)) == SOCKET_ERROR) - { + &saddr, sizeof(saddr)) == SOCKET_ERROR) { /* Warn -- don't exit. */ syslog(LOG_ERR, "couldn't bind to " "address %s port %d (%m)\n", @@ -290,6 +294,8 @@ void addServer(char *bindAddress, int bindPort, int bindProto, closesocket(fd); } + /* Make socket nonblocking in TCP mode only, otherwise + we may miss some data. */ setSocketDefaults(fd); } @@ -535,6 +541,18 @@ static void handleRead(ConnectionInfo *cnx, Socket *socket, Socket *other_socket socket->recvPos += got; } +static void handleUdpRead(ConnectionInfo *cnx, char const *buffer, int bytes) +{ + Socket *socket = &cnx->remote; + int got = bytes < RINETD_BUFFER_SIZE - socket->recvPos + ? bytes : RINETD_BUFFER_SIZE - socket->recvPos; + if (got > 0) { + memcpy(socket->buffer + socket->recvPos, buffer, got); + socket->recvBytes += got; + socket->recvPos += got; + } +} + static void handleWrite(ConnectionInfo *cnx, Socket *socket, Socket *other_socket) { if (cnx->coClosing && (socket->sentPos == other_socket->recvPos)) { @@ -583,17 +601,25 @@ static void handleClose(ConnectionInfo *cnx, Socket *socket, Socket *other_socke /* Nothing to do in UDP mode */ } socket->fd = INVALID_SOCKET; - if (other_socket->fd != INVALID_SOCKET) { -#if !defined __linux__ && !defined _WIN32 - /* Now set up the other end for a polite closing */ - /* Request a low-water mark equal to the entire - output buffer, so the next write notification - tells us for sure that we can close the socket. */ - int arg = 1024; - setsockopt(other_socket->fd, SOL_SOCKET, SO_SNDLOWAT, - &arg, sizeof(arg)); + if (other_socket->fd != INVALID_SOCKET) { + if (other_socket->proto == protoTcp) { +#if !defined __linux__ && !defined _WIN32 + /* Now set up the other end for a polite closing */ + + /* Request a low-water mark equal to the entire + output buffer, so the next write notification + tells us for sure that we can close the socket. */ + int arg = 1024; + setsockopt(other_socket->fd, SOL_SOCKET, SO_SNDLOWAT, + &arg, sizeof(arg)); #endif + } else /* if (other_socket->proto == protoUdp) */ { + if (other_socket == &cnx->local) + closesocket(other_socket->fd); + other_socket->fd = INVALID_SOCKET; + } + cnx->coLog = socket == &cnx->local ? logLocalClosedFirst : logRemoteClosedFirst; } @@ -601,6 +627,8 @@ static void handleClose(ConnectionInfo *cnx, Socket *socket, Socket *other_socke static void handleAccept(ServerInfo const *srv) { + int udpBytes = 0; + struct sockaddr addr; SOCKLEN_T addrlen = sizeof(addr); @@ -617,10 +645,12 @@ static void handleAccept(ServerInfo const *srv) setSocketDefaults(nfd); } else /* if (srv->fromProto == protoUdp) */ { /* In UDP mode, get remote address using recvfrom() and check - for an existing connection from this client. */ + for an existing connection from this client. We need + to read a lot of data otherwise the datagram contents + may be lost later. */ nfd = srv->fd; - ssize_t ret = recvfrom(srv->fd, NULL, 0, MSG_PEEK, - &addr, &addrlen); + ssize_t ret = recvfrom(nfd, globalUdpBuffer, + sizeof(globalUdpBuffer), 0, &addr, &addrlen); if (ret < 0) { if (GetLastError() == WSAEWOULDBLOCK) { return; @@ -633,6 +663,8 @@ static void handleAccept(ServerInfo const *srv) return; } + udpBytes = (int)ret; + for (int i = 0; i < coTotal; ++i) { ConnectionInfo *cnx = &coInfo[i]; struct sockaddr_in *addr_in = (struct sockaddr_in *)&addr; @@ -641,7 +673,7 @@ static void handleAccept(ServerInfo const *srv) && cnx->remoteAddress.sin_port == addr_in->sin_port && cnx->remoteAddress.sin_addr.s_addr == addr_in->sin_addr.s_addr) { cnx->remoteTimeout = time(NULL) + srv->serverTimeout; - handleRead(cnx, &cnx->remote, &cnx->local); + handleUdpRead(cnx, globalUdpBuffer, udpBytes); return; } } @@ -696,6 +728,9 @@ static void handleAccept(ServerInfo const *srv) return; } + if (srv->toProto == protoTcp) + setSocketDefaults(cnx->local.fd); + #if 0 // You don't need bind(2) on a socket you'll use for connect(2). /* Bind the local socket */ saddr.sin_family = AF_INET; @@ -717,9 +752,6 @@ static void handleAccept(ServerInfo const *srv) memcpy(&saddr.sin_addr, &srv->localAddr, sizeof(struct in_addr)); saddr.sin_port = srv->localPort; - if (srv->toProto == protoTcp) - setSocketDefaults(cnx->local.fd); - if (connect(cnx->local.fd, (struct sockaddr *)&saddr, sizeof(struct sockaddr_in)) == SOCKET_ERROR) { @@ -737,6 +769,19 @@ static void handleAccept(ServerInfo const *srv) } } + /* Send a zero-size UDP packet to simulate a connection */ + if (srv->toProto == protoUdp) { + int got = sendto(cnx->local.fd, NULL, 0, 0, + &saddr, (SOCKLEN_T)sizeof(saddr)); + /* FIXME: we ignore errors here... is it safe? */ + (void)got; + } + + /* Send UDP data to the other socket */ + if (srv->fromProto == protoUdp) { + handleUdpRead(cnx, globalUdpBuffer, udpBytes); + } + #ifndef _WIN32 if (cnx->local.fd > maxfd) { maxfd = cnx->local.fd;