1
0
mirror of https://github.com/mamoe/mirai.git synced 2025-04-02 05:00:35 +08:00

Fix socket issues

This commit is contained in:
Him188 2022-06-01 23:35:22 +01:00
parent 06616bac35
commit 8085d94dfe
No known key found for this signature in database
GPG Key ID: BA439CDDCF652375
11 changed files with 167 additions and 94 deletions
mirai-core-utils/src/nativeMain/kotlin
mirai-core
.gitignorebuild.gradle.kts
src
commonMain/kotlin/network
nativeMain/kotlin
unixMain
cinterop
kotlin/utils

View File

@ -46,7 +46,7 @@ public actual fun currentTimeFormatted(format: String?): String = timeLock.withL
val tm = localtime(timeT.ptr) // localtime is not thread-safe
val bb = allocArray<ByteVar>(40)
strftime(bb, 40, "%Y-%M-%d %H:%M:%S", tm);
strftime(bb, 40, "%Y-%m-%d %H:%M:%S", tm);
bb.toKString()
}

View File

@ -1 +1,2 @@
src/jvmTest/kotlin/local
src/jvmTest/kotlin/local
test-sandbox/

View File

@ -114,6 +114,14 @@ kotlin {
}
}
UNIX_LIKE_TARGETS.forEach { target ->
(targets.getByName(target) as KotlinNativeTarget).compilations.getByName("main").cinterops.create("Socket")
.apply {
defFile = projectDir.resolve("src/unixMain/cinterop/Socket.def")
packageName("sockets")
}
}
configure(WIN_TARGETS.map { getByName(it + "Main") }) {
dependencies {
implementation(`ktor-client-curl`)

View File

@ -10,6 +10,7 @@
package net.mamoe.mirai.internal.network.components
import io.ktor.client.request.*
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
@ -22,6 +23,7 @@ import net.mamoe.mirai.internal.utils.crypto.ECDHWithPublicKey
import net.mamoe.mirai.internal.utils.crypto.defaultInitialPublicKey
import net.mamoe.mirai.utils.MiraiLogger
import net.mamoe.mirai.utils.currentTimeSeconds
import kotlin.time.Duration.Companion.seconds
/**
@ -88,16 +90,19 @@ internal class EcdhInitialPublicKeyUpdaterImpl(
logger.info("ECDH key is invalid, start to fetch ecdh public key from server.")
val respStr =
@Suppress("DEPRECATION", "DEPRECATION_ERROR")
Mirai.Http.get<String>("https://keyrotate.qq.com/rotate_key?cipher_suite_ver=305&uin=${bot.client.uin}")
withTimeout(10.seconds) { Mirai.Http.get<String>("https://keyrotate.qq.com/rotate_key?cipher_suite_ver=305&uin=${bot.client.uin}") }
val resp = json.decodeFromString(ServerRespPOJO.serializer(), respStr)
println("check2")
resp.pubKeyMeta.let { meta ->
val isValid = ECDH.verifyPublicKey(
version = meta.keyVer,
publicKey = meta.pubKey,
publicKeySign = meta.pubKeySign
)
println("check1")
check(isValid) { "Ecdh public key which from server is invalid" }
logger.info("Successfully fetched ecdh public key from server.")
println("check3")
ECDHInitialPublicKey(meta.keyVer, meta.pubKey, currentTimeSeconds() + resp.querySpan)
}
}

View File

@ -145,8 +145,6 @@ internal abstract class CommonNetworkHandler<Conn>(
this.setState { StateConnecting(ExceptionCollector()) }
?.resumeConnection()
?: this@CommonNetworkHandler.resumeConnection() // concurrently closed by other thread.
println("INITIALIZED RETURN")
}
override fun toString(): String = "StateInitialized"
@ -214,9 +212,6 @@ internal abstract class CommonNetworkHandler<Conn>(
connectResult.await() // propagates exceptions
val connection = connection.await()
this.setState { StateLoading(connection) }
.also {
println(" this.setState { StateLoading(connection) }: " + it)
}
?.resumeConnection()
?: this@CommonNetworkHandler.resumeConnection() // concurrently closed by other thread.
}

View File

@ -422,5 +422,5 @@ internal fun String.toIpV4Long(): Long {
if (isEmpty()) return 0
val split = split('.')
if (split.size != 4) return 0
return split.mapToByteArray { it.toByte() }.toInt().toLongUnsigned()
return split.mapToByteArray { it.toUByte().toByte() }.toInt().toLongUnsigned()
}

View File

@ -22,9 +22,7 @@ import net.mamoe.mirai.internal.network.components.SsoProcessor
import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacket
import net.mamoe.mirai.internal.utils.PlatformSocket
import net.mamoe.mirai.internal.utils.connect
import net.mamoe.mirai.utils.childScope
import net.mamoe.mirai.utils.readPacketExact
import net.mamoe.mirai.utils.toLongUnsigned
import net.mamoe.mirai.utils.*
internal class NativeNetworkHandler(
context: NetworkHandlerContext,
@ -58,7 +56,12 @@ internal class NativeNetworkHandler(
private val bufferedPackets: MutableList<ByteReadPacket> = ArrayList(10)
fun offer(packet: ByteReadPacket) {
if (missingLength == 0L) {
// initial
missingLength = packet.readInt().toLongUnsigned() - 4
}
missingLength -= packet.remaining
bufferedPackets.add(packet)
if (missingLength <= 0) {
emit()
}
@ -70,17 +73,13 @@ internal class NativeNetworkHandler(
1 -> {
val packet = bufferedPackets.first()
if (missingLength == 0L) {
packetCodec.decodeRaw(ssoProcessor.ssoSession, packet)
sendDecode(packet)
bufferedPackets.clear()
} else {
check(missingLength < 0L) { "Failed check: remainingLength < 0L" }
val previousPacketLength = missingLength + packet.remaining
decodePipeline.send(
packetCodec.decodeRaw(
ssoProcessor.ssoSession,
packet.readPacketExact(previousPacketLength.toInt())
)
)
sendDecode(packet.readPacketExact(previousPacketLength.toInt()))
// now packet contain new packet.
missingLength = packet.readInt().toLongUnsigned() - 4
@ -115,16 +114,23 @@ internal class NativeNetworkHandler(
bufferedPackets.add(lastPacket)
}
decodePipeline.send(
packetCodec.decodeRaw(
ssoProcessor.ssoSession,
combined
)
)
sendDecode(combined)
}
}
}
private fun sendDecode(combined: ByteReadPacket) {
packetLogger.verbose { "Decoding: len=${combined.remaining}" }
val raw = packetCodec.decodeRaw(
ssoProcessor.ssoSession,
combined
)
packetLogger.verbose { "Decoded: ${raw.commandName}" }
decodePipeline.send(
raw
)
}
override fun close() {
bufferedPackets.forEach { it.close() }
}
@ -133,6 +139,7 @@ internal class NativeNetworkHandler(
private val sender = launch {
while (isActive) {
val result = sendQueue.receiveCatching()
logger.info { "Native sender: $result" }
result.onFailure { if (it is CancellationException) return@launch }
result.getOrNull()?.let { packet ->
@ -150,6 +157,7 @@ internal class NativeNetworkHandler(
while (isActive) {
try {
val packet = socket.read()
lengthDelimitedPacketReader.offer(packet)
} catch (e: Throwable) {
if (e is CancellationException) return@launch
@ -171,7 +179,10 @@ internal class NativeNetworkHandler(
}
override suspend fun createConnection(): NativeConn {
return NativeConn(PlatformSocket.connect(address))
logger.info { "Connecting to $address" }
return NativeConn(PlatformSocket.connect(address)).also {
logger.info { "Connected to server $address" }
}
}
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")

View File

@ -19,7 +19,11 @@ internal actual fun SocketAddress.getHost(): String = host
internal actual fun SocketAddress.getPort(): Int = port
internal class SocketAddressImpl(host: String, port: Int) : SocketAddress(host, port, null)
internal class SocketAddressImpl(host: String, port: Int) : SocketAddress(host, port, null) {
override fun toString(): String {
return "$host:$port"
}
}
internal actual fun createSocketAddress(host: String, port: Int): SocketAddress {
return SocketAddressImpl(host, port)

View File

@ -94,7 +94,7 @@ internal class OpenSslPublicKey(override val hex: String) : ECDHPublicKey {
internal actual class ECDHKeyPairImpl(
override val privateKey: OpenSslPrivateKey,
override val publicKey: OpenSslPublicKey,
private val initialPublicKey: ECDHPublicKey
initialPublicKey: ECDHPublicKey
) : ECDHKeyPair {
override val maskedPublicKey: ByteArray by lazy { publicKey.encoded }
@ -136,19 +136,12 @@ internal actual class ECDH actual constructor(actual val keyPair: ECDHKeyPair) {
* 由完整的 publicKey ByteArray 得到 [ECDHPublicKey]
*/
actual fun constructPublicKey(key: ByteArray): ECDHPublicKey {
memScoped {
key.usePinned { pin ->
val group = EC_GROUP_new_by_curve_name(curveId)
?: error("Failed to create EC_GROUP")
val p = EC_POINT_new(group) ?: error("Failed to create EC_POINT")
val p = EC_POINT_new(group) ?: error("Failed to create EC_POINT")
EC_POINT_hex2point(group, pin.get().toUHexString("").lowercase(), p, bnCtx)
return OpenSslPublicKey.fromPoint(p)
}
}
// TODO: 2022/6/1 native: check memory
EC_POINT_hex2point(group, key.toUHexString("").lowercase(), p, bnCtx)
return OpenSslPublicKey.fromPoint(p)
}
/**
@ -186,7 +179,7 @@ internal actual class ECDH actual constructor(actual val keyPair: ECDHKeyPair) {
try {
val privateBignum = privateKey.toBignum()
try {
EC_KEY_set_private_key(k, privateKey.toBignum()).let { r ->
EC_KEY_set_private_key(k, privateBignum).let { r ->
if (r != 1) error("Failed EC_KEY_set_private_key: $r")
}
@ -228,5 +221,6 @@ internal actual class ECDH actual constructor(actual val keyPair: ECDHKeyPair) {
}
internal actual fun ByteArray.adjustToPublicKey(): ECDHPublicKey {
println("adjustToPublicKey: ${this.toUHexString()}")
return ECDH.constructPublicKey(this)
}

View File

@ -0,0 +1,30 @@
headers = netdb.h
---
#include <stdlib.h>
#include <string.h>
#include <netdb.h>
static int socket_create_connect(char *host, ushort port) {
struct hostent *he;
struct sockaddr_in their_addr; /* connector's address information */
if ((he = gethostbyname(host)) == NULL) { /* get the host info */
return -1;
}
int sockfd;
if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
return -2;
}
their_addr.sin_family = AF_INET; /* host byte order */
their_addr.sin_port = htons(port); /* short, network byte order */
their_addr.sin_addr = *((struct in_addr *) he->h_addr);
bzero(&(their_addr.sin_zero), 8); /* zero the rest of the struct */
if (connect(sockfd, (struct sockaddr *) &their_addr, sizeof(struct sockaddr)) == -1) {
return -3;
}
return sockfd;
}

