fix: platform socket on Windows (#2121)

* fix: platform socket on Windows

* fix(workflow): use win platform ssl provider instead of openssl for cURL
openssl can't use system's native CA store by default
This commit is contained in:
AdoptOSS 2022-06-27 20:54:44 +08:00 committed by Him188
parent 2022074007
commit dc747ea438
3 changed files with 74 additions and 77 deletions

View File

@ -350,7 +350,7 @@ jobs:
uses: pat-s/always-upload-cache@v3 uses: pat-s/always-upload-cache@v3
with: with:
path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }} path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}
key: ${{ runner.os }}-vcpkg-binary-cache-${{ github.job }} key: ${{ runner.os }}-vcpkg-binary-cache-${{ github.job }}
restore-keys: | restore-keys: |
${{ runner.os }}-vcpkg-binary-cache- ${{ runner.os }}-vcpkg-binary-cache-
@ -386,7 +386,7 @@ jobs:
name: Install OpenSSL & cURL on Windows name: Install OpenSSL & cURL on Windows
run: | run: |
echo "set(VCPKG_BUILD_TYPE release)" | Out-File -FilePath "$env:VCPKG_INSTALLATION_ROOT\triplets\x64-windows.cmake" -Encoding utf8 -Append echo "set(VCPKG_BUILD_TYPE release)" | Out-File -FilePath "$env:VCPKG_INSTALLATION_ROOT\triplets\x64-windows.cmake" -Encoding utf8 -Append
vcpkg install openssl:x64-windows curl[core,openssl]:x64-windows vcpkg install openssl:x64-windows curl[core,ssl]:x64-windows
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\crypto.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcrypto.lib New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\crypto.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcrypto.lib
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\ssl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libssl.lib New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\ssl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libssl.lib
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\curl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcurl.lib New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\curl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcurl.lib

View File

@ -1,31 +1,38 @@
headers = winsock.h headers = winsock.h
--- ---
#define WIN32_LEAN_AND_MEAN
#include <winsock2.h>
#include <Ws2tcpip.h>
#include <mswsock.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <winsock.h> static SOCKET socket_create_connect(char *host, unsigned short port) {
SOCKADDR_STORAGE local_addr = {0};
static int socket_create_connect(char *host, unsigned short port) { SOCKADDR_STORAGE remote_addr = {0};
struct hostent *he; DWORD local_addr_size = sizeof(local_addr);
struct sockaddr_in their_addr; /* connector's address information */ DWORD remote_addr_size = sizeof(remote_addr);
if ((he = gethostbyname(host)) == NULL) { /* get the host info */ char port_name[6];
return -1;
}
int sockfd; int sockfd;
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
return -2; sprintf(port_name, "%d", (int)port);
if ((sockfd = socket(AF_INET6, SOCK_STREAM, 0)) == INVALID_SOCKET) {
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) {
return INVALID_SOCKET;
}
} else {
int ipv6only = 0;
setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&ipv6only, sizeof(ipv6only));
} }
if (!WSAConnectByNameA(sockfd, host, port_name, &local_addr_size, (SOCKADDR*)&local_addr, &remote_addr_size, (SOCKADDR*)&remote_addr, NULL, NULL)) {
their_addr.sin_family = AF_INET; /* host byte order */ closesocket(sockfd);
their_addr.sin_port = htons(port); /* short, network byte order */ return INVALID_SOCKET;
their_addr.sin_addr = *((struct in_addr *) he->h_addr); }
bzero(&(their_addr.sin_zero), 8); /* zero the rest of the struct */ if (setsockopt(sockfd, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0) == SOCKET_ERROR) {
closesocket(sockfd);
if (connect(sockfd, (struct sockaddr *) &their_addr, sizeof(struct sockaddr)) == -1) { return INVALID_SOCKET;
return -3;
} }
return sockfd; return sockfd;
} }

View File

