From 853bcc22a4e32abb212dc358adab663ad340d150 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Thu, 2 Jun 2022 17:50:01 +0100
Subject: [PATCH] Commonize PacketDecodePipeline for PacketCodec for all
 platforms

---
 .../network/handler/CommonNetworkHandler.kt   | 67 +++++++++++++++++--
 .../network/impl/netty/NettyNetworkHandler.kt | 30 ++-------
 .../kotlin/netinternalkit/NetReplayHelper.kt  |  5 +-
 .../network/handler/NativeNetworkHandler.kt   | 57 +++++++---------
 4 files changed, 96 insertions(+), 63 deletions(-)

diff --git a/mirai-core/src/commonMain/kotlin/network/handler/CommonNetworkHandler.kt b/mirai-core/src/commonMain/kotlin/network/handler/CommonNetworkHandler.kt
index 1eb114916..ed0a38d24 100644
--- a/mirai-core/src/commonMain/kotlin/network/handler/CommonNetworkHandler.kt
+++ b/mirai-core/src/commonMain/kotlin/network/handler/CommonNetworkHandler.kt
@@ -11,6 +11,8 @@ package net.mamoe.mirai.internal.network.handler
 
 import io.ktor.utils.io.core.*
 import kotlinx.coroutines.*
+import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.channels.onFailure
 import net.mamoe.mirai.internal.network.components.*
 import net.mamoe.mirai.internal.network.handler.selector.NetworkException
 import net.mamoe.mirai.internal.network.handler.selector.NetworkHandlerSelector
