fix: platform socket on Windows (#2121)

* fix: platform socket on Windows

* fix(workflow): use win platform ssl provider instead of openssl for cURL
openssl can't use system's native CA store by default
This commit is contained in:
AdoptOSS 2022-06-27 20:54:44 +08:00 committed by Him188
parent 2022074007
commit dc747ea438
3 changed files with 74 additions and 77 deletions

View File

@ -386,7 +386,7 @@ jobs:
name: Install OpenSSL & cURL on Windows
run: |
echo "set(VCPKG_BUILD_TYPE release)" | Out-File -FilePath "$env:VCPKG_INSTALLATION_ROOT\triplets\x64-windows.cmake" -Encoding utf8 -Append
vcpkg install openssl:x64-windows curl[core,openssl]:x64-windows
vcpkg install openssl:x64-windows curl[core,ssl]:x64-windows
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\crypto.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcrypto.lib
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\ssl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libssl.lib
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\curl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcurl.lib

View File

@ -1,31 +1,38 @@
headers = winsock.h
---
#define WIN32_LEAN_AND_MEAN
#include <winsock2.h>
#include <Ws2tcpip.h>
#include <mswsock.h>
#include <stdlib.h>
#include <string.h>
#include <winsock.h>
static int socket_create_connect(char *host, unsigned short port) {
struct hostent *he;
struct sockaddr_in their_addr; /* connector's address information */
if ((he = gethostbyname(host)) == NULL) { /* get the host info */
return -1;
}
static SOCKET socket_create_connect(char *host, unsigned short port) {
SOCKADDR_STORAGE local_addr = {0};
SOCKADDR_STORAGE remote_addr = {0};
DWORD local_addr_size = sizeof(local_addr);
DWORD remote_addr_size = sizeof(remote_addr);
char port_name[6];
int sockfd;
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
return -2;
sprintf(port_name, "%d", (int)port);
if ((sockfd = socket(AF_INET6, SOCK_STREAM, 0)) == INVALID_SOCKET) {
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) {
return INVALID_SOCKET;
}
their_addr.sin_family = AF_INET; /* host byte order */
their_addr.sin_port = htons(port); /* short, network byte order */
their_addr.sin_addr = *((struct in_addr *) he->h_addr);
bzero(&(their_addr.sin_zero), 8); /* zero the rest of the struct */
if (connect(sockfd, (struct sockaddr *) &their_addr, sizeof(struct sockaddr)) == -1) {
return -3;
} else {
int ipv6only = 0;
setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&ipv6only, sizeof(ipv6only));
}
if (!WSAConnectByNameA(sockfd, host, port_name, &local_addr_size, (SOCKADDR*)&local_addr, &remote_addr_size, (SOCKADDR*)&remote_addr, NULL, NULL)) {
closesocket(sockfd);
return INVALID_SOCKET;
}
if (setsockopt(sockfd, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0) == SOCKET_ERROR) {
closesocket(sockfd);
return INVALID_SOCKET;
}
return sockfd;
}

View File

@ -10,21 +10,17 @@
package net.mamoe.mirai.internal.utils
import io.ktor.utils.io.core.*
import io.ktor.utils.io.core.EOFException
import io.ktor.utils.io.errors.*
import kotlinx.cinterop.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import net.mamoe.mirai.internal.network.highway.HighwayProtocolChannel
import net.mamoe.mirai.utils.ByteArrayPool
import net.mamoe.mirai.utils.DEFAULT_BUFFER_SIZE
import net.mamoe.mirai.utils.toReadPacket
import net.mamoe.mirai.utils.wrapIO
import platform.posix.close
import platform.posix.read
import platform.posix.write
import platform.posix.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
@ -32,57 +28,73 @@ import kotlin.contracts.contract
* TCP Socket.
*/
internal actual class PlatformSocket(
private val socket: Int
private val socket: SOCKET
) : Closeable, HighwayProtocolChannel {
@OptIn(ExperimentalCoroutinesApi::class)
private val dispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher")
@Suppress("UnnecessaryOptInAnnotation")
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
private val readDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.readDispatcher")
@Suppress("UnnecessaryOptInAnnotation")
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
private val sendDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.sendDispatcher")
private val readLock = Mutex()
private val readBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
private val writeLock = Mutex()
private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
actual val isOpen: Boolean
get() = write(socket, null, 0) != 0
get() = send(socket, null, 0, 0).convert<Long>() != 0L
actual override fun close() {
if (close(socket) != 0) {
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
}
closesocket(socket)
(readDispatcher as? CloseableCoroutineDispatcher)?.close()
(sendDispatcher as? CloseableCoroutineDispatcher)?.close()
writeBuffer.unpin()
}
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = readLock.withLock {
withContext(dispatcher) {
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = writeLock.withLock {
withContext(sendDispatcher) {
require(offset >= 0) { "offset must >= 0" }
require(length >= 0) { "length must >= 0" }
require(offset + length <= packet.size) { "It must follows offset + length <= packet.size" }
packet.usePinned { pin ->
if (write(socket, pin.addressOf(offset), length.convert()) != 0) {
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
if (send(socket, pin.addressOf(offset), length.convert(), 0).convert<Long>() < 0L) {
@Suppress("INVISIBLE_MEMBER")
throw PosixException.forErrno(posixFunctionName = "send()").wrapIO()
}
}
}
}
actual override suspend fun send(packet: ByteReadPacket): Unit = readLock.withLock {
withContext(dispatcher) {
actual override suspend fun send(packet: ByteReadPacket): Unit = writeLock.withLock {
withContext(sendDispatcher) {
val writeBuffer = writeBuffer
while (packet.remaining != 0L) {
val length = packet.readAvailable(writeBuffer.get())
if (write(socket, writeBuffer.addressOf(0), length.convert()) != 0) {
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
if (send(socket, writeBuffer.addressOf(0), length.convert(), 0).convert<Long>() < 0L) {
@Suppress("INVISIBLE_MEMBER")
throw PosixException.forErrno(posixFunctionName = "send()").wrapIO()
}
}
}
}
/**
* @throws ReadPacketInternalException
* @throws EOFException
*/
actual override suspend fun read(): ByteReadPacket = writeLock.withLock {
withContext(dispatcher) {
val readBuffer = readBuffer
val length = read(socket, readBuffer.addressOf(0), readBuffer.get().size.convert())
readBuffer.get().toReadPacket(length = length)
actual override suspend fun read(): ByteReadPacket = withContext(readDispatcher) {
val readBuffer = ByteArrayPool.borrow()
try {
val length = readBuffer.usePinned { pinned ->
recv(socket, pinned.addressOf(0), pinned.get().size.convert(), 0).convert<Long>()
}
if (length <= 0L) throw EOFException("recv: $length, errno=$errno")
// [toReadPacket] 后的 readBuffer 会被其内部直接使用,而不是 copy 一份
// 在 release 前不能被复用
readBuffer.toReadPacket(length = length.toInt()) { ByteArrayPool.recycle(it) }
} catch (e: Throwable) {
ByteArrayPool.recycle(readBuffer)
throw e
}
}
@ -93,32 +105,10 @@ internal actual class PlatformSocket(
serverPort: Int
): PlatformSocket {
val r = sockets.socket_create_connect(serverIp.cstr, serverPort.toUShort())
if (r < 0) error("Failed socket_create_connect: $r")
if (r == INVALID_SOCKET) error("Failed socket_create_connect: $r")
return PlatformSocket(r)
// val addr = memScoped {
// alloc<sockaddr_in>() {
// sin_family = AF_INET.convert()
// sin_port = htons(serverPort.toUShort())
// sin_addr.S_un
// sin_addr = resolveIpFromHost(serverIp).reinterpret<in_addr>().rawValue
// }
// }.reinterpret<sockaddr>()
//
// val id = socket(AF_INET, SOCK_STREAM, 0)
// if (id.toInt() == -1) throw PosixException.forErrno(posixFunctionName = "socket()")
//
// val conn = connect(id, addr.ptr, sizeOf<sockaddr_in>().convert())
// if (conn != 0) throw PosixException.forErrno(posixFunctionName = "connect()")
//
// return PlatformSocket(conn)
}
// private fun resolveIpFromHost(serverIp: String): CPointer<hostent> {
// return gethostbyname(serverIp)
// ?: throw IllegalStateException("Failed to resolve IP from host. host=$serverIp")
// }
actual suspend inline fun <R> withConnection(
serverIp: String,
serverPort: Int,