Improve UDP connection handling.

Send a zero-sized UDP packet when first accessing the server; this makes
TCP→UDP→TCP redirects work in both directions.

Properly close UDP connections when the other end disconnects (TCP) or
timeouts (UDP).
This commit is contained in:
Sam Hocevar 2017-09-08 13:28:09 +02:00
parent b833d456dd
commit 37437a5c5b
2 changed files with 65 additions and 20 deletions

View File

@ -1,6 +1,6 @@
# Process this file with autoconf to produce a configure script.
AC_PREREQ(2.52)
AC_INIT(rinetd, 0.63, sam@hocevar.net)
AC_INIT(rinetd, 0.63.test, sam@hocevar.net)
AC_CONFIG_AUX_DIR(.auto)
AC_CONFIG_SRCDIR([getopt.h])
AC_CONFIG_HEADER([config.h])

View File

@ -67,6 +67,9 @@ int const maxfd = 0;
int maxfd = 0;
#endif
/* Global static buffer for UDP data. */
static char globalUdpBuffer[65536];
char *logFileName = NULL;
char *pidLogFileName = NULL;
int logFormatCommon = 0;
@ -107,6 +110,7 @@ RinetdOptions options = {
static void selectPass(void);
static void handleWrite(ConnectionInfo *cnx, Socket *socket, Socket *other_socket);
static void handleRead(ConnectionInfo *cnx, Socket *socket, Socket *other_socket);
static void handleUdpRead(ConnectionInfo *cnx, char const *buffer, int bytes);
static void handleClose(ConnectionInfo *cnx, Socket *socket, Socket *other_socket);
static void handleAccept(ServerInfo const *srv);
static ConnectionInfo *findAvailableConnection(void);
@ -246,7 +250,8 @@ static void readConfiguration(char const *file) {
void addServer(char *bindAddress, int bindPort, int bindProto,
char *connectAddress, int connectPort, int connectProto,
int serverTimeout) {
int serverTimeout)
{
/* Turn all of this stuff into reasonable addresses */
struct in_addr iaddr;
if (getAddress(bindAddress, &iaddr) < 0) {
@ -271,8 +276,7 @@ void addServer(char *bindAddress, int bindPort, int bindProto,
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
(const char *) &tmp, sizeof(tmp));
if (bind(fd, (struct sockaddr *)
&saddr, sizeof(saddr)) == SOCKET_ERROR)
{
&saddr, sizeof(saddr)) == SOCKET_ERROR) {
/* Warn -- don't exit. */
syslog(LOG_ERR, "couldn't bind to "
"address %s port %d (%m)\n",
@ -290,6 +294,8 @@ void addServer(char *bindAddress, int bindPort, int bindProto,
closesocket(fd);
}
/* Make socket nonblocking in TCP mode only, otherwise
we may miss some data. */
setSocketDefaults(fd);
}
@ -535,6 +541,18 @@ static void handleRead(ConnectionInfo *cnx, Socket *socket, Socket *other_socket
socket->recvPos += got;
}
static void handleUdpRead(ConnectionInfo *cnx, char const *buffer, int bytes)
{
Socket *socket = &cnx->remote;
int got = bytes < RINETD_BUFFER_SIZE - socket->recvPos
? bytes : RINETD_BUFFER_SIZE - socket->recvPos;
if (got > 0) {
memcpy(socket->buffer + socket->recvPos, buffer, got);
socket->recvBytes += got;
socket->recvPos += got;
}
}
static void handleWrite(ConnectionInfo *cnx, Socket *socket, Socket *other_socket)
{
if (cnx->coClosing && (socket->sentPos == other_socket->recvPos)) {
@ -583,17 +601,25 @@ static void handleClose(ConnectionInfo *cnx, Socket *socket, Socket *other_socke
/* Nothing to do in UDP mode */
}
socket->fd = INVALID_SOCKET;
if (other_socket->fd != INVALID_SOCKET) {
#if !defined __linux__ && !defined _WIN32
/* Now set up the other end for a polite closing */
/* Request a low-water mark equal to the entire
output buffer, so the next write notification
tells us for sure that we can close the socket. */
int arg = 1024;
setsockopt(other_socket->fd, SOL_SOCKET, SO_SNDLOWAT,
&arg, sizeof(arg));
if (other_socket->fd != INVALID_SOCKET) {
if (other_socket->proto == protoTcp) {
#if !defined __linux__ && !defined _WIN32
/* Now set up the other end for a polite closing */
/* Request a low-water mark equal to the entire
output buffer, so the next write notification
tells us for sure that we can close the socket. */
int arg = 1024;
setsockopt(other_socket->fd, SOL_SOCKET, SO_SNDLOWAT,
&arg, sizeof(arg));
#endif
} else /* if (other_socket->proto == protoUdp) */ {
if (other_socket == &cnx->local)
closesocket(other_socket->fd);
other_socket->fd = INVALID_SOCKET;
}
cnx->coLog = socket == &cnx->local ?
logLocalClosedFirst : logRemoteClosedFirst;
}
@ -601,6 +627,8 @@ static void handleClose(ConnectionInfo *cnx, Socket *socket, Socket *other_socke
static void handleAccept(ServerInfo const *srv)
{
int udpBytes = 0;
struct sockaddr addr;
SOCKLEN_T addrlen = sizeof(addr);
@ -617,10 +645,12 @@ static void handleAccept(ServerInfo const *srv)
setSocketDefaults(nfd);
} else /* if (srv->fromProto == protoUdp) */ {
/* In UDP mode, get remote address using recvfrom() and check
for an existing connection from this client. */
for an existing connection from this client. We need
to read a lot of data otherwise the datagram contents
may be lost later. */
nfd = srv->fd;
ssize_t ret = recvfrom(srv->fd, NULL, 0, MSG_PEEK,
&addr, &addrlen);
ssize_t ret = recvfrom(nfd, globalUdpBuffer,
sizeof(globalUdpBuffer), 0, &addr, &addrlen);
if (ret < 0) {
if (GetLastError() == WSAEWOULDBLOCK) {
return;
@ -633,6 +663,8 @@ static void handleAccept(ServerInfo const *srv)
return;
}
udpBytes = (int)ret;
for (int i = 0; i < coTotal; ++i) {
ConnectionInfo *cnx = &coInfo[i];
struct sockaddr_in *addr_in = (struct sockaddr_in *)&addr;
@ -641,7 +673,7 @@ static void handleAccept(ServerInfo const *srv)
&& cnx->remoteAddress.sin_port == addr_in->sin_port
&& cnx->remoteAddress.sin_addr.s_addr == addr_in->sin_addr.s_addr) {
cnx->remoteTimeout = time(NULL) + srv->serverTimeout;
handleRead(cnx, &cnx->remote, &cnx->local);
handleUdpRead(cnx, globalUdpBuffer, udpBytes);
return;
}
}
@ -696,6 +728,9 @@ static void handleAccept(ServerInfo const *srv)
return;
}
if (srv->toProto == protoTcp)
setSocketDefaults(cnx->local.fd);
#if 0 // You don't need bind(2) on a socket you'll use for connect(2).
/* Bind the local socket */
saddr.sin_family = AF_INET;
@ -717,9 +752,6 @@ static void handleAccept(ServerInfo const *srv)
memcpy(&saddr.sin_addr, &srv->localAddr, sizeof(struct in_addr));
saddr.sin_port = srv->localPort;
if (srv->toProto == protoTcp)
setSocketDefaults(cnx->local.fd);
if (connect(cnx->local.fd, (struct sockaddr *)&saddr,
sizeof(struct sockaddr_in)) == SOCKET_ERROR)
{
@ -737,6 +769,19 @@ static void handleAccept(ServerInfo const *srv)
}
}
/* Send a zero-size UDP packet to simulate a connection */
if (srv->toProto == protoUdp) {
int got = sendto(cnx->local.fd, NULL, 0, 0,
&saddr, (SOCKLEN_T)sizeof(saddr));
/* FIXME: we ignore errors here... is it safe? */
(void)got;
}
/* Send UDP data to the other socket */
if (srv->fromProto == protoUdp) {
handleUdpRead(cnx, globalUdpBuffer, udpBytes);
}
#ifndef _WIN32
if (cnx->local.fd > maxfd) {
maxfd = cnx->local.fd;