1
0
mirror of https://github.com/mamoe/mirai.git synced 2025-04-25 04:50:26 +08:00

Fix LengthDelimitedPacketReader

This commit is contained in:
Him188 2022-06-02 14:45:03 +01:00
parent 8085d94dfe
commit 900c7feac7
No known key found for this signature in database
GPG Key ID: BA439CDDCF652375
5 changed files with 629 additions and 95 deletions
mirai-core/src

View File

@ -0,0 +1,119 @@
/*
* Copyright 2019-2022 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.network.handler
import io.ktor.utils.io.core.*
import net.mamoe.mirai.utils.*
private val debugLogger: MiraiLogger by lazy {
MiraiLogger.Factory.create(
LengthDelimitedPacketReader::class, "LengthDelimitedPacketReader"
).withSwitch(systemProp("mirai.network.handler.length.delimited.packet.reader.debug", false))
}
/**
* Not thread-safe
*/
internal class LengthDelimitedPacketReader(
private val sendDecode: (combined: ByteReadPacket) -> Unit
) : Closeable {
private var missingLength: Long = 0
set(value) {
field = value
debugLogger.info { "missingLength = $field" }
}
private val bufferedParts: MutableList<ByteReadPacket> = ArrayList(10)
@TestOnly
fun getMissingLength() = missingLength
@TestOnly
fun getBufferedPackets() = bufferedParts.toList()
fun offer(packet: ByteReadPacket) {
if (missingLength == 0L) {
// initial
debugLogger.info { "initial length == 0" }
missingLength = packet.readInt().toLongUnsigned() - 4
}
debugLogger.info { "Offering packet len = ${packet.remaining}" }
missingLength -= packet.remaining
bufferedParts.add(packet)
if (missingLength <= 0) {
emit()
}
}
private fun emit() {
debugLogger.info { "Emitting, buffered = ${bufferedParts.map { it.remaining }}" }
when (bufferedParts.size) {
0 -> {}
1 -> {
val part = bufferedParts.first()
if (missingLength == 0L) {
debugLogger.info { "Single packet length perfectly matched." }
sendDecode(part)
bufferedParts.clear()
} else {
check(missingLength < 0L) { "Failed check: remainingLength < 0L" }
val previousPacketLength = missingLength + part.remaining
debugLogger.info { "Got extra packets, previousPacketLength = $previousPacketLength" }
sendDecode(part.readPacketExact(previousPacketLength.toInt()))
bufferedParts.clear()
// now packet contain new part.
missingLength = part.readInt().toLongUnsigned() - 4
offer(part)
}
}
else -> {
if (missingLength == 0L) {
debugLogger.info { "Multiple packets length perfectly matched." }
sendDecode(buildPacket(bufferedParts.sumOf { it.remaining }.toInt()) {
bufferedParts.forEach { writePacket(it) }
})
bufferedParts.clear()
} else {
val lastPart = bufferedParts.last()
val previousPacketPartLength = missingLength + lastPart.remaining
debugLogger.debug { "previousPacketPartLength = $previousPacketPartLength" }
val combinedLength =
(bufferedParts.sumOf { it.remaining } - lastPart.remaining // buffered length without last part
+ previousPacketPartLength).toInt()
debugLogger.debug { "combinedLength = $combinedLength" }
if (combinedLength < 0) return // not enough, still more parts missing.
sendDecode(buildPacket(combinedLength) {
repeat(bufferedParts.size - 1) { i ->
writePacket(bufferedParts[i])
}
writePacket(lastPart, previousPacketPartLength)
})
bufferedParts.clear()
// now packet contain new part.
missingLength = lastPart.readInt().toLongUnsigned() - 4
offer(lastPart)
}
}
}
}
override fun close() {
bufferedParts.forEach { it.close() }
}
}

View File

