Enhanced message selection

This commit is contained in:
Him188 2020-03-23 21:49:42 +08:00
parent fccb69bb3f
commit a6196d8580
2 changed files with 374 additions and 90 deletions

View File

@ -12,6 +12,7 @@ package net.mamoe.mirai.event
import kotlinx.coroutines.*
import net.mamoe.mirai.message.MessagePacket
import net.mamoe.mirai.message.data.Message
import net.mamoe.mirai.message.data.PlainText
import net.mamoe.mirai.message.isContextIdenticalWith
import net.mamoe.mirai.message.nextMessage
import net.mamoe.mirai.utils.MiraiExperimentalAPI
@ -41,6 +42,10 @@ import kotlin.jvm.JvmSynthetic
* reply(message)
* true // 继续循环
* }
* timeout(3000) {
* // on
* true
* }
* } // 等待直到 `false`
*
* reply("复读模式结束")
@ -57,43 +62,7 @@ import kotlin.jvm.JvmSynthetic
suspend inline fun <reified T : MessagePacket<*, *>> T.whileSelectMessages(
timeoutMillis: Long = -1,
crossinline selectBuilder: @MessageDsl MessageSelectBuilder<T, Boolean>.() -> Unit
) = withTimeoutOrCoroutineScope(timeoutMillis) {
var deferred: CompletableDeferred<Boolean>? = CompletableDeferred()
// ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
MessageSelectBuilder<T, Boolean>(SELECT_MESSAGE_STUB) { filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
listeners += filter to listener
}.apply(selectBuilder)
// ensure atomic completing
subscribeAlways<T>(concurrency = Listener.ConcurrencyKind.LOCKED) { event ->
if (!this.isContextIdenticalWith(this@whileSelectMessages))
return@subscribeAlways
listeners.forEach { (filter, listener) ->
if (deferred?.isCompleted != false || !isActive)
return@subscribeAlways
val toString = event.message.toString()
if (filter.invoke(event, toString)) {
listener.invoke(event, toString).let { value ->
if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean)
return@subscribeAlways // accept the first value only
}
}
}
}
}
while (deferred?.await() == true) {
deferred = CompletableDeferred()
}
deferred = null
coroutineContext[Job]!!.cancelChildren()
}
) = whileSelectMessagesImpl(timeoutMillis, selectBuilder)
/**
* [selectMessages] [Unit] 返回值捷径 (由于 Kotlin 无法推断 [Unit] 类型)
@ -104,8 +73,8 @@ suspend inline fun <reified T : MessagePacket<*, *>> T.whileSelectMessages(
@JvmName("selectMessages1")
suspend inline fun <reified T : MessagePacket<*, *>> T.selectMessagesUnit(
timeoutMillis: Long = -1,
crossinline selectBuilder: @MessageDsl MessageSelectBuilder<T, Unit>.() -> Unit
) = selectMessages(timeoutMillis, selectBuilder)
crossinline selectBuilder: @MessageDsl MessageSelectBuilderUnit<T, Unit>.() -> Unit
) = selectMessagesImpl(timeoutMillis, true, selectBuilder)
/**
@ -136,57 +105,37 @@ suspend inline fun <reified T : MessagePacket<*, *>, R> T.selectMessages(
timeoutMillis: Long = -1,
@BuilderInference
crossinline selectBuilder: @MessageDsl MessageSelectBuilder<T, R>.() -> Unit
): R = withTimeoutOrCoroutineScope(timeoutMillis) {
val deferred = CompletableDeferred<R>()
// ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
MessageSelectBuilder<T, R>(SELECT_MESSAGE_STUB) { filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
listeners += filter to listener
}.apply(selectBuilder)
subscribeAlways<T> { event ->
if (!this.isContextIdenticalWith(this@selectMessages))
return@subscribeAlways
listeners.forEach { (filter, listener) ->
if (deferred.isCompleted || !isActive)
return@subscribeAlways
val toString = event.message.toString()
if (filter.invoke(event, toString)) {
val value = listener.invoke(event, toString)
if (value !== SELECT_MESSAGE_STUB) {
@Suppress("UNCHECKED_CAST")
deferred.complete(value as R)
return@subscribeAlways
}
}
}
}
deferred.await().also { coroutineContext[Job]!!.cancelChildren() }
}
): R = selectMessagesImpl(timeoutMillis, false) { selectBuilder.invoke(this as MessageSelectBuilder<T, R>) }
/**
* [selectMessages] 时的 DSL 构建器.
*
* 它是特殊化的消息监听 ([subscribeMessages]) DSL, 屏蔽了一些 `reply` DSL 以确保作用域安全性
*
* @see MessageSelectBuilderUnit 查看上层 API
*/
@SinceMirai("0.29.0")
open class MessageSelectBuilder<M : MessagePacket<*, *>, R> @PublishedApi internal constructor(
abstract class MessageSelectBuilder<M : MessagePacket<*, *>, R> @PublishedApi internal constructor(
ownerMessagePacket: M,
stub: Any?,
subscriber: (M.(String) -> Boolean, MessageListener<M, Any?>) -> Unit
) : MessageSubscribersBuilder<M, Unit, R, Any?>(stub, subscriber) {
/**
* 无任何触发条件.
*/
@MessageDsl
fun default(onEvent: MessageListener<M, R>): Unit = subscriber({ true }, onEvent)
@Deprecated("Use `default` instead", level = DeprecationLevel.HIDDEN)
override fun always(onEvent: MessageListener<M, Any?>) {
super.always(onEvent)
}
) : MessageSelectBuilderUnit<M, R>(ownerMessagePacket, stub, subscriber) {
// 这些函数无法获取返回值. 必须屏蔽.
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun <N : Any> mapping(
mapper: M.(String) -> N?,
onEvent: @MessageDsl suspend M.(N) -> R
) = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.reply(block: suspend () -> Any?): Nothing = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override infix fun MessageSelectionTimeoutChecker.quoteReply(block: suspend () -> Any?): Nothing =
error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun String.containsReply(reply: String): Nothing = error("prohibited")
@ -222,7 +171,8 @@ open class MessageSelectBuilder<M : MessagePacket<*, *>, R> @PublishedApi intern
override fun ListeningFilter.reply(message: Message) = error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.reply(replier: suspend M.(String) -> Any?) = error("prohibited")
override fun ListeningFilter.reply(replier: suspend M.(String) -> Any?) =
error("prohibited")
@Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
override fun ListeningFilter.quoteReply(toReply: String) = error("prohibited")
@ -234,6 +184,175 @@ open class MessageSelectBuilder<M : MessagePacket<*, *>, R> @PublishedApi intern
override fun ListeningFilter.quoteReply(replier: suspend M.(String) -> Any?) = error("prohibited")
}
/**
* [selectMessagesUnit] [selectMessages] 时的 DSL 构建器.
*
* 它是特殊化的消息监听 ([subscribeMessages]) DSL, 没有屏蔽 `reply` DSL 以确保作用域安全性
*
* @see MessageSubscribersBuilder 查看上层 API
*/
@SinceMirai("0.29.0")
abstract class MessageSelectBuilderUnit<M : MessagePacket<*, *>, R> @PublishedApi internal constructor(
private val ownerMessagePacket: M,
stub: Any?,
subscriber: (M.(String) -> Boolean, MessageListener<M, Any?>) -> Unit
) : MessageSubscribersBuilder<M, Unit, R, Any?>(stub, subscriber) {
/**
* 当其他条件都不满足时的默认处理.
*/
@MessageDsl
abstract fun default(onEvent: MessageListener<M, R>) // 需要后置默认监听器
@Deprecated("Use `default` instead", level = DeprecationLevel.HIDDEN)
override fun always(onEvent: MessageListener<M, Any?>) {
super.always(onEvent)
}
/**
* 限制本次 select 的最长等待时间, 当超时后抛出 [TimeoutCancellationException]
*/
@Suppress("NOTHING_TO_INLINE")
@MessageDsl
fun timeoutException(
timeoutMillis: Long,
exception: () -> Throwable = { throw MessageSelectionTimeoutException() }
) {
require(timeoutMillis > 0) { "timeoutMillis must be positive" }
obtainCurrentCoroutineScope().launch {
delay(timeoutMillis)
val deferred = obtainCurrentDeferred() ?: return@launch
if (deferred.isActive) {
deferred.completeExceptionally(exception())
}
}
}
/**
* 限制本次 select 的最长等待时间, 当超时后执行 [block] 以完成 select
*/
@MessageDsl
fun timeout(timeoutMillis: Long, block: suspend () -> R) {
require(timeoutMillis > 0) { "timeoutMillis must be positive" }
obtainCurrentCoroutineScope().launch {
delay(timeoutMillis)
val deferred = obtainCurrentDeferred() ?: return@launch
if (deferred.isActive) {
deferred.complete(block())
}
}
}
/**
* 返回一个限制本次 select 的最长等待时间的 [Deferred]
*
* @see invoke
* @see reply
*/
@MessageDsl
fun timeout(timeoutMillis: Long): MessageSelectionTimeoutChecker {
require(timeoutMillis > 0) { "timeoutMillis must be positive" }
return MessageSelectionTimeoutChecker(timeoutMillis)
}
/**
* 返回一个限制本次 select 的最长等待时间的 [Deferred]
*
* @see Deferred<Unit>.invoke
*/
@Suppress("unused")
fun MessageSelectionTimeoutChecker.invoke(block: suspend () -> R) {
return timeout(this.timeoutMillis, block)
}
/**
* 在超时后回复原消息
*
* [block] 返回值为 [Unit] 时不回复, [Message] 时回复 [Message], 其他将 [toString] 后回复为 [PlainText]
*
* @see timeout
* @see quoteReply
*/
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.reply(block: suspend () -> Any?) {
return timeout(this.timeoutMillis) {
executeAndReply(block)
Unit as R
}
}
/**
* 在超时后引用回复原消息
*
* [block] 返回值为 [Unit] 时不回复, [Message] 时回复 [Message], 其他将 [toString] 后回复为 [PlainText]
* @see timeout
* @see reply
*/
@Suppress("unused", "UNCHECKED_CAST")
open infix fun MessageSelectionTimeoutChecker.quoteReply(block: suspend () -> Any?) {
return timeout(this.timeoutMillis) {
executeAndQuoteReply(block)
Unit as R
}
}
/**
* 当其他条件都不满足时回复原消息.
*
* [block] 返回值为 [Unit] 时不回复, [Message] 时回复 [Message], 其他将 [toString] 后回复为 [PlainText]
*/
@MessageDsl
fun defaultReply(block: suspend () -> Any?): Unit = subscriber({ true }, {
@Suppress("DSL_SCOPE_VIOLATION_WARNING") // false positive
executeAndReply(block)
})
/**
* 当其他条件都不满足时引用回复原消息.
*
* [block] 返回值为 [Unit] 时不回复, [Message] 时回复 [Message], 其他将 [toString] 后回复为 [PlainText]
*/
@MessageDsl
fun defaultQuoteReply(block: suspend () -> Any?): Unit = subscriber({ true }, {
@Suppress("DSL_SCOPE_VIOLATION_WARNING") // false positive
executeAndQuoteReply(block)
})
private suspend inline fun executeAndReply(noinline block: suspend () -> Any?) {
when (val result = block()) {
Unit -> {
}
is Message -> ownerMessagePacket.reply(result)
else -> ownerMessagePacket.reply(result.toString())
}
}
private suspend inline fun executeAndQuoteReply(noinline block: suspend () -> Any?) {
when (val result = block()) {
Unit -> {
}
is Message -> ownerMessagePacket.quoteReply(result)
else -> ownerMessagePacket.quoteReply(result.toString())
}
}
protected abstract fun obtainCurrentCoroutineScope(): CoroutineScope
protected abstract fun obtainCurrentDeferred(): CompletableDeferred<R>?
}
@Suppress("NON_PUBLIC_PRIMARY_CONSTRUCTOR_OF_INLINE_CLASS")
inline class MessageSelectionTimeoutChecker internal constructor(val timeoutMillis: Long)
class MessageSelectionTimeoutException : RuntimeException()
// implementations
@JvmSynthetic
@PublishedApi
internal suspend inline fun <R> withTimeoutOrCoroutineScope(
@ -251,3 +370,154 @@ internal suspend inline fun <R> withTimeoutOrCoroutineScope(
@PublishedApi
internal val SELECT_MESSAGE_STUB = Any()
@PublishedApi
@BuilderInference
@OptIn(ExperimentalTypeInference::class)
internal suspend inline fun <reified T : MessagePacket<*, *>, R> T.selectMessagesImpl(
timeoutMillis: Long = -1,
isUnit: Boolean,
@BuilderInference
crossinline selectBuilder: @MessageDsl MessageSelectBuilderUnit<T, R>.() -> Unit
): R = withTimeoutOrCoroutineScope(timeoutMillis) {
val deferred = CompletableDeferred<R>()
// ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
val defaultListeners: MutableList<MessageListener<T, Any?>> = mutableListOf()
if (isUnit) {
object : MessageSelectBuilderUnit<T, R>(
this@selectMessagesImpl,
SELECT_MESSAGE_STUB,
{ filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
listeners += filter to listener
}) {
override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
override fun obtainCurrentDeferred(): CompletableDeferred<R>? = deferred
override fun default(onEvent: MessageListener<T, R>) {
defaultListeners += onEvent
}
}
} else {
object : MessageSelectBuilder<T, R>(
this@selectMessagesImpl,
SELECT_MESSAGE_STUB,
{ filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
listeners += filter to listener
}) {
override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
override fun obtainCurrentDeferred(): CompletableDeferred<R>? = deferred
override fun default(onEvent: MessageListener<T, R>) {
defaultListeners += onEvent
}
}
}.apply(selectBuilder)
// we don't have any way to reduce duplication yet,
// until local functions is supported in inline functions
@Suppress("DuplicatedCode")
subscribeAlways<T> { event ->
if (!this.isContextIdenticalWith(this@selectMessagesImpl))
return@subscribeAlways
val toString = event.message.toString()
listeners.forEach { (filter, listener) ->
if (deferred.isCompleted || !isActive)
return@subscribeAlways
if (filter.invoke(event, toString)) {
// same to the one below
val value = listener.invoke(event, toString)
if (value !== SELECT_MESSAGE_STUB) {
@Suppress("UNCHECKED_CAST")
deferred.complete(value as R)
return@subscribeAlways
} else if (isUnit) { // value === stub
// unit mode: we can directly complete this selection
@Suppress("UNCHECKED_CAST")
deferred.complete(Unit as R)
}
}
}
defaultListeners.forEach { listener ->
// same to the one above
val value = listener.invoke(event, toString)
if (value !== SELECT_MESSAGE_STUB) {
@Suppress("UNCHECKED_CAST")
deferred.complete(value as R)
return@subscribeAlways
} else if (isUnit) { // value === stub
// unit mode: we can directly complete this selection
@Suppress("UNCHECKED_CAST")
deferred.complete(Unit as R)
}
}
}
deferred.await().also { coroutineContext[Job]!!.cancelChildren() }
}
@Suppress("unused")
@PublishedApi
internal suspend inline fun <reified T : MessagePacket<*, *>> T.whileSelectMessagesImpl(
timeoutMillis: Long = -1,
crossinline selectBuilder: @MessageDsl MessageSelectBuilder<T, Boolean>.() -> Unit
) {
withTimeoutOrCoroutineScope(timeoutMillis) {
var deferred: CompletableDeferred<Boolean>? = CompletableDeferred()
// ensure sequential invoking
val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
val defaltListeners: MutableList<MessageListener<T, Any?>> = mutableListOf()
object : MessageSelectBuilder<T, Boolean>(
this@whileSelectMessagesImpl,
SELECT_MESSAGE_STUB,
{ filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
listeners += filter to listener
}) {
override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
override fun obtainCurrentDeferred(): CompletableDeferred<Boolean>? = deferred
override fun default(onEvent: MessageListener<T, Boolean>) {
defaltListeners += onEvent
}
}.apply(selectBuilder)
// ensure atomic completing
subscribeAlways<T>(concurrency = Listener.ConcurrencyKind.LOCKED) { event ->
if (!this.isContextIdenticalWith(this@whileSelectMessagesImpl))
return@subscribeAlways
val toString = event.message.toString()
listeners.forEach { (filter, listener) ->
if (deferred?.isCompleted != false || !isActive)
return@subscribeAlways
if (filter.invoke(event, toString)) {
listener.invoke(event, toString).let { value ->
if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean)
return@subscribeAlways // accept the first value only
}
}
}
}
defaltListeners.forEach { listener ->
listener.invoke(event, toString).let { value ->
if (value !== SELECT_MESSAGE_STUB) {
deferred?.complete(value as Boolean)
return@subscribeAlways // accept the first value only
}
}
}
}
while (deferred?.await() == true) {
deferred = CompletableDeferred()
}
deferred = null
coroutineContext[Job]!!.cancelChildren()
}
}

View File

@ -24,6 +24,7 @@ import net.mamoe.mirai.message.FriendMessage
import net.mamoe.mirai.message.GroupMessage
import net.mamoe.mirai.message.MessagePacket
import net.mamoe.mirai.message.data.Message
import net.mamoe.mirai.message.data.first
import net.mamoe.mirai.utils.SinceMirai
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
@ -729,18 +730,31 @@ open class MessageSubscribersBuilder<M : MessagePacket<*, *>, out Ret, R : RR, R
}
/**
* 如果消息内容包含 [M] 类型的 [Message]
* 如果消息内容包含 [N] 类型的 [Message]
*/
@MessageDsl
inline fun <reified M : Message> has(): ListeningFilter =
content { message.any { it is M } }
inline fun <reified N : Message> has(): ListeningFilter =
content { message.any { it is N } }
/**
* 如果消息内容包含 [M] 类型的 [Message], 就执行 [onEvent]
*/
@MessageDsl
inline fun <reified N : Message> has(noinline onEvent: MessageListener<M, R>): Ret =
content({ message.any { it is N } }, onEvent)
@SinceMirai("0.30.0")
inline fun <reified N : Message> has(noinline onEvent: @MessageDsl suspend M.(N) -> R): Ret =
content({ message.any { it is N } }, { onEvent.invoke(this, message.first<N>()) })
/**
* 如果 [mapper] 返回值非空, 就执行 [onEvent]
*/
@MessageDsl
@SinceMirai("0.30.0")
open fun <N : Any> mapping(
mapper: M.(String) -> N?,
onEvent: @MessageDsl suspend M.(N) -> R
): Ret = always {
onEvent.invoke(this, mapper.invoke(this, message.toString()) ?: return@always stub)
}
/**
* 如果 [filter] 返回 `true`