1
0
mirror of https://github.com/mamoe/mirai.git synced 2025-04-25 13:03:35 +08:00

[core] Revise OnDemandChannel, improve state abstraction:

[core] Start producer coroutine immediately on `expectMore` and yield

Improve docs for OnDemandChannel

Rename factory function `OnDemandReceiveChannel` to  `OnDemandChannel` to better cover its meaning

Create deferred lazily in Producing state

Rename ProducerState to ChannelState

fix atomicfu bug receiveOrNull

Add docs (WIP, to be rebased)

[core] OnDemandChannel: Catch Throwable in `receiveOrNull` to prevent possible failures
This commit is contained in:
Him188 2023-04-22 13:12:11 +01:00
parent 99f592d614
commit 87fbaa4fb2
6 changed files with 173 additions and 67 deletions
mirai-core-utils/src
mirai-core/src/commonMain/kotlin/network/auth

View File

@ -12,13 +12,13 @@ package net.mamoe.mirai.utils.channels
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import net.mamoe.mirai.utils.sync.Latch
import kotlinx.coroutines.Job
import kotlin.coroutines.CoroutineContext
/**
* Producer states.
*/
internal sealed interface ProducerState<T, V> {
internal sealed interface ChannelState<T, V> {
/*
* 可变更状态的函数: [emit], [receiveOrNull], [expectMore], [finish], [finishExceptionally]
*
@ -93,11 +93,11 @@ internal sealed interface ProducerState<T, V> {
*/
abstract override fun toString(): String
class JustInitialized<T, V> : ProducerState<T, V> {
class JustInitialized<T, V> : ChannelState<T, V> {
override fun toString(): String = "JustInitialized"
}
sealed interface HasProducer<T, V> : ProducerState<T, V> {
sealed interface HasProducer<T, V> : ChannelState<T, V> {
val producer: OnDemandSendChannel<T, V>
}
@ -116,8 +116,10 @@ internal sealed interface ProducerState<T, V> {
class Producing<T, V>(
override val producer: OnDemandSendChannel<T, V>,
val deferred: CompletableDeferred<V>,
parentJob: Job,
) : HasProducer<T, V> {
val deferred: CompletableDeferred<V> by lazy { CompletableDeferred<V>(parentJob) }
override fun toString(): String = "Producing(deferred.completed=${deferred.isCompleted})"
}
@ -126,7 +128,7 @@ internal sealed interface ProducerState<T, V> {
val value: Deferred<V>,
parentCoroutineContext: CoroutineContext,
) : HasProducer<T, V> {
val producerLatch: Latch<T> = Latch(parentCoroutineContext)
val producerLatch: CompletableDeferred<T> = CompletableDeferred(parentCoroutineContext[Job])
override fun toString(): String {
@OptIn(ExperimentalCoroutinesApi::class)
@ -138,15 +140,15 @@ internal sealed interface ProducerState<T, V> {
class Consumed<T, V>(
override val producer: OnDemandSendChannel<T, V>,
val producerLatch: Latch<T>
val producerLatch: CompletableDeferred<T>
) : HasProducer<T, V> {
override fun toString(): String = "Consumed($producerLatch)"
}
class Finished<T, V>(
private val previousState: ProducerState<T, V>,
private val previousState: ChannelState<T, V>,
val exception: Throwable?,
) : ProducerState<T, V> {
) : ChannelState<T, V> {
val isSuccess: Boolean get() = exception == null
fun createAlreadyFinishedException(cause: Throwable?): IllegalProducerStateException {

View File

@ -10,9 +10,9 @@
package net.mamoe.mirai.utils.channels
public class IllegalProducerStateException internal constructor(
private val state: ProducerState<*, *>,
private val state: ChannelState<*, *>,
message: String? = state.toString(),
cause: Throwable? = null,
) : IllegalStateException(message, cause) {
public val lastStateWasSucceed: Boolean get() = (state is ProducerState.Finished) && state.isSuccess
public val lastStateWasSucceed: Boolean get() = (state is ChannelState.Finished) && state.isSuccess
}

View File

@ -12,10 +12,7 @@ package net.mamoe.mirai.utils.channels
import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.loop
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.cancel
import kotlinx.coroutines.job
import kotlinx.coroutines.launch
import kotlinx.coroutines.*
import net.mamoe.mirai.utils.UtilsLogger
import net.mamoe.mirai.utils.childScope
import net.mamoe.mirai.utils.debug
@ -30,14 +27,17 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
) : OnDemandReceiveChannel<T, V> {
private val coroutineScope = parentCoroutineContext.childScope("CoroutineOnDemandReceiveChannel")
private val state: AtomicRef<ProducerState<T, V>> = atomic(ProducerState.JustInitialized())
private val state: AtomicRef<ChannelState<T, V>> = atomic(ChannelState.JustInitialized())
inner class Producer(
private val initialTicket: T,
) : OnDemandSendChannel<T, V> {
init {
coroutineScope.launch {
// `UNDISPATCHED` with `yield()`: start the coroutine immediately in current thread,
// attaching Job to the coroutineScope, then `yield` the thread back, to complete `launch`.
coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
yield()
try {
producerCoroutine(initialTicket)
} catch (_: CancellationException) {
@ -51,21 +51,21 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
override suspend fun emit(value: V): T {
state.loop { state ->
when (state) {
is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
is ProducerState.Producing -> {
is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
is ChannelState.Producing -> {
val deferred = state.deferred
val consumingState = ProducerState.Consuming(
val consumingState = ChannelState.Consuming(
state.producer,
state.deferred,
coroutineScope.coroutineContext
)
if (compareAndSetState(state, consumingState)) {
deferred.complete(value) // produce a value
return consumingState.producerLatch.acquire() // wait for producer to consume the previous value.
return consumingState.producerLatch.await() // wait for producer to consume the previous value.
}
}
is ProducerState.ProducerReady -> {
is ChannelState.ProducerReady -> {
setStateProducing(state)
}
@ -81,9 +81,9 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
override fun finish() {
state.loop { state ->
when (state) {
is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
else -> {
if (compareAndSetState(state, ProducerState.Finished(state, null))) {
if (compareAndSetState(state, ChannelState.Finished(state, null))) {
return
}
}
@ -92,20 +92,16 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
}
}
private fun setStateProducing(state: ProducerState.ProducerReady<T, V>) {
val deferred = CompletableDeferred<V>(coroutineScope.coroutineContext.job)
if (!compareAndSetState(state, ProducerState.Producing(state.producer, deferred))) {
deferred.cancel() // avoid leak
}
// loop again
private fun setStateProducing(state: ChannelState.ProducerReady<T, V>) {
compareAndSetState(state, ChannelState.Producing(state.producer, coroutineScope.coroutineContext.job))
}
private fun finishImpl(exception: Throwable?) {
state.loop { state ->
when (state) {
is ProducerState.Finished -> {} // ignore
is ChannelState.Finished -> {} // ignore
else -> {
if (compareAndSetState(state, ProducerState.Finished(state, exception))) {
if (compareAndSetState(state, ChannelState.Finished(state, exception))) {
val cancellationException = kotlinx.coroutines.CancellationException("Finished", exception)
coroutineScope.cancel(cancellationException)
return
@ -115,24 +111,31 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
}
}
private fun compareAndSetState(state: ProducerState<T, V>, newState: ProducerState<T, V>): Boolean {
private fun compareAndSetState(state: ChannelState<T, V>, newState: ChannelState<T, V>): Boolean {
return this.state.compareAndSet(state, newState).also {
logger.debug { "CAS: $state -> $newState: $it" }
}
}
override suspend fun receiveOrNull(): V? {
state.loop { state ->
when (state) {
is ProducerState.Consuming -> {
// value is ready, switch state to Consumed
// don't use `.loop`:
// java.lang.VerifyError: Bad type on operand stack
// net/mamoe/mirai/utils/channels/CoroutineOnDemandReceiveChannel.receiveOrNull(Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @103: getfield
while (true) {
when (val state = state.value) {
is ChannelState.Consuming -> {
// value is ready, now we consume the value
if (compareAndSetState(state, ChannelState.Consumed(state.producer, state.producerLatch))) {
// value is consumed, no contention, safe to retrieve
if (compareAndSetState(state, ProducerState.Consumed(state.producer, state.producerLatch))) {
return try {
// This actually won't suspend, since the value is already completed
// Just to be error-tolerating
// Just to be error-tolerating and re-throwing exceptions.
state.value.await()
} catch (e: Exception) {
} catch (e: Throwable) {
// Producer failed to produce the previous value with exception
throw ProducerFailureException(cause = e)
}
}
@ -140,7 +143,7 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
// note: actually, this case should be the first case (for code consistency) in `when`,
// but atomicfu 1.8.10 fails on this.
is ProducerState.Producing<T, V> -> {
is ChannelState.Producing<T, V> -> {
// still producing value
state.deferred.await() // just wait for value, but does not return it.
@ -151,7 +154,7 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
// Here we will loop again, to atomically switch to Consumed state.
}
is ProducerState.Finished -> {
is ChannelState.Finished -> {
state.exception?.let { err ->
throw ProducerFailureException(cause = err)
}
@ -166,32 +169,33 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
override fun expectMore(ticket: T): Boolean {
state.loop { state ->
when (state) {
is ProducerState.JustInitialized -> {
val ready = ProducerState.ProducerReady { Producer(ticket) }
is ChannelState.JustInitialized -> {
// start producer atomically
val ready = ChannelState.ProducerReady { Producer(ticket) }
if (compareAndSetState(state, ready)) {
ready.startProducerIfNotYet()
}
// loop again
}
is ProducerState.ProducerReady -> {
is ChannelState.ProducerReady -> {
setStateProducing(state)
}
is ProducerState.Producing -> return true // ok
is ChannelState.Producing -> return true // ok
is ProducerState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
is ChannelState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
is ProducerState.Consumed -> {
if (compareAndSetState(state, ProducerState.ProducerReady { state.producer })) {
is ChannelState.Consumed -> {
if (compareAndSetState(state, ChannelState.ProducerReady { state.producer })) {
// wake up producer async.
state.producerLatch.resumeWith(Result.success(ticket))
state.producerLatch.complete(ticket)
// loop again to switch state atomically to Producing.
// Do not do switch state directly here — async producer may race with you!
}
}
is ProducerState.Finished -> return false
is ChannelState.Finished -> return false
}
}
}

View File

@ -9,8 +9,10 @@
package net.mamoe.mirai.utils.channels
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import net.mamoe.mirai.utils.UtilsLogger
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
@ -18,43 +20,60 @@ import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.cancellation.CancellationException
/**
* 按需供给的 [Channel].
* 按需供给的 [SendChannel].
*/
public interface OnDemandSendChannel<T, V> {
/**
* 挂起协程, 直到 [OnDemandReceiveChannel] 期望接收一个 [V], 届时将 [value] 传递给 [OnDemandReceiveChannel.receiveOrNull], 成为其返回值.
* 挂起协程, 直到 [OnDemandReceiveChannel] [期望接收][OnDemandReceiveChannel.receiveOrNull]一个 [V], 届时将 [value] 传递给 [OnDemandReceiveChannel.receiveOrNull], 成为其返回值.
*
* 若在调用 [emit] 时已经有 [OnDemandReceiveChannel] 正在等待, 则该 [OnDemandReceiveChannel] 协程会立即[恢复][Continuation.resumeWith].
* 若在调用 [emit] 时已经有 [OnDemandReceiveChannel.receiveOrNull] 正在等待, 则该协程会立即[恢复][Continuation.resumeWith], [emit] 不会挂起.
*
* [OnDemandReceiveChannel] 已经[完结][OnDemandReceiveChannel.finish], [OnDemandSendChannel.emit] 会抛出 [IllegalProducerStateException].
*
* @see OnDemandReceiveChannel.receiveOrNull
*/
public suspend fun emit(value: V): T
/**
* 标记此 [OnDemandSendChannel] 在生产 [V] 的过程中出现错误.
* 标记此 [OnDemandSendChannel] 在生产 [V] 的过程中出现异常.
*
* 这也会终止此 [OnDemandSendChannel], 随后 [OnDemandReceiveChannel.receiveOrNull] 将会抛出 [ProducerFailureException].
* 这也会终止此 [OnDemandSendChannel], [OnDemandReceiveChannel] 正在期待一个值, 则当它调用 [OnDemandReceiveChannel.receiveOrNull] , 它将得到一个 [ProducerFailureException].
*
* [finishExceptionally] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
*/
public fun finishExceptionally(exception: Throwable)
/**
* 标记此 [OnDemandSendChannel] 已经没有更多 [V] 可生产.
*
* 随后 [OnDemandReceiveChannel.receiveOrNull] 将会抛出 [IllegalStateException].
* 这会终止此 [OnDemandSendChannel], [OnDemandReceiveChannel] 正在期待一个值, 则当它调用 [OnDemandReceiveChannel.receiveOrNull] , 它将得到一个 [ProducerFailureException].
*
* [finish] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
*/
public fun finish()
}
/**
* 按需消费者.
* 线程安全的按需接收通道.
*
* [ReceiveChannel] 不同, [OnDemandReceiveChannel] 只有在调用 [expectMore] 后才会让[生产者][OnDemandSendChannel] 开始生产下一个 [V].
*/
public interface OnDemandReceiveChannel<T, V> {
/**
* 挂起协程并等待从 [OnDemandSendChannel] [接收][OnDemandSendChannel.emit]一个 [V].
* 尝试从 [OnDemandSendChannel] [接收][OnDemandSendChannel.emit]一个 [V].
* 当且仅当在 [OnDemandSendChannel] 已经[正常结束][OnDemandSendChannel.finish] 时返回 `null`.
*
* 当此函数被多个线程 (协程) 同时调用时, 只有一个线程挂起并获得 [V], 其他线程将会
* 若目前已有 [V], 此函数立即返回该 [V], 不会挂起.
* 否则, 此函数将会挂起直到 [OnDemandSendChannel.emit].
*
* 当此函数被多个协程 (线程) 同时调用时, 只有一个协程会获得 [V], 其他协程将会挂起.
*
* 若在等待过程中 [OnDemandSendChannel] [异常结束][OnDemandSendChannel.finishExceptionally],
* 本函数会立即恢复并抛出 [ProducerFailureException], `cause` 为令 [OnDemandSendChannel] 的异常.
*
* 此挂起函数可被取消.
* 如果在此函数挂起时当前协程的 [Job] 被取消或完结, 此函数会立即恢复并抛出 [CancellationException]. 此行为与 [Deferred.await] 相同.
*
* @throws ProducerFailureException [OnDemandSendChannel.finishExceptionally] 时抛出.
* @throws CancellationException 当协程被取消时抛出
@ -64,7 +83,9 @@ public interface OnDemandReceiveChannel<T, V> {
public suspend fun receiveOrNull(): V?
/**
* 期待 [OnDemandSendChannel] 再生产一个 [V]. 期望生产后必须在之后调用 [receiveOrNull] [finish] 来消耗生产的 [V].
* 期待 [OnDemandSendChannel] 再生产一个 [V].
* 期望生产后必须在之后调用 [receiveOrNull] [finish] 来消耗生产的 [V].
* 不可连续重复调用 [expectMore].
*
* 在成功发起期待后返回 `true`; [OnDemandSendChannel] 已经[完结][OnDemandSendChannel.finish] 时返回 `false`.
*
@ -73,16 +94,18 @@ public interface OnDemandReceiveChannel<T, V> {
public fun expectMore(ticket: T): Boolean
/**
* 标记此 [OnDemandReceiveChannel] 已经完结.
* 标记此 [OnDemandSendChannel] 已经不再需要更多的值.
*
* 如果 [OnDemandSendChannel] 仍在运行, 将会 (正常地) 取消 [OnDemandSendChannel].
* 如果 [OnDemandSendChannel] 仍在运行 (无论是挂起中还是正在计算下一个值), 都会正常地[取消][Job.cancel] [OnDemandSendChannel].
*
* 随后 [OnDemandSendChannel.emit] 将会抛出 [IllegalStateException].
*
* [finish] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
*/
public fun finish()
}
public fun <T, V> OnDemandReceiveChannel(
@Suppress("FunctionName")
public fun <T, V> OnDemandChannel(
parentCoroutineContext: CoroutineContext = EmptyCoroutineContext,
logger: UtilsLogger = UtilsLogger.noop(),
producerCoroutine: suspend OnDemandSendChannel<T, V>.(initialTicket: T) -> Unit,

View File

@ -0,0 +1,76 @@
/*
* Copyright 2019-2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.utils.channels
import kotlinx.coroutines.*
import kotlin.test.*
class OnDemandChannelTest {
///////////////////////////////////////////////////////////////////////////
// CoroutineScope lifecycle
///////////////////////////////////////////////////////////////////////////
@Test
fun attachScopeJob() {
val job = SupervisorJob()
val channel = OnDemandChannel<Int, Int>(job) {
fail()
}
assertEquals(1, job.children.toList().size)
channel.finish()
}
@Test
fun finishAfterInstantiation() {
val supervisor = SupervisorJob()
val channel = OnDemandChannel<Int, Int>(supervisor) {
fail("ran")
}
assertEquals(1, supervisor.children.toList().size)
val job = supervisor.children.single()
assertEquals(true, job.isActive)
channel.finish()
assertEquals(0, supervisor.children.toList().size)
assertEquals(false, job.isActive)
}
///////////////////////////////////////////////////////////////////////////
// Producer Coroutine — Lazy Initialization
///////////////////////////////////////////////////////////////////////////
@Test
fun `producer coroutine won't start until expectMore`() {
val channel = OnDemandChannel<Int, Int> {
fail()
}
channel.finish()
}
@Test
fun `producer coroutine starts iff expectMore`() = runBlocking(Dispatchers.Default.limitedParallelism(1)) {
var started = false
val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
// (1)
assertEquals(false, started)
started = true
yield() // goto (2)
fail()
}
assertFalse { started }
channel.expectMore(1) // launches the job, but it won't execute due to single parallelism
yield() // goto (1)
// (2)
assertTrue { started }
channel.finish()
}
}

View File

@ -16,6 +16,7 @@ import net.mamoe.mirai.internal.utils.asUtilsLogger
import net.mamoe.mirai.internal.utils.subLogger
import net.mamoe.mirai.utils.ExceptionCollector
import net.mamoe.mirai.utils.MiraiLogger
import net.mamoe.mirai.utils.channels.OnDemandChannel
import net.mamoe.mirai.utils.channels.OnDemandReceiveChannel
import net.mamoe.mirai.utils.channels.ProducerFailureException
import net.mamoe.mirai.utils.debug
@ -39,7 +40,7 @@ internal class AuthControl(
internal val exceptionCollector = ExceptionCollector()
private val userDecisions: OnDemandReceiveChannel<Throwable?, SsoProcessorImpl.AuthMethod> =
OnDemandReceiveChannel(
OnDemandChannel(
parentCoroutineContext,
logger.subLogger("AuthControl/UserDecisions").asUtilsLogger()
) { _ ->