mirror of
https://github.com/mamoe/mirai.git
synced 2025-01-30 02:30:12 +08:00
[core] OnDemandChannel: Ensure scope is closed on exceptional cases:
[core] OnDemandChannel: Ensure channel close and emit shares same behavior [core] OnDemandChannel: Ensure channel is closed properly when producer throws an exception [core] OnDemandChannel: Ensure coroutine scope cancelled when producer coroutine throws exception, also added more tests
This commit is contained in:
parent
87fbaa4fb2
commit
0e625d3368
@ -101,28 +101,27 @@ internal sealed interface ChannelState<T, V> {
|
||||
val producer: OnDemandSendChannel<T, V>
|
||||
}
|
||||
|
||||
// Producer is not running until `expectMore`. `emit` and `receiveOrNull` not allowed.
|
||||
class ProducerReady<T, V>(
|
||||
launchProducer: () -> OnDemandSendChannel<T, V>,
|
||||
) : HasProducer<T, V> {
|
||||
// Lazily start the producer job since it's on-demand
|
||||
override val producer: OnDemandSendChannel<T, V> by lazy(launchProducer) // `lazy` is synchronized
|
||||
|
||||
fun startProducerIfNotYet() {
|
||||
producer
|
||||
}
|
||||
|
||||
override fun toString(): String = "ProducerReady"
|
||||
}
|
||||
|
||||
// Producer is running. `emit` and `receiveOrNull` both allowed.
|
||||
class Producing<T, V>(
|
||||
override val producer: OnDemandSendChannel<T, V>,
|
||||
parentJob: Job,
|
||||
) : HasProducer<T, V> {
|
||||
val deferred: CompletableDeferred<V> by lazy { CompletableDeferred<V>(parentJob) }
|
||||
|
||||
|
||||
override fun toString(): String = "Producing(deferred.completed=${deferred.isCompleted})"
|
||||
}
|
||||
|
||||
// Producer is suspended because it called `emit`. Expecting `receiveOrNull`.
|
||||
class Consuming<T, V>(
|
||||
override val producer: OnDemandSendChannel<T, V>,
|
||||
val value: Deferred<V>,
|
||||
@ -138,6 +137,7 @@ internal sealed interface ChannelState<T, V> {
|
||||
}
|
||||
}
|
||||
|
||||
// Producer is suspended. `expectMore` will resume producer with a ticket.
|
||||
class Consumed<T, V>(
|
||||
override val producer: OnDemandSendChannel<T, V>,
|
||||
val producerLatch: CompletableDeferred<T>
|
||||
@ -151,7 +151,7 @@ internal sealed interface ChannelState<T, V> {
|
||||
) : ChannelState<T, V> {
|
||||
val isSuccess: Boolean get() = exception == null
|
||||
|
||||
fun createAlreadyFinishedException(cause: Throwable?): IllegalProducerStateException {
|
||||
fun createAlreadyFinishedException(cause: Throwable?): IllegalChannelStateException {
|
||||
val exception = exception
|
||||
val causeMessage = if (cause == null) {
|
||||
""
|
||||
@ -159,13 +159,13 @@ internal sealed interface ChannelState<T, V> {
|
||||
", but attempting to finish with the cause $cause"
|
||||
}
|
||||
return if (exception == null) {
|
||||
IllegalProducerStateException(
|
||||
IllegalChannelStateException(
|
||||
this,
|
||||
"Producer has already finished normally$causeMessage. Previous state was: $previousState",
|
||||
cause = cause
|
||||
)
|
||||
} else {
|
||||
IllegalProducerStateException(
|
||||
IllegalChannelStateException(
|
||||
this,
|
||||
"Producer has already finished with the suppressed exception$causeMessage. Previous state was: $previousState",
|
||||
cause = cause
|
||||
|
@ -9,7 +9,8 @@
|
||||
|
||||
package net.mamoe.mirai.utils.channels
|
||||
|
||||
public class IllegalProducerStateException internal constructor(
|
||||
// An internal error exception
|
||||
public class IllegalChannelStateException internal constructor(
|
||||
private val state: ChannelState<*, *>,
|
||||
message: String? = state.toString(),
|
||||
cause: Throwable? = null,
|
@ -13,11 +13,11 @@ import kotlinx.atomicfu.AtomicRef
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.atomicfu.loop
|
||||
import kotlinx.coroutines.*
|
||||
import net.mamoe.mirai.utils.TestOnly
|
||||
import net.mamoe.mirai.utils.UtilsLogger
|
||||
import net.mamoe.mirai.utils.childScope
|
||||
import net.mamoe.mirai.utils.debug
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.coroutines.cancellation.CancellationException
|
||||
|
||||
|
||||
internal class CoroutineOnDemandReceiveChannel<T, V>(
|
||||
@ -27,8 +27,14 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
||||
) : OnDemandReceiveChannel<T, V> {
|
||||
private val coroutineScope = parentCoroutineContext.childScope("CoroutineOnDemandReceiveChannel")
|
||||
|
||||
@TestOnly
|
||||
internal fun getScope() = coroutineScope
|
||||
|
||||
private val state: AtomicRef<ChannelState<T, V>> = atomic(ChannelState.JustInitialized())
|
||||
|
||||
@TestOnly
|
||||
internal fun getState() = state.value
|
||||
|
||||
|
||||
inner class Producer(
|
||||
private val initialTicket: T,
|
||||
@ -38,20 +44,33 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
||||
// attaching Job to the coroutineScope, then `yield` the thread back, to complete `launch`.
|
||||
coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
|
||||
yield()
|
||||
|
||||
try {
|
||||
producerCoroutine(initialTicket)
|
||||
} catch (_: CancellationException) {
|
||||
// ignored
|
||||
} catch (e: Exception) {
|
||||
finishExceptionally(e)
|
||||
} catch (e: Throwable) {
|
||||
// close exceptionally
|
||||
val r = emitImpl(Result.failure(e))
|
||||
check(r == null) // assertion
|
||||
return@launch
|
||||
}
|
||||
|
||||
close()
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun emit(value: V): T {
|
||||
override suspend fun emit(value: V): T = emitImpl(Result.success(value))!!
|
||||
|
||||
private suspend inline fun emitImpl(value: Result<V>): T? {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
|
||||
is ChannelState.Finished -> {
|
||||
if (value.isFailure) {
|
||||
return null
|
||||
} else {
|
||||
throw state.createAlreadyFinishedException(null)
|
||||
}
|
||||
}
|
||||
|
||||
is ChannelState.Producing -> {
|
||||
val deferred = state.deferred
|
||||
val consumingState = ChannelState.Consuming(
|
||||
@ -60,55 +79,46 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
||||
coroutineScope.coroutineContext
|
||||
)
|
||||
if (compareAndSetState(state, consumingState)) {
|
||||
deferred.complete(value) // produce a value
|
||||
deferred.completeWith(value) // produce a value
|
||||
return consumingState.producerLatch.await() // wait for producer to consume the previous value.
|
||||
}
|
||||
// failed race, try again
|
||||
}
|
||||
|
||||
is ChannelState.ProducerReady -> {
|
||||
// This implies another coroutine is running `expectMore`,
|
||||
// and we are a bit faster than it!
|
||||
setStateProducing(state)
|
||||
}
|
||||
|
||||
else -> throw IllegalProducerStateException(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
else -> throw IllegalChannelStateException(
|
||||
state,
|
||||
if (value.isFailure)
|
||||
"Producer threw an exception (see cause), so completing with the exception, but current state is not Producing"
|
||||
else "Producer is emitting an value, but current state is not Producing",
|
||||
value.exceptionOrNull()
|
||||
)
|
||||
|
||||
override fun finishExceptionally(exception: Throwable) {
|
||||
finishImpl(exception)
|
||||
}
|
||||
|
||||
override fun finish() {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
|
||||
else -> {
|
||||
if (compareAndSetState(state, ChannelState.Finished(state, null))) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun setStateProducing(state: ChannelState.ProducerReady<T, V>) {
|
||||
compareAndSetState(state, ChannelState.Producing(state.producer, coroutineScope.coroutineContext.job))
|
||||
private fun setStateProducing(state: ChannelState.ProducerReady<T, V>): Boolean {
|
||||
return compareAndSetState(state, ChannelState.Producing(state.producer, coroutineScope.coroutineContext.job))
|
||||
}
|
||||
|
||||
private fun finishImpl(exception: Throwable?) {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ChannelState.Finished -> {} // ignore
|
||||
else -> {
|
||||
if (compareAndSetState(state, ChannelState.Finished(state, exception))) {
|
||||
val cancellationException = kotlinx.coroutines.CancellationException("Finished", exception)
|
||||
coroutineScope.cancel(cancellationException)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
private fun setStateFinished(
|
||||
currState: ChannelState<T, V>,
|
||||
message: String,
|
||||
exception: ProducerFailureException?
|
||||
): Boolean {
|
||||
if (compareAndSetState(currState, ChannelState.Finished(currState, exception))) {
|
||||
val cancellationException = CancellationException(message, exception)
|
||||
coroutineScope.cancel(cancellationException)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private fun compareAndSetState(state: ChannelState<T, V>, newState: ChannelState<T, V>): Boolean {
|
||||
@ -117,27 +127,27 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
||||
}
|
||||
}
|
||||
|
||||
override val isClosed: Boolean
|
||||
get() = state.value is ChannelState.Finished
|
||||
|
||||
override suspend fun receiveOrNull(): V? {
|
||||
// don't use `.loop`:
|
||||
// don't use atomicfu `.loop`:
|
||||
// java.lang.VerifyError: Bad type on operand stack
|
||||
// net/mamoe/mirai/utils/channels/CoroutineOnDemandReceiveChannel.receiveOrNull(Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @103: getfield
|
||||
|
||||
while (true) {
|
||||
when (val state = state.value) {
|
||||
is ChannelState.Consuming -> {
|
||||
// value is ready, now we consume the value
|
||||
// value is ready, now we try to consume the value
|
||||
|
||||
if (compareAndSetState(state, ChannelState.Consumed(state.producer, state.producerLatch))) {
|
||||
// value is consumed, no contention, safe to retrieve
|
||||
// value is now reserved for us, no contention is possible, safe to retrieve
|
||||
|
||||
return try {
|
||||
// This actually won't suspend, since the value is already completed
|
||||
// Just to be error-tolerating and re-throwing exceptions.
|
||||
state.value.await()
|
||||
} catch (e: Throwable) {
|
||||
// Producer failed to produce the previous value with exception
|
||||
throw ProducerFailureException(cause = e)
|
||||
}
|
||||
// This actually won't suspend (there are tests ensuring this point),
|
||||
// since the value is already completed.
|
||||
// Just to be error-tolerating and re-throwing exceptions.
|
||||
// (Also because `Deferred.getCompleted()` is not stable yet (coroutines 1.6))
|
||||
return awaitValueSafe(state.value)
|
||||
}
|
||||
}
|
||||
|
||||
@ -146,45 +156,59 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
||||
is ChannelState.Producing<T, V> -> {
|
||||
// still producing value
|
||||
|
||||
state.deferred.await() // just wait for value, but does not return it.
|
||||
// Wait for value and throw exception caused by the producer if there is one.
|
||||
awaitValueSafe(state.deferred) // this may or may not suspend.
|
||||
|
||||
// The value will be completed in ProducerState.Consuming state,
|
||||
// but you cannot thread-safely assume current state is Consuming.
|
||||
// Now deferred is complete, and we will be in the Consuming state, but we can't use the value here.
|
||||
// We must ensure only one thread gets the value, and state should then be Consumed
|
||||
|
||||
// Here we will loop again, to atomically switch to Consumed state.
|
||||
// So we loop again and do this in the Consuming state.
|
||||
}
|
||||
|
||||
is ChannelState.Finished -> {
|
||||
state.exception?.let { err ->
|
||||
throw ProducerFailureException(cause = err)
|
||||
}
|
||||
// see public API docs for behavior
|
||||
return null
|
||||
}
|
||||
|
||||
else -> throw IllegalProducerStateException(state)
|
||||
else ->
|
||||
// internal error
|
||||
throw IllegalChannelStateException(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend inline fun awaitValueSafe(deferred: Deferred<V>) = try {
|
||||
deferred.await()
|
||||
} catch (e: Throwable) {
|
||||
// Producer failed to produce the previous value with exception
|
||||
val producerFailureException = ProducerFailureException(cause = e)
|
||||
setStateFinished(
|
||||
this.state.value,
|
||||
"OnDemandChannel is closed because producer failed to produce value, see cause",
|
||||
producerFailureException
|
||||
)
|
||||
throw producerFailureException
|
||||
}
|
||||
|
||||
override fun expectMore(ticket: T): Boolean {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ChannelState.JustInitialized -> {
|
||||
// start producer atomically
|
||||
val ready = ChannelState.ProducerReady { Producer(ticket) }
|
||||
if (compareAndSetState(state, ready)) {
|
||||
ready.startProducerIfNotYet()
|
||||
}
|
||||
compareAndSetState(state, ready)
|
||||
// loop again
|
||||
}
|
||||
|
||||
is ChannelState.ProducerReady -> {
|
||||
setStateProducing(state)
|
||||
if (setStateProducing(state)) {
|
||||
return true
|
||||
}
|
||||
// lost race, try again
|
||||
}
|
||||
|
||||
is ChannelState.Producing -> return true // ok
|
||||
|
||||
is ChannelState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
|
||||
is ChannelState.Producing,
|
||||
is ChannelState.Consuming -> throw IllegalChannelStateException(state) // a value is already ready
|
||||
|
||||
is ChannelState.Consumed -> {
|
||||
if (compareAndSetState(state, ChannelState.ProducerReady { state.producer })) {
|
||||
@ -200,7 +224,12 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
||||
}
|
||||
}
|
||||
|
||||
override fun finish() {
|
||||
finishImpl(null)
|
||||
override fun close() {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ChannelState.Finished -> return
|
||||
else -> if (setStateFinished(state, "OnDemandChannel is closed normally", null)) return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -21,36 +21,27 @@ import kotlin.coroutines.cancellation.CancellationException
|
||||
|
||||
/**
|
||||
* 按需供给的 [SendChannel].
|
||||
*
|
||||
* @param T 令牌类型.
|
||||
* @param V 值类型.
|
||||
*/
|
||||
public interface OnDemandSendChannel<T, V> {
|
||||
/**
|
||||
* 挂起协程, 直到 [OnDemandReceiveChannel] [期望接收][OnDemandReceiveChannel.receiveOrNull]一个 [V], 届时将 [value] 传递给 [OnDemandReceiveChannel.receiveOrNull], 成为其返回值.
|
||||
* 挂起协程, 直到 [OnDemandReceiveChannel] [期望接收][OnDemandReceiveChannel.receiveOrNull]一个 [V],
|
||||
* 届时将 [value] 传递给 [OnDemandReceiveChannel.receiveOrNull], 成为其返回值.
|
||||
*
|
||||
* 若在调用 [emit] 时已经有 [OnDemandReceiveChannel.receiveOrNull] 正在等待, 则该协程会立即[恢复][Continuation.resumeWith], [emit] 不会挂起.
|
||||
*
|
||||
* 若 [OnDemandReceiveChannel] 已经[完结][OnDemandReceiveChannel.finish], [OnDemandSendChannel.emit] 会抛出 [IllegalProducerStateException].
|
||||
* 若 [OnDemandReceiveChannel] 已经[完结][OnDemandReceiveChannel.close], [OnDemandSendChannel.emit] 会抛出 [IllegalChannelStateException].
|
||||
*
|
||||
* @see OnDemandReceiveChannel.receiveOrNull
|
||||
*
|
||||
* @param value 需要传递给 [OnDemandReceiveChannel.receiveOrNull] 的值
|
||||
* @return 下一个 ticket [T].
|
||||
*
|
||||
* @throws CancellationException 当此协程被取消时抛出
|
||||
*/
|
||||
public suspend fun emit(value: V): T
|
||||
|
||||
/**
|
||||
* 标记此 [OnDemandSendChannel] 在生产 [V] 的过程中出现异常.
|
||||
*
|
||||
* 这也会终止此 [OnDemandSendChannel], 若 [OnDemandReceiveChannel] 正在期待一个值, 则当它调用 [OnDemandReceiveChannel.receiveOrNull] 时, 它将得到一个 [ProducerFailureException].
|
||||
*
|
||||
* 在 [finishExceptionally] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] 或 [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
|
||||
*/
|
||||
public fun finishExceptionally(exception: Throwable)
|
||||
|
||||
/**
|
||||
* 标记此 [OnDemandSendChannel] 已经没有更多 [V] 可生产.
|
||||
*
|
||||
* 这会终止此 [OnDemandSendChannel], 若 [OnDemandReceiveChannel] 正在期待一个值, 则当它调用 [OnDemandReceiveChannel.receiveOrNull] 时, 它将得到一个 [ProducerFailureException].
|
||||
*
|
||||
* 在 [finish] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] 或 [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
|
||||
*/
|
||||
public fun finish()
|
||||
}
|
||||
|
||||
|
||||
@ -60,36 +51,42 @@ public interface OnDemandSendChannel<T, V> {
|
||||
* 与 [ReceiveChannel] 不同, [OnDemandReceiveChannel] 只有在调用 [expectMore] 后才会让[生产者][OnDemandSendChannel] 开始生产下一个 [V].
|
||||
*/
|
||||
public interface OnDemandReceiveChannel<T, V> {
|
||||
/**
|
||||
* 当此 [OnDemandReceiveChannel] 已经关闭, 即不再期望更多值时返回 `true`,
|
||||
* 无论是调用了 [close] (主动关闭) 还是 [OnDemandSendChannel] 没有更多值了 (被动关闭).
|
||||
*/
|
||||
public val isClosed: Boolean
|
||||
|
||||
/**
|
||||
* 尝试从 [OnDemandSendChannel] [接收][OnDemandSendChannel.emit]一个 [V].
|
||||
* 当且仅当在 [OnDemandSendChannel] 已经[正常结束][OnDemandSendChannel.finish] 时返回 `null`.
|
||||
* 当且仅当在 [OnDemandSendChannel] 已经正常结束时返回 `null`.
|
||||
*
|
||||
* 若目前已有 [V], 此函数立即返回该 [V], 不会挂起.
|
||||
* 否则, 此函数将会挂起直到 [OnDemandSendChannel.emit].
|
||||
*
|
||||
* 当此函数被多个协程 (线程) 同时调用时, 只有一个协程会获得 [V], 其他协程将会挂起.
|
||||
*
|
||||
* 若在等待过程中 [OnDemandSendChannel] [异常结束][OnDemandSendChannel.finishExceptionally],
|
||||
* 若在等待过程中 [OnDemandSendChannel] 异常结束,
|
||||
* 本函数会立即恢复并抛出 [ProducerFailureException], 其 `cause` 为令 [OnDemandSendChannel] 的异常.
|
||||
*
|
||||
* 此挂起函数可被取消.
|
||||
* 如果在此函数挂起时当前协程的 [Job] 被取消或完结, 此函数会立即恢复并抛出 [CancellationException]. 此行为与 [Deferred.await] 相同.
|
||||
*
|
||||
* @throws ProducerFailureException 当 [OnDemandSendChannel.finishExceptionally] 时抛出.
|
||||
* @throws ProducerFailureException 当 [OnDemandSendChannel] 产生了一个异常时抛出.
|
||||
* @throws CancellationException 当协程被取消时抛出
|
||||
* @throws IllegalProducerStateException 当状态异常, 如未调用 [expectMore] 时抛出
|
||||
* @throws IllegalChannelStateException 当状态异常, 如未调用 [expectMore] 时抛出
|
||||
*/
|
||||
@Throws(ProducerFailureException::class, CancellationException::class)
|
||||
public suspend fun receiveOrNull(): V?
|
||||
|
||||
/**
|
||||
* 期待 [OnDemandSendChannel] 再生产一个 [V].
|
||||
* 期望生产后必须在之后调用 [receiveOrNull] 或 [finish] 来消耗生产的 [V].
|
||||
* 期望生产后必须在之后调用 [receiveOrNull] 或 [close] 来消耗生产的 [V].
|
||||
* 不可连续重复调用 [expectMore].
|
||||
*
|
||||
* 在成功发起期待后返回 `true`; 在 [OnDemandSendChannel] 已经[完结][OnDemandSendChannel.finish] 时返回 `false`.
|
||||
*
|
||||
* @throws IllegalProducerStateException 当 [expectMore] 被调用后, 没有调用 [receiveOrNull] 就又调用了 [expectMore] 时抛出
|
||||
* @throws IllegalChannelStateException 当 [expectMore] 被调用后, 没有调用 [receiveOrNull] 就又调用了 [expectMore] 时抛出
|
||||
*/
|
||||
public fun expectMore(ticket: T): Boolean
|
||||
|
||||
@ -98,10 +95,11 @@ public interface OnDemandReceiveChannel<T, V> {
|
||||
*
|
||||
* 如果 [OnDemandSendChannel] 仍在运行 (无论是挂起中还是正在计算下一个值), 都会正常地[取消][Job.cancel] [OnDemandSendChannel].
|
||||
*
|
||||
* 若此 [OnDemandSendChannel] 已经被关闭, 则此函数不会进行任何操作.
|
||||
*
|
||||
* 在 [finish] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] 或 [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
|
||||
* 在 [close] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] 或 [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
|
||||
*/
|
||||
public fun finish()
|
||||
public fun close()
|
||||
}
|
||||
|
||||
@Suppress("FunctionName")
|
||||
|
@ -10,6 +10,6 @@
|
||||
package net.mamoe.mirai.utils.channels
|
||||
|
||||
public class ProducerFailureException(
|
||||
override val message: String? = null,
|
||||
override val message: String? = "Producer failed to produce a value, see cause",
|
||||
override val cause: Throwable?
|
||||
) : Exception()
|
@ -10,6 +10,10 @@
|
||||
package net.mamoe.mirai.utils.channels
|
||||
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import net.mamoe.mirai.utils.AtomicBoolean
|
||||
import net.mamoe.mirai.utils.testFramework.assertCoroutineSuspends
|
||||
import net.mamoe.mirai.utils.testFramework.assertNoCoroutineSuspension
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
@ -25,7 +29,7 @@ class OnDemandChannelTest {
|
||||
fail()
|
||||
}
|
||||
assertEquals(1, job.children.toList().size)
|
||||
channel.finish()
|
||||
channel.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -38,14 +42,113 @@ class OnDemandChannelTest {
|
||||
val job = supervisor.children.single()
|
||||
assertEquals(true, job.isActive)
|
||||
|
||||
channel.finish()
|
||||
channel.close()
|
||||
|
||||
assertEquals(0, supervisor.children.toList().size)
|
||||
assertEquals(false, job.isActive)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `cancel producer job on finish`() = runTest {
|
||||
// Actually, this case won't happen, because producer coroutine will be cancelled on [finish]
|
||||
|
||||
lateinit var job: Job
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
job = currentCoroutineContext()[Job]!!
|
||||
emit(1)
|
||||
emit(1)
|
||||
emit(1)
|
||||
emit(1)
|
||||
fail()
|
||||
}
|
||||
|
||||
channel.expectMore(1)
|
||||
channel.receiveOrNull()
|
||||
assertTrue { job.isActive }
|
||||
channel.close()
|
||||
assertFalse { job.isActive }
|
||||
yield()
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Producer Coroutine — Lazy Initialization
|
||||
// Producer Coroutine — Tickets
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@Test
|
||||
fun `producer receives initial ticket`() = runTest {
|
||||
val channel = OnDemandChannel(currentCoroutineContext()) { initialTicket ->
|
||||
assertEquals(1, initialTicket)
|
||||
emit(2)
|
||||
}
|
||||
|
||||
channel.expectMore(1)
|
||||
channel.receiveOrNull()
|
||||
|
||||
channel.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `producer receives second ticket`() = runTest {
|
||||
val channel = OnDemandChannel(currentCoroutineContext()) { initialTicket ->
|
||||
assertEquals(1, initialTicket)
|
||||
assertEquals(2, emit(3))
|
||||
}
|
||||
|
||||
channel.expectMore(1)
|
||||
channel.receiveOrNull()
|
||||
channel.expectMore(2)
|
||||
|
||||
channel.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `producer receives third ticket`() = runTest {
|
||||
val channel = OnDemandChannel(currentCoroutineContext()) { initialTicket ->
|
||||
assertEquals(1, initialTicket)
|
||||
assertEquals(2, emit(4))
|
||||
assertEquals(3, emit(5))
|
||||
}
|
||||
|
||||
channel.expectMore(1)
|
||||
channel.receiveOrNull()
|
||||
channel.expectMore(2)
|
||||
channel.receiveOrNull()
|
||||
channel.expectMore(3)
|
||||
|
||||
channel.close()
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Consumer — Receive Correct Values
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@Test
|
||||
fun `receives correct first value`() = runTest {
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
emit(3)
|
||||
}
|
||||
|
||||
channel.expectMore(1)
|
||||
assertEquals(3, channel.receiveOrNull())
|
||||
channel.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `receives correct second value`() = runTest {
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
emit(3)
|
||||
emit(4)
|
||||
}
|
||||
|
||||
channel.expectMore(1)
|
||||
assertEquals(3, channel.receiveOrNull())
|
||||
channel.expectMore(2)
|
||||
assertEquals(4, channel.receiveOrNull())
|
||||
channel.close()
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// expectMore/emit/receiveOrNull
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@Test
|
||||
@ -53,11 +156,11 @@ class OnDemandChannelTest {
|
||||
val channel = OnDemandChannel<Int, Int> {
|
||||
fail()
|
||||
}
|
||||
channel.finish()
|
||||
channel.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `producer coroutine starts iff expectMore`() = runBlocking(Dispatchers.Default.limitedParallelism(1)) {
|
||||
fun `producer coroutine starts iff expectMore`() = runTest {
|
||||
var started = false
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
// (1)
|
||||
@ -67,10 +170,164 @@ class OnDemandChannelTest {
|
||||
fail()
|
||||
}
|
||||
assertFalse { started }
|
||||
channel.expectMore(1) // launches the job, but it won't execute due to single parallelism
|
||||
assertTrue { channel.expectMore(1) } // launches the job, but it won't execute due to single parallelism
|
||||
yield() // goto (1)
|
||||
// (2)
|
||||
assertTrue { started }
|
||||
channel.finish()
|
||||
channel.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `receiveOrNull does not suspend if value is ready`() = runTest {
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
emit(1)
|
||||
}
|
||||
|
||||
assertTrue { channel.expectMore(1) }
|
||||
yield() // run `emit`
|
||||
// now value is ready
|
||||
assertNoCoroutineSuspension { channel.receiveOrNull() }
|
||||
channel.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `receiveOrNull does suspend if value is not ready`() = runTest {
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
yield()
|
||||
emit(1)
|
||||
}
|
||||
|
||||
assertTrue { channel.expectMore(1) }
|
||||
assertCoroutineSuspends { channel.receiveOrNull() }
|
||||
channel.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `emit won't resume unless another expectMore`() = runTest {
|
||||
val canResume = AtomicBoolean(false)
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
emit(1)
|
||||
|
||||
if (!canResume.value) fail("Emit should not resume")
|
||||
|
||||
canResume.value = false
|
||||
}
|
||||
|
||||
channel.expectMore(1)
|
||||
channel.receiveOrNull()
|
||||
canResume.value = true
|
||||
channel.expectMore(2)
|
||||
yield() // run producer
|
||||
assertEquals(false, canResume.value)
|
||||
channel.close()
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Operation while already finished
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@Test
|
||||
fun `expectMore and receiveOrNull while already finished just after instantiation`() = runTest {
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
fail("Producer should not run")
|
||||
}
|
||||
channel.close()
|
||||
|
||||
assertFalse { channel.expectMore(1) }
|
||||
assertNull(channel.receiveOrNull())
|
||||
assertFalse { channel.expectMore(1) }
|
||||
assertFalse { channel.expectMore(1) }
|
||||
assertNull(channel.receiveOrNull())
|
||||
assertNull(channel.receiveOrNull())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `expectMore and receiveOrNull while already finished`() = runTest {
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
emit(1)
|
||||
}
|
||||
|
||||
assertTrue { channel.expectMore(1) }
|
||||
assertNotNull(channel.receiveOrNull())
|
||||
assertFalse { channel.isClosed }
|
||||
|
||||
assertTrue { channel.expectMore(1) } // `expectMore` don't know if more values are available
|
||||
yield() // go to producer
|
||||
// now we must know producer has no more value
|
||||
assertTrue { channel.isClosed }
|
||||
assertNull(channel.receiveOrNull())
|
||||
assertFalse { channel.expectMore(1) }
|
||||
assertNull(channel.receiveOrNull())
|
||||
assertFalse { (channel as CoroutineOnDemandReceiveChannel).getScope().isActive }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `emit while already finished`() {
|
||||
// Actually, this case won't happen, because producer coroutine will be cancelled on [finish]
|
||||
|
||||
`cancel producer job on finish`()
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `producer exception closes channel then receiveOrNull throws`() = runTest {
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
throw NoSuchElementException("Oops")
|
||||
}
|
||||
|
||||
assertTrue { channel.expectMore(1) }
|
||||
assertFalse { channel.isClosed }
|
||||
assertIs<ChannelState.Producing<*, *>>(channel.state)
|
||||
assertFailsWith<ProducerFailureException> {
|
||||
println(channel.receiveOrNull())
|
||||
}.also {
|
||||
assertIs<NoSuchElementException>(it.cause)
|
||||
}
|
||||
assertTrue { channel.isClosed }
|
||||
|
||||
// The exception looks like this, though I don't know why there are two causes.
|
||||
|
||||
//net.mamoe.mirai.utils.channels.ProducerFailureException: Producer failed to produce a value, see cause
|
||||
// at net.mamoe.mirai.utils.channels.CoroutineOnDemandReceiveChannel.receiveOrNull(OnDemandChannelImpl.kt:164)
|
||||
// at net.mamoe.mirai.utils.channels.CoroutineOnDemandReceiveChannel$receiveOrNull$1.invokeSuspend(OnDemandChannelImpl.kt)
|
||||
// at kotlin.coroutines.jvm.internal.BaseContinuationImpl.resumeWith(ContinuationImpl.kt:33)
|
||||
// at kotlinx.coroutines.test.TestBuildersKt.runTest$default(Unknown Source)
|
||||
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest.producer exception(OnDemandChannelTest.kt:273)
|
||||
// at worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)
|
||||
//Caused by: java.util.NoSuchElementException: Oops
|
||||
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest$producer exception$1$channel$1.invokeSuspend(OnDemandChannelTest.kt:275)
|
||||
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest$producer exception$1$channel$1.invoke(OnDemandChannelTest.kt)
|
||||
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest$producer exception$1$channel$1.invoke(OnDemandChannelTest.kt)
|
||||
// at net.mamoe.mirai.utils.channels.CoroutineOnDemandReceiveChannel$Producer$1.invokeSuspend(OnDemandChannelImpl.kt:46)
|
||||
// (Coroutine boundary)
|
||||
// at net.mamoe.mirai.utils.channels.CoroutineOnDemandReceiveChannel.receiveOrNull(OnDemandChannelImpl.kt:162)
|
||||
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest$producer exception$1.invokeSuspend(OnDemandChannelTest.kt:280)
|
||||
// at kotlinx.coroutines.test.TestBuildersKt__TestBuildersKt$runTestCoroutine$2.invokeSuspend(TestBuilders.kt:212)
|
||||
//Caused by: java.util.NoSuchElementException: Oops
|
||||
// ...
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `producer exception closes channel then receiveOrNull throws in Producing state`() = runTest {
|
||||
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
|
||||
throw NoSuchElementException("Oops")
|
||||
}
|
||||
|
||||
assertTrue { channel.expectMore(1) }
|
||||
yield() // fail the channel first
|
||||
assertIs<ChannelState.Consuming<*, *>>(channel.state)
|
||||
assertFalse { channel.isClosed } // channel won't close until receiveOrNull
|
||||
|
||||
assertFailsWith<ProducerFailureException> {
|
||||
println(channel.receiveOrNull())
|
||||
}.also {
|
||||
assertIs<NoSuchElementException>(it.cause)
|
||||
}
|
||||
assertTrue { channel.isClosed }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private val <T, V> OnDemandReceiveChannel<T, V>.state
|
||||
get() = (this as CoroutineOnDemandReceiveChannel<T, V>).getState()
|
@ -34,14 +34,23 @@ suspend inline fun <R> assertNoCoroutineSuspension(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes [block], and asserts there happens at least one coroutine suspension in [block].
|
||||
*
|
||||
* When the first coroutine suspension happens, [onSuspend] will be called.
|
||||
*/
|
||||
@OptIn(ExperimentalStdlibApi::class)
|
||||
suspend inline fun <R> assertCoroutineSuspends(
|
||||
noinline onSuspend: (suspend () -> Unit)? = null,
|
||||
crossinline block: suspend () -> R,
|
||||
): R {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
|
||||
return withContext(Dispatchers.Default.limitedParallelism(1)) {
|
||||
val dispatcher = currentCoroutineContext()[CoroutineDispatcher] ?: Dispatchers.Main.limitedParallelism(1)
|
||||
return withContext(dispatcher.limitedParallelism(1)) {
|
||||
val job = launch(start = CoroutineStart.UNDISPATCHED) {
|
||||
yield() // goto block
|
||||
onSuspend?.invoke()
|
||||
}
|
||||
val ret = block()
|
||||
kotlin.test.assertTrue("Expected coroutine suspension") { job.isCompleted }
|
||||
|
@ -22,7 +22,6 @@ import net.mamoe.mirai.utils.channels.ProducerFailureException
|
||||
import net.mamoe.mirai.utils.debug
|
||||
import net.mamoe.mirai.utils.verbose
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.coroutines.cancellation.CancellationException
|
||||
|
||||
|
||||
/**
|
||||
@ -45,20 +44,7 @@ internal class AuthControl(
|
||||
logger.subLogger("AuthControl/UserDecisions").asUtilsLogger()
|
||||
) { _ ->
|
||||
val sessionImpl = SafeBotAuthSession(this)
|
||||
|
||||
try {
|
||||
logger.verbose { "[AuthControl/auth] Authorization started" }
|
||||
|
||||
authorization.authorize(sessionImpl, botAuthInfo)
|
||||
|
||||
logger.verbose { "[AuthControl/auth] Authorization exited" }
|
||||
finish()
|
||||
} catch (e: CancellationException) {
|
||||
logger.verbose { "[AuthControl/auth] Authorization cancelled" }
|
||||
} catch (e: Throwable) {
|
||||
logger.verbose { "[AuthControl/auth] Authorization failed: $e" }
|
||||
finishExceptionally(e)
|
||||
}
|
||||
authorization.authorize(sessionImpl, botAuthInfo) // OnDemandChannel handles exceptions for us
|
||||
}
|
||||
|
||||
fun start() {
|
||||
@ -86,6 +72,6 @@ internal class AuthControl(
|
||||
|
||||
fun actComplete() {
|
||||
logger.verbose { "[AuthControl/resume] Fire auth completed" }
|
||||
userDecisions.finish()
|
||||
userDecisions.close()
|
||||
}
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ package net.mamoe.mirai.internal.network.auth
|
||||
import net.mamoe.mirai.auth.BotAuthResult
|
||||
import net.mamoe.mirai.internal.network.components.SsoProcessorImpl
|
||||
import net.mamoe.mirai.utils.SecretsProtection
|
||||
import net.mamoe.mirai.utils.channels.IllegalProducerStateException
|
||||
import net.mamoe.mirai.utils.channels.IllegalChannelStateException
|
||||
import net.mamoe.mirai.utils.channels.OnDemandSendChannel
|
||||
|
||||
/**
|
||||
@ -40,7 +40,7 @@ internal class SafeBotAuthSession(
|
||||
private inline fun <R> runWrapInternalException(block: () -> R): R {
|
||||
try {
|
||||
return block()
|
||||
} catch (e: IllegalProducerStateException) {
|
||||
} catch (e: IllegalChannelStateException) {
|
||||
if (e.lastStateWasSucceed) {
|
||||
throw IllegalStateException(
|
||||
"This login session has already completed. Please return the BotAuthResult you get from 'authBy*()' immediately",
|
||||
|
Loading…
Reference in New Issue
Block a user