This commit is contained in:
Him188 2023-01-18 11:22:45 +00:00
parent 0a7ebb7b4a
commit c793280413
No known key found for this signature in database
GPG Key ID: BA439CDDCF652375
3 changed files with 466 additions and 1 deletions

View File

@ -0,0 +1,203 @@
/*
* Copyright 2019-2022 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.message.protocol.encode
import net.mamoe.mirai.contact.Contact
import net.mamoe.mirai.contact.Group
import net.mamoe.mirai.contact.getMember
import net.mamoe.mirai.contact.nameCardOrNick
import net.mamoe.mirai.message.data.*
import net.mamoe.mirai.message.data.visitor.MessageVisitor
import net.mamoe.mirai.message.data.visitor.RecursiveMessageVisitor
import net.mamoe.mirai.message.data.visitor.accept
import net.mamoe.mirai.utils.toLongUnsigned
import kotlin.jvm.JvmStatic
import kotlin.math.log10
/**
* An object that stores these length properties.
* @see MessageLengthVerifier
*/
internal interface MessageLengthTokens {
val uiChars: Long
val uiImages: Long
val uiForwardNodes: Long
// val protocolTotal: Long
companion object {
val comparator: Comparator<MessageLengthTokens> =
compareBy<MessageLengthTokens> { it.uiChars }
.then(compareBy { it.uiImages })
.then(compareBy { it.uiForwardNodes })
// .then(compareBy { it.protocolTotal })
}
}
/**
* A [MessageVisitor] that calculates and verifies length of a message.
*
* Can be applied to any [Message] by calling[Message.accept] passing this visitor.
*
* Applying to [ForwardMessage] verifies the [ForwardMessage] itself and its nodes recursively.
*
* Use properties from [MessageLengthTokens] to retrieve calculation results.
* @since 2.14
*/
internal interface MessageLengthVerifier : MessageVisitor<Unit, Unit>, MessageLengthTokens {
val nestedVerifiers: List<MessageLengthVerifier>
fun isLengthValid(): Boolean
}
/**
* Gets an [MessageLengthVerifier] with specified configuration [lengthTokens] and [context].
*/
internal fun MessageLengthVerifier(
context: Contact?,
lengthTokens: MessageLengthLimits,
failfast: Boolean,
): MessageLengthVerifier {
return MessageLengthVerifierImpl(context, lengthTokens, failfast)
}
/**
* Specifies length limits for [MessageLengthVerifier]
* @sample net.mamoe.mirai.internal.message.protocol.encode.MessageLengthVerifierTest
*/
internal class MessageLengthLimits(
override val uiChars: Long = 5000,// 5000 chars
override val uiImages: Long = 50,
override val uiForwardNodes: Long = 200, // 200 nodes for each forward message
// override val protocolTotal: Long = 1 * 1000 * 1000, // 1 MB
) : MessageLengthTokens {
companion object {
@JvmStatic
val DEFAULT = MessageLengthLimits()
}
}
///////////////////////////////////////////////////////////////////////////
// IMPLEMENTATION
///////////////////////////////////////////////////////////////////////////
private inline operator fun MessageLengthTokens.compareTo(other: MessageLengthTokens): Int =
MessageLengthTokens.comparator.compare(this, other)
internal class MessageLengthVerifierImpl constructor(
private val context: Contact?,
private val limits: MessageLengthLimits,
private val failfast: Boolean,
) : RecursiveMessageVisitor<Unit>(), MessageLengthVerifier {
override val nestedVerifiers: MutableList<MessageLengthVerifierImpl> = mutableListOf()
private var hasInvalidNested: Boolean = false
/**
* 展示在 UI 的字符长度.
* @see MessageLengthLimits.uiChars
*/
override var uiChars: Long = 0
private set
override var uiImages: Long = 0
private set
override var uiForwardNodes: Long = 0
private set
override fun isFinished(): Boolean {
if (!failfast) return false
return !isLengthValid()
}
override fun isLengthValid(): Boolean = this <= limits && !hasInvalidNested
override fun visitPlainText(message: PlainText, data: Unit) {
uiChars += message.content.length
}
override fun visitAt(message: At, data: Unit) {
val length = message.displayInGroup()
?: message.target.numberOfDigitsInDecimal
uiChars += length + 1 // + `@`
}
private fun At.displayInGroup(): Long? {
return if (context is Group) {
context.getMember(target)?.nameCardOrNick?.length?.toLongUnsigned()
} else {
null
}
}
override fun visitAtAll(message: AtAll, data: Unit) {
uiChars += message.content.length
}
override fun visitFace(message: Face, data: Unit) {
uiChars += 4
}
override fun visitImage(message: Image, data: Unit) {
uiImages++
// protocolTotal = TypicalMessageSize.image
}
override fun visitFlashImage(message: FlashImage, data: Unit) {
visitImage(message.image, data)
}
override fun visitQuoteReply(message: QuoteReply, data: Unit) {
message.source.originalMessage.accept(this)
}
override fun visitForwardMessage(message: ForwardMessage, data: Unit) {
val nested = MessageLengthVerifierImpl(context, limits, failfast)
nestedVerifiers.add(nested)
for (node in message.nodeList) {
if (nested.isFinished()) break
nested.visitForwardMessageNode(node)
}
if (!nested.isLengthValid()) {
hasInvalidNested = true
}
}
fun visitForwardMessageNode(node: ForwardMessage.INode) {
uiForwardNodes++
node.messageChain.accept(this)
}
companion object {
val Long.numberOfDigitsInDecimal: Long
get() = if (this == 0L) 1 else 1 + log10(this.toDouble()).toLong()
}
}
//private object TypicalMessageSize {
// @Serializable
// private class Elements(
// @ProtoNumber(1) val elements: List<ImMsgBody.Elem>
// )
//
// val image: Long = kotlin.run {
// val elems = MessageProtocolFacade.encode(
// chain = Image("{01E9451B-70ED-EAE3-B37C-101F1EEBF5B5}.jpg").toMessageChain(),
// messageTarget = null,
// withGeneralFlags = false,
// isForward = false
// )
// ProtoBuf.encodeToByteArray(Elements.serializer(), Elements(elems)).size.toLongUnsigned()
// }
//}

