Fix EventChannel.filter and .filterIsInstance when chained, fix #953

This commit is contained in:
Him188 2021-02-03 11:41:05 +08:00
parent c261b8b00e
commit 0bbea5706f
2 changed files with 51 additions and 12 deletions

View File

@ -28,6 +28,7 @@ import net.mamoe.mirai.internal.event.ListenerRegistry
import net.mamoe.mirai.internal.event.registerEventHandler import net.mamoe.mirai.internal.event.registerEventHandler
import net.mamoe.mirai.utils.MiraiExperimentalApi import net.mamoe.mirai.utils.MiraiExperimentalApi
import net.mamoe.mirai.utils.MiraiLogger import net.mamoe.mirai.utils.MiraiLogger
import net.mamoe.mirai.utils.cast
import java.util.function.Consumer import java.util.function.Consumer
import kotlin.coroutines.CoroutineContext import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext import kotlin.coroutines.EmptyCoroutineContext
@ -129,11 +130,12 @@ public open class EventChannel<out BaseEvent : Event> @JvmOverloads internal con
*/ */
@JvmSynthetic @JvmSynthetic
public fun filter(filter: suspend (event: BaseEvent) -> Boolean): EventChannel<BaseEvent> { public fun filter(filter: suspend (event: BaseEvent) -> Boolean): EventChannel<BaseEvent> {
val parent = this
return object : EventChannel<BaseEvent>(baseEventClass, defaultCoroutineContext) { return object : EventChannel<BaseEvent>(baseEventClass, defaultCoroutineContext) {
private inline val innerThis get() = this private inline val innerThis get() = this
override fun <E : Event> (suspend (E) -> ListeningStatus).intercepted(): suspend (E) -> ListeningStatus { override fun <E : Event> (suspend (E) -> ListeningStatus).intercepted(): suspend (E) -> ListeningStatus {
return { ev -> val thisIntercepted: suspend (E) -> ListeningStatus = { ev ->
val filterResult = try { val filterResult = try {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
baseEventClass.isInstance(ev) && filter(ev as BaseEvent) baseEventClass.isInstance(ev) && filter(ev as BaseEvent)
@ -141,9 +143,10 @@ public open class EventChannel<out BaseEvent : Event> @JvmOverloads internal con
if (e is ExceptionInEventChannelFilterException) throw e // wrapped by another filter if (e is ExceptionInEventChannelFilterException) throw e // wrapped by another filter
throw ExceptionInEventChannelFilterException(ev, innerThis, cause = e) throw ExceptionInEventChannelFilterException(ev, innerThis, cause = e)
} }
if (filterResult) this.invoke(ev) if (filterResult) this@intercepted.invoke(ev)
else ListeningStatus.LISTENING else ListeningStatus.LISTENING
} }
return parent.run { thisIntercepted.intercepted() }
} }
} }
} }
@ -203,16 +206,7 @@ public open class EventChannel<out BaseEvent : Event> @JvmOverloads internal con
* @see filter 获取更多信息 * @see filter 获取更多信息
*/ */
public fun <E : Event> filterIsInstance(kClass: KClass<out E>): EventChannel<E> { public fun <E : Event> filterIsInstance(kClass: KClass<out E>): EventChannel<E> {
return object : EventChannel<E>(kClass, defaultCoroutineContext) { return filter { kClass.isInstance(it) }.cast()
private inline val innerThis get() = this
override fun <E1 : Event> (suspend (E1) -> ListeningStatus).intercepted(): suspend (E1) -> ListeningStatus {
return { ev ->
if (kClass.isInstance(ev)) this.invoke(ev)
else ListeningStatus.LISTENING
}
}
}
} }
/** /**

View File

@ -9,17 +9,62 @@
package net.mamoe.mirai.event package net.mamoe.mirai.event
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import net.mamoe.mirai.event.events.FriendEvent import net.mamoe.mirai.event.events.FriendEvent
import net.mamoe.mirai.event.events.GroupEvent import net.mamoe.mirai.event.events.GroupEvent
import net.mamoe.mirai.event.events.GroupMessageEvent import net.mamoe.mirai.event.events.GroupMessageEvent
import net.mamoe.mirai.event.events.MessageEvent import net.mamoe.mirai.event.events.MessageEvent
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import java.lang.IllegalStateException
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
internal class EventChannelTest { internal class EventChannelTest {
suspend fun suspendCall() { suspend fun suspendCall() {
} }
data class TE(
val x: Int
) : AbstractEvent()
@Test
fun testFilter() {
runBlocking {
val received = suspendCoroutine<Int> { cont ->
GlobalEventChannel
.filterIsInstance<TE>()
.filter {
true
}
.filter {
it.x == 2
}
.filter {
true
}
.subscribeOnce<TE> {
cont.resume(it.x)
}
launch {
println("Broadcast 1")
TE(1).broadcast()
println("Broadcast 2")
TE(2).broadcast()
println("Broadcast done")
}
}
assertEquals(2, received)
}
}
@Suppress("UNUSED_VARIABLE") @Suppress("UNUSED_VARIABLE")
@Test @Test
fun testVariance() { fun testVariance() {