@ -22,7 +22,9 @@ import net.mamoe.mirai.internal.network.components.SsoProcessor
import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacket
import net.mamoe.mirai.internal.utils.PlatformSocket
import net.mamoe.mirai.internal.utils.connect
import net.mamoe.mirai.utils.*
import net.mamoe.mirai.utils.childScope
import net.mamoe.mirai.utils.info
import net.mamoe.mirai.utils.verbose
internal class NativeNetworkHandler(
context: NetworkHandlerContext,
@ -46,95 +48,15 @@ internal class NativeNetworkHandler(
launch { write(undelivered) }
}
private val lengthDelimitedPacketReader = LengthDelimitedPacketReader()
/**
* Not thread-safe
*/
private inner class LengthDelimitedPacketReader : Closeable {
private var missingLength: Long = 0
private val bufferedPackets: MutableList<ByteReadPacket> = ArrayList(10)
fun offer(packet: ByteReadPacket) {
if (missingLength == 0L) {
// initial
missingLength = packet.readInt().toLongUnsigned() - 4
}
missingLength -= packet.remaining
bufferedPackets.add(packet)
if (missingLength <= 0) {
emit()
}
}
fun emit() {
when (bufferedPackets.size) {
0 -> {}
1 -> {
val packet = bufferedPackets.first()
if (missingLength == 0L) {
sendDecode(packet)
bufferedPackets.clear()
} else {
check(missingLength < 0L) { "Failed check: remainingLength < 0L" }
val previousPacketLength = missingLength + packet.remaining
sendDecode(packet.readPacketExact(previousPacketLength.toInt()))
// now packet contain new packet.
missingLength = packet.readInt().toLongUnsigned() - 4
bufferedPackets[0] = packet
}
}
else -> {
val combined: ByteReadPacket
if (missingLength == 0L) {
combined = buildPacket(bufferedPackets.sumOf { it.remaining }.toInt()) {
bufferedPackets.forEach { writePacket(it) }
}
bufferedPackets.clear()
} else {
val lastPacket = bufferedPackets.last()
val previousPacketPartLength = missingLength + lastPacket.remaining
val combinedLength =
(bufferedPackets.sumOf { it.remaining } - lastPacket.remaining + previousPacketPartLength).toInt()
combined = buildPacket(combinedLength) {
repeat(bufferedPackets.size - 1) { i ->
writePacket(bufferedPackets[i])
}
writePacket(lastPacket, previousPacketPartLength)
}
bufferedPackets.clear()
// now packet contain new packet.
missingLength = lastPacket.readInt().toLongUnsigned() - 4
bufferedPackets.add(lastPacket)
}
sendDecode(combined)
}
}
}
private fun sendDecode(combined: ByteReadPacket) {
packetLogger.verbose { "Decoding: len=${combined.remaining}" }
val raw = packetCodec.decodeRaw(
ssoProcessor.ssoSession,
combined
)
packetLogger.verbose { "Decoded: ${raw.commandName}" }
decodePipeline.send(
raw
)
}
override fun close() {
bufferedPackets.forEach { it.close() }
}
}
private val lengthDelimitedPacketReader = LengthDelimitedPacketReader(fun(combined: ByteReadPacket) {
logger.verbose { "Decoding: len=${combined.remaining}" }
val raw = packetCodec.decodeRaw(
ssoProcessor.ssoSession,
combined
)
logger.verbose { "Decoded: ${raw.commandName}" }
decodePipeline.send(raw)
})
private val sender = launch {
while (isActive) {

View File

@ -0,0 +1,490 @@
/*
* Copyright 2019-2022 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
@file:OptIn(TestOnly::class)
package net.mamoe.mirai.internal.network
import io.ktor.utils.io.core.*
import net.mamoe.mirai.internal.network.handler.LengthDelimitedPacketReader
import net.mamoe.mirai.internal.test.AbstractTest
import net.mamoe.mirai.internal.utils.io.writeIntLVPacket
import net.mamoe.mirai.internal.utils.io.writeShortLVString
import net.mamoe.mirai.utils.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.test.Test
import kotlin.test.assertEquals
internal class LengthDelimitedPacketReaderTest : AbstractTest() {
init {
setSystemProp("mirai.network.handler.length.delimited.packet.reader.debug", "true")
}
private val received = mutableListOf<ByteArray>()
private val reader = LengthDelimitedPacketReader { received.add(it.readBytes()) }
/*
* All these tests cases can happen in the real time, and even before logon is complete.
*/
@Test
fun `can read exact packet`() {
val original = buildLVPacket {
writeShortLVString("some strings")
writeInt(123)
}
val originalLength = original.remaining
reader.offer(original)
assertEquals(1, received.size)
received.single().read {
assertEquals(originalLength - 4, this.remaining)
assertEquals("some strings", readUShortLVString())
assertEquals(123, readInt())
assertEquals(0, remaining)
}
assertEquals(0, reader.getBufferedPackets().size)
assertEquals(0, reader.getMissingLength())
}
@Test
fun `can read 2 part packets`() {
val part1 = buildPacket {
writeShortLVString("some strings")
writeInt(123)
}.readBytes()
val part2 = buildPacket {
writeShortLVString("some strings")
writeInt(123)
}.readBytes()
reader.offer(buildPacket {
writeInt(part1.size + part2.size + 4)
writeFully(part1)
})
assertEquals(0, received.size)
reader.offer(part2.toReadPacket())
assertEquals(1, received.size)
received.single().read {
assertEquals(part1.size + part2.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(123, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(123, readInt())
assertEquals(0, remaining)
}
assertEquals(0, reader.getBufferedPackets().size)
assertEquals(0, reader.getMissingLength())
}
@Test
fun `can read 3 part packets`() {
val part1 = buildPacket {
writeShortLVString("some strings")
writeInt(111)
}.readBytes()
val part2 = buildPacket {
writeShortLVString("some strings")
writeInt(222)
}.readBytes()
val part3 = buildPacket {
writeShortLVString("some strings")
writeInt(333)
}.readBytes()
reader.offer(buildPacket {
writeInt(part1.size + part2.size + part3.size + 4)
writeFully(part1)
// part2 and part3 missing
})
assertEquals(0, received.size)
assertEquals(1, reader.getBufferedPackets().size)
assertEquals(part2.size + part3.size, reader.getMissingLength().toInt())
reader.offer(part2.toReadPacket())
assertEquals(0, received.size)
assertEquals(2, reader.getBufferedPackets().size)
assertEquals(part3.size, reader.getMissingLength().toInt())
reader.offer(part3.toReadPacket())
assertEquals(1, received.size)
assertEquals(0, reader.getBufferedPackets().size)
assertEquals(0, reader.getMissingLength())
received.single().read {
assertEquals(part1.size + part2.size + part3.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(111, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(222, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(333, readInt())
assertEquals(0, remaining)
}
}
@Test
fun `can read 3 part packets with a combined`() {
val part1 = buildPacket {
writeShortLVString("some strings")
writeInt(111)
}.readBytes()
val part2 = buildPacket {
writeShortLVString("some strings")
writeInt(222)
}.readBytes()
val part3 = buildPacket {
writeShortLVString("some strings")
writeInt(333)
}.readBytes()
val part4 = buildPacket {
writeShortLVString("some strings")
writeInt(444)
}.readBytes()
reader.offer(buildPacket {
writeInt(part1.size + part2.size + part3.size + 4)
writeFully(part1)
// part2 and part3 missing
})
assertEquals(0, received.size)
assertEquals(1, reader.getBufferedPackets().size)
assertEquals(part2.size + part3.size, reader.getMissingLength().toInt())
reader.offer(part2.toReadPacket())
assertEquals(0, received.size)
assertEquals(2, reader.getBufferedPackets().size)
assertEquals(part3.size, reader.getMissingLength().toInt())
reader.offer(buildPacket {
writeFully(part3)
writePacket(buildLVPacket { writeFully(part4) })
})
assertEquals(2, received.size)
assertEquals(0, reader.getBufferedPackets().size)
assertEquals(0, reader.getMissingLength())
received[0].read {
assertEquals(part1.size + part2.size + part3.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(111, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(222, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(333, readInt())
assertEquals(0, remaining)
}
received[1].read {
assertEquals(part4.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(444, readInt())
assertEquals(0, remaining)
}
}
@Test
fun `can read 3 part packets from combined with a combined`() {
val part1 = buildPacket {
writeShortLVString("some strings")
writeInt(111)
}.readBytes()
val part2 = buildPacket {
writeShortLVString("some strings")
writeInt(222)
}.readBytes()
val part3 = buildPacket {
writeShortLVString("some strings")
writeInt(333)
}.readBytes()
val part4 = buildPacket {
writeShortLVString("some strings")
writeInt(444)
}.readBytes()
val part5 = buildPacket {
writeShortLVString("some strings")
writeInt(555)
}.readBytes()
reader.offer(buildPacket {
writeInt(part1.size + part2.size + part3.size + 4)
writeFully(part1)
// part2 and part3 missing
})
assertEquals(0, received.size)
assertEquals(1, reader.getBufferedPackets().size)
assertEquals(part2.size + part3.size, reader.getMissingLength().toInt())
reader.offer(part2.toReadPacket())
assertEquals(0, received.size)
assertEquals(2, reader.getBufferedPackets().size)
assertEquals(part3.size, reader.getMissingLength().toInt())
reader.offer(buildPacket {
writeFully(part3)
writePacket(buildPacket {
writeInt(part4.size + part5.size + 4)
writeFully(part4)
// part5 missing
})
})
assertEquals(1, received.size)
assertEquals(1, reader.getBufferedPackets().size)
assertEquals(part5.size, reader.getMissingLength().toInt())
reader.offer(part5.toReadPacket())
assertEquals(2, received.size)
assertEquals(0, reader.getBufferedPackets().size)
assertEquals(0, reader.getMissingLength().toInt())
received[0].read {
assertEquals(part1.size + part2.size + part3.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(111, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(222, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(333, readInt())
assertEquals(0, remaining)
}
received[1].read {
assertEquals(part4.size + part5.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(444, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(555, readInt())
assertEquals(0, remaining)
}
}
// Ensures it will not emit without any length check if received a missing part.
@Test
fun `can read 4 part packets`() {
val part1 = buildPacket {
writeShortLVString("some strings")
writeInt(111)
}.readBytes()
val part2 = buildPacket {
writeShortLVString("some strings")
writeInt(222)
}.readBytes()
val part3 = buildPacket {
writeShortLVString("some strings")
writeInt(333)
}.readBytes()
val part4 = buildPacket {
writeShortLVString("some strings")
writeInt(444)
}.readBytes()
reader.offer(buildPacket {
writeInt(part1.size + part2.size + part3.size + part4.size + 4)
writeFully(part1)
// part2, part3 and part4 missing
})
assertEquals(0, received.size)
assertEquals(1, reader.getBufferedPackets().size)
assertEquals(part2.size + part3.size + part4.size, reader.getMissingLength().toInt())
reader.offer(part2.toReadPacket())
assertEquals(0, received.size)
assertEquals(2, reader.getBufferedPackets().size)
assertEquals(part3.size + part4.size, reader.getMissingLength().toInt())
reader.offer(part3.toReadPacket())
assertEquals(0, received.size)
assertEquals(3, reader.getBufferedPackets().size)
assertEquals(part4.size, reader.getMissingLength().toInt())
reader.offer(part4.toReadPacket())
assertEquals(1, received.size)
assertEquals(0, reader.getBufferedPackets().size)
assertEquals(0, reader.getMissingLength())
received.single().read {
assertEquals(part1.size + part2.size + part3.size + part4.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(111, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(222, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(333, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(444, readInt())
assertEquals(0, remaining)
}
}
@Test
fun `can read 2 combined packets`() {
val part1 = buildPacket {
writeShortLVString("some strings")
writeInt(123)
}.readBytes()
println("part1.size = ${part1.size}")
val part2 = buildPacket {
writeShortLVString("some strings")
writeInt(222)
}.readBytes()
println("part2.size = ${part2.size}")
reader.offer(buildPacket {
writePacket(buildLVPacket { writeFully(part1) })
writePacket(buildLVPacket { writeFully(part2) })
})
assertEquals(2, received.size)
received[0].read {
assertEquals(part1.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(123, readInt())
assertEquals(0, remaining)
}
received[1].read {
assertEquals(part2.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(222, readInt())
assertEquals(0, remaining)
}
assertEquals(0, reader.getBufferedPackets().size)
assertEquals(0, reader.getMissingLength())
}
@Test
fun `can emit 2 combined packets with another part`() {
val part1 = buildPacket {
writeShortLVString("some strings")
writeInt(111)
}.readBytes()
val part2 = buildPacket {
writeShortLVString("some strings")
writeInt(222)
}.readBytes()
val part3 = buildPacket {
writeShortLVString("some strings")
writeInt(333)
}.readBytes()
val part4 = buildPacket {
writeShortLVString("some strings")
writeInt(444)
}.readBytes()
println("part1.size = ${part1.size}")
println("part2.size = ${part2.size}")
println("part3.size = ${part3.size}")
println("part4.size = ${part4.size}")
reader.offer(buildPacket {
// should emit two packets
writePacket(buildLVPacket { writeFully(part1) })
writePacket(buildLVPacket { writeFully(part2) })
// and process this part
writePacket(buildPacket {
writeInt(part3.size + part4.size + 4)
writeFully(part3)
// part4 missing
})
})
assertEquals(2, received.size)
assertEquals(1, reader.getBufferedPackets().size)
assertEquals(part4.size, reader.getMissingLength().toInt())
received[0].read {
assertEquals(part1.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(111, readInt())
assertEquals(0, remaining)
}
received[1].read {
assertEquals(part2.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(222, readInt())
assertEquals(0, remaining)
}
// part4, here you are
reader.offer(buildPacket {
writePacket(buildPacket { writeFully(part4) })
})
received[2].read {
assertEquals(part3.size + part4.size, this.remaining.toInt())
assertEquals("some strings", readUShortLVString())
assertEquals(333, readInt())
assertEquals("some strings", readUShortLVString())
assertEquals(444, readInt())
assertEquals(0, remaining)
}
assertEquals(0, reader.getBufferedPackets().size)
assertEquals(0, reader.getMissingLength())
}
private inline fun buildLVPacket(block: BytePacketBuilder.() -> Unit): ByteReadPacket {
contract {
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
}
return buildPacket {
writeIntLVPacket(lengthOffset = { it + 4 }) {
block()
}
}
}
}

View File

@ -27,6 +27,9 @@ internal actual class PlatformInitializationTest actual constructor() : Abstract
/**
* All test classes should inherit from [AbstractTest]
*
* Note: To run a test in native sourceSets, use IDEA key shortcut 'control + shift + R' on macOS and 'Ctrl + Shift + R' on Windows.
* Or you can right-click the function name of the test case and invoke 'Run ...'. You should not expect to see a button icon around the line numbers.
*/
internal actual abstract class AbstractTest actual constructor() : CommonAbstractTest() {
init {

View File

@ -28,9 +28,9 @@ import kotlin.contracts.contract
/**
* TCP Socket.
*/
@OptIn(UnsafeNumber::class)
internal actual class PlatformSocket(
private val socket: Int
private val socket: Int,
bufferSize: Int = DEFAULT_BUFFER_SIZE * 2 // improve performance for some big packets
) : Closeable, HighwayProtocolChannel {
@Suppress("UnnecessaryOptInAnnotation")
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
@ -42,9 +42,9 @@ internal actual class PlatformSocket(
private val sendDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher")
private val readLock = Mutex()
private val readBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
private val readBuffer = ByteArray(bufferSize).pin()
private val writeLock = Mutex()
private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
private val writeBuffer = ByteArray(bufferSize).pin()
actual val isOpen: Boolean
get() = write(socket, null, 0).convert<Long>() != 0L
@ -95,7 +95,7 @@ internal actual class PlatformSocket(
logger.info { "Native socket reading." }
val readBuffer = readBuffer
val length = recv(socket, readBuffer.addressOf(0), readBuffer.get().size.convert(), 0).convert<Long>()
if (length < 0L) {
if (length <= 0L) {
throw EOFException("recv: $length, errno=$errno")
}
logger.info {