From dc747ea43837e6892b8c5644bc605637796315b2 Mon Sep 17 00:00:00 2001 From: AdoptOSS Date: Mon, 27 Jun 2022 20:54:44 +0800 Subject: [PATCH] fix: platform socket on Windows (#2121) * fix: platform socket on Windows * fix(workflow): use win platform ssl provider instead of openssl for cURL openssl can't use system's native CA store by default --- .github/workflows/build.yml | 4 +- .../src/mingwX64Main/cinterop/Socket.def | 47 ++++---- .../kotlin/utils/PlatformSocket.kt | 100 ++++++++---------- 3 files changed, 74 insertions(+), 77 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cd383b4d1..88b7be217 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -350,7 +350,7 @@ jobs: uses: pat-s/always-upload-cache@v3 with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }} - key: ${{ runner.os }}-vcpkg-binary-cache-${{ github.job }} + key: ${{ runner.os }}-vcpkg-binary-cache-${{ github.job }} restore-keys: | ${{ runner.os }}-vcpkg-binary-cache- @@ -386,7 +386,7 @@ jobs: name: Install OpenSSL & cURL on Windows run: | echo "set(VCPKG_BUILD_TYPE release)" | Out-File -FilePath "$env:VCPKG_INSTALLATION_ROOT\triplets\x64-windows.cmake" -Encoding utf8 -Append - vcpkg install openssl:x64-windows curl[core,openssl]:x64-windows + vcpkg install openssl:x64-windows curl[core,ssl]:x64-windows New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\crypto.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcrypto.lib New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\ssl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libssl.lib New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\curl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcurl.lib diff --git a/mirai-core/src/mingwX64Main/cinterop/Socket.def b/mirai-core/src/mingwX64Main/cinterop/Socket.def index eadcc2599..debd8e44c 100644 --- a/mirai-core/src/mingwX64Main/cinterop/Socket.def +++ b/mirai-core/src/mingwX64Main/cinterop/Socket.def @@ -1,31 +1,38 @@ headers = winsock.h --- - +#define WIN32_LEAN_AND_MEAN +#include +#include +#include #include #include -#include - -static int socket_create_connect(char *host, unsigned short port) { - struct hostent *he; - struct sockaddr_in their_addr; /* connector's address information */ - if ((he = gethostbyname(host)) == NULL) { /* get the host info */ - return -1; - } +static SOCKET socket_create_connect(char *host, unsigned short port) { + SOCKADDR_STORAGE local_addr = {0}; + SOCKADDR_STORAGE remote_addr = {0}; + DWORD local_addr_size = sizeof(local_addr); + DWORD remote_addr_size = sizeof(remote_addr); + char port_name[6]; int sockfd; - if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) { - return -2; + + sprintf(port_name, "%d", (int)port); + + if ((sockfd = socket(AF_INET6, SOCK_STREAM, 0)) == INVALID_SOCKET) { + if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) { + return INVALID_SOCKET; + } + } else { + int ipv6only = 0; + setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&ipv6only, sizeof(ipv6only)); } - - their_addr.sin_family = AF_INET; /* host byte order */ - their_addr.sin_port = htons(port); /* short, network byte order */ - their_addr.sin_addr = *((struct in_addr *) he->h_addr); - bzero(&(their_addr.sin_zero), 8); /* zero the rest of the struct */ - - if (connect(sockfd, (struct sockaddr *) &their_addr, sizeof(struct sockaddr)) == -1) { - return -3; + if (!WSAConnectByNameA(sockfd, host, port_name, &local_addr_size, (SOCKADDR*)&local_addr, &remote_addr_size, (SOCKADDR*)&remote_addr, NULL, NULL)) { + closesocket(sockfd); + return INVALID_SOCKET; + } + if (setsockopt(sockfd, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0) == SOCKET_ERROR) { + closesocket(sockfd); + return INVALID_SOCKET; } - return sockfd; } \ No newline at end of file diff --git a/mirai-core/src/mingwX64Main/kotlin/utils/PlatformSocket.kt b/mirai-core/src/mingwX64Main/kotlin/utils/PlatformSocket.kt index 34873f035..c58420297 100644 --- a/mirai-core/src/mingwX64Main/kotlin/utils/PlatformSocket.kt +++ b/mirai-core/src/mingwX64Main/kotlin/utils/PlatformSocket.kt @@ -10,21 +10,17 @@ package net.mamoe.mirai.internal.utils import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.EOFException import io.ktor.utils.io.errors.* import kotlinx.cinterop.* -import kotlinx.coroutines.CoroutineDispatcher -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.newSingleThreadContext +import kotlinx.coroutines.* import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import kotlinx.coroutines.withContext import net.mamoe.mirai.internal.network.highway.HighwayProtocolChannel +import net.mamoe.mirai.utils.ByteArrayPool import net.mamoe.mirai.utils.DEFAULT_BUFFER_SIZE import net.mamoe.mirai.utils.toReadPacket -import net.mamoe.mirai.utils.wrapIO -import platform.posix.close -import platform.posix.read -import platform.posix.write +import platform.posix.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -32,57 +28,73 @@ import kotlin.contracts.contract * TCP Socket. */ internal actual class PlatformSocket( - private val socket: Int + private val socket: SOCKET ) : Closeable, HighwayProtocolChannel { - @OptIn(ExperimentalCoroutinesApi::class) - private val dispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher") + @Suppress("UnnecessaryOptInAnnotation") + @OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class) + private val readDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.readDispatcher") + + @Suppress("UnnecessaryOptInAnnotation") + @OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class) + private val sendDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.sendDispatcher") - private val readLock = Mutex() - private val readBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin() private val writeLock = Mutex() private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin() actual val isOpen: Boolean - get() = write(socket, null, 0) != 0 + get() = send(socket, null, 0, 0).convert() != 0L actual override fun close() { - if (close(socket) != 0) { - throw PosixException.forErrno(posixFunctionName = "close()").wrapIO() - } + closesocket(socket) + (readDispatcher as? CloseableCoroutineDispatcher)?.close() + (sendDispatcher as? CloseableCoroutineDispatcher)?.close() + writeBuffer.unpin() } - actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = readLock.withLock { - withContext(dispatcher) { + actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = writeLock.withLock { + withContext(sendDispatcher) { require(offset >= 0) { "offset must >= 0" } require(length >= 0) { "length must >= 0" } require(offset + length <= packet.size) { "It must follows offset + length <= packet.size" } packet.usePinned { pin -> - if (write(socket, pin.addressOf(offset), length.convert()) != 0) { - throw PosixException.forErrno(posixFunctionName = "close()").wrapIO() + if (send(socket, pin.addressOf(offset), length.convert(), 0).convert() < 0L) { + @Suppress("INVISIBLE_MEMBER") + throw PosixException.forErrno(posixFunctionName = "send()").wrapIO() } } } } - actual override suspend fun send(packet: ByteReadPacket): Unit = readLock.withLock { - withContext(dispatcher) { + actual override suspend fun send(packet: ByteReadPacket): Unit = writeLock.withLock { + withContext(sendDispatcher) { val writeBuffer = writeBuffer - val length = packet.readAvailable(writeBuffer.get()) - if (write(socket, writeBuffer.addressOf(0), length.convert()) != 0) { - throw PosixException.forErrno(posixFunctionName = "close()").wrapIO() + while (packet.remaining != 0L) { + val length = packet.readAvailable(writeBuffer.get()) + if (send(socket, writeBuffer.addressOf(0), length.convert(), 0).convert() < 0L) { + @Suppress("INVISIBLE_MEMBER") + throw PosixException.forErrno(posixFunctionName = "send()").wrapIO() + } } } } /** - * @throws ReadPacketInternalException + * @throws EOFException */ - actual override suspend fun read(): ByteReadPacket = writeLock.withLock { - withContext(dispatcher) { - val readBuffer = readBuffer - val length = read(socket, readBuffer.addressOf(0), readBuffer.get().size.convert()) - readBuffer.get().toReadPacket(length = length) + actual override suspend fun read(): ByteReadPacket = withContext(readDispatcher) { + val readBuffer = ByteArrayPool.borrow() + try { + val length = readBuffer.usePinned { pinned -> + recv(socket, pinned.addressOf(0), pinned.get().size.convert(), 0).convert() + } + if (length <= 0L) throw EOFException("recv: $length, errno=$errno") + // [toReadPacket] 后的 readBuffer 会被其内部直接使用,而不是 copy 一份 + // 在 release 前不能被复用 + readBuffer.toReadPacket(length = length.toInt()) { ByteArrayPool.recycle(it) } + } catch (e: Throwable) { + ByteArrayPool.recycle(readBuffer) + throw e } } @@ -93,32 +105,10 @@ internal actual class PlatformSocket( serverPort: Int ): PlatformSocket { val r = sockets.socket_create_connect(serverIp.cstr, serverPort.toUShort()) - if (r < 0) error("Failed socket_create_connect: $r") + if (r == INVALID_SOCKET) error("Failed socket_create_connect: $r") return PlatformSocket(r) - -// val addr = memScoped { -// alloc() { -// sin_family = AF_INET.convert() -// sin_port = htons(serverPort.toUShort()) -// sin_addr.S_un -// sin_addr = resolveIpFromHost(serverIp).reinterpret().rawValue -// } -// }.reinterpret() -// -// val id = socket(AF_INET, SOCK_STREAM, 0) -// if (id.toInt() == -1) throw PosixException.forErrno(posixFunctionName = "socket()") -// -// val conn = connect(id, addr.ptr, sizeOf().convert()) -// if (conn != 0) throw PosixException.forErrno(posixFunctionName = "connect()") -// -// return PlatformSocket(conn) } -// private fun resolveIpFromHost(serverIp: String): CPointer { -// return gethostbyname(serverIp) -// ?: throw IllegalStateException("Failed to resolve IP from host. host=$serverIp") -// } - actual suspend inline fun withConnection( serverIp: String, serverPort: Int,