diff --git a/mirai-core/src/commonMain/kotlin/network/net/NetworkHandler.kt b/mirai-core/src/commonMain/kotlin/network/net/NetworkHandler.kt index eb426fa4e..3dd725cec 100644 --- a/mirai-core/src/commonMain/kotlin/network/net/NetworkHandler.kt +++ b/mirai-core/src/commonMain/kotlin/network/net/NetworkHandler.kt @@ -14,7 +14,7 @@ import net.mamoe.mirai.Bot import net.mamoe.mirai.internal.QQAndroidBot import net.mamoe.mirai.internal.network.Packet import net.mamoe.mirai.internal.network.net.NetworkHandler.State -import net.mamoe.mirai.internal.network.net.protocol.SsoController +import net.mamoe.mirai.internal.network.net.protocol.SsoContext import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacket import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacketWithRespType import net.mamoe.mirai.utils.BotConfiguration @@ -32,7 +32,7 @@ internal interface NetworkHandlerContext { val bot: QQAndroidBot val logger: MiraiLogger - val ssoController: SsoController + val ssoContext: SsoContext val configuration: BotConfiguration fun getNextAddress(): SocketAddress // FIXME: 2021/4/14 @@ -40,7 +40,7 @@ internal interface NetworkHandlerContext { internal class NetworkHandlerContextImpl( override val bot: QQAndroidBot, - override val ssoController: SsoController, + override val ssoContext: SsoContext, ) : NetworkHandlerContext { override val configuration: BotConfiguration get() = bot.configuration diff --git a/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NettyNetworkHandler.kt b/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NettyNetworkHandler.kt index a58dcb760..b61c55d08 100644 --- a/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NettyNetworkHandler.kt +++ b/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NettyNetworkHandler.kt @@ -12,10 +12,7 @@ package net.mamoe.mirai.internal.network.net.impl.netty import io.netty.bootstrap.Bootstrap import io.netty.buffer.ByteBuf import io.netty.buffer.ByteBufInputStream -import io.netty.channel.ChannelHandlerContext -import io.netty.channel.ChannelInboundHandlerAdapter -import io.netty.channel.ChannelInitializer -import io.netty.channel.SimpleChannelInboundHandler +import io.netty.channel.* import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel @@ -30,6 +27,7 @@ import net.mamoe.mirai.internal.network.net.NetworkHandler import net.mamoe.mirai.internal.network.net.NetworkHandlerContext import net.mamoe.mirai.internal.network.net.protocol.PacketCodec import net.mamoe.mirai.internal.network.net.protocol.RawIncomingPacket +import net.mamoe.mirai.internal.network.net.protocol.SsoController import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacket import net.mamoe.mirai.utils.childScope import java.net.SocketAddress @@ -40,10 +38,11 @@ internal class NettyNetworkHandler( private val address: SocketAddress, ) : NetworkHandlerSupport(context) { override fun close() { - super.close() - setState(StateClosed()) + setState { StateClosed(null) } } + private fun closeSuper() = super.close() + override suspend fun sendPacketImpl(packet: OutgoingPacket) { val state = _state as NettyState state.sendPacketImpl(packet) @@ -71,24 +70,33 @@ internal class NettyNetworkHandler( private suspend fun createConnection(decodePipeline: PacketDecodePipeline): ChannelHandlerContext { val contextResult = CompletableDeferred<ChannelHandlerContext>() - val eventLoopGroup = NioEventLoopGroup() - Bootstrap().group(eventLoopGroup) + + val future = Bootstrap().group(eventLoopGroup) .channel(NioSocketChannel::class.java) .handler(object : ChannelInitializer<SocketChannel>() { override fun initChannel(ch: SocketChannel) { - ch.pipeline().addLast(object : ChannelInboundHandlerAdapter() { - override fun channelActive(ctx: ChannelHandlerContext) { - contextResult.complete(ctx) - } - }) + ch.pipeline() + .addLast(object : ChannelInboundHandlerAdapter() { + override fun channelActive(ctx: ChannelHandlerContext) { + contextResult.complete(ctx) + } + + override fun channelInactive(ctx: ChannelHandlerContext?) { + eventLoopGroup.shutdownGracefully() + } + }) .addLast(LengthFieldBasedFrameDecoder(Int.MAX_VALUE, 0, 4, -4, 0)) .addLast(ByteBufToIncomingPacketDecoder()) .addLast(RawIncomingPacketCollector(decodePipeline)) } }) - .connect(address).runBIO { await() } - // TODO: 2021/4/14 eventLoopGroup 关闭 + .connect(address) + .awaitKt() + + future.channel().closeFuture().addListener { + setState { StateConnectionLost(it.cause()) } + } return contextResult.await() } @@ -117,6 +125,9 @@ internal class NettyNetworkHandler( // states /////////////////////////////////////////////////////////////////////////// + /** + * When state is initialized, it must be set to [_state]. (inside [setState]) + */ private abstract inner class NettyState( correspondingState: NetworkHandler.State ) : BaseStateImpl(correspondingState) { @@ -129,24 +140,25 @@ internal class NettyNetworkHandler( } override suspend fun resumeConnection() { - setState(StateConnecting(PacketDecodePipeline(this@NettyNetworkHandler.coroutineContext))) + setState { StateConnecting(PacketDecodePipeline(this@NettyNetworkHandler.coroutineContext)) } } } private inner class StateConnecting( val decodePipeline: PacketDecodePipeline, ) : NettyState(NetworkHandler.State.CONNECTING) { - private val connection = async { - createConnection(decodePipeline) - } + private val ssoController = SsoController(context.ssoContext, this@NettyNetworkHandler) + + private val connection = async { createConnection(decodePipeline) } private val connectResult = async { val connection = connection.await() - context.ssoController.login() - setState(StateOK(connection)) + ssoController.login() + setState { StateOK(connection) } }.apply { invokeOnCompletion { error -> - if (error != null) setState(StateClosed()) // logon failure closes the network handler. + if (error != null) setState { StateClosed(error) } // logon failure closes the network handler. + // and this error will also be thrown by `StateConnecting.resumeConnection` } } @@ -169,14 +181,45 @@ internal class NettyNetworkHandler( override suspend fun resumeConnection() {} // noop } - private inner class StateClosed : NettyState(NetworkHandler.State.OK) { + private inner class StateConnectionLost(private val cause: Throwable) : + NettyState(NetworkHandler.State.CONNECTION_LOST) { + override suspend fun sendPacketImpl(packet: OutgoingPacket) { + throw IllegalStateException("Connection is lost so cannot send packet. Call resumeConnection first.", cause) + } + + override suspend fun resumeConnection() { + setState { StateConnecting(PacketDecodePipeline(this@NettyNetworkHandler.coroutineContext)) } + } // noop + } + + private inner class StateClosed( + val exception: Throwable? + ) : NettyState(NetworkHandler.State.OK) { + init { + closeSuper() + } + override suspend fun sendPacketImpl(packet: OutgoingPacket) = error("NetworkHandler is already closed.") - override suspend fun resumeConnection() {} // noop + override suspend fun resumeConnection() { + exception?.let { throw it } + } // noop } override fun initialState(): BaseStateImpl = StateInitialized() } +private suspend fun ChannelFuture.awaitKt(): ChannelFuture { + suspendCancellableCoroutine<Unit> { cont -> + cont.invokeOnCancellation { + channel().close() + } + addListener { + cont.resumeWith(Result.success(Unit)) + } + } + return this +} + // TODO: 2021/4/14 Add test for toReadPacket private fun ByteBuf.toReadPacket(): ByteReadPacket { val buf = this diff --git a/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NetworkHandlerSupport.kt b/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NetworkHandlerSupport.kt index 9ca96a9b9..af48436e6 100644 --- a/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NetworkHandlerSupport.kt +++ b/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NetworkHandlerSupport.kt @@ -116,6 +116,8 @@ internal abstract class NetworkHandlerSupport( * A **scoped** state corresponding to [NetworkHandler.State]. * * CoroutineScope is cancelled when switched to another state. + * + * State can only be changed inside [setState]. */ protected abstract inner class BaseStateImpl( val correspondingState: NetworkHandler.State, @@ -132,7 +134,11 @@ internal abstract class NetworkHandlerSupport( private set final override val state: NetworkHandler.State get() = _state.correspondingState - protected fun setState(impl: BaseStateImpl) { // we can add monitor here for debug. + protected inline fun setState(crossinline new: () -> BaseStateImpl) = synchronized(this) { + // we can add hooks here for debug. + + val impl = new() + val old = _state check(old !== impl) { "Old and new states cannot be the same." } old.cancel()