View File

@ -63,7 +63,7 @@ internal class ForwardMessageProtocol : MessageProtocol() {
forward: ForwardMessage,
contact: AbstractContact
) {
check(forward.nodeList.size <= 200) {
check(forward.nodeList.size <= MAX_NODES) {
throw MessageTooLargeException(
contact, forward, forward,
"ForwardMessage allows up to 200 nodes, but found ${forward.nodeList.size}"
@ -71,4 +71,8 @@ internal class ForwardMessageProtocol : MessageProtocol() {
}
}
}
companion object {
const val MAX_NODES : Int = 200
}
}

View File

@ -0,0 +1,258 @@
/*
* Copyright 2019-2022 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.message.protocol.encode
import net.mamoe.mirai.contact.MemberPermission
import net.mamoe.mirai.internal.MockBot
import net.mamoe.mirai.internal.message.protocol.encode.MessageLengthVerifierImpl.Companion.numberOfDigitsInDecimal
import net.mamoe.mirai.internal.notice.processors.GroupExtensions
import net.mamoe.mirai.internal.test.AbstractTest
import net.mamoe.mirai.message.data.*
import net.mamoe.mirai.message.data.visitor.accept
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
/**
* @see MessageLengthVerifier
*/
internal class MessageLengthVerifierTest : AbstractTest(), GroupExtensions {
private val bot = MockBot { }
private val group = bot.addGroup(123L, 1111L).apply {
addMember(1111L, permission = MemberPermission.OWNER)
}
private companion object {
private val fiveThousandChars = PlainText("a".repeat(5000))
private val anImage =
// [mirai:image:{9D97AF44-0007-5F86-6567-C0BD3F6A5C5C}.gif, width=211, height=243, size=108292, type=GIF, isEmoji=true]
Image("{9D97AF44-0007-5F86-6567-C0BD3F6A5C5C}.gif") { // guess what it is?
width = 211
height = 243
size = 108292
type = ImageType.GIF
isEmoji = true
}
}
@Test
fun numberOfDigitsInDecimal() {
assertEquals(1, 0L.numberOfDigitsInDecimal)
assertEquals(2, 10L.numberOfDigitsInDecimal)
assertEquals(2, 11L.numberOfDigitsInDecimal)
assertEquals(4, 1000L.numberOfDigitsInDecimal)
assertEquals(4, 1001L.numberOfDigitsInDecimal)
assertEquals(4, 1999L.numberOfDigitsInDecimal)
}
@Test
fun `initial values`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = false)
assertEquals(0, verifier.uiChars)
assertEquals(0, verifier.uiImages)
assertEquals(0, verifier.uiForwardNodes)
assertTrue(verifier.isLengthValid())
}
@Test
fun `count PlainTexts`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = false)
val chain = messageChainOf(fiveThousandChars)
chain.accept(verifier)
assertEquals(5000, verifier.uiChars)
assertEquals(0, verifier.uiImages)
assertEquals(0, verifier.uiForwardNodes)
}
@Test
fun `count Images`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = false)
val chain = messageChainOf(anImage, anImage)
chain.accept(verifier)
assertEquals(0, verifier.uiChars)
assertEquals(2, verifier.uiImages)
assertEquals(0, verifier.uiForwardNodes)
}
@Test
fun `count Images and PlainTexts`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = false)
val chain = messageChainOf(fiveThousandChars, anImage)
chain.accept(verifier)
assertEquals(5000, verifier.uiChars)
assertEquals(1, verifier.uiImages)
assertEquals(0, verifier.uiForwardNodes)
}
@Test
fun failfast() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = true)
val chain = messageChainOf(fiveThousandChars, anImage, fiveThousandChars, fiveThousandChars, anImage)
chain.accept(verifier)
assertEquals(fiveThousandChars.content.length * 2L, verifier.uiChars)
assertEquals(1, verifier.uiImages)
assertEquals(0, verifier.uiForwardNodes)
assertFalse(verifier.isLengthValid())
}
@Test
fun `disable failfast`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = false)
val chain = messageChainOf(fiveThousandChars, anImage, fiveThousandChars, fiveThousandChars, anImage)
chain.accept(verifier)
assertEquals(fiveThousandChars.content.length * 3L, verifier.uiChars)
assertEquals(2, verifier.uiImages)
assertEquals(0, verifier.uiForwardNodes)
assertFalse(verifier.isLengthValid())
}
@Test
fun `limits are inclusive`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = true)
val chain = messageChainOf(fiveThousandChars, anImage)
chain.accept(verifier)
assertEquals(5000, verifier.uiChars)
assertEquals(1, verifier.uiImages)
assertEquals(0, verifier.uiForwardNodes)
assertTrue(verifier.isLengthValid())
}
@Test
fun `count recursively ForwardMessage nodes`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = false)
val chain = messageChainOf(buildForwardMessage(group) {
1111 says fiveThousandChars
1111 says anImage
1111 says fiveThousandChars
1111 says fiveThousandChars
1111 says anImage
})
chain.accept(verifier)
assertEquals(fiveThousandChars.content.length * 3L, verifier.uiChars)
assertEquals(2, verifier.uiImages)
assertEquals(5, verifier.uiForwardNodes)
assertFalse(verifier.isLengthValid())
}
@Test
fun `count deeply recursively ForwardMessage nodes`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = false)
val chain = messageChainOf(buildForwardMessage(group) {
1111 says fiveThousandChars
1111 says anImage
1111 says fiveThousandChars
1111 says fiveThousandChars
1111 says anImage
1111 says buildForwardMessage(group) {
1111 says fiveThousandChars
1111 says anImage
1111 says fiveThousandChars
1111 says fiveThousandChars
1111 says anImage
}
})
chain.accept(verifier)
assertEquals(fiveThousandChars.content.length * 3L * 2, verifier.uiChars)
assertEquals(2 * 2, verifier.uiImages)
assertEquals(6 + 5, verifier.uiForwardNodes)
assertFalse(verifier.isLengthValid())
}
@Test
fun `count deeply recursively ForwardMessage nodes failfast`() {
val limits = MessageLengthLimits(
uiChars = 5000,
uiImages = 50,
uiForwardNodes = 200,
)
val verifier = MessageLengthVerifier(null, limits, failfast = true)
val chain = messageChainOf(buildForwardMessage(group) {
1111 says fiveThousandChars
1111 says anImage
1111 says fiveThousandChars
1111 says fiveThousandChars
1111 says anImage
1111 says buildForwardMessage(group) {
1111 says fiveThousandChars
1111 says anImage
1111 says fiveThousandChars
1111 says fiveThousandChars
1111 says anImage
}
})
chain.accept(verifier)
assertEquals(fiveThousandChars.content.length * 1L, verifier.uiChars)
assertEquals(1, verifier.uiImages)
assertEquals(2, verifier.uiForwardNodes)
assertFalse(verifier.isLengthValid())
}
}