View File

@ -10,20 +10,18 @@
package net.mamoe.mirai.internal.utils
import io.ktor.utils.io.core.*
import io.ktor.utils.io.core.EOFException
import io.ktor.utils.io.errors.*
import kotlinx.cinterop.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import net.mamoe.mirai.internal.network.highway.HighwayProtocolChannel
import net.mamoe.mirai.internal.network.protocol.packet.login.toIpV4Long
import net.mamoe.mirai.utils.DEFAULT_BUFFER_SIZE
import net.mamoe.mirai.utils.toReadPacket
import net.mamoe.mirai.utils.wrapIO
import platform.posix.*
import net.mamoe.mirai.utils.*
import platform.posix.close
import platform.posix.errno
import platform.posix.recv
import platform.posix.write
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
@ -34,8 +32,14 @@ import kotlin.contracts.contract
internal actual class PlatformSocket(
private val socket: Int
) : Closeable, HighwayProtocolChannel {
@OptIn(ExperimentalCoroutinesApi::class)
private val dispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher")
@Suppress("UnnecessaryOptInAnnotation")
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
private val readDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher")
// Native send and read are blocking. Using a dedicated thread(dispatcher) to do the job.
@Suppress("UnnecessaryOptInAnnotation")
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
private val sendDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher")
private val readLock = Mutex()
private val readBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
@ -45,22 +49,21 @@ internal actual class PlatformSocket(
actual val isOpen: Boolean
get() = write(socket, null, 0).convert<Long>() != 0L
@OptIn(ExperimentalIoApi::class)
actual override fun close() {
if (close(socket) != 0) {
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
}
close(socket)
(readDispatcher as CloseableCoroutineDispatcher).close()
(sendDispatcher as CloseableCoroutineDispatcher).close()
}
@OptIn(ExperimentalIoApi::class)
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = readLock.withLock {
withContext(dispatcher) {
actual suspend fun send(packet: ByteArray, offset: Int, length: Int): Unit = writeLock.withLock {
withContext(sendDispatcher) {
require(offset >= 0) { "offset must >= 0" }
require(length >= 0) { "length must >= 0" }
require(offset + length <= packet.size) { "It must follows offset + length <= packet.size" }
packet.usePinned { pin ->
if (write(socket, pin.addressOf(offset), length.convert()).convert<Long>() != 0L) {
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
if (write(socket, pin.addressOf(offset), length.convert()).convert<Long>() < 0L) {
throw PosixException.forErrno(posixFunctionName = "write()").wrapIO()
}
}
}
@ -70,12 +73,16 @@ internal actual class PlatformSocket(
* @throws SendPacketInternalException
*/
@OptIn(ExperimentalIoApi::class)
actual override suspend fun send(packet: ByteReadPacket): Unit = readLock.withLock {
withContext(dispatcher) {
actual override suspend fun send(packet: ByteReadPacket): Unit = writeLock.withLock {
withContext(sendDispatcher) {
logger.info { "Native socket sending: len=${packet.remaining}" }
val writeBuffer = writeBuffer
val length = packet.readAvailable(writeBuffer.get())
if (write(socket, writeBuffer.addressOf(0), length.convert()).convert<Long>() != 0L) {
throw PosixException.forErrno(posixFunctionName = "close()").wrapIO()
while (packet.remaining != 0L) {
val length = packet.readAvailable(writeBuffer.get())
if (write(socket, writeBuffer.addressOf(0), length.convert()).convert<Long>() < 0L) {
throw PosixException.forErrno(posixFunctionName = "write()").wrapIO()
}
logger.info { "Native socket sent $length bytes." }
}
}
}
@ -83,49 +90,67 @@ internal actual class PlatformSocket(
/**
* @throws ReadPacketInternalException
*/
actual override suspend fun read(): ByteReadPacket = writeLock.withLock {
withContext(dispatcher) {
actual override suspend fun read(): ByteReadPacket = readLock.withLock {
withContext(readDispatcher) {
logger.info { "Native socket reading." }
val readBuffer = readBuffer
val length = read(socket, readBuffer.addressOf(0), readBuffer.get().size.convert()).convert<Long>()
val length = recv(socket, readBuffer.addressOf(0), readBuffer.get().size.convert(), 0).convert<Long>()
if (length < 0L) {
throw EOFException("recv: $length, errno=$errno")
}
logger.info {
"Native socket read $length bytes: ${
readBuffer.get().copyOf(length.toInt()).toUHexString()
}"
}
readBuffer.get().toReadPacket(length = length.toInt())
}
}
actual companion object {
private val logger: MiraiLogger = MiraiLogger.Factory.create(PlatformSocket::class)
@OptIn(UnsafeNumber::class, ExperimentalIoApi::class)
actual suspend fun connect(
serverIp: String,
serverPort: Int
): PlatformSocket {
val addr = memScoped {
alloc<sockaddr_in>() {
sin_family = AF_INET.convert()
resolveIpFromHost(serverIp)
sin_addr.s_addr = resolveIpFromHost(serverIp)
}
}.reinterpret<sockaddr>()
val id = socket(AF_INET, 1 /* SOCKET_STREAM */, IPPROTO_TCP)
if (id != 0) throw PosixException.forErrno(posixFunctionName = "socket()")
val conn = connect(id, addr.ptr, sizeOf<sockaddr_in>().convert())
if (conn != 0) throw PosixException.forErrno(posixFunctionName = "connect()")
return PlatformSocket(conn)
val r = sockets.socket_create_connect(serverIp.cstr, serverPort.toUShort())
if (r < 0) error("Failed socket_create_connect: $r")
return PlatformSocket(r)
// val addr = nativeHeap.alloc<sockaddr_in>() {
// sin_family = AF_INET.toUByte()
// sin_addr.s_addr = resolveIpFromHost(serverIp).pointed.s_addr
// sin_port = serverPort.toUInt().toUShort()
// }
//
// val id = socket(AF_INET, SOCK_STREAM, 0)
// if (id == -1) throw PosixException.forErrno(posixFunctionName = "socket()")
//
// println("connect")
// val conn = connect(id, addr.ptr.reinterpret(), sizeOf<sockaddr_in>().toUInt())
// println("connect: $conn, $errno")
// if (conn < 0) throw PosixException.forErrno(posixFunctionName = "connect()")
//
// return PlatformSocket(conn)
}
private fun resolveIpFromHost(serverIp: String): UInt {
val host = gethostbyname(serverIp) // points to static data, don't free
?: throw IllegalStateException("Failed to resolve IP from host. host=$serverIp")
val hAddrList = host.pointed.h_addr_list
?: throw IllegalStateException("Empty IP list resolved from host. host=$serverIp")
val str = hAddrList[0]!!.toKString()
return str.toIpV4Long().toUInt()
}
// private fun resolveIpFromHost(serverIp: String): CPointer<in_addr> {
// val host = gethostbyname(serverIp) // points to static data, don't free
// ?: throw IllegalStateException("Failed to resolve IP from host. host=$serverIp")
// println(host.pointed.h_addr_list?.get(1)?.reinterpret<in_addr>()?.pointed?.s_addr)
// return host.pointed.h_addr_list?.get(1)?.reinterpret<in_addr>() ?: error("Failed to get ip")
//// val hAddrList = host.pointed.h_addr_list
//// ?: throw IllegalStateException("Empty IP list resolved from host. host=$serverIp")
////
////
//// val str = hAddrList[0]!!.reinterpret<UIntVar>()
////
//// try {
//// return str.pointed.value
//// } finally {
////// free()
//// }
// }
actual suspend inline fun <R> withConnection(
serverIp: String,