[core] OnDemandChannel: Ensure scope is closed on exceptional cases:

[core] OnDemandChannel: Ensure channel close and emit shares same behavior

[core] OnDemandChannel: Ensure channel is closed properly when producer throws an exception

[core] OnDemandChannel: Ensure coroutine scope cancelled when producer coroutine throws exception,

also added more tests
This commit is contained in:
Him188 2023-04-23 12:54:46 +01:00
parent 87fbaa4fb2
commit 0e625d3368
9 changed files with 411 additions and 131 deletions

View File

@ -101,28 +101,27 @@ internal sealed interface ChannelState<T, V> {
val producer: OnDemandSendChannel<T, V>
}
// Producer is not running until `expectMore`. `emit` and `receiveOrNull` not allowed.
class ProducerReady<T, V>(
launchProducer: () -> OnDemandSendChannel<T, V>,
) : HasProducer<T, V> {
// Lazily start the producer job since it's on-demand
override val producer: OnDemandSendChannel<T, V> by lazy(launchProducer) // `lazy` is synchronized
fun startProducerIfNotYet() {
producer
}
override fun toString(): String = "ProducerReady"
}
// Producer is running. `emit` and `receiveOrNull` both allowed.
class Producing<T, V>(
override val producer: OnDemandSendChannel<T, V>,
parentJob: Job,
) : HasProducer<T, V> {
val deferred: CompletableDeferred<V> by lazy { CompletableDeferred<V>(parentJob) }
override fun toString(): String = "Producing(deferred.completed=${deferred.isCompleted})"
}
// Producer is suspended because it called `emit`. Expecting `receiveOrNull`.
class Consuming<T, V>(
override val producer: OnDemandSendChannel<T, V>,
val value: Deferred<V>,
@ -138,6 +137,7 @@ internal sealed interface ChannelState<T, V> {
}
}
// Producer is suspended. `expectMore` will resume producer with a ticket.
class Consumed<T, V>(
override val producer: OnDemandSendChannel<T, V>,
val producerLatch: CompletableDeferred<T>
@ -151,7 +151,7 @@ internal sealed interface ChannelState<T, V> {
) : ChannelState<T, V> {
val isSuccess: Boolean get() = exception == null
fun createAlreadyFinishedException(cause: Throwable?): IllegalProducerStateException {
fun createAlreadyFinishedException(cause: Throwable?): IllegalChannelStateException {
val exception = exception
val causeMessage = if (cause == null) {
""
@ -159,13 +159,13 @@ internal sealed interface ChannelState<T, V> {
", but attempting to finish with the cause $cause"
}
return if (exception == null) {
IllegalProducerStateException(
IllegalChannelStateException(
this,
"Producer has already finished normally$causeMessage. Previous state was: $previousState",
cause = cause
)
} else {
IllegalProducerStateException(
IllegalChannelStateException(
this,
"Producer has already finished with the suppressed exception$causeMessage. Previous state was: $previousState",
cause = cause

View File

@ -9,7 +9,8 @@
package net.mamoe.mirai.utils.channels
public class IllegalProducerStateException internal constructor(
// An internal error exception
public class IllegalChannelStateException internal constructor(
private val state: ChannelState<*, *>,
message: String? = state.toString(),
cause: Throwable? = null,

View File

@ -13,11 +13,11 @@ import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.loop
import kotlinx.coroutines.*
import net.mamoe.mirai.utils.TestOnly
import net.mamoe.mirai.utils.UtilsLogger
import net.mamoe.mirai.utils.childScope
import net.mamoe.mirai.utils.debug
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.cancellation.CancellationException
internal class CoroutineOnDemandReceiveChannel<T, V>(
@ -27,8 +27,14 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
) : OnDemandReceiveChannel<T, V> {
private val coroutineScope = parentCoroutineContext.childScope("CoroutineOnDemandReceiveChannel")
@TestOnly
internal fun getScope() = coroutineScope
private val state: AtomicRef<ChannelState<T, V>> = atomic(ChannelState.JustInitialized())
@TestOnly
internal fun getState() = state.value
inner class Producer(
private val initialTicket: T,
@ -38,20 +44,33 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
// attaching Job to the coroutineScope, then `yield` the thread back, to complete `launch`.
coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
yield()
try {
producerCoroutine(initialTicket)
} catch (_: CancellationException) {
// ignored
} catch (e: Exception) {
finishExceptionally(e)
} catch (e: Throwable) {
// close exceptionally
val r = emitImpl(Result.failure(e))
check(r == null) // assertion
return@launch
}
close()
}
}
override suspend fun emit(value: V): T {
override suspend fun emit(value: V): T = emitImpl(Result.success(value))!!
private suspend inline fun emitImpl(value: Result<V>): T? {
state.loop { state ->
when (state) {
is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
is ChannelState.Finished -> {
if (value.isFailure) {
return null
} else {
throw state.createAlreadyFinishedException(null)
}
}
is ChannelState.Producing -> {
val deferred = state.deferred
val consumingState = ChannelState.Consuming(
@ -60,55 +79,46 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
coroutineScope.coroutineContext
)
if (compareAndSetState(state, consumingState)) {
deferred.complete(value) // produce a value
deferred.completeWith(value) // produce a value
return consumingState.producerLatch.await() // wait for producer to consume the previous value.
}
// failed race, try again
}
is ChannelState.ProducerReady -> {
// This implies another coroutine is running `expectMore`,
// and we are a bit faster than it!
setStateProducing(state)
}
else -> throw IllegalProducerStateException(state)
}
}
}
else -> throw IllegalChannelStateException(
state,
if (value.isFailure)
"Producer threw an exception (see cause), so completing with the exception, but current state is not Producing"
else "Producer is emitting an value, but current state is not Producing",
value.exceptionOrNull()
)
override fun finishExceptionally(exception: Throwable) {
finishImpl(exception)
}
override fun finish() {
state.loop { state ->
when (state) {
is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
else -> {
if (compareAndSetState(state, ChannelState.Finished(state, null))) {
return
}
}
}
}
}
}
private fun setStateProducing(state: ChannelState.ProducerReady<T, V>) {
compareAndSetState(state, ChannelState.Producing(state.producer, coroutineScope.coroutineContext.job))
private fun setStateProducing(state: ChannelState.ProducerReady<T, V>): Boolean {
return compareAndSetState(state, ChannelState.Producing(state.producer, coroutineScope.coroutineContext.job))
}
private fun finishImpl(exception: Throwable?) {
state.loop { state ->
when (state) {
is ChannelState.Finished -> {} // ignore
else -> {
if (compareAndSetState(state, ChannelState.Finished(state, exception))) {
val cancellationException = kotlinx.coroutines.CancellationException("Finished", exception)
coroutineScope.cancel(cancellationException)
return
}
}
}
private fun setStateFinished(
currState: ChannelState<T, V>,
message: String,
exception: ProducerFailureException?
): Boolean {
if (compareAndSetState(currState, ChannelState.Finished(currState, exception))) {
val cancellationException = CancellationException(message, exception)
coroutineScope.cancel(cancellationException)
return true
}
return false
}
private fun compareAndSetState(state: ChannelState<T, V>, newState: ChannelState<T, V>): Boolean {
@ -117,27 +127,27 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
}
}
override val isClosed: Boolean
get() = state.value is ChannelState.Finished
override suspend fun receiveOrNull(): V? {
// don't use `.loop`:
// don't use atomicfu `.loop`:
// java.lang.VerifyError: Bad type on operand stack
// net/mamoe/mirai/utils/channels/CoroutineOnDemandReceiveChannel.receiveOrNull(Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @103: getfield
while (true) {
when (val state = state.value) {
is ChannelState.Consuming -> {
// value is ready, now we consume the value
// value is ready, now we try to consume the value
if (compareAndSetState(state, ChannelState.Consumed(state.producer, state.producerLatch))) {
// value is consumed, no contention, safe to retrieve
// value is now reserved for us, no contention is possible, safe to retrieve
return try {
// This actually won't suspend, since the value is already completed
// Just to be error-tolerating and re-throwing exceptions.
state.value.await()
} catch (e: Throwable) {
// Producer failed to produce the previous value with exception
throw ProducerFailureException(cause = e)
}
// This actually won't suspend (there are tests ensuring this point),
// since the value is already completed.
// Just to be error-tolerating and re-throwing exceptions.
// (Also because `Deferred.getCompleted()` is not stable yet (coroutines 1.6))
return awaitValueSafe(state.value)
}
}
@ -146,45 +156,59 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
is ChannelState.Producing<T, V> -> {
// still producing value
state.deferred.await() // just wait for value, but does not return it.
// Wait for value and throw exception caused by the producer if there is one.
awaitValueSafe(state.deferred) // this may or may not suspend.
// The value will be completed in ProducerState.Consuming state,
// but you cannot thread-safely assume current state is Consuming.
// Now deferred is complete, and we will be in the Consuming state, but we can't use the value here.
// We must ensure only one thread gets the value, and state should then be Consumed
// Here we will loop again, to atomically switch to Consumed state.
// So we loop again and do this in the Consuming state.
}
is ChannelState.Finished -> {
state.exception?.let { err ->
throw ProducerFailureException(cause = err)
}
// see public API docs for behavior
return null
}
else -> throw IllegalProducerStateException(state)
else ->
// internal error
throw IllegalChannelStateException(state)
}
}
}
private suspend inline fun awaitValueSafe(deferred: Deferred<V>) = try {
deferred.await()
} catch (e: Throwable) {
// Producer failed to produce the previous value with exception
val producerFailureException = ProducerFailureException(cause = e)
setStateFinished(
this.state.value,
"OnDemandChannel is closed because producer failed to produce value, see cause",
producerFailureException
)
throw producerFailureException
}
override fun expectMore(ticket: T): Boolean {
state.loop { state ->
when (state) {
is ChannelState.JustInitialized -> {
// start producer atomically
val ready = ChannelState.ProducerReady { Producer(ticket) }
if (compareAndSetState(state, ready)) {
ready.startProducerIfNotYet()
}
compareAndSetState(state, ready)
// loop again
}
is ChannelState.ProducerReady -> {
setStateProducing(state)
if (setStateProducing(state)) {
return true
}
// lost race, try again
}
is ChannelState.Producing -> return true // ok
is ChannelState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
is ChannelState.Producing,
is ChannelState.Consuming -> throw IllegalChannelStateException(state) // a value is already ready
is ChannelState.Consumed -> {
if (compareAndSetState(state, ChannelState.ProducerReady { state.producer })) {
@ -200,7 +224,12 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
}
}
override fun finish() {
finishImpl(null)
override fun close() {
state.loop { state ->
when (state) {
is ChannelState.Finished -> return
else -> if (setStateFinished(state, "OnDemandChannel is closed normally", null)) return
}
}
}
}

View File

@ -21,36 +21,27 @@ import kotlin.coroutines.cancellation.CancellationException
/**
* 按需供给的 [SendChannel].
*
* @param T 令牌类型.
* @param V 值类型.
*/
public interface OnDemandSendChannel<T, V> {
/**
* 挂起协程, 直到 [OnDemandReceiveChannel] [期望接收][OnDemandReceiveChannel.receiveOrNull]一个 [V], 届时将 [value] 传递给 [OnDemandReceiveChannel.receiveOrNull], 成为其返回值.
* 挂起协程, 直到 [OnDemandReceiveChannel] [期望接收][OnDemandReceiveChannel.receiveOrNull]一个 [V],
* 届时将 [value] 传递给 [OnDemandReceiveChannel.receiveOrNull], 成为其返回值.
*
* 若在调用 [emit] 时已经有 [OnDemandReceiveChannel.receiveOrNull] 正在等待, 则该协程会立即[恢复][Continuation.resumeWith], [emit] 不会挂起.
*
* [OnDemandReceiveChannel] 已经[完结][OnDemandReceiveChannel.finish], [OnDemandSendChannel.emit] 会抛出 [IllegalProducerStateException].
* [OnDemandReceiveChannel] 已经[完结][OnDemandReceiveChannel.close], [OnDemandSendChannel.emit] 会抛出 [IllegalChannelStateException].
*
* @see OnDemandReceiveChannel.receiveOrNull
*
* @param value 需要传递给 [OnDemandReceiveChannel.receiveOrNull] 的值
* @return 下一个 ticket [T].
*
* @throws CancellationException 当此协程被取消时抛出
*/
public suspend fun emit(value: V): T
/**
* 标记此 [OnDemandSendChannel] 在生产 [V] 的过程中出现异常.
*
* 这也会终止此 [OnDemandSendChannel], [OnDemandReceiveChannel] 正在期待一个值, 则当它调用 [OnDemandReceiveChannel.receiveOrNull] , 它将得到一个 [ProducerFailureException].
*
* [finishExceptionally] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
*/
public fun finishExceptionally(exception: Throwable)
/**
* 标记此 [OnDemandSendChannel] 已经没有更多 [V] 可生产.
*
* 这会终止此 [OnDemandSendChannel], [OnDemandReceiveChannel] 正在期待一个值, 则当它调用 [OnDemandReceiveChannel.receiveOrNull] , 它将得到一个 [ProducerFailureException].
*
* [finish] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
*/
public fun finish()
}
@ -60,36 +51,42 @@ public interface OnDemandSendChannel<T, V> {
* [ReceiveChannel] 不同, [OnDemandReceiveChannel] 只有在调用 [expectMore] 后才会让[生产者][OnDemandSendChannel] 开始生产下一个 [V].
*/
public interface OnDemandReceiveChannel<T, V> {
/**
* 当此 [OnDemandReceiveChannel] 已经关闭, 即不再期望更多值时返回 `true`,
* 无论是调用了 [close] (主动关闭) 还是 [OnDemandSendChannel] 没有更多值了 (被动关闭).
*/
public val isClosed: Boolean
/**
* 尝试从 [OnDemandSendChannel] [接收][OnDemandSendChannel.emit]一个 [V].
* 当且仅当在 [OnDemandSendChannel] 已经[正常结束][OnDemandSendChannel.finish] 时返回 `null`.
* 当且仅当在 [OnDemandSendChannel] 已经正常结束时返回 `null`.
*
* 若目前已有 [V], 此函数立即返回该 [V], 不会挂起.
* 否则, 此函数将会挂起直到 [OnDemandSendChannel.emit].
*
* 当此函数被多个协程 (线程) 同时调用时, 只有一个协程会获得 [V], 其他协程将会挂起.
*
* 若在等待过程中 [OnDemandSendChannel] [异常结束][OnDemandSendChannel.finishExceptionally],
* 若在等待过程中 [OnDemandSendChannel] 异常结束,
* 本函数会立即恢复并抛出 [ProducerFailureException], `cause` 为令 [OnDemandSendChannel] 的异常.
*
* 此挂起函数可被取消.
* 如果在此函数挂起时当前协程的 [Job] 被取消或完结, 此函数会立即恢复并抛出 [CancellationException]. 此行为与 [Deferred.await] 相同.
*
* @throws ProducerFailureException [OnDemandSendChannel.finishExceptionally] 时抛出.
* @throws ProducerFailureException [OnDemandSendChannel] 产生了一个异常时抛出.
* @throws CancellationException 当协程被取消时抛出
* @throws IllegalProducerStateException 当状态异常, 如未调用 [expectMore] 时抛出
* @throws IllegalChannelStateException 当状态异常, 如未调用 [expectMore] 时抛出
*/
@Throws(ProducerFailureException::class, CancellationException::class)
public suspend fun receiveOrNull(): V?
/**
* 期待 [OnDemandSendChannel] 再生产一个 [V].
* 期望生产后必须在之后调用 [receiveOrNull] [finish] 来消耗生产的 [V].
* 期望生产后必须在之后调用 [receiveOrNull] [close] 来消耗生产的 [V].
* 不可连续重复调用 [expectMore].
*
* 在成功发起期待后返回 `true`; [OnDemandSendChannel] 已经[完结][OnDemandSendChannel.finish] 时返回 `false`.
*
* @throws IllegalProducerStateException [expectMore] 被调用后, 没有调用 [receiveOrNull] 就又调用了 [expectMore] 时抛出
* @throws IllegalChannelStateException [expectMore] 被调用后, 没有调用 [receiveOrNull] 就又调用了 [expectMore] 时抛出
*/
public fun expectMore(ticket: T): Boolean
@ -98,10 +95,11 @@ public interface OnDemandReceiveChannel<T, V> {
*
* 如果 [OnDemandSendChannel] 仍在运行 (无论是挂起中还是正在计算下一个值), 都会正常地[取消][Job.cancel] [OnDemandSendChannel].
*
* 若此 [OnDemandSendChannel] 已经被关闭, 则此函数不会进行任何操作.
*
* [finish] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
* [close] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
*/
public fun finish()
public fun close()
}
@Suppress("FunctionName")

View File

@ -10,6 +10,6 @@
package net.mamoe.mirai.utils.channels
public class ProducerFailureException(
override val message: String? = null,
override val message: String? = "Producer failed to produce a value, see cause",
override val cause: Throwable?
) : Exception()

View File

@ -10,6 +10,10 @@
package net.mamoe.mirai.utils.channels
import kotlinx.coroutines.*
import kotlinx.coroutines.test.runTest
import net.mamoe.mirai.utils.AtomicBoolean
import net.mamoe.mirai.utils.testFramework.assertCoroutineSuspends
import net.mamoe.mirai.utils.testFramework.assertNoCoroutineSuspension
import kotlin.test.*
@ -25,7 +29,7 @@ class OnDemandChannelTest {
fail()
}
assertEquals(1, job.children.toList().size)
channel.finish()
channel.close()
}
@Test
@ -38,14 +42,113 @@ class OnDemandChannelTest {
val job = supervisor.children.single()
assertEquals(true, job.isActive)
channel.finish()
channel.close()
assertEquals(0, supervisor.children.toList().size)
assertEquals(false, job.isActive)
}
@Test
fun `cancel producer job on finish`() = runTest {
// Actually, this case won't happen, because producer coroutine will be cancelled on [finish]
lateinit var job: Job
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
job = currentCoroutineContext()[Job]!!
emit(1)
emit(1)
emit(1)
emit(1)
fail()
}
channel.expectMore(1)
channel.receiveOrNull()
assertTrue { job.isActive }
channel.close()
assertFalse { job.isActive }
yield()
}
///////////////////////////////////////////////////////////////////////////
// Producer Coroutine — Lazy Initialization
// Producer Coroutine — Tickets
///////////////////////////////////////////////////////////////////////////
@Test
fun `producer receives initial ticket`() = runTest {
val channel = OnDemandChannel(currentCoroutineContext()) { initialTicket ->
assertEquals(1, initialTicket)
emit(2)
}
channel.expectMore(1)
channel.receiveOrNull()
channel.close()
}
@Test
fun `producer receives second ticket`() = runTest {
val channel = OnDemandChannel(currentCoroutineContext()) { initialTicket ->
assertEquals(1, initialTicket)
assertEquals(2, emit(3))
}
channel.expectMore(1)
channel.receiveOrNull()
channel.expectMore(2)
channel.close()
}
@Test
fun `producer receives third ticket`() = runTest {
val channel = OnDemandChannel(currentCoroutineContext()) { initialTicket ->
assertEquals(1, initialTicket)
assertEquals(2, emit(4))
assertEquals(3, emit(5))
}
channel.expectMore(1)
channel.receiveOrNull()
channel.expectMore(2)
channel.receiveOrNull()
channel.expectMore(3)
channel.close()
}
///////////////////////////////////////////////////////////////////////////
// Consumer — Receive Correct Values
///////////////////////////////////////////////////////////////////////////
@Test
fun `receives correct first value`() = runTest {
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
emit(3)
}
channel.expectMore(1)
assertEquals(3, channel.receiveOrNull())
channel.close()
}
@Test
fun `receives correct second value`() = runTest {
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
emit(3)
emit(4)
}
channel.expectMore(1)
assertEquals(3, channel.receiveOrNull())
channel.expectMore(2)
assertEquals(4, channel.receiveOrNull())
channel.close()
}
///////////////////////////////////////////////////////////////////////////
// expectMore/emit/receiveOrNull
///////////////////////////////////////////////////////////////////////////
@Test
@ -53,11 +156,11 @@ class OnDemandChannelTest {
val channel = OnDemandChannel<Int, Int> {
fail()
}
channel.finish()
channel.close()
}
@Test
fun `producer coroutine starts iff expectMore`() = runBlocking(Dispatchers.Default.limitedParallelism(1)) {
fun `producer coroutine starts iff expectMore`() = runTest {
var started = false
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
// (1)
@ -67,10 +170,164 @@ class OnDemandChannelTest {
fail()
}
assertFalse { started }
channel.expectMore(1) // launches the job, but it won't execute due to single parallelism
assertTrue { channel.expectMore(1) } // launches the job, but it won't execute due to single parallelism
yield() // goto (1)
// (2)
assertTrue { started }
channel.finish()
channel.close()
}
@Test
fun `receiveOrNull does not suspend if value is ready`() = runTest {
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
emit(1)
}
assertTrue { channel.expectMore(1) }
yield() // run `emit`
// now value is ready
assertNoCoroutineSuspension { channel.receiveOrNull() }
channel.close()
}
@Test
fun `receiveOrNull does suspend if value is not ready`() = runTest {
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
yield()
emit(1)
}
assertTrue { channel.expectMore(1) }
assertCoroutineSuspends { channel.receiveOrNull() }
channel.close()
}
@Test
fun `emit won't resume unless another expectMore`() = runTest {
val canResume = AtomicBoolean(false)
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
emit(1)
if (!canResume.value) fail("Emit should not resume")
canResume.value = false
}
channel.expectMore(1)
channel.receiveOrNull()
canResume.value = true
channel.expectMore(2)
yield() // run producer
assertEquals(false, canResume.value)
channel.close()
}
///////////////////////////////////////////////////////////////////////////
// Operation while already finished
///////////////////////////////////////////////////////////////////////////
@Test
fun `expectMore and receiveOrNull while already finished just after instantiation`() = runTest {
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
fail("Producer should not run")
}
channel.close()
assertFalse { channel.expectMore(1) }
assertNull(channel.receiveOrNull())
assertFalse { channel.expectMore(1) }
assertFalse { channel.expectMore(1) }
assertNull(channel.receiveOrNull())
assertNull(channel.receiveOrNull())
}
@Test
fun `expectMore and receiveOrNull while already finished`() = runTest {
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
emit(1)
}
assertTrue { channel.expectMore(1) }
assertNotNull(channel.receiveOrNull())
assertFalse { channel.isClosed }
assertTrue { channel.expectMore(1) } // `expectMore` don't know if more values are available
yield() // go to producer
// now we must know producer has no more value
assertTrue { channel.isClosed }
assertNull(channel.receiveOrNull())
assertFalse { channel.expectMore(1) }
assertNull(channel.receiveOrNull())
assertFalse { (channel as CoroutineOnDemandReceiveChannel).getScope().isActive }
}
@Test
fun `emit while already finished`() {
// Actually, this case won't happen, because producer coroutine will be cancelled on [finish]
`cancel producer job on finish`()
}
@Test
fun `producer exception closes channel then receiveOrNull throws`() = runTest {
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
throw NoSuchElementException("Oops")
}
assertTrue { channel.expectMore(1) }
assertFalse { channel.isClosed }
assertIs<ChannelState.Producing<*, *>>(channel.state)
assertFailsWith<ProducerFailureException> {
println(channel.receiveOrNull())
}.also {
assertIs<NoSuchElementException>(it.cause)
}
assertTrue { channel.isClosed }
// The exception looks like this, though I don't know why there are two causes.
//net.mamoe.mirai.utils.channels.ProducerFailureException: Producer failed to produce a value, see cause
// at net.mamoe.mirai.utils.channels.CoroutineOnDemandReceiveChannel.receiveOrNull(OnDemandChannelImpl.kt:164)
// at net.mamoe.mirai.utils.channels.CoroutineOnDemandReceiveChannel$receiveOrNull$1.invokeSuspend(OnDemandChannelImpl.kt)
// at kotlin.coroutines.jvm.internal.BaseContinuationImpl.resumeWith(ContinuationImpl.kt:33)
// at kotlinx.coroutines.test.TestBuildersKt.runTest$default(Unknown Source)
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest.producer exception(OnDemandChannelTest.kt:273)
// at worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)
//Caused by: java.util.NoSuchElementException: Oops
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest$producer exception$1$channel$1.invokeSuspend(OnDemandChannelTest.kt:275)
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest$producer exception$1$channel$1.invoke(OnDemandChannelTest.kt)
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest$producer exception$1$channel$1.invoke(OnDemandChannelTest.kt)
// at net.mamoe.mirai.utils.channels.CoroutineOnDemandReceiveChannel$Producer$1.invokeSuspend(OnDemandChannelImpl.kt:46)
// (Coroutine boundary)
// at net.mamoe.mirai.utils.channels.CoroutineOnDemandReceiveChannel.receiveOrNull(OnDemandChannelImpl.kt:162)
// at net.mamoe.mirai.utils.channels.OnDemandChannelTest$producer exception$1.invokeSuspend(OnDemandChannelTest.kt:280)
// at kotlinx.coroutines.test.TestBuildersKt__TestBuildersKt$runTestCoroutine$2.invokeSuspend(TestBuilders.kt:212)
//Caused by: java.util.NoSuchElementException: Oops
// ...
}
@Test
fun `producer exception closes channel then receiveOrNull throws in Producing state`() = runTest {
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
throw NoSuchElementException("Oops")
}
assertTrue { channel.expectMore(1) }
yield() // fail the channel first
assertIs<ChannelState.Consuming<*, *>>(channel.state)
assertFalse { channel.isClosed } // channel won't close until receiveOrNull
assertFailsWith<ProducerFailureException> {
println(channel.receiveOrNull())
}.also {
assertIs<NoSuchElementException>(it.cause)
}
assertTrue { channel.isClosed }
}
}
private val <T, V> OnDemandReceiveChannel<T, V>.state
get() = (this as CoroutineOnDemandReceiveChannel<T, V>).getState()

View File

@ -34,14 +34,23 @@ suspend inline fun <R> assertNoCoroutineSuspension(
}
}
/**
* Executes [block], and asserts there happens at least one coroutine suspension in [block].
*
* When the first coroutine suspension happens, [onSuspend] will be called.
*/
@OptIn(ExperimentalStdlibApi::class)
suspend inline fun <R> assertCoroutineSuspends(
noinline onSuspend: (suspend () -> Unit)? = null,
crossinline block: suspend () -> R,
): R {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return withContext(Dispatchers.Default.limitedParallelism(1)) {
val dispatcher = currentCoroutineContext()[CoroutineDispatcher] ?: Dispatchers.Main.limitedParallelism(1)
return withContext(dispatcher.limitedParallelism(1)) {
val job = launch(start = CoroutineStart.UNDISPATCHED) {
yield() // goto block
onSuspend?.invoke()
}
val ret = block()
kotlin.test.assertTrue("Expected coroutine suspension") { job.isCompleted }

View File

@ -22,7 +22,6 @@ import net.mamoe.mirai.utils.channels.ProducerFailureException
import net.mamoe.mirai.utils.debug
import net.mamoe.mirai.utils.verbose
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.cancellation.CancellationException
/**
@ -45,20 +44,7 @@ internal class AuthControl(
logger.subLogger("AuthControl/UserDecisions").asUtilsLogger()
) { _ ->
val sessionImpl = SafeBotAuthSession(this)
try {
logger.verbose { "[AuthControl/auth] Authorization started" }
authorization.authorize(sessionImpl, botAuthInfo)
logger.verbose { "[AuthControl/auth] Authorization exited" }
finish()
} catch (e: CancellationException) {
logger.verbose { "[AuthControl/auth] Authorization cancelled" }
} catch (e: Throwable) {
logger.verbose { "[AuthControl/auth] Authorization failed: $e" }
finishExceptionally(e)
}
authorization.authorize(sessionImpl, botAuthInfo) // OnDemandChannel handles exceptions for us
}
fun start() {
@ -86,6 +72,6 @@ internal class AuthControl(
fun actComplete() {
logger.verbose { "[AuthControl/resume] Fire auth completed" }
userDecisions.finish()
userDecisions.close()
}
}

View File

@ -12,7 +12,7 @@ package net.mamoe.mirai.internal.network.auth
import net.mamoe.mirai.auth.BotAuthResult
import net.mamoe.mirai.internal.network.components.SsoProcessorImpl
import net.mamoe.mirai.utils.SecretsProtection
import net.mamoe.mirai.utils.channels.IllegalProducerStateException
import net.mamoe.mirai.utils.channels.IllegalChannelStateException
import net.mamoe.mirai.utils.channels.OnDemandSendChannel
/**
@ -40,7 +40,7 @@ internal class SafeBotAuthSession(
private inline fun <R> runWrapInternalException(block: () -> R): R {
try {
return block()
} catch (e: IllegalProducerStateException) {
} catch (e: IllegalChannelStateException) {
if (e.lastStateWasSucceed) {
throw IllegalStateException(
"This login session has already completed. Please return the BotAuthResult you get from 'authBy*()' immediately",