mirror of
https://github.com/mamoe/mirai.git
synced 2025-01-04 06:55:00 +08:00
fix: platform socket on Windows (#2121)
* fix: platform socket on Windows * fix(workflow): use win platform ssl provider instead of openssl for cURL openssl can't use system's native CA store by default
This commit is contained in:
parent
2022074007
commit
dc747ea438
4
.github/workflows/build.yml
vendored
4
.github/workflows/build.yml
vendored
@ -350,7 +350,7 @@ jobs:
|
|||||||
uses: pat-s/always-upload-cache@v3
|
uses: pat-s/always-upload-cache@v3
|
||||||
with:
|
with:
|
||||||
path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}
|
path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}
|
||||||
key: ${{ runner.os }}-vcpkg-binary-cache-${{ github.job }}
|
key: ${{ runner.os }}-vcpkg-binary-cache-${{ github.job }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-vcpkg-binary-cache-
|
${{ runner.os }}-vcpkg-binary-cache-
|
||||||
|
|
||||||
@ -386,7 +386,7 @@ jobs:
|
|||||||
name: Install OpenSSL & cURL on Windows
|
name: Install OpenSSL & cURL on Windows
|
||||||
run: |
|
run: |
|
||||||
echo "set(VCPKG_BUILD_TYPE release)" | Out-File -FilePath "$env:VCPKG_INSTALLATION_ROOT\triplets\x64-windows.cmake" -Encoding utf8 -Append
|
echo "set(VCPKG_BUILD_TYPE release)" | Out-File -FilePath "$env:VCPKG_INSTALLATION_ROOT\triplets\x64-windows.cmake" -Encoding utf8 -Append
|
||||||
vcpkg install openssl:x64-windows curl[core,openssl]:x64-windows
|
vcpkg install openssl:x64-windows curl[core,ssl]:x64-windows
|
||||||
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\crypto.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcrypto.lib
|
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\crypto.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcrypto.lib
|
||||||
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\ssl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libssl.lib
|
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\ssl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libssl.lib
|
||||||
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\curl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcurl.lib
|
New-Item -Path $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\curl.lib -ItemType SymbolicLink -Value $env:VCPKG_INSTALLATION_ROOT\installed\x64-windows\lib\libcurl.lib
|
||||||
|
@ -1,31 +1,38 @@
|
|||||||
headers = winsock.h
|
headers = winsock.h
|
||||||
|
|
||||||
---
|
---
|
||||||
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
#include <winsock2.h>
|
||||||
|
#include <Ws2tcpip.h>
|
||||||
|
#include <mswsock.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include <winsock.h>
|
static SOCKET socket_create_connect(char *host, unsigned short port) {
|
||||||
|
SOCKADDR_STORAGE local_addr = {0};
|
||||||
static int socket_create_connect(char *host, unsigned short port) {
|
SOCKADDR_STORAGE remote_addr = {0};
|
||||||
struct hostent *he;
|
DWORD local_addr_size = sizeof(local_addr);
|
||||||
struct sockaddr_in their_addr; /* connector's address information */
|
DWORD remote_addr_size = sizeof(remote_addr);
|
||||||
if ((he = gethostbyname(host)) == NULL) { /* get the host info */
|
char port_name[6];
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
int sockfd;
|
int sockfd;
|
||||||
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
|
|
||||||
return -2;
|
sprintf(port_name, "%d", (int)port);
|
||||||
|
|
||||||
|
if ((sockfd = socket(AF_INET6, SOCK_STREAM, 0)) == INVALID_SOCKET) {
|
||||||
|
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) {
|
||||||
|
return INVALID_SOCKET;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int ipv6only = 0;
|
||||||
|
setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&ipv6only, sizeof(ipv6only));
|
||||||
}
|
}
|
||||||
|
if (!WSAConnectByNameA(sockfd, host, port_name, &local_addr_size, (SOCKADDR*)&local_addr, &remote_addr_size, (SOCKADDR*)&remote_addr, NULL, NULL)) {
|
||||||
their_addr.sin_family = AF_INET; /* host byte order */
|
closesocket(sockfd);
|
||||||
their_addr.sin_port = htons(port); /* short, network byte order */
|
return INVALID_SOCKET;
|
||||||
their_addr.sin_addr = *((struct in_addr *) he->h_addr);
|
}
|
||||||
bzero(&(their_addr.sin_zero), 8); /* zero the rest of the struct */
|
if (setsockopt(sockfd, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0) == SOCKET_ERROR) {
|
||||||
|
closesocket(sockfd);
|
||||||
if (connect(sockfd, (struct sockaddr *) &their_addr, sizeof(struct sockaddr)) == -1) {
|
return INVALID_SOCKET;
|
||||||
return -3;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockfd;
|
return sockfd;
|
||||||
}
|
}
|
@ -10,21 +10,17 @@
|
|||||||
package net.mamoe.mirai.internal.utils
|
package net.mamoe.mirai.internal.utils
|
||||||
|
|
||||||
import io.ktor.utils.io.core.*
|
import io.ktor.utils.io.core.*
|
||||||
|
import io.ktor.utils.io.core.EOFException
|
||||||
import io.ktor.utils.io.errors.*
|
import io.ktor.utils.io.errors.*
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kotlinx.coroutines.CoroutineDispatcher
|
import kotlinx.coroutines.*
|
||||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
|
||||||
import kotlinx.coroutines.newSingleThreadContext
|
|
||||||
import kotlinx.coroutines.sync.Mutex
|
import kotlinx.coroutines.sync.Mutex
|
||||||
import kotlinx.coroutines.sync.withLock
|
import kotlinx.coroutines.sync.withLock
|
||||||
import kotlinx.coroutines.withContext
|
|
||||||
import net.mamoe.mirai.internal.network.highway.HighwayProtocolChannel
|
import net.mamoe.mirai.internal.network.highway.HighwayProtocolChannel
|
||||||
|
import net.mamoe.mirai.utils.ByteArrayPool
|
||||||
import net.mamoe.mirai.utils.DEFAULT_BUFFER_SIZE
|
import net.mamoe.mirai.utils.DEFAULT_BUFFER_SIZE
|
||||||
import net.mamoe.mirai.utils.toReadPacket
|
import net.mamoe.mirai.utils.toReadPacket
|
||||||
import net.mamoe.mirai.utils.wrapIO
|
import platform.posix.*
|
||||||
import platform.posix.close
|
|
||||||
import platform.posix.read
|
|
||||||
import platform.posix.write
|
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
@ -32,57 +28,73 @@ import kotlin.contracts.contract
|
|||||||
* TCP Socket.
|
* TCP Socket.
|
||||||
*/
|
*/
|
||||||
internal actual class PlatformSocket(
|
internal actual class PlatformSocket(
|
||||||
private val socket: Int
|
private val socket: SOCKET
|
||||||
) : Closeable, HighwayProtocolChannel {
|
) : Closeable, HighwayProtocolChannel {
|
||||||
@OptIn(ExperimentalCoroutinesApi::class)
|
@Suppress("UnnecessaryOptInAnnotation")
|
||||||
private val dispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher")
|
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
|
||||||
|
private val readDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.readDispatcher")
|
||||||
|
|
||||||
|
@Suppress("UnnecessaryOptInAnnotation")
|
||||||
|
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
|
||||||
|
private val sendDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.sendDispatcher")
|
||||||
|
|
||||||
private val readLock = Mutex()
|
|
||||||
private val readBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
|
|
||||||
private val writeLock = Mutex()
|
private val writeLock = Mutex()
|
||||||
private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
|
private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
|
||||||
|
|
||||||
actual val isOpen: Boolean
|
actual val isOpen: Boolean
|
||||||
get() = write(socket, null, 0) != 0
|
get() = send(socket, null, 0, 0).convert<Long>() != 0L
|
||||||
|
|
||||||
actual override fun close() {
|
actual override fun close() {
|
||||||
if (close(socket) != 0) {
|
closesocket(socket)
|
||||||
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
|
(readDispatcher as? CloseableCoroutineDispatcher)?.close()
|
||||||
}
|
(sendDispatcher as? CloseableCoroutineDispatcher)?.close()
|
||||||
|
writeBuffer.unpin()
|
||||||
}
|
}
|
||||||
|
|
||||||
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = readLock.withLock {
|
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = writeLock.withLock {
|
||||||
withContext(dispatcher) {
|
withContext(sendDispatcher) {
|
||||||
require(offset >= 0) { "offset must >= 0" }
|
require(offset >= 0) { "offset must >= 0" }
|
||||||
require(length >= 0) { "length must >= 0" }
|
require(length >= 0) { "length must >= 0" }
|
||||||
require(offset + length <= packet.size) { "It must follows offset + length <= packet.size" }
|
require(offset + length <= packet.size) { "It must follows offset + length <= packet.size" }
|
||||||
packet.usePinned { pin ->
|
packet.usePinned { pin ->
|
||||||
if (write(socket, pin.addressOf(offset), length.convert()) != 0) {
|
if (send(socket, pin.addressOf(offset), length.convert(), 0).convert<Long>() < 0L) {
|
||||||
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
|
@Suppress("INVISIBLE_MEMBER")
|
||||||
|
throw PosixException.forErrno(posixFunctionName = "send()").wrapIO()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
actual override suspend fun send(packet: ByteReadPacket): Unit = readLock.withLock {
|
actual override suspend fun send(packet: ByteReadPacket): Unit = writeLock.withLock {
|
||||||
withContext(dispatcher) {
|
withContext(sendDispatcher) {
|
||||||
val writeBuffer = writeBuffer
|
val writeBuffer = writeBuffer
|
||||||
val length = packet.readAvailable(writeBuffer.get())
|
while (packet.remaining != 0L) {
|
||||||
if (write(socket, writeBuffer.addressOf(0), length.convert()) != 0) {
|
val length = packet.readAvailable(writeBuffer.get())
|
||||||
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
|
if (send(socket, writeBuffer.addressOf(0), length.convert(), 0).convert<Long>() < 0L) {
|
||||||
|
@Suppress("INVISIBLE_MEMBER")
|
||||||
|
throw PosixException.forErrno(posixFunctionName = "send()").wrapIO()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @throws ReadPacketInternalException
|
* @throws EOFException
|
||||||
*/
|
*/
|
||||||
actual override suspend fun read(): ByteReadPacket = writeLock.withLock {
|
actual override suspend fun read(): ByteReadPacket = withContext(readDispatcher) {
|
||||||
withContext(dispatcher) {
|
val readBuffer = ByteArrayPool.borrow()
|
||||||
val readBuffer = readBuffer
|
try {
|
||||||
val length = read(socket, readBuffer.addressOf(0), readBuffer.get().size.convert())
|
val length = readBuffer.usePinned { pinned ->
|
||||||
readBuffer.get().toReadPacket(length = length)
|
recv(socket, pinned.addressOf(0), pinned.get().size.convert(), 0).convert<Long>()
|
||||||
|
}
|
||||||
|
if (length <= 0L) throw EOFException("recv: $length, errno=$errno")
|
||||||
|
// [toReadPacket] 后的 readBuffer 会被其内部直接使用,而不是 copy 一份
|
||||||
|
// 在 release 前不能被复用
|
||||||
|
readBuffer.toReadPacket(length = length.toInt()) { ByteArrayPool.recycle(it) }
|
||||||
|
} catch (e: Throwable) {
|
||||||
|
ByteArrayPool.recycle(readBuffer)
|
||||||
|
throw e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,32 +105,10 @@ internal actual class PlatformSocket(
|
|||||||
serverPort: Int
|
serverPort: Int
|
||||||
): PlatformSocket {
|
): PlatformSocket {
|
||||||
val r = sockets.socket_create_connect(serverIp.cstr, serverPort.toUShort())
|
val r = sockets.socket_create_connect(serverIp.cstr, serverPort.toUShort())
|
||||||
if (r < 0) error("Failed socket_create_connect: $r")
|
if (r == INVALID_SOCKET) error("Failed socket_create_connect: $r")
|
||||||
return PlatformSocket(r)
|
return PlatformSocket(r)
|
||||||
|
|
||||||
// val addr = memScoped {
|
|
||||||
// alloc<sockaddr_in>() {
|
|
||||||
// sin_family = AF_INET.convert()
|
|
||||||
// sin_port = htons(serverPort.toUShort())
|
|
||||||
// sin_addr.S_un
|
|
||||||
// sin_addr = resolveIpFromHost(serverIp).reinterpret<in_addr>().rawValue
|
|
||||||
// }
|
|
||||||
// }.reinterpret<sockaddr>()
|
|
||||||
//
|
|
||||||
// val id = socket(AF_INET, SOCK_STREAM, 0)
|
|
||||||
// if (id.toInt() == -1) throw PosixException.forErrno(posixFunctionName = "socket()")
|
|
||||||
//
|
|
||||||
// val conn = connect(id, addr.ptr, sizeOf<sockaddr_in>().convert())
|
|
||||||
// if (conn != 0) throw PosixException.forErrno(posixFunctionName = "connect()")
|
|
||||||
//
|
|
||||||
// return PlatformSocket(conn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// private fun resolveIpFromHost(serverIp: String): CPointer<hostent> {
|
|
||||||
// return gethostbyname(serverIp)
|
|
||||||
// ?: throw IllegalStateException("Failed to resolve IP from host. host=$serverIp")
|
|
||||||
// }
|
|
||||||
|
|
||||||
actual suspend inline fun <R> withConnection(
|
actual suspend inline fun <R> withConnection(
|
||||||
serverIp: String,
|
serverIp: String,
|
||||||
serverPort: Int,
|
serverPort: Int,
|
||||||
|
Loading…
Reference in New Issue
Block a user