mirror of
https://github.com/mamoe/mirai.git
synced 2025-01-19 10:34:44 +08:00
fix: platform socket on Windows (#2121)
* fix: platform socket on Windows * fix(workflow): use win platform ssl provider instead of openssl for cURL openssl can't use system's native CA store by default
This commit is contained in:
parent
2022074007
commit
dc747ea438
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -386,7 +386,7 @@ jobs:
|
||||
name: Install OpenSSL & cURL on Windows
|
||||
run: |
|
||||
echo "set(VCPKG_BUILD_TYPE release)" | Out-File -FilePath "$env:VCPKG_INSTALLATION_ROOT\triplets\x64-windows.cmake" -Encoding utf8 -Append
|
||||
vcpkg install openssl:x64-windows curl[core,openssl]:x64-windows
|
||||
vcpkg install openssl:x64-windows curl[core,ssl]:x64-windows
|
||||
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\crypto.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcrypto.lib
|
||||
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\ssl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libssl.lib
|
||||
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\curl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcurl.lib
|
||||
|
@ -1,31 +1,38 @@
|
||||
headers = winsock.h
|
||||
|
||||
---
|
||||
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#include <winsock2.h>
|
||||
#include <Ws2tcpip.h>
|
||||
#include <mswsock.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <winsock.h>
|
||||
|
||||
static int socket_create_connect(char *host, unsigned short port) {
|
||||
struct hostent *he;
|
||||
struct sockaddr_in their_addr; /* connector's address information */
|
||||
if ((he = gethostbyname(host)) == NULL) { /* get the host info */
|
||||
return -1;
|
||||
}
|
||||
static SOCKET socket_create_connect(char *host, unsigned short port) {
|
||||
SOCKADDR_STORAGE local_addr = {0};
|
||||
SOCKADDR_STORAGE remote_addr = {0};
|
||||
DWORD local_addr_size = sizeof(local_addr);
|
||||
DWORD remote_addr_size = sizeof(remote_addr);
|
||||
char port_name[6];
|
||||
int sockfd;
|
||||
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
|
||||
return -2;
|
||||
|
||||
sprintf(port_name, "%d", (int)port);
|
||||
|
||||
if ((sockfd = socket(AF_INET6, SOCK_STREAM, 0)) == INVALID_SOCKET) {
|
||||
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) {
|
||||
return INVALID_SOCKET;
|
||||
}
|
||||
|
||||
their_addr.sin_family = AF_INET; /* host byte order */
|
||||
their_addr.sin_port = htons(port); /* short, network byte order */
|
||||
their_addr.sin_addr = *((struct in_addr *) he->h_addr);
|
||||
bzero(&(their_addr.sin_zero), 8); /* zero the rest of the struct */
|
||||
|
||||
if (connect(sockfd, (struct sockaddr *) &their_addr, sizeof(struct sockaddr)) == -1) {
|
||||
return -3;
|
||||
} else {
|
||||
int ipv6only = 0;
|
||||
setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&ipv6only, sizeof(ipv6only));
|
||||
}
|
||||
if (!WSAConnectByNameA(sockfd, host, port_name, &local_addr_size, (SOCKADDR*)&local_addr, &remote_addr_size, (SOCKADDR*)&remote_addr, NULL, NULL)) {
|
||||
closesocket(sockfd);
|
||||
return INVALID_SOCKET;
|
||||
}
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0) == SOCKET_ERROR) {
|
||||
closesocket(sockfd);
|
||||
return INVALID_SOCKET;
|
||||
}
|
||||
|
||||
return sockfd;
|
||||
}
|
@ -10,21 +10,17 @@
|
||||
package net.mamoe.mirai.internal.utils
|
||||
|
||||
import io.ktor.utils.io.core.*
|
||||
import io.ktor.utils.io.core.EOFException
|
||||
import io.ktor.utils.io.errors.*
|
||||
import kotlinx.cinterop.*
|
||||
import kotlinx.coroutines.CoroutineDispatcher
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.newSingleThreadContext
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import kotlinx.coroutines.withContext
|
||||
import net.mamoe.mirai.internal.network.highway.HighwayProtocolChannel
|
||||
import net.mamoe.mirai.utils.ByteArrayPool
|
||||
import net.mamoe.mirai.utils.DEFAULT_BUFFER_SIZE
|
||||
import net.mamoe.mirai.utils.toReadPacket
|
||||
import net.mamoe.mirai.utils.wrapIO
|
||||
import platform.posix.close
|
||||
import platform.posix.read
|
||||
import platform.posix.write
|
||||
import platform.posix.*
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
@ -32,57 +28,73 @@ import kotlin.contracts.contract
|
||||
* TCP Socket.
|
||||
*/
|
||||
internal actual class PlatformSocket(
|
||||
private val socket: Int
|
||||
private val socket: SOCKET
|
||||
) : Closeable, HighwayProtocolChannel {
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
private val dispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher")
|
||||
@Suppress("UnnecessaryOptInAnnotation")
|
||||
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
|
||||
private val readDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.readDispatcher")
|
||||
|
||||
@Suppress("UnnecessaryOptInAnnotation")
|
||||
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
|
||||
private val sendDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.sendDispatcher")
|
||||
|
||||
private val readLock = Mutex()
|
||||
private val readBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
|
||||
private val writeLock = Mutex()
|
||||
private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
|
||||
|
||||
actual val isOpen: Boolean
|
||||
get() = write(socket, null, 0) != 0
|
||||
get() = send(socket, null, 0, 0).convert<Long>() != 0L
|
||||
|
||||
actual override fun close() {
|
||||
if (close(socket) != 0) {
|
||||
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
|
||||
}
|
||||
closesocket(socket)
|
||||
(readDispatcher as? CloseableCoroutineDispatcher)?.close()
|
||||
(sendDispatcher as? CloseableCoroutineDispatcher)?.close()
|
||||
writeBuffer.unpin()
|
||||
}
|
||||
|
||||
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = readLock.withLock {
|
||||
withContext(dispatcher) {
|
||||
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = writeLock.withLock {
|
||||
withContext(sendDispatcher) {
|
||||
require(offset >= 0) { "offset must >= 0" }
|
||||
require(length >= 0) { "length must >= 0" }
|
||||
require(offset + length <= packet.size) { "It must follows offset + length <= packet.size" }
|
||||
packet.usePinned { pin ->
|
||||
if (write(socket, pin.addressOf(offset), length.convert()) != 0) {
|
||||
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
|
||||
if (send(socket, pin.addressOf(offset), length.convert(), 0).convert<Long>() < 0L) {
|
||||
@Suppress("INVISIBLE_MEMBER")
|
||||
throw PosixException.forErrno(posixFunctionName = "send()").wrapIO()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
actual override suspend fun send(packet: ByteReadPacket): Unit = readLock.withLock {
|
||||
withContext(dispatcher) {
|
||||
actual override suspend fun send(packet: ByteReadPacket): Unit = writeLock.withLock {
|
||||
withContext(sendDispatcher) {
|
||||
val writeBuffer = writeBuffer
|
||||
while (packet.remaining != 0L) {
|
||||
val length = packet.readAvailable(writeBuffer.get())
|
||||
if (write(socket, writeBuffer.addressOf(0), length.convert()) != 0) {
|
||||
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
|
||||
if (send(socket, writeBuffer.addressOf(0), length.convert(), 0).convert<Long>() < 0L) {
|
||||
@Suppress("INVISIBLE_MEMBER")
|
||||
throw PosixException.forErrno(posixFunctionName = "send()").wrapIO()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @throws ReadPacketInternalException
|
||||
* @throws EOFException
|
||||
*/
|
||||
actual override suspend fun read(): ByteReadPacket = writeLock.withLock {
|
||||
withContext(dispatcher) {
|
||||
val readBuffer = readBuffer
|
||||
val length = read(socket, readBuffer.addressOf(0), readBuffer.get().size.convert())
|
||||
readBuffer.get().toReadPacket(length = length)
|
||||
actual override suspend fun read(): ByteReadPacket = withContext(readDispatcher) {
|
||||
val readBuffer = ByteArrayPool.borrow()
|
||||
try {
|
||||
val length = readBuffer.usePinned { pinned ->
|
||||
recv(socket, pinned.addressOf(0), pinned.get().size.convert(), 0).convert<Long>()
|
||||
}
|
||||
if (length <= 0L) throw EOFException("recv: $length, errno=$errno")
|
||||
// [toReadPacket] 后的 readBuffer 会被其内部直接使用,而不是 copy 一份
|
||||
// 在 release 前不能被复用
|
||||
readBuffer.toReadPacket(length = length.toInt()) { ByteArrayPool.recycle(it) }
|
||||
} catch (e: Throwable) {
|
||||
ByteArrayPool.recycle(readBuffer)
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
@ -93,32 +105,10 @@ internal actual class PlatformSocket(
|
||||
serverPort: Int
|
||||
): PlatformSocket {
|
||||
val r = sockets.socket_create_connect(serverIp.cstr, serverPort.toUShort())
|
||||
if (r < 0) error("Failed socket_create_connect: $r")
|
||||
if (r == INVALID_SOCKET) error("Failed socket_create_connect: $r")
|
||||
return PlatformSocket(r)
|
||||
|
||||
// val addr = memScoped {
|
||||
// alloc<sockaddr_in>() {
|
||||
// sin_family = AF_INET.convert()
|
||||
// sin_port = htons(serverPort.toUShort())
|
||||
// sin_addr.S_un
|
||||
// sin_addr = resolveIpFromHost(serverIp).reinterpret<in_addr>().rawValue
|
||||
// }
|
||||
// }.reinterpret<sockaddr>()
|
||||
//
|
||||
// val id = socket(AF_INET, SOCK_STREAM, 0)
|
||||
// if (id.toInt() == -1) throw PosixException.forErrno(posixFunctionName = "socket()")
|
||||
//
|
||||
// val conn = connect(id, addr.ptr, sizeOf<sockaddr_in>().convert())
|
||||
// if (conn != 0) throw PosixException.forErrno(posixFunctionName = "connect()")
|
||||
//
|
||||
// return PlatformSocket(conn)
|
||||
}
|
||||
|
||||
// private fun resolveIpFromHost(serverIp: String): CPointer<hostent> {
|
||||
// return gethostbyname(serverIp)
|
||||
// ?: throw IllegalStateException("Failed to resolve IP from host. host=$serverIp")
|
||||
// }
|
||||
|
||||
actual suspend inline fun <R> withConnection(
|
||||
serverIp: String,
|
||||
serverPort: Int,
|
||||
|
Loading…
Reference in New Issue
Block a user