mirror of
https://github.com/mamoe/mirai.git
synced 2025-04-25 04:50:26 +08:00
Fix LengthDelimitedPacketReader
This commit is contained in:
parent
8085d94dfe
commit
900c7feac7
mirai-core/src
nativeMain/kotlin/network/handler
nativeTest/kotlin
unixMain/kotlin/utils
@ -0,0 +1,119 @@
|
||||
/*
|
||||
* Copyright 2019-2022 Mamoe Technologies and contributors.
|
||||
*
|
||||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
|
||||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
|
||||
*
|
||||
* https://github.com/mamoe/mirai/blob/dev/LICENSE
|
||||
*/
|
||||
|
||||
package net.mamoe.mirai.internal.network.handler
|
||||
|
||||
import io.ktor.utils.io.core.*
|
||||
import net.mamoe.mirai.utils.*
|
||||
|
||||
private val debugLogger: MiraiLogger by lazy {
|
||||
MiraiLogger.Factory.create(
|
||||
LengthDelimitedPacketReader::class, "LengthDelimitedPacketReader"
|
||||
).withSwitch(systemProp("mirai.network.handler.length.delimited.packet.reader.debug", false))
|
||||
}
|
||||
|
||||
/**
|
||||
* Not thread-safe
|
||||
*/
|
||||
internal class LengthDelimitedPacketReader(
|
||||
private val sendDecode: (combined: ByteReadPacket) -> Unit
|
||||
) : Closeable {
|
||||
private var missingLength: Long = 0
|
||||
set(value) {
|
||||
field = value
|
||||
debugLogger.info { "missingLength = $field" }
|
||||
}
|
||||
private val bufferedParts: MutableList<ByteReadPacket> = ArrayList(10)
|
||||
|
||||
@TestOnly
|
||||
fun getMissingLength() = missingLength
|
||||
|
||||
@TestOnly
|
||||
fun getBufferedPackets() = bufferedParts.toList()
|
||||
|
||||
fun offer(packet: ByteReadPacket) {
|
||||
if (missingLength == 0L) {
|
||||
// initial
|
||||
debugLogger.info { "initial length == 0" }
|
||||
missingLength = packet.readInt().toLongUnsigned() - 4
|
||||
}
|
||||
debugLogger.info { "Offering packet len = ${packet.remaining}" }
|
||||
missingLength -= packet.remaining
|
||||
bufferedParts.add(packet)
|
||||
if (missingLength <= 0) {
|
||||
emit()
|
||||
}
|
||||
}
|
||||
|
||||
private fun emit() {
|
||||
debugLogger.info { "Emitting, buffered = ${bufferedParts.map { it.remaining }}" }
|
||||
when (bufferedParts.size) {
|
||||
0 -> {}
|
||||
1 -> {
|
||||
val part = bufferedParts.first()
|
||||
if (missingLength == 0L) {
|
||||
debugLogger.info { "Single packet length perfectly matched." }
|
||||
sendDecode(part)
|
||||
|
||||
bufferedParts.clear()
|
||||
} else {
|
||||
check(missingLength < 0L) { "Failed check: remainingLength < 0L" }
|
||||
|
||||
val previousPacketLength = missingLength + part.remaining
|
||||
debugLogger.info { "Got extra packets, previousPacketLength = $previousPacketLength" }
|
||||
sendDecode(part.readPacketExact(previousPacketLength.toInt()))
|
||||
|
||||
bufferedParts.clear()
|
||||
|
||||
// now packet contain new part.
|
||||
missingLength = part.readInt().toLongUnsigned() - 4
|
||||
offer(part)
|
||||
}
|
||||
}
|
||||
else -> {
|
||||
if (missingLength == 0L) {
|
||||
debugLogger.info { "Multiple packets length perfectly matched." }
|
||||
sendDecode(buildPacket(bufferedParts.sumOf { it.remaining }.toInt()) {
|
||||
bufferedParts.forEach { writePacket(it) }
|
||||
})
|
||||
|
||||
bufferedParts.clear()
|
||||
} else {
|
||||
val lastPart = bufferedParts.last()
|
||||
val previousPacketPartLength = missingLength + lastPart.remaining
|
||||
debugLogger.debug { "previousPacketPartLength = $previousPacketPartLength" }
|
||||
val combinedLength =
|
||||
(bufferedParts.sumOf { it.remaining } - lastPart.remaining // buffered length without last part
|
||||
+ previousPacketPartLength).toInt()
|
||||
debugLogger.debug { "combinedLength = $combinedLength" }
|
||||
|
||||
if (combinedLength < 0) return // not enough, still more parts missing.
|
||||
|
||||
sendDecode(buildPacket(combinedLength) {
|
||||
repeat(bufferedParts.size - 1) { i ->
|
||||
writePacket(bufferedParts[i])
|
||||
}
|
||||
writePacket(lastPart, previousPacketPartLength)
|
||||
})
|
||||
|
||||
bufferedParts.clear()
|
||||
|
||||
// now packet contain new part.
|
||||
missingLength = lastPart.readInt().toLongUnsigned() - 4
|
||||
offer(lastPart)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
bufferedParts.forEach { it.close() }
|
||||
}
|
||||
}
|
@ -22,7 +22,9 @@ import net.mamoe.mirai.internal.network.components.SsoProcessor
|
||||
import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacket
|
||||
import net.mamoe.mirai.internal.utils.PlatformSocket
|
||||
import net.mamoe.mirai.internal.utils.connect
|
||||
import net.mamoe.mirai.utils.*
|
||||
import net.mamoe.mirai.utils.childScope
|
||||
import net.mamoe.mirai.utils.info
|
||||
import net.mamoe.mirai.utils.verbose
|
||||
|
||||
internal class NativeNetworkHandler(
|
||||
context: NetworkHandlerContext,
|
||||
@ -46,95 +48,15 @@ internal class NativeNetworkHandler(
|
||||
launch { write(undelivered) }
|
||||
}
|
||||
|
||||
private val lengthDelimitedPacketReader = LengthDelimitedPacketReader()
|
||||
|
||||
/**
|
||||
* Not thread-safe
|
||||
*/
|
||||
private inner class LengthDelimitedPacketReader : Closeable {
|
||||
private var missingLength: Long = 0
|
||||
private val bufferedPackets: MutableList<ByteReadPacket> = ArrayList(10)
|
||||
|
||||
fun offer(packet: ByteReadPacket) {
|
||||
if (missingLength == 0L) {
|
||||
// initial
|
||||
missingLength = packet.readInt().toLongUnsigned() - 4
|
||||
}
|
||||
missingLength -= packet.remaining
|
||||
bufferedPackets.add(packet)
|
||||
if (missingLength <= 0) {
|
||||
emit()
|
||||
}
|
||||
}
|
||||
|
||||
fun emit() {
|
||||
when (bufferedPackets.size) {
|
||||
0 -> {}
|
||||
1 -> {
|
||||
val packet = bufferedPackets.first()
|
||||
if (missingLength == 0L) {
|
||||
sendDecode(packet)
|
||||
bufferedPackets.clear()
|
||||
} else {
|
||||
check(missingLength < 0L) { "Failed check: remainingLength < 0L" }
|
||||
|
||||
val previousPacketLength = missingLength + packet.remaining
|
||||
sendDecode(packet.readPacketExact(previousPacketLength.toInt()))
|
||||
|
||||
// now packet contain new packet.
|
||||
missingLength = packet.readInt().toLongUnsigned() - 4
|
||||
bufferedPackets[0] = packet
|
||||
}
|
||||
}
|
||||
else -> {
|
||||
val combined: ByteReadPacket
|
||||
if (missingLength == 0L) {
|
||||
combined = buildPacket(bufferedPackets.sumOf { it.remaining }.toInt()) {
|
||||
bufferedPackets.forEach { writePacket(it) }
|
||||
}
|
||||
|
||||
bufferedPackets.clear()
|
||||
} else {
|
||||
val lastPacket = bufferedPackets.last()
|
||||
val previousPacketPartLength = missingLength + lastPacket.remaining
|
||||
val combinedLength =
|
||||
(bufferedPackets.sumOf { it.remaining } - lastPacket.remaining + previousPacketPartLength).toInt()
|
||||
|
||||
combined = buildPacket(combinedLength) {
|
||||
repeat(bufferedPackets.size - 1) { i ->
|
||||
writePacket(bufferedPackets[i])
|
||||
}
|
||||
writePacket(lastPacket, previousPacketPartLength)
|
||||
}
|
||||
|
||||
bufferedPackets.clear()
|
||||
|
||||
// now packet contain new packet.
|
||||
missingLength = lastPacket.readInt().toLongUnsigned() - 4
|
||||
bufferedPackets.add(lastPacket)
|
||||
}
|
||||
|
||||
sendDecode(combined)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun sendDecode(combined: ByteReadPacket) {
|
||||
packetLogger.verbose { "Decoding: len=${combined.remaining}" }
|
||||
val raw = packetCodec.decodeRaw(
|
||||
ssoProcessor.ssoSession,
|
||||
combined
|
||||
)
|
||||
packetLogger.verbose { "Decoded: ${raw.commandName}" }
|
||||
decodePipeline.send(
|
||||
raw
|
||||
)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
bufferedPackets.forEach { it.close() }
|
||||
}
|
||||
}
|
||||
private val lengthDelimitedPacketReader = LengthDelimitedPacketReader(fun(combined: ByteReadPacket) {
|
||||
logger.verbose { "Decoding: len=${combined.remaining}" }
|
||||
val raw = packetCodec.decodeRaw(
|
||||
ssoProcessor.ssoSession,
|
||||
combined
|
||||
)
|
||||
logger.verbose { "Decoded: ${raw.commandName}" }
|
||||
decodePipeline.send(raw)
|
||||
})
|
||||
|
||||
private val sender = launch {
|
||||
while (isActive) {
|
||||
|
@ -0,0 +1,490 @@
|
||||
/*
|
||||
* Copyright 2019-2022 Mamoe Technologies and contributors.
|
||||
*
|
||||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
|
||||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
|
||||
*
|
||||
* https://github.com/mamoe/mirai/blob/dev/LICENSE
|
||||
*/
|
||||
|
||||
@file:OptIn(TestOnly::class)
|
||||
|
||||
package net.mamoe.mirai.internal.network
|
||||
|
||||
import io.ktor.utils.io.core.*
|
||||
import net.mamoe.mirai.internal.network.handler.LengthDelimitedPacketReader
|
||||
import net.mamoe.mirai.internal.test.AbstractTest
|
||||
import net.mamoe.mirai.internal.utils.io.writeIntLVPacket
|
||||
import net.mamoe.mirai.internal.utils.io.writeShortLVString
|
||||
import net.mamoe.mirai.utils.*
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class LengthDelimitedPacketReaderTest : AbstractTest() {
|
||||
init {
|
||||
setSystemProp("mirai.network.handler.length.delimited.packet.reader.debug", "true")
|
||||
}
|
||||
|
||||
private val received = mutableListOf<ByteArray>()
|
||||
private val reader = LengthDelimitedPacketReader { received.add(it.readBytes()) }
|
||||
|
||||
/*
|
||||
* All these tests cases can happen in the real time, and even before logon is complete.
|
||||
*/
|
||||
|
||||
@Test
|
||||
fun `can read exact packet`() {
|
||||
val original = buildLVPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(123)
|
||||
}
|
||||
val originalLength = original.remaining
|
||||
reader.offer(original)
|
||||
|
||||
assertEquals(1, received.size)
|
||||
received.single().read {
|
||||
assertEquals(originalLength - 4, this.remaining)
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(123, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
assertEquals(0, reader.getBufferedPackets().size)
|
||||
assertEquals(0, reader.getMissingLength())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `can read 2 part packets`() {
|
||||
val part1 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(123)
|
||||
}.readBytes()
|
||||
|
||||
val part2 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(123)
|
||||
}.readBytes()
|
||||
|
||||
reader.offer(buildPacket {
|
||||
writeInt(part1.size + part2.size + 4)
|
||||
writeFully(part1)
|
||||
})
|
||||
assertEquals(0, received.size)
|
||||
|
||||
reader.offer(part2.toReadPacket())
|
||||
assertEquals(1, received.size)
|
||||
|
||||
received.single().read {
|
||||
assertEquals(part1.size + part2.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(123, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(123, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
assertEquals(0, reader.getBufferedPackets().size)
|
||||
assertEquals(0, reader.getMissingLength())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `can read 3 part packets`() {
|
||||
val part1 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(111)
|
||||
}.readBytes()
|
||||
|
||||
val part2 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(222)
|
||||
}.readBytes()
|
||||
|
||||
val part3 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(333)
|
||||
}.readBytes()
|
||||
|
||||
reader.offer(buildPacket {
|
||||
writeInt(part1.size + part2.size + part3.size + 4)
|
||||
writeFully(part1)
|
||||
|
||||
// part2 and part3 missing
|
||||
})
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(1, reader.getBufferedPackets().size)
|
||||
assertEquals(part2.size + part3.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(part2.toReadPacket())
|
||||
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(2, reader.getBufferedPackets().size)
|
||||
assertEquals(part3.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(part3.toReadPacket())
|
||||
|
||||
assertEquals(1, received.size)
|
||||
assertEquals(0, reader.getBufferedPackets().size)
|
||||
assertEquals(0, reader.getMissingLength())
|
||||
|
||||
received.single().read {
|
||||
assertEquals(part1.size + part2.size + part3.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(111, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(222, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(333, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `can read 3 part packets with a combined`() {
|
||||
val part1 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(111)
|
||||
}.readBytes()
|
||||
|
||||
val part2 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(222)
|
||||
}.readBytes()
|
||||
|
||||
val part3 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(333)
|
||||
}.readBytes()
|
||||
|
||||
val part4 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(444)
|
||||
}.readBytes()
|
||||
|
||||
reader.offer(buildPacket {
|
||||
writeInt(part1.size + part2.size + part3.size + 4)
|
||||
writeFully(part1)
|
||||
|
||||
// part2 and part3 missing
|
||||
})
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(1, reader.getBufferedPackets().size)
|
||||
assertEquals(part2.size + part3.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(part2.toReadPacket())
|
||||
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(2, reader.getBufferedPackets().size)
|
||||
assertEquals(part3.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(buildPacket {
|
||||
writeFully(part3)
|
||||
writePacket(buildLVPacket { writeFully(part4) })
|
||||
})
|
||||
|
||||
assertEquals(2, received.size)
|
||||
assertEquals(0, reader.getBufferedPackets().size)
|
||||
assertEquals(0, reader.getMissingLength())
|
||||
|
||||
received[0].read {
|
||||
assertEquals(part1.size + part2.size + part3.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(111, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(222, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(333, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
received[1].read {
|
||||
assertEquals(part4.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(444, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `can read 3 part packets from combined with a combined`() {
|
||||
val part1 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(111)
|
||||
}.readBytes()
|
||||
|
||||
val part2 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(222)
|
||||
}.readBytes()
|
||||
|
||||
val part3 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(333)
|
||||
}.readBytes()
|
||||
|
||||
val part4 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(444)
|
||||
}.readBytes()
|
||||
|
||||
val part5 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(555)
|
||||
}.readBytes()
|
||||
|
||||
reader.offer(buildPacket {
|
||||
writeInt(part1.size + part2.size + part3.size + 4)
|
||||
writeFully(part1)
|
||||
|
||||
// part2 and part3 missing
|
||||
})
|
||||
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(1, reader.getBufferedPackets().size)
|
||||
assertEquals(part2.size + part3.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(part2.toReadPacket())
|
||||
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(2, reader.getBufferedPackets().size)
|
||||
assertEquals(part3.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(buildPacket {
|
||||
writeFully(part3)
|
||||
writePacket(buildPacket {
|
||||
writeInt(part4.size + part5.size + 4)
|
||||
writeFully(part4)
|
||||
|
||||
// part5 missing
|
||||
})
|
||||
})
|
||||
|
||||
assertEquals(1, received.size)
|
||||
assertEquals(1, reader.getBufferedPackets().size)
|
||||
assertEquals(part5.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(part5.toReadPacket())
|
||||
|
||||
assertEquals(2, received.size)
|
||||
assertEquals(0, reader.getBufferedPackets().size)
|
||||
assertEquals(0, reader.getMissingLength().toInt())
|
||||
|
||||
received[0].read {
|
||||
assertEquals(part1.size + part2.size + part3.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(111, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(222, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(333, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
received[1].read {
|
||||
assertEquals(part4.size + part5.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(444, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(555, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensures it will not emit without any length check if received a missing part.
|
||||
@Test
|
||||
fun `can read 4 part packets`() {
|
||||
val part1 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(111)
|
||||
}.readBytes()
|
||||
|
||||
val part2 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(222)
|
||||
}.readBytes()
|
||||
|
||||
val part3 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(333)
|
||||
}.readBytes()
|
||||
|
||||
val part4 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(444)
|
||||
}.readBytes()
|
||||
|
||||
reader.offer(buildPacket {
|
||||
writeInt(part1.size + part2.size + part3.size + part4.size + 4)
|
||||
writeFully(part1)
|
||||
|
||||
// part2, part3 and part4 missing
|
||||
})
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(1, reader.getBufferedPackets().size)
|
||||
assertEquals(part2.size + part3.size + part4.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(part2.toReadPacket())
|
||||
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(2, reader.getBufferedPackets().size)
|
||||
assertEquals(part3.size + part4.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(part3.toReadPacket())
|
||||
|
||||
assertEquals(0, received.size)
|
||||
assertEquals(3, reader.getBufferedPackets().size)
|
||||
assertEquals(part4.size, reader.getMissingLength().toInt())
|
||||
|
||||
reader.offer(part4.toReadPacket())
|
||||
|
||||
assertEquals(1, received.size)
|
||||
assertEquals(0, reader.getBufferedPackets().size)
|
||||
assertEquals(0, reader.getMissingLength())
|
||||
|
||||
received.single().read {
|
||||
assertEquals(part1.size + part2.size + part3.size + part4.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(111, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(222, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(333, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(444, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `can read 2 combined packets`() {
|
||||
val part1 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(123)
|
||||
}.readBytes()
|
||||
|
||||
println("part1.size = ${part1.size}")
|
||||
|
||||
val part2 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(222)
|
||||
}.readBytes()
|
||||
|
||||
println("part2.size = ${part2.size}")
|
||||
|
||||
reader.offer(buildPacket {
|
||||
writePacket(buildLVPacket { writeFully(part1) })
|
||||
writePacket(buildLVPacket { writeFully(part2) })
|
||||
})
|
||||
|
||||
assertEquals(2, received.size)
|
||||
|
||||
received[0].read {
|
||||
assertEquals(part1.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(123, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
received[1].read {
|
||||
assertEquals(part2.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(222, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
assertEquals(0, reader.getBufferedPackets().size)
|
||||
assertEquals(0, reader.getMissingLength())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `can emit 2 combined packets with another part`() {
|
||||
val part1 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(111)
|
||||
}.readBytes()
|
||||
|
||||
val part2 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(222)
|
||||
}.readBytes()
|
||||
|
||||
val part3 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(333)
|
||||
}.readBytes()
|
||||
|
||||
val part4 = buildPacket {
|
||||
writeShortLVString("some strings")
|
||||
writeInt(444)
|
||||
}.readBytes()
|
||||
|
||||
|
||||
println("part1.size = ${part1.size}")
|
||||
println("part2.size = ${part2.size}")
|
||||
println("part3.size = ${part3.size}")
|
||||
println("part4.size = ${part4.size}")
|
||||
|
||||
reader.offer(buildPacket {
|
||||
// should emit two packets
|
||||
writePacket(buildLVPacket { writeFully(part1) })
|
||||
writePacket(buildLVPacket { writeFully(part2) })
|
||||
|
||||
// and process this part
|
||||
writePacket(buildPacket {
|
||||
writeInt(part3.size + part4.size + 4)
|
||||
writeFully(part3)
|
||||
// part4 missing
|
||||
})
|
||||
})
|
||||
|
||||
assertEquals(2, received.size)
|
||||
assertEquals(1, reader.getBufferedPackets().size)
|
||||
assertEquals(part4.size, reader.getMissingLength().toInt())
|
||||
|
||||
received[0].read {
|
||||
assertEquals(part1.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(111, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
received[1].read {
|
||||
assertEquals(part2.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(222, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
// part4, here you are
|
||||
reader.offer(buildPacket {
|
||||
writePacket(buildPacket { writeFully(part4) })
|
||||
})
|
||||
|
||||
received[2].read {
|
||||
assertEquals(part3.size + part4.size, this.remaining.toInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(333, readInt())
|
||||
assertEquals("some strings", readUShortLVString())
|
||||
assertEquals(444, readInt())
|
||||
assertEquals(0, remaining)
|
||||
}
|
||||
|
||||
assertEquals(0, reader.getBufferedPackets().size)
|
||||
assertEquals(0, reader.getMissingLength())
|
||||
}
|
||||
|
||||
private inline fun buildLVPacket(block: BytePacketBuilder.() -> Unit): ByteReadPacket {
|
||||
contract {
|
||||
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
|
||||
}
|
||||
|
||||
return buildPacket {
|
||||
writeIntLVPacket(lengthOffset = { it + 4 }) {
|
||||
block()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -27,6 +27,9 @@ internal actual class PlatformInitializationTest actual constructor() : Abstract
|
||||
|
||||
/**
|
||||
* All test classes should inherit from [AbstractTest]
|
||||
*
|
||||
* Note: To run a test in native sourceSets, use IDEA key shortcut 'control + shift + R' on macOS and 'Ctrl + Shift + R' on Windows.
|
||||
* Or you can right-click the function name of the test case and invoke 'Run ...'. You should not expect to see a button icon around the line numbers.
|
||||
*/
|
||||
internal actual abstract class AbstractTest actual constructor() : CommonAbstractTest() {
|
||||
init {
|
||||
|
@ -28,9 +28,9 @@ import kotlin.contracts.contract
|
||||
/**
|
||||
* TCP Socket.
|
||||
*/
|
||||
@OptIn(UnsafeNumber::class)
|
||||
internal actual class PlatformSocket(
|
||||
private val socket: Int
|
||||
private val socket: Int,
|
||||
bufferSize: Int = DEFAULT_BUFFER_SIZE * 2 // improve performance for some big packets
|
||||
) : Closeable, HighwayProtocolChannel {
|
||||
@Suppress("UnnecessaryOptInAnnotation")
|
||||
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
|
||||
@ -42,9 +42,9 @@ internal actual class PlatformSocket(
|
||||
private val sendDispatcher: CoroutineDispatcher = newSingleThreadContext("PlatformSocket#$socket.dispatcher")
|
||||
|
||||
private val readLock = Mutex()
|
||||
private val readBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
|
||||
private val readBuffer = ByteArray(bufferSize).pin()
|
||||
private val writeLock = Mutex()
|
||||
private val writeBuffer = ByteArray(DEFAULT_BUFFER_SIZE).pin()
|
||||
private val writeBuffer = ByteArray(bufferSize).pin()
|
||||
|
||||
actual val isOpen: Boolean
|
||||
get() = write(socket, null, 0).convert<Long>() != 0L
|
||||
@ -95,7 +95,7 @@ internal actual class PlatformSocket(
|
||||
logger.info { "Native socket reading." }
|
||||
val readBuffer = readBuffer
|
||||
val length = recv(socket, readBuffer.addressOf(0), readBuffer.get().size.convert(), 0).convert<Long>()
|
||||
if (length < 0L) {
|
||||
if (length <= 0L) {
|
||||
throw EOFException("recv: $length, errno=$errno")
|
||||
}
|
||||
logger.info {
|
||||
|
Loading…
Reference in New Issue
Block a user