diff --git a/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NettyNetworkHandler.kt b/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NettyNetworkHandler.kt index b61c55d08..8c0ff9abb 100644 --- a/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NettyNetworkHandler.kt +++ b/mirai-core/src/commonMain/kotlin/network/net/impl/netty/NettyNetworkHandler.kt @@ -208,13 +208,17 @@ internal class NettyNetworkHandler( override fun initialState(): BaseStateImpl = StateInitialized() } -private suspend fun ChannelFuture.awaitKt(): ChannelFuture { +internal suspend fun ChannelFuture.awaitKt(): ChannelFuture { suspendCancellableCoroutine { cont -> cont.invokeOnCancellation { channel().close() } - addListener { - cont.resumeWith(Result.success(Unit)) + addListener { f -> + if (f.isSuccess) { + cont.resumeWith(Result.success(Unit)) + } else { + cont.resumeWith(Result.failure(f.cause())) + } } } return this diff --git a/mirai-core/src/commonTest/kotlin/network/NettyTestUnit.kt b/mirai-core/src/commonTest/kotlin/network/NettyTestUnit.kt new file mode 100644 index 000000000..3a1ec03b7 --- /dev/null +++ b/mirai-core/src/commonTest/kotlin/network/NettyTestUnit.kt @@ -0,0 +1,69 @@ +/* + * Copyright 2019-2021 Mamoe Technologies and contributors. + * + * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证. + * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link. + * + * https://github.com/mamoe/mirai/blob/master/LICENSE + */ + +package net.mamoe.mirai.internal.network + +import io.netty.channel.DefaultChannelPromise +import io.netty.channel.embedded.EmbeddedChannel +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import net.mamoe.mirai.internal.network.net.impl.netty.awaitKt +import net.mamoe.mirai.internal.test.runBlockingUnit +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.Test +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import kotlin.time.seconds + +internal class NettyTestUnit { + companion object { + private val channel = EmbeddedChannel() + + @JvmStatic + @AfterAll + fun afterAll() { + channel.close() + } + } + + @Test + fun canAwait() = runBlockingUnit(timeout = 5.seconds) { + val future = DefaultChannelPromise(channel) + launch { + delay(2000) + future.setSuccess() + } + future.awaitKt() + } + + @Test + fun returnsImmediatelyIfCompleted() = runBlockingUnit(timeout = 5.seconds) { + val future = DefaultChannelPromise(channel) + future.setSuccess() + future.awaitKt() + } + + @Test + fun testAwait() { + class MyError : AssertionError("My") // coroutine debugger will modify the exception if inside coroutine + + runBlockingUnit(timeout = 5.seconds) { + val future = DefaultChannelPromise(channel) + launch { + delay(2000) + future.setFailure(MyError()) + } + assertFailsWith { + future.awaitKt() + }.let { actual -> + assertTrue { actual is MyError } + } + } + } +} diff --git a/mirai-core/src/commonTest/kotlin/test/utils.kt b/mirai-core/src/commonTest/kotlin/test/utils.kt index 70fa773bf..ef765f5b1 100644 --- a/mirai-core/src/commonTest/kotlin/test/utils.kt +++ b/mirai-core/src/commonTest/kotlin/test/utils.kt @@ -12,12 +12,26 @@ package net.mamoe.mirai.internal.test import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext +import kotlin.time.Duration -internal fun runBlockingUnit( +fun runBlockingUnit( context: CoroutineContext = EmptyCoroutineContext, block: suspend CoroutineScope.() -> Unit -): Unit { +) { return runBlocking(context, block) +} + +fun runBlockingUnit( + context: CoroutineContext = EmptyCoroutineContext, + timeout: Duration, + block: suspend CoroutineScope.() -> Unit +) { + runBlockingUnit(context) { + withTimeout(timeout) { + block() + } + } } \ No newline at end of file