From a6196d8580923d569b05a4dc27b9b7993c87a4c3 Mon Sep 17 00:00:00 2001 From: Him188 Date: Mon, 23 Mar 2020 21:49:42 +0800 Subject: [PATCH] Enhanced message selection --- .../kotlin/net.mamoe.mirai/event/select.kt | 440 ++++++++++++++---- .../event/subscribeMessages.kt | 24 +- 2 files changed, 374 insertions(+), 90 deletions(-) diff --git a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/select.kt b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/select.kt index e1c20e5f0..7b2e1f3fc 100644 --- a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/select.kt +++ b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/select.kt @@ -12,6 +12,7 @@ package net.mamoe.mirai.event import kotlinx.coroutines.* import net.mamoe.mirai.message.MessagePacket import net.mamoe.mirai.message.data.Message +import net.mamoe.mirai.message.data.PlainText import net.mamoe.mirai.message.isContextIdenticalWith import net.mamoe.mirai.message.nextMessage import net.mamoe.mirai.utils.MiraiExperimentalAPI @@ -41,6 +42,10 @@ import kotlin.jvm.JvmSynthetic * reply(message) * true // 继续循环 * } + * timeout(3000) { + * // on + * true + * } * } // 等待直到 `false` * * reply("复读模式结束") @@ -57,43 +62,7 @@ import kotlin.jvm.JvmSynthetic suspend inline fun > T.whileSelectMessages( timeoutMillis: Long = -1, crossinline selectBuilder: @MessageDsl MessageSelectBuilder.() -> Unit -) = withTimeoutOrCoroutineScope(timeoutMillis) { - var deferred: CompletableDeferred? = CompletableDeferred() - - // ensure sequential invoking - val listeners: MutableList Boolean, MessageListener>> = mutableListOf() - - MessageSelectBuilder(SELECT_MESSAGE_STUB) { filter: T.(String) -> Boolean, listener: MessageListener -> - listeners += filter to listener - }.apply(selectBuilder) - - // ensure atomic completing - subscribeAlways(concurrency = Listener.ConcurrencyKind.LOCKED) { event -> - if (!this.isContextIdenticalWith(this@whileSelectMessages)) - return@subscribeAlways - - listeners.forEach { (filter, listener) -> - if (deferred?.isCompleted != false || !isActive) - return@subscribeAlways - - val toString = event.message.toString() - if (filter.invoke(event, toString)) { - listener.invoke(event, toString).let { value -> - if (value !== SELECT_MESSAGE_STUB) { - deferred?.complete(value as Boolean) - return@subscribeAlways // accept the first value only - } - } - } - } - } - - while (deferred?.await() == true) { - deferred = CompletableDeferred() - } - deferred = null - coroutineContext[Job]!!.cancelChildren() -} +) = whileSelectMessagesImpl(timeoutMillis, selectBuilder) /** * [selectMessages] 的 [Unit] 返回值捷径 (由于 Kotlin 无法推断 [Unit] 类型) @@ -104,8 +73,8 @@ suspend inline fun > T.whileSelectMessages( @JvmName("selectMessages1") suspend inline fun > T.selectMessagesUnit( timeoutMillis: Long = -1, - crossinline selectBuilder: @MessageDsl MessageSelectBuilder.() -> Unit -) = selectMessages(timeoutMillis, selectBuilder) + crossinline selectBuilder: @MessageDsl MessageSelectBuilderUnit.() -> Unit +) = selectMessagesImpl(timeoutMillis, true, selectBuilder) /** @@ -136,57 +105,37 @@ suspend inline fun , R> T.selectMessages( timeoutMillis: Long = -1, @BuilderInference crossinline selectBuilder: @MessageDsl MessageSelectBuilder.() -> Unit -): R = withTimeoutOrCoroutineScope(timeoutMillis) { - val deferred = CompletableDeferred() - - // ensure sequential invoking - val listeners: MutableList Boolean, MessageListener>> = mutableListOf() - - MessageSelectBuilder(SELECT_MESSAGE_STUB) { filter: T.(String) -> Boolean, listener: MessageListener -> - listeners += filter to listener - }.apply(selectBuilder) - - subscribeAlways { event -> - if (!this.isContextIdenticalWith(this@selectMessages)) - return@subscribeAlways - - listeners.forEach { (filter, listener) -> - if (deferred.isCompleted || !isActive) - return@subscribeAlways - - val toString = event.message.toString() - if (filter.invoke(event, toString)) { - val value = listener.invoke(event, toString) - if (value !== SELECT_MESSAGE_STUB) { - @Suppress("UNCHECKED_CAST") - deferred.complete(value as R) - return@subscribeAlways - } - } - } - } - - deferred.await().also { coroutineContext[Job]!!.cancelChildren() } -} +): R = selectMessagesImpl(timeoutMillis, false) { selectBuilder.invoke(this as MessageSelectBuilder) } +/** + * [selectMessages] 时的 DSL 构建器. + * + * 它是特殊化的消息监听 ([subscribeMessages]) DSL, 屏蔽了一些 `reply` DSL 以确保作用域安全性 + * + * @see MessageSelectBuilderUnit 查看上层 API + */ @SinceMirai("0.29.0") -open class MessageSelectBuilder, R> @PublishedApi internal constructor( +abstract class MessageSelectBuilder, R> @PublishedApi internal constructor( + ownerMessagePacket: M, stub: Any?, subscriber: (M.(String) -> Boolean, MessageListener) -> Unit -) : MessageSubscribersBuilder(stub, subscriber) { - /** - * 无任何触发条件. - */ - @MessageDsl - fun default(onEvent: MessageListener): Unit = subscriber({ true }, onEvent) - - @Deprecated("Use `default` instead", level = DeprecationLevel.HIDDEN) - override fun always(onEvent: MessageListener) { - super.always(onEvent) - } +) : MessageSelectBuilderUnit(ownerMessagePacket, stub, subscriber) { // 这些函数无法获取返回值. 必须屏蔽. + @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) + override fun mapping( + mapper: M.(String) -> N?, + onEvent: @MessageDsl suspend M.(N) -> R + ) = error("prohibited") + + @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) + override infix fun MessageSelectionTimeoutChecker.reply(block: suspend () -> Any?): Nothing = error("prohibited") + + @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) + override infix fun MessageSelectionTimeoutChecker.quoteReply(block: suspend () -> Any?): Nothing = + error("prohibited") + @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) override fun String.containsReply(reply: String): Nothing = error("prohibited") @@ -222,7 +171,8 @@ open class MessageSelectBuilder, R> @PublishedApi intern override fun ListeningFilter.reply(message: Message) = error("prohibited") @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) - override fun ListeningFilter.reply(replier: suspend M.(String) -> Any?) = error("prohibited") + override fun ListeningFilter.reply(replier: suspend M.(String) -> Any?) = + error("prohibited") @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN) override fun ListeningFilter.quoteReply(toReply: String) = error("prohibited") @@ -234,6 +184,175 @@ open class MessageSelectBuilder, R> @PublishedApi intern override fun ListeningFilter.quoteReply(replier: suspend M.(String) -> Any?) = error("prohibited") } +/** + * [selectMessagesUnit] 或 [selectMessages] 时的 DSL 构建器. + * + * 它是特殊化的消息监听 ([subscribeMessages]) DSL, 没有屏蔽 `reply` DSL 以确保作用域安全性 + * + * @see MessageSubscribersBuilder 查看上层 API + */ +@SinceMirai("0.29.0") +abstract class MessageSelectBuilderUnit, R> @PublishedApi internal constructor( + private val ownerMessagePacket: M, + stub: Any?, + subscriber: (M.(String) -> Boolean, MessageListener) -> Unit +) : MessageSubscribersBuilder(stub, subscriber) { + /** + * 当其他条件都不满足时的默认处理. + */ + @MessageDsl + abstract fun default(onEvent: MessageListener) // 需要后置默认监听器 + + @Deprecated("Use `default` instead", level = DeprecationLevel.HIDDEN) + override fun always(onEvent: MessageListener) { + super.always(onEvent) + } + + /** + * 限制本次 select 的最长等待时间, 当超时后抛出 [TimeoutCancellationException] + */ + @Suppress("NOTHING_TO_INLINE") + @MessageDsl + fun timeoutException( + timeoutMillis: Long, + exception: () -> Throwable = { throw MessageSelectionTimeoutException() } + ) { + require(timeoutMillis > 0) { "timeoutMillis must be positive" } + obtainCurrentCoroutineScope().launch { + delay(timeoutMillis) + val deferred = obtainCurrentDeferred() ?: return@launch + if (deferred.isActive) { + deferred.completeExceptionally(exception()) + } + } + } + + /** + * 限制本次 select 的最长等待时间, 当超时后执行 [block] 以完成 select + */ + @MessageDsl + fun timeout(timeoutMillis: Long, block: suspend () -> R) { + require(timeoutMillis > 0) { "timeoutMillis must be positive" } + obtainCurrentCoroutineScope().launch { + delay(timeoutMillis) + val deferred = obtainCurrentDeferred() ?: return@launch + if (deferred.isActive) { + deferred.complete(block()) + } + } + } + + + /** + * 返回一个限制本次 select 的最长等待时间的 [Deferred] + * + * @see invoke + * @see reply + */ + @MessageDsl + fun timeout(timeoutMillis: Long): MessageSelectionTimeoutChecker { + require(timeoutMillis > 0) { "timeoutMillis must be positive" } + return MessageSelectionTimeoutChecker(timeoutMillis) + } + + /** + * 返回一个限制本次 select 的最长等待时间的 [Deferred] + * + * @see Deferred.invoke + */ + @Suppress("unused") + fun MessageSelectionTimeoutChecker.invoke(block: suspend () -> R) { + return timeout(this.timeoutMillis, block) + } + + /** + * 在超时后回复原消息 + * + * 当 [block] 返回值为 [Unit] 时不回复, 为 [Message] 时回复 [Message], 其他将 [toString] 后回复为 [PlainText] + * + * @see timeout + * @see quoteReply + */ + @Suppress("unused", "UNCHECKED_CAST") + open infix fun MessageSelectionTimeoutChecker.reply(block: suspend () -> Any?) { + return timeout(this.timeoutMillis) { + executeAndReply(block) + Unit as R + } + } + + /** + * 在超时后引用回复原消息 + * + * 当 [block] 返回值为 [Unit] 时不回复, 为 [Message] 时回复 [Message], 其他将 [toString] 后回复为 [PlainText] + + * @see timeout + * @see reply + */ + @Suppress("unused", "UNCHECKED_CAST") + open infix fun MessageSelectionTimeoutChecker.quoteReply(block: suspend () -> Any?) { + return timeout(this.timeoutMillis) { + executeAndQuoteReply(block) + Unit as R + } + } + + /** + * 当其他条件都不满足时回复原消息. + * + * 当 [block] 返回值为 [Unit] 时不回复, 为 [Message] 时回复 [Message], 其他将 [toString] 后回复为 [PlainText] + */ + @MessageDsl + fun defaultReply(block: suspend () -> Any?): Unit = subscriber({ true }, { + @Suppress("DSL_SCOPE_VIOLATION_WARNING") // false positive + executeAndReply(block) + }) + + + /** + * 当其他条件都不满足时引用回复原消息. + * + * 当 [block] 返回值为 [Unit] 时不回复, 为 [Message] 时回复 [Message], 其他将 [toString] 后回复为 [PlainText] + */ + @MessageDsl + fun defaultQuoteReply(block: suspend () -> Any?): Unit = subscriber({ true }, { + @Suppress("DSL_SCOPE_VIOLATION_WARNING") // false positive + executeAndQuoteReply(block) + }) + + private suspend inline fun executeAndReply(noinline block: suspend () -> Any?) { + when (val result = block()) { + Unit -> { + + } + is Message -> ownerMessagePacket.reply(result) + else -> ownerMessagePacket.reply(result.toString()) + } + } + + private suspend inline fun executeAndQuoteReply(noinline block: suspend () -> Any?) { + when (val result = block()) { + Unit -> { + + } + is Message -> ownerMessagePacket.quoteReply(result) + else -> ownerMessagePacket.quoteReply(result.toString()) + } + } + + protected abstract fun obtainCurrentCoroutineScope(): CoroutineScope + protected abstract fun obtainCurrentDeferred(): CompletableDeferred? +} + +@Suppress("NON_PUBLIC_PRIMARY_CONSTRUCTOR_OF_INLINE_CLASS") +inline class MessageSelectionTimeoutChecker internal constructor(val timeoutMillis: Long) + +class MessageSelectionTimeoutException : RuntimeException() + + +// implementations + + @JvmSynthetic @PublishedApi internal suspend inline fun withTimeoutOrCoroutineScope( @@ -250,4 +369,155 @@ internal suspend inline fun withTimeoutOrCoroutineScope( } @PublishedApi -internal val SELECT_MESSAGE_STUB = Any() \ No newline at end of file +internal val SELECT_MESSAGE_STUB = Any() + + +@PublishedApi +@BuilderInference +@OptIn(ExperimentalTypeInference::class) +internal suspend inline fun , R> T.selectMessagesImpl( + timeoutMillis: Long = -1, + isUnit: Boolean, + @BuilderInference + crossinline selectBuilder: @MessageDsl MessageSelectBuilderUnit.() -> Unit +): R = withTimeoutOrCoroutineScope(timeoutMillis) { + val deferred = CompletableDeferred() + + // ensure sequential invoking + val listeners: MutableList Boolean, MessageListener>> = mutableListOf() + val defaultListeners: MutableList> = mutableListOf() + + if (isUnit) { + object : MessageSelectBuilderUnit( + this@selectMessagesImpl, + SELECT_MESSAGE_STUB, + { filter: T.(String) -> Boolean, listener: MessageListener -> + listeners += filter to listener + }) { + override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope + override fun obtainCurrentDeferred(): CompletableDeferred? = deferred + override fun default(onEvent: MessageListener) { + defaultListeners += onEvent + } + } + } else { + object : MessageSelectBuilder( + this@selectMessagesImpl, + SELECT_MESSAGE_STUB, + { filter: T.(String) -> Boolean, listener: MessageListener -> + listeners += filter to listener + }) { + override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope + override fun obtainCurrentDeferred(): CompletableDeferred? = deferred + override fun default(onEvent: MessageListener) { + defaultListeners += onEvent + } + } + }.apply(selectBuilder) + + // we don't have any way to reduce duplication yet, + // until local functions is supported in inline functions + @Suppress("DuplicatedCode") + subscribeAlways { event -> + if (!this.isContextIdenticalWith(this@selectMessagesImpl)) + return@subscribeAlways + + val toString = event.message.toString() + listeners.forEach { (filter, listener) -> + if (deferred.isCompleted || !isActive) + return@subscribeAlways + + if (filter.invoke(event, toString)) { + // same to the one below + val value = listener.invoke(event, toString) + if (value !== SELECT_MESSAGE_STUB) { + @Suppress("UNCHECKED_CAST") + deferred.complete(value as R) + return@subscribeAlways + } else if (isUnit) { // value === stub + // unit mode: we can directly complete this selection + @Suppress("UNCHECKED_CAST") + deferred.complete(Unit as R) + } + } + } + defaultListeners.forEach { listener -> + // same to the one above + val value = listener.invoke(event, toString) + if (value !== SELECT_MESSAGE_STUB) { + @Suppress("UNCHECKED_CAST") + deferred.complete(value as R) + return@subscribeAlways + } else if (isUnit) { // value === stub + // unit mode: we can directly complete this selection + @Suppress("UNCHECKED_CAST") + deferred.complete(Unit as R) + } + } + } + + deferred.await().also { coroutineContext[Job]!!.cancelChildren() } +} + +@Suppress("unused") +@PublishedApi +internal suspend inline fun > T.whileSelectMessagesImpl( + timeoutMillis: Long = -1, + crossinline selectBuilder: @MessageDsl MessageSelectBuilder.() -> Unit +) { + withTimeoutOrCoroutineScope(timeoutMillis) { + var deferred: CompletableDeferred? = CompletableDeferred() + + // ensure sequential invoking + val listeners: MutableList Boolean, MessageListener>> = mutableListOf() + val defaltListeners: MutableList> = mutableListOf() + + object : MessageSelectBuilder( + this@whileSelectMessagesImpl, + SELECT_MESSAGE_STUB, + { filter: T.(String) -> Boolean, listener: MessageListener -> + listeners += filter to listener + }) { + override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope + override fun obtainCurrentDeferred(): CompletableDeferred? = deferred + override fun default(onEvent: MessageListener) { + defaltListeners += onEvent + } + }.apply(selectBuilder) + + // ensure atomic completing + subscribeAlways(concurrency = Listener.ConcurrencyKind.LOCKED) { event -> + if (!this.isContextIdenticalWith(this@whileSelectMessagesImpl)) + return@subscribeAlways + + val toString = event.message.toString() + listeners.forEach { (filter, listener) -> + if (deferred?.isCompleted != false || !isActive) + return@subscribeAlways + + if (filter.invoke(event, toString)) { + listener.invoke(event, toString).let { value -> + if (value !== SELECT_MESSAGE_STUB) { + deferred?.complete(value as Boolean) + return@subscribeAlways // accept the first value only + } + } + } + } + defaltListeners.forEach { listener -> + listener.invoke(event, toString).let { value -> + if (value !== SELECT_MESSAGE_STUB) { + deferred?.complete(value as Boolean) + return@subscribeAlways // accept the first value only + } + } + } + } + + while (deferred?.await() == true) { + deferred = CompletableDeferred() + } + deferred = null + coroutineContext[Job]!!.cancelChildren() + } +} \ No newline at end of file diff --git a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/subscribeMessages.kt b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/subscribeMessages.kt index 66a2befc9..b29aac1eb 100644 --- a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/subscribeMessages.kt +++ b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/subscribeMessages.kt @@ -24,6 +24,7 @@ import net.mamoe.mirai.message.FriendMessage import net.mamoe.mirai.message.GroupMessage import net.mamoe.mirai.message.MessagePacket import net.mamoe.mirai.message.data.Message +import net.mamoe.mirai.message.data.first import net.mamoe.mirai.utils.SinceMirai import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind @@ -729,18 +730,31 @@ open class MessageSubscribersBuilder, out Ret, R : RR, R } /** - * 如果消息内容包含 [M] 类型的 [Message] + * 如果消息内容包含 [N] 类型的 [Message] */ @MessageDsl - inline fun has(): ListeningFilter = - content { message.any { it is M } } + inline fun has(): ListeningFilter = + content { message.any { it is N } } /** * 如果消息内容包含 [M] 类型的 [Message], 就执行 [onEvent] */ @MessageDsl - inline fun has(noinline onEvent: MessageListener): Ret = - content({ message.any { it is N } }, onEvent) + @SinceMirai("0.30.0") + inline fun has(noinline onEvent: @MessageDsl suspend M.(N) -> R): Ret = + content({ message.any { it is N } }, { onEvent.invoke(this, message.first()) }) + + /** + * 如果 [mapper] 返回值非空, 就执行 [onEvent] + */ + @MessageDsl + @SinceMirai("0.30.0") + open fun mapping( + mapper: M.(String) -> N?, + onEvent: @MessageDsl suspend M.(N) -> R + ): Ret = always { + onEvent.invoke(this, mapper.invoke(this, message.toString()) ?: return@always stub) + } /** * 如果 [filter] 返回 `true`