@@ -89,14 +91,67 @@ internal abstract class CommonNetworkHandler<Conn>(
     internal inner class PacketDecodePipeline(parentContext: CoroutineContext) :
         CoroutineScope by parentContext.childScope() {
         private val packetCodec: PacketCodec by lazy { context[PacketCodec] }
+        private val ssoProcessor: SsoProcessor by lazy { context[SsoProcessor] }
 
-        fun send(raw: RawIncomingPacket) {
+
+        private val queue: Channel<ByteReadPacket> = Channel<ByteReadPacket>(Channel.BUFFERED) { undelivered ->
+            launch { sendQueue(undelivered) }
+        }.also { channel -> coroutineContext[Job]!!.invokeOnCompletion { channel.close(it) } }
+
+        private suspend inline fun sendQueue(packet: ByteReadPacket) {
+            queue.send(packet)
+        }
+
+        init {
             launch {
-                packetLogger.debug { "Packet Handling Processor: receive packet ${raw.commandName}" }
-                val result = packetCodec.processBody(context.bot, raw)
-                if (result == null) {
-                    collectUnknownPacket(raw)
-                } else collectReceived(result)
+                while (isActive) {
+                    val result = queue.receiveCatching()
+                    packetLogger.verbose { "Decoding packet: $result" }
+                    result.onFailure { if (it is CancellationException) return@launch }
+
+                    result.getOrNull()?.let { packet ->
+                        try {
+                            val decoded = decodePacket(packet)
+                            processBody(decoded)
+                        } catch (e: Throwable) {
+                            if (e is CancellationException) return@launch
+                            handleExceptionInDecoding(e)
+                            logger.error("Error while decoding packet '${packet}'", e)
+                        }
+                    }
+                }
+            }
+        }
+
+        private fun decodePacket(packet: ByteReadPacket): RawIncomingPacket {
+            return if (packetLogger.isDebugEnabled) {
+                val bytes = packet.readBytes()
+                logger.verbose { "Decoding: len=${bytes.size}, value=${bytes.toUHexString()}" }
+                val raw = packetCodec.decodeRaw(
+                    ssoProcessor.ssoSession,
+                    bytes.toReadPacket()
+                )
+                logger.verbose { "Decoded: ${raw.commandName}" }
+                raw
+            } else {
+                packetCodec.decodeRaw(
+                    ssoProcessor.ssoSession,
+                    packet
+                )
+            }
+        }
+
+        private suspend fun processBody(raw: RawIncomingPacket) {
+            packetLogger.debug { "Packet Handling Processor: receive packet ${raw.commandName}" }
+            val result = packetCodec.processBody(context.bot, raw)
+            if (result == null) {
+                collectUnknownPacket(raw)
+            } else collectReceived(result)
+        }
+
+        fun send(packet: ByteReadPacket) {
+            queue.trySend(packet).onFailure {
+                throw it ?: throw IllegalStateException("Internal error: Failed to decode '$packet' without reason.")
             }
         }
     }
diff --git a/mirai-core/src/jvmBaseMain/kotlin/network/impl/netty/NettyNetworkHandler.kt b/mirai-core/src/jvmBaseMain/kotlin/network/impl/netty/NettyNetworkHandler.kt
index e1baad02b..4790f352c 100644
--- a/mirai-core/src/jvmBaseMain/kotlin/network/impl/netty/NettyNetworkHandler.kt
+++ b/mirai-core/src/jvmBaseMain/kotlin/network/impl/netty/NettyNetworkHandler.kt
@@ -9,6 +9,7 @@
 
 package net.mamoe.mirai.internal.network.impl.netty
 
+import io.ktor.utils.io.core.*
 import io.netty.bootstrap.Bootstrap
 import io.netty.buffer.ByteBuf
 import io.netty.channel.*
@@ -20,9 +21,6 @@ import io.netty.handler.codec.MessageToByteEncoder
 import kotlinx.coroutines.CompletableDeferred
 import kotlinx.coroutines.asCoroutineDispatcher
 import kotlinx.coroutines.job
-import net.mamoe.mirai.internal.network.components.PacketCodec
-import net.mamoe.mirai.internal.network.components.RawIncomingPacket
-import net.mamoe.mirai.internal.network.components.SsoProcessor
 import net.mamoe.mirai.internal.network.handler.CommonNetworkHandler
 import net.mamoe.mirai.internal.network.handler.NetworkHandler.State
 import net.mamoe.mirai.internal.network.handler.NetworkHandlerContext
@@ -53,26 +51,11 @@ internal open class NettyNetworkHandler(
     // netty conn.
     ///////////////////////////////////////////////////////////////////////////
 
-    private inner class ByteBufToIncomingPacketDecoder : SimpleChannelInboundHandler<ByteBuf>(ByteBuf::class.java) {
-        private val packetCodec: PacketCodec by lazy { context[PacketCodec] }
-        private val ssoProcessor: SsoProcessor by lazy { context[SsoProcessor] }
-
-        override fun channelRead0(ctx: ChannelHandlerContext, msg: ByteBuf) {
-            kotlin.runCatching {
-                ctx.fireChannelRead(msg.toReadPacket().use { packet ->
-                    packetCodec.decodeRaw(ssoProcessor.ssoSession, packet)
-                })
-            }.onFailure { error ->
-                handleExceptionInDecoding(error)
-            }
-        }
-    }
-
-    private inner class RawIncomingPacketCollector(
+    private inner class IncomingPacketDecoder(
         private val decodePipeline: PacketDecodePipeline,
-    ) : SimpleChannelInboundHandler<RawIncomingPacket>(RawIncomingPacket::class.java) {
-        override fun channelRead0(ctx: ChannelHandlerContext, msg: RawIncomingPacket) {
-            decodePipeline.send(msg)
+    ) : SimpleChannelInboundHandler<ByteBuf>(ByteBuf::class.java) {
+        override fun channelRead0(ctx: ChannelHandlerContext, msg: ByteBuf) {
+            decodePipeline.send(msg.toReadPacket())
         }
     }
 
@@ -93,8 +76,7 @@ internal open class NettyNetworkHandler(
             })
             .addLast("outgoing-packet-encoder", OutgoingPacketEncoder())
             .addLast(LengthFieldBasedFrameDecoder(Int.MAX_VALUE, 0, 4, -4, 4))
-            .addLast(ByteBufToIncomingPacketDecoder())
-            .addLast("raw-packet-collector", RawIncomingPacketCollector(decodePipeline))
+            .addLast(IncomingPacketDecoder(decodePipeline))
     }
 
     protected open fun createDummyDecodePipeline() = PacketDecodePipeline(this@NettyNetworkHandler.coroutineContext)
diff --git a/mirai-core/src/jvmTest/kotlin/netinternalkit/NetReplayHelper.kt b/mirai-core/src/jvmTest/kotlin/netinternalkit/NetReplayHelper.kt
index 178342a66..8dfbf51cb 100644
--- a/mirai-core/src/jvmTest/kotlin/netinternalkit/NetReplayHelper.kt
+++ b/mirai-core/src/jvmTest/kotlin/netinternalkit/NetReplayHelper.kt
@@ -76,7 +76,7 @@ private fun NetReplayHelperClass(): Class<*> {
 
 
 private fun attachNetReplayHelper(channel: Channel) {
-    channel.pipeline()
+    channel.pipeline() // TODO: 2022/6/2 will not work since "raw-packet-collector" has been removed
         .addBefore("raw-packet-collector", "raw-packet-dumper", newRawPacketDumper())
 
     attachNetReplayWView(channel)
@@ -280,5 +280,6 @@ fun Bot.attachNetReplayHelper() {
 
 fun main() {
     val bot = BotFactory.newBot(0, "")
-    bot.attachNetReplayHelper()
+    bot.attachNetReplayHelper() //
+    // TODO: 2022/6/2 will not work since "raw-packet-collector" has been removed, see net.mamoe.mirai.internal.netinternalkit.NetReplayHelper.attachNetReplayHelper(io.netty.channel.Channel)
 }
diff --git a/mirai-core/src/nativeMain/kotlin/network/handler/NativeNetworkHandler.kt b/mirai-core/src/nativeMain/kotlin/network/handler/NativeNetworkHandler.kt
index 6bb9d65db..599d40711 100644
--- a/mirai-core/src/nativeMain/kotlin/network/handler/NativeNetworkHandler.kt
+++ b/mirai-core/src/nativeMain/kotlin/network/handler/NativeNetworkHandler.kt
@@ -24,7 +24,6 @@ import net.mamoe.mirai.internal.utils.PlatformSocket
 import net.mamoe.mirai.internal.utils.connect
 import net.mamoe.mirai.utils.childScope
 import net.mamoe.mirai.utils.info
-import net.mamoe.mirai.utils.verbose
 
 internal class NativeNetworkHandler(
     context: NetworkHandlerContext,
@@ -48,42 +47,37 @@ internal class NativeNetworkHandler(
             launch { write(undelivered) }
         }
 
-        private val lengthDelimitedPacketReader = LengthDelimitedPacketReader(fun(combined: ByteReadPacket) {
-            logger.verbose { "Decoding: len=${combined.remaining}" }
-            val raw = packetCodec.decodeRaw(
-                ssoProcessor.ssoSession,
-                combined
-            )
-            logger.verbose { "Decoded: ${raw.commandName}" }
-            decodePipeline.send(raw)
-        })
+        private val lengthDelimitedPacketReader = LengthDelimitedPacketReader(decodePipeline::send)
 
-        private val sender = launch {
-            while (isActive) {
-                val result = sendQueue.receiveCatching()
-                logger.info { "Native sender: $result" }
-                result.onFailure { if (it is CancellationException) return@launch }
+        init {
+            launch {
+                while (isActive) {
+                    val result = sendQueue.receiveCatching()
+                    logger.info { "Native sender: $result" }
+                    result.onFailure { if (it is CancellationException) return@launch }
 
-                result.getOrNull()?.let { packet ->
-                    try {
-                        socket.send(packet.delegate, 0, packet.delegate.size)
-                    } catch (e: Throwable) {
-                        if (e is CancellationException) return@launch
-                        logger.error("Error while sending packet '${packet.commandName}'", e)
+                    result.getOrNull()?.let { packet ->
+                        try {
+                            socket.send(packet.delegate, 0, packet.delegate.size)
+                        } catch (e: Throwable) {
+                            if (e is CancellationException) return@launch
+                            logger.error("Error while sending packet '${packet.commandName}'", e)
+                        }
                     }
                 }
             }
-        }
 
-        private val receiver = launch {
-            while (isActive) {
-                try {
-                    val packet = socket.read()
+            launch {
+                while (isActive) {
+                    try {
+                        val packet = socket.read()
 
-                    lengthDelimitedPacketReader.offer(packet)
-                } catch (e: Throwable) {
-                    if (e is CancellationException) return@launch
-                    logger.error("Error while reading packet.", e)
+                        lengthDelimitedPacketReader.offer(packet)
+                    } catch (e: Throwable) {
+                        if (e is CancellationException) return@launch
+                        logger.error("Error while reading packet.", e)
+                        setState { StateClosed(e) }
+                    }
                 }
             }
         }
@@ -91,12 +85,13 @@ internal class NativeNetworkHandler(
         fun write(packet: OutgoingPacket) {
             sendQueue.trySend(packet).onFailure {
                 throw it
-                    ?: throw IllegalStateException("Failed to send packet '${packet.commandName}' without reason.")
+                    ?: throw IllegalStateException("Internal error: Failed to send packet '${packet.commandName}' without reason.")
             }
         }
 
         override fun close() {
             cancel()
+            sendQueue.close()
         }
     }