From b31ef37c8d25ac2e34390b641d3f797d1d86ed48 Mon Sep 17 00:00:00 2001 From: Karlatemp Date: Wed, 5 May 2021 00:30:33 +0800 Subject: [PATCH] Exception Handling; Normal login tests --- .../kotlin/network/components/PacketCodec.kt | 20 +++++- .../network/impl/netty/NettyNetworkHandler.kt | 39 ++++++++++- .../network/impl/netty/AbstractNettyNHTest.kt | 31 ++++++++- .../impl/netty/NettyBotNormalLoginTest.kt | 69 +++++++++++++++++++ .../src/commonTest/kotlin/test/events.kt | 2 + 5 files changed, 155 insertions(+), 6 deletions(-) create mode 100644 mirai-core/src/commonTest/kotlin/network/impl/netty/NettyBotNormalLoginTest.kt diff --git a/mirai-core/src/commonMain/kotlin/network/components/PacketCodec.kt b/mirai-core/src/commonMain/kotlin/network/components/PacketCodec.kt index 66d47ee1c..b13cef78e 100644 --- a/mirai-core/src/commonMain/kotlin/network/components/PacketCodec.kt +++ b/mirai-core/src/commonMain/kotlin/network/components/PacketCodec.kt @@ -50,6 +50,18 @@ internal interface PacketCodec { } } +internal class OicqDecodingException( + val targetException: Throwable +) : RuntimeException( + null, targetException, + true, // enableSuppression + false, // writableStackTrace +) { + override fun getStackTrace(): Array { + return targetException.stackTrace + } +} + internal class PacketCodecImpl : PacketCodec { override fun decodeRaw(client: SsoSession, input: ByteReadPacket): RawIncomingPacket = input.run { @@ -87,7 +99,13 @@ internal class PacketCodecImpl : PacketCodec { 2 -> RawIncomingPacket( raw.commandName, raw.sequenceId, - raw.body.withUse { parseOicqResponse(client) } + raw.body.withUse { + try { + parseOicqResponse(client) + } catch (e: Throwable) { + throw OicqDecodingException(e) + } + } ) else -> error("Unknown flag2=$flag2") } diff --git a/mirai-core/src/commonMain/kotlin/network/impl/netty/NettyNetworkHandler.kt b/mirai-core/src/commonMain/kotlin/network/impl/netty/NettyNetworkHandler.kt index e1104924c..21c10e572 100644 --- a/mirai-core/src/commonMain/kotlin/network/impl/netty/NettyNetworkHandler.kt +++ b/mirai-core/src/commonMain/kotlin/network/impl/netty/NettyNetworkHandler.kt @@ -31,6 +31,7 @@ import net.mamoe.mirai.internal.network.handler.logger import net.mamoe.mirai.internal.network.handler.state.StateObserver import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacket import net.mamoe.mirai.utils.* +import java.io.EOFException import java.net.SocketAddress import kotlin.coroutines.CoroutineContext import io.netty.channel.Channel as NettyChannel @@ -58,6 +59,29 @@ internal open class NettyNetworkHandler( return "NettyNetworkHandler(context=$context, address=$address)" } + + /////////////////////////////////////////////////////////////////////////// + // exception handling + /////////////////////////////////////////////////////////////////////////// + protected open fun handleExceptionInDecoding(error: Throwable) { + if (error is OicqDecodingException) { + if (error.targetException is EOFException) return + throw error.targetException + } + throw error + } + + protected open fun handlePipelineException(ctx: ChannelHandlerContext, error: Throwable) { + context.bot.logger.error(error) + synchronized(this) { + if (_state !is StateConnecting) { + setState { StateConnecting(ExceptionCollector(error)) } + } else { + close(error) + } + } + } + /////////////////////////////////////////////////////////////////////////// // netty conn. /////////////////////////////////////////////////////////////////////////// @@ -67,9 +91,13 @@ internal open class NettyNetworkHandler( private val ssoProcessor: SsoProcessor by lazy { context[SsoProcessor] } override fun channelRead0(ctx: ChannelHandlerContext, msg: ByteBuf) { - ctx.fireChannelRead(msg.toReadPacket().use { packet -> - packetCodec.decodeRaw(ssoProcessor.ssoSession, packet) - }) + kotlin.runCatching { + ctx.fireChannelRead(msg.toReadPacket().use { packet -> + packetCodec.decodeRaw(ssoProcessor.ssoSession, packet) + }) + }.onFailure { error -> + handleExceptionInDecoding(error) + } } } @@ -90,6 +118,11 @@ internal open class NettyNetworkHandler( protected open fun setupChannelPipeline(pipeline: ChannelPipeline, decodePipeline: PacketDecodePipeline) { pipeline + .addLast(object : ChannelInboundHandlerAdapter() { + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + handlePipelineException(ctx, cause) + } + }) .addLast(OutgoingPacketEncoder()) .addLast(LengthFieldBasedFrameDecoder(Int.MAX_VALUE, 0, 4, -4, 4)) .addLast(ByteBufToIncomingPacketDecoder()) diff --git a/mirai-core/src/commonTest/kotlin/network/impl/netty/AbstractNettyNHTest.kt b/mirai-core/src/commonTest/kotlin/network/impl/netty/AbstractNettyNHTest.kt index b8fc2838e..8c8f703a4 100644 --- a/mirai-core/src/commonTest/kotlin/network/impl/netty/AbstractNettyNHTest.kt +++ b/mirai-core/src/commonTest/kotlin/network/impl/netty/AbstractNettyNHTest.kt @@ -11,6 +11,7 @@ package net.mamoe.mirai.internal.network.impl.netty import io.netty.channel.Channel import io.netty.channel.embedded.EmbeddedChannel +import io.netty.util.ReferenceCountUtil import kotlinx.coroutines.CompletableDeferred import net.mamoe.mirai.internal.network.framework.AbstractRealNetworkHandlerTest import net.mamoe.mirai.internal.network.handler.NetworkHandlerContext @@ -42,7 +43,30 @@ internal open class TestNettyNH( } internal abstract class AbstractNettyNHTest : AbstractRealNetworkHandlerTest() { - val channel = EmbeddedChannel() + var fakeServer: (NettyNHTestChannel.(msg: Any?) -> Unit)? = null + + internal inner class NettyNHTestChannel : EmbeddedChannel() { + public /*internal*/ override fun doRegister() { + super.doRegister() // Set channel state to ACTIVE + // Drop old handlers + pipeline().let { p -> + while (p.first() != null) { + p.removeFirst() + } + } + } + + override fun handleInboundMessage(msg: Any?) { + ReferenceCountUtil.release(msg) // Not handled, Drop + } + + override fun handleOutboundMessage(msg: Any?) { + fakeServer?.invoke(this, msg) ?: ReferenceCountUtil.release(msg) + } + } + + val channel = NettyNHTestChannel() + override val network: TestNettyNH get() = bot.network as TestNettyNH override val factory: NetworkHandlerFactory = @@ -50,7 +74,10 @@ internal abstract class AbstractNettyNHTest : AbstractRealNetworkHandlerTest("A") { bot.login() } + } + + @Test + fun `test network broken`() = runBlockingUnit { + withSsoProcessor { + delay(1000) + channel.pipeline().fireExceptionCaught(IOException("TestNetworkBroken")) + delay(100000) // receive bits from "network" + } + assertFailsWith("TestNetworkBroken") { + bot.login() + } + } + + @Test + fun `test errors after logon`() = runBlockingUnit { + bot.login() + delay(1000) + assertEventBroadcasts(-1) { + launch { + delay(1000) + channel.pipeline().fireExceptionCaught(CusLoginException("Net error")) + } + assertNotNull( + nextEvent(5000) { it.bot === bot } + ) + }.let { events -> + assertFailsWith("Net error") { + throw events.firstIsInstanceOrNull()!!.cause!! + } + } + assertState(NetworkHandler.State.OK) + } +} diff --git a/mirai-core/src/commonTest/kotlin/test/events.kt b/mirai-core/src/commonTest/kotlin/test/events.kt index cb4a44740..98484b597 100644 --- a/mirai-core/src/commonTest/kotlin/test/events.kt +++ b/mirai-core/src/commonTest/kotlin/test/events.kt @@ -40,6 +40,8 @@ internal inline fun assertEventBroadcasts(times: Int = 1, bl listener.complete() } + if (times < 0) return receivedEvents.filterIsInstance().cast() + val actual = receivedEvents.filterIsInstance().count() assertEquals( times,