@ -10,21 +10,17 @@
package net.mamoe.mirai.internal.utils package net.mamoe.mirai.internal.utils
import io.ktor.utils.io.core.* import io.ktor.utils.io.core.*
import io.ktor.utils.io.core.EOFException
import io.ktor.utils.io.errors.* import io.ktor.utils.io.errors.*
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.*
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import net.mamoe.mirai.internal.network.highway.HighwayProtocolChannel import net.mamoe.mirai.internal.network.highway.HighwayProtocolChannel
import net.mamoe.mirai.utils.ByteArrayPool
import net.mamoe.mirai.utils.DEFAULT_BUFFER_SIZE import net.mamoe.mirai.utils.DEFAULT_BUFFER_SIZE
import net.mamoe.mirai.utils.toReadPacket import net.mamoe.mirai.utils.toReadPacket
import net.mamoe.mirai.utils.wrapIO import platform.posix.*
import platform.posix.close
import platform.posix.read
import platform.posix.write
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -32,57 +28,73 @@ import kotlin.contracts.contract
* TCP Socket. * TCP Socket.
*/ */
internal actual class PlatformSocket( internal actual class PlatformSocket(
private val socket: Int private val socket: SOCKET
) : Closeable, HighwayProtocolChannel { ) : Closeable, HighwayProtocolChannel {
@OptIn(ExperimentalCoroutinesApi::class) @Suppress("UnnecessaryOptInAnnotation")
private val dispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher") @OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
private val readDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.readDispatcher")
@Suppress("UnnecessaryOptInAnnotation")
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
private val sendDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.sendDispatcher")
private val readLock = Mutex()
private val readBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
private val writeLock = Mutex() private val writeLock = Mutex()
private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin() private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
actual val isOpen: Boolean actual val isOpen: Boolean
get() = write(socket, null, 0) != 0 get() = send(socket, null, 0, 0).convert<Long>() != 0L
actual override fun close() { actual override fun close() {
if (close(socket) != 0) { closesocket(socket)
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO() (readDispatcher as? CloseableCoroutineDispatcher)?.close()
} (sendDispatcher as? CloseableCoroutineDispatcher)?.close()
writeBuffer.unpin()
} }
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = readLock.withLock { actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = writeLock.withLock {
withContext(dispatcher) { withContext(sendDispatcher) {
require(offset >= 0) { "offset must >= 0" } require(offset >= 0) { "offset must >= 0" }
require(length >= 0) { "length must >= 0" } require(length >= 0) { "length must >= 0" }
require(offset + length <= packet.size) { "It must follows offset + length <= packet.size" } require(offset + length <= packet.size) { "It must follows offset + length <= packet.size" }
packet.usePinned { pin -> packet.usePinned { pin ->
if (write(socket, pin.addressOf(offset), length.convert()) != 0) { if (send(socket, pin.addressOf(offset), length.convert(), 0).convert<Long>() < 0L) {
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO() @Suppress("INVISIBLE_MEMBER")
throw PosixException.forErrno(posixFunctionName = "send()").wrapIO()
} }
} }
} }
} }
actual override suspend fun send(packet: ByteReadPacket): Unit = readLock.withLock { actual override suspend fun send(packet: ByteReadPacket): Unit = writeLock.withLock {
withContext(dispatcher) { withContext(sendDispatcher) {
val writeBuffer = writeBuffer val writeBuffer = writeBuffer
val length = packet.readAvailable(writeBuffer.get()) while (packet.remaining != 0L) {
if (write(socket, writeBuffer.addressOf(0), length.convert()) != 0) { val length = packet.readAvailable(writeBuffer.get())
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO() if (send(socket, writeBuffer.addressOf(0), length.convert(), 0).convert<Long>() < 0L) {
@Suppress("INVISIBLE_MEMBER")
throw PosixException.forErrno(posixFunctionName = "send()").wrapIO()
}
} }
} }
} }
/** /**
* @throws ReadPacketInternalException * @throws EOFException
*/ */
actual override suspend fun read(): ByteReadPacket = writeLock.withLock { actual override suspend fun read(): ByteReadPacket = withContext(readDispatcher) {
withContext(dispatcher) { val readBuffer = ByteArrayPool.borrow()
val readBuffer = readBuffer try {
val length = read(socket, readBuffer.addressOf(0), readBuffer.get().size.convert()) val length = readBuffer.usePinned { pinned ->
readBuffer.get().toReadPacket(length = length) recv(socket, pinned.addressOf(0), pinned.get().size.convert(), 0).convert<Long>()
}
if (length <= 0L) throw EOFException("recv: $length, errno=$errno")
// [toReadPacket] 后的 readBuffer 会被其内部直接使用,而不是 copy 一份
// 在 release 前不能被复用
readBuffer.toReadPacket(length = length.toInt()) { ByteArrayPool.recycle(it) }
} catch (e: Throwable) {
ByteArrayPool.recycle(readBuffer)
throw e
} }
} }
@ -93,32 +105,10 @@ internal actual class PlatformSocket(
serverPort: Int serverPort: Int
): PlatformSocket { ): PlatformSocket {
val r = sockets.socket_create_connect(serverIp.cstr, serverPort.toUShort()) val r = sockets.socket_create_connect(serverIp.cstr, serverPort.toUShort())
if (r < 0) error("Failed socket_create_connect: $r") if (r == INVALID_SOCKET) error("Failed socket_create_connect: $r")
return PlatformSocket(r) return PlatformSocket(r)
// val addr = memScoped {
// alloc<sockaddr_in>() {
// sin_family = AF_INET.convert()
// sin_port = htons(serverPort.toUShort())
// sin_addr.S_un
// sin_addr = resolveIpFromHost(serverIp).reinterpret<in_addr>().rawValue
// }
// }.reinterpret<sockaddr>()
//
// val id = socket(AF_INET, SOCK_STREAM, 0)
// if (id.toInt() == -1) throw PosixException.forErrno(posixFunctionName = "socket()")
//
// val conn = connect(id, addr.ptr, sizeOf<sockaddr_in>().convert())
// if (conn != 0) throw PosixException.forErrno(posixFunctionName = "connect()")
//
// return PlatformSocket(conn)
} }
// private fun resolveIpFromHost(serverIp: String): CPointer<hostent> {
// return gethostbyname(serverIp)
// ?: throw IllegalStateException("Failed to resolve IP from host. host=$serverIp")
// }
actual suspend inline fun <R> withConnection( actual suspend inline fun <R> withConnection(
serverIp: String, serverIp: String,
serverPort: Int, serverPort: Int,