From c7932804135ed61479779bdff89101348139ab06 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Wed, 18 Jan 2023 11:22:45 +0000
Subject: [PATCH] v

---
 .../protocol/encode/MessageLengthVerifier.kt  | 203 ++++++++++++++
 .../protocol/impl/ForwardMessageProtocol.kt   |   6 +-
 .../encode/MessageLengthVerifierTest.kt       | 258 ++++++++++++++++++
 3 files changed, 466 insertions(+), 1 deletion(-)
 create mode 100644 mirai-core/src/commonMain/kotlin/message/protocol/encode/MessageLengthVerifier.kt
 create mode 100644 mirai-core/src/commonTest/kotlin/message/protocol/encode/MessageLengthVerifierTest.kt

diff --git a/mirai-core/src/commonMain/kotlin/message/protocol/encode/MessageLengthVerifier.kt b/mirai-core/src/commonMain/kotlin/message/protocol/encode/MessageLengthVerifier.kt
new file mode 100644
index 000000000..ed441e637
--- /dev/null
+++ b/mirai-core/src/commonMain/kotlin/message/protocol/encode/MessageLengthVerifier.kt
@@ -0,0 +1,203 @@
+/*
+ * Copyright 2019-2022 Mamoe Technologies and contributors.
+ *
+ * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
+ * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
+ *
+ * https://github.com/mamoe/mirai/blob/dev/LICENSE
+ */
+
+package net.mamoe.mirai.internal.message.protocol.encode
+
+import net.mamoe.mirai.contact.Contact
+import net.mamoe.mirai.contact.Group
+import net.mamoe.mirai.contact.getMember
+import net.mamoe.mirai.contact.nameCardOrNick
+import net.mamoe.mirai.message.data.*
+import net.mamoe.mirai.message.data.visitor.MessageVisitor
+import net.mamoe.mirai.message.data.visitor.RecursiveMessageVisitor
+import net.mamoe.mirai.message.data.visitor.accept
+import net.mamoe.mirai.utils.toLongUnsigned
+import kotlin.jvm.JvmStatic
+import kotlin.math.log10
+
+
+/**
+ * An object that stores these length properties.
+ * @see MessageLengthVerifier
+ */
+internal interface MessageLengthTokens {
+    val uiChars: Long
+    val uiImages: Long
+    val uiForwardNodes: Long
+//    val protocolTotal: Long
+
+    companion object {
+        val comparator: Comparator<MessageLengthTokens> =
+            compareBy<MessageLengthTokens> { it.uiChars }
+                .then(compareBy { it.uiImages })
+                .then(compareBy { it.uiForwardNodes })
+//                .then(compareBy { it.protocolTotal })
+    }
+}
+
+/**
+ * A [MessageVisitor] that calculates and verifies length of a message.
+ *
+ * Can be applied to any [Message] by calling[Message.accept] passing this visitor.
+ *
+ * Applying to [ForwardMessage] verifies the [ForwardMessage] itself and its nodes recursively.
+ *
+ * Use properties from [MessageLengthTokens] to retrieve calculation results.
+ * @since 2.14
+ */
+internal interface MessageLengthVerifier : MessageVisitor<Unit, Unit>, MessageLengthTokens {
+    val nestedVerifiers: List<MessageLengthVerifier>
+
+    fun isLengthValid(): Boolean
+}
+
+/**
+ * Gets an [MessageLengthVerifier] with specified configuration [lengthTokens] and [context].
+ */
+internal fun MessageLengthVerifier(
+    context: Contact?,
+    lengthTokens: MessageLengthLimits,
+    failfast: Boolean,
+): MessageLengthVerifier {
+    return MessageLengthVerifierImpl(context, lengthTokens, failfast)
+}
+
+/**
+ * Specifies length limits for [MessageLengthVerifier]
+ * @sample net.mamoe.mirai.internal.message.protocol.encode.MessageLengthVerifierTest
+ */
+internal class MessageLengthLimits(
+    override val uiChars: Long = 5000,// 5000 chars
+    override val uiImages: Long = 50,
+    override val uiForwardNodes: Long = 200, // 200 nodes for each forward message
+
+//    override val protocolTotal: Long = 1 * 1000 * 1000, // 1 MB
+) : MessageLengthTokens {
+    companion object {
+        @JvmStatic
+        val DEFAULT = MessageLengthLimits()
+    }
+}
+
+///////////////////////////////////////////////////////////////////////////
+// IMPLEMENTATION
+///////////////////////////////////////////////////////////////////////////
+
+private inline operator fun MessageLengthTokens.compareTo(other: MessageLengthTokens): Int =
+    MessageLengthTokens.comparator.compare(this, other)
+
+internal class MessageLengthVerifierImpl constructor(
+    private val context: Contact?,
+    private val limits: MessageLengthLimits,
+    private val failfast: Boolean,
+) : RecursiveMessageVisitor<Unit>(), MessageLengthVerifier {
+    override val nestedVerifiers: MutableList<MessageLengthVerifierImpl> = mutableListOf()
+    private var hasInvalidNested: Boolean = false
+
+    /**
+     * 展示在 UI 的字符长度.
+     * @see MessageLengthLimits.uiChars
+     */
+    override var uiChars: Long = 0
+        private set
+
+    override var uiImages: Long = 0
+        private set
+
+    override var uiForwardNodes: Long = 0
+        private set
+
+    override fun isFinished(): Boolean {
+        if (!failfast) return false
+        return !isLengthValid()
+    }
+
+    override fun isLengthValid(): Boolean = this <= limits && !hasInvalidNested
+
+    override fun visitPlainText(message: PlainText, data: Unit) {
+        uiChars += message.content.length
+    }
+
+    override fun visitAt(message: At, data: Unit) {
+        val length = message.displayInGroup()
+            ?: message.target.numberOfDigitsInDecimal
+        uiChars += length + 1 // + `@`
+    }
+
+    private fun At.displayInGroup(): Long? {
+        return if (context is Group) {
+            context.getMember(target)?.nameCardOrNick?.length?.toLongUnsigned()
+        } else {
+            null
+        }
+    }
+
+    override fun visitAtAll(message: AtAll, data: Unit) {
+        uiChars += message.content.length
+    }
+
+    override fun visitFace(message: Face, data: Unit) {
+        uiChars += 4
+    }
+
+    override fun visitImage(message: Image, data: Unit) {
+        uiImages++
+//        protocolTotal = TypicalMessageSize.image
+    }
+
+    override fun visitFlashImage(message: FlashImage, data: Unit) {
+        visitImage(message.image, data)
+    }
+
+    override fun visitQuoteReply(message: QuoteReply, data: Unit) {
+        message.source.originalMessage.accept(this)
+    }
+
+    override fun visitForwardMessage(message: ForwardMessage, data: Unit) {
+        val nested = MessageLengthVerifierImpl(context, limits, failfast)
+        nestedVerifiers.add(nested)
+        
+        for (node in message.nodeList) {
+            if (nested.isFinished()) break
+            nested.visitForwardMessageNode(node)
+        }
+        
+        if (!nested.isLengthValid()) {
+            hasInvalidNested = true
+        }
+    }
+
+    fun visitForwardMessageNode(node: ForwardMessage.INode) {
+        uiForwardNodes++
+        node.messageChain.accept(this)
+    }
+
+    companion object {
+        val Long.numberOfDigitsInDecimal: Long
+            get() = if (this == 0L) 1 else 1 + log10(this.toDouble()).toLong()
+    }
+}
+
+
+//private object TypicalMessageSize {
+//    @Serializable
+//    private class Elements(
+//        @ProtoNumber(1) val elements: List<ImMsgBody.Elem>
+//    )
+//
+//    val image: Long = kotlin.run {
+//        val elems = MessageProtocolFacade.encode(
+//            chain = Image("{01E9451B-70ED-EAE3-B37C-101F1EEBF5B5}.jpg").toMessageChain(),
+//            messageTarget = null,
+//            withGeneralFlags = false,
+//            isForward = false
+//        )
+//        ProtoBuf.encodeToByteArray(Elements.serializer(), Elements(elems)).size.toLongUnsigned()
+//    }
+//}
\ No newline at end of file
diff --git a/mirai-core/src/commonMain/kotlin/message/protocol/impl/ForwardMessageProtocol.kt b/mirai-core/src/commonMain/kotlin/message/protocol/impl/ForwardMessageProtocol.kt
index ad4c1ec2e..9943c7ad6 100644
--- a/mirai-core/src/commonMain/kotlin/message/protocol/impl/ForwardMessageProtocol.kt
+++ b/mirai-core/src/commonMain/kotlin/message/protocol/impl/ForwardMessageProtocol.kt
@@ -63,7 +63,7 @@ internal class ForwardMessageProtocol : MessageProtocol() {
             forward: ForwardMessage,
             contact: AbstractContact
         ) {
-            check(forward.nodeList.size <= 200) {
+            check(forward.nodeList.size <= MAX_NODES) {
                 throw MessageTooLargeException(
                     contact, forward, forward,
                     "ForwardMessage allows up to 200 nodes, but found ${forward.nodeList.size}"
@@ -71,4 +71,8 @@ internal class ForwardMessageProtocol : MessageProtocol() {
             }
         }
     }
+    
+    companion object {
+        const val MAX_NODES : Int = 200
+    }
 }
\ No newline at end of file
diff --git a/mirai-core/src/commonTest/kotlin/message/protocol/encode/MessageLengthVerifierTest.kt b/mirai-core/src/commonTest/kotlin/message/protocol/encode/MessageLengthVerifierTest.kt
new file mode 100644
index 000000000..4846196e1
--- /dev/null
+++ b/mirai-core/src/commonTest/kotlin/message/protocol/encode/MessageLengthVerifierTest.kt
@@ -0,0 +1,258 @@
+/*
+ * Copyright 2019-2022 Mamoe Technologies and contributors.
+ *
+ * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
+ * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
+ *
+ * https://github.com/mamoe/mirai/blob/dev/LICENSE
+ */
+
+package net.mamoe.mirai.internal.message.protocol.encode
+
+import net.mamoe.mirai.contact.MemberPermission
+import net.mamoe.mirai.internal.MockBot
+import net.mamoe.mirai.internal.message.protocol.encode.MessageLengthVerifierImpl.Companion.numberOfDigitsInDecimal
+import net.mamoe.mirai.internal.notice.processors.GroupExtensions
+import net.mamoe.mirai.internal.test.AbstractTest
+import net.mamoe.mirai.message.data.*
+import net.mamoe.mirai.message.data.visitor.accept
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+/**
+ * @see MessageLengthVerifier
+ */
+internal class MessageLengthVerifierTest : AbstractTest(), GroupExtensions {
+    private val bot = MockBot { }
+    private val group = bot.addGroup(123L, 1111L).apply {
+        addMember(1111L, permission = MemberPermission.OWNER)
+    }
+
+    private companion object {
+        private val fiveThousandChars = PlainText("a".repeat(5000))
+        private val anImage =
+            // [mirai:image:{9D97AF44-0007-5F86-6567-C0BD3F6A5C5C}.gif, width=211, height=243, size=108292, type=GIF, isEmoji=true]
+            Image("{9D97AF44-0007-5F86-6567-C0BD3F6A5C5C}.gif") { // guess what it is?
+                width = 211
+                height = 243
+                size = 108292
+                type = ImageType.GIF
+                isEmoji = true
+            }
+    }
+
+    @Test
+    fun numberOfDigitsInDecimal() {
+        assertEquals(1, 0L.numberOfDigitsInDecimal)
+        assertEquals(2, 10L.numberOfDigitsInDecimal)
+        assertEquals(2, 11L.numberOfDigitsInDecimal)
+        assertEquals(4, 1000L.numberOfDigitsInDecimal)
+        assertEquals(4, 1001L.numberOfDigitsInDecimal)
+        assertEquals(4, 1999L.numberOfDigitsInDecimal)
+    }
+
+    @Test
+    fun `initial values`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = false)
+
+        assertEquals(0, verifier.uiChars)
+        assertEquals(0, verifier.uiImages)
+        assertEquals(0, verifier.uiForwardNodes)
+        assertTrue(verifier.isLengthValid())
+    }
+
+    @Test
+    fun `count PlainTexts`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = false)
+
+        val chain = messageChainOf(fiveThousandChars)
+        chain.accept(verifier)
+        assertEquals(5000, verifier.uiChars)
+        assertEquals(0, verifier.uiImages)
+        assertEquals(0, verifier.uiForwardNodes)
+    }
+
+    @Test
+    fun `count Images`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = false)
+
+        val chain = messageChainOf(anImage, anImage)
+        chain.accept(verifier)
+        assertEquals(0, verifier.uiChars)
+        assertEquals(2, verifier.uiImages)
+        assertEquals(0, verifier.uiForwardNodes)
+    }
+
+    @Test
+    fun `count Images and PlainTexts`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = false)
+
+        val chain = messageChainOf(fiveThousandChars, anImage)
+        chain.accept(verifier)
+        assertEquals(5000, verifier.uiChars)
+        assertEquals(1, verifier.uiImages)
+        assertEquals(0, verifier.uiForwardNodes)
+    }
+
+    @Test
+    fun failfast() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = true)
+
+        val chain = messageChainOf(fiveThousandChars, anImage, fiveThousandChars, fiveThousandChars, anImage)
+        chain.accept(verifier)
+        assertEquals(fiveThousandChars.content.length * 2L, verifier.uiChars)
+        assertEquals(1, verifier.uiImages)
+        assertEquals(0, verifier.uiForwardNodes)
+        assertFalse(verifier.isLengthValid())
+    }
+
+    @Test
+    fun `disable failfast`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = false)
+
+        val chain = messageChainOf(fiveThousandChars, anImage, fiveThousandChars, fiveThousandChars, anImage)
+        chain.accept(verifier)
+        assertEquals(fiveThousandChars.content.length * 3L, verifier.uiChars)
+        assertEquals(2, verifier.uiImages)
+        assertEquals(0, verifier.uiForwardNodes)
+        assertFalse(verifier.isLengthValid())
+    }
+
+    @Test
+    fun `limits are inclusive`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = true)
+
+        val chain = messageChainOf(fiveThousandChars, anImage)
+        chain.accept(verifier)
+        assertEquals(5000, verifier.uiChars)
+        assertEquals(1, verifier.uiImages)
+        assertEquals(0, verifier.uiForwardNodes)
+        assertTrue(verifier.isLengthValid())
+    }
+
+
+    @Test
+    fun `count recursively ForwardMessage nodes`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = false)
+
+        val chain = messageChainOf(buildForwardMessage(group) {
+            1111 says fiveThousandChars
+            1111 says anImage
+            1111 says fiveThousandChars
+            1111 says fiveThousandChars
+            1111 says anImage
+        })
+
+        chain.accept(verifier)
+        assertEquals(fiveThousandChars.content.length * 3L, verifier.uiChars)
+        assertEquals(2, verifier.uiImages)
+        assertEquals(5, verifier.uiForwardNodes)
+        assertFalse(verifier.isLengthValid())
+    }
+
+    @Test
+    fun `count deeply recursively ForwardMessage nodes`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = false)
+
+        val chain = messageChainOf(buildForwardMessage(group) {
+            1111 says fiveThousandChars
+            1111 says anImage
+            1111 says fiveThousandChars
+            1111 says fiveThousandChars
+            1111 says anImage
+
+            1111 says buildForwardMessage(group) {
+                1111 says fiveThousandChars
+                1111 says anImage
+                1111 says fiveThousandChars
+                1111 says fiveThousandChars
+                1111 says anImage
+            }
+        })
+
+        chain.accept(verifier)
+        assertEquals(fiveThousandChars.content.length * 3L * 2, verifier.uiChars)
+        assertEquals(2 * 2, verifier.uiImages)
+        assertEquals(6 + 5, verifier.uiForwardNodes)
+        assertFalse(verifier.isLengthValid())
+    }
+
+    @Test
+    fun `count deeply recursively ForwardMessage nodes failfast`() {
+        val limits = MessageLengthLimits(
+            uiChars = 5000,
+            uiImages = 50,
+            uiForwardNodes = 200,
+        )
+        val verifier = MessageLengthVerifier(null, limits, failfast = true)
+
+        val chain = messageChainOf(buildForwardMessage(group) {
+            1111 says fiveThousandChars
+            1111 says anImage
+            1111 says fiveThousandChars
+            1111 says fiveThousandChars
+            1111 says anImage
+
+            1111 says buildForwardMessage(group) {
+                1111 says fiveThousandChars
+                1111 says anImage
+                1111 says fiveThousandChars
+                1111 says fiveThousandChars
+                1111 says anImage
+            }
+        })
+
+        chain.accept(verifier)
+        assertEquals(fiveThousandChars.content.length * 1L, verifier.uiChars)
+        assertEquals(1, verifier.uiImages)
+        assertEquals(2, verifier.uiForwardNodes)
+        assertFalse(verifier.isLengthValid())
+    }
+}
\ No newline at end of file