Ensure for all MessageChain subclasses, equals, hashCode give consistent results.

This commit is contained in:
Him188 2022-05-30 16:19:09 +01:00
parent 7088835402
commit 732e61e37d
5 changed files with 57 additions and 38 deletions

View File

@ -13,18 +13,18 @@ import net.mamoe.mirai.message.data.visitor.MessageVisitor
import net.mamoe.mirai.message.data.visitor.RecursiveMessageVisitor
import net.mamoe.mirai.message.data.visitor.accept
import net.mamoe.mirai.utils.MiraiInternalApi
import net.mamoe.mirai.utils.isSameType
/**
* One after one, hierarchically.
* @since 2.12
*/
@MiraiInternalApi
@Suppress("EXPOSED_SUPER_CLASS")
public class CombinedMessage @MessageChainConstructor constructor(
@MiraiInternalApi public val element: Message,
@MiraiInternalApi public val tail: Message,
@MiraiInternalApi public override val hasConstrainSingle: Boolean
) : MessageChainImpl, List<SingleMessage> {
) : AbstractMessageChain(), List<SingleMessage> {
override fun <D, R> accept(visitor: MessageVisitor<D, R>, data: D): R {
return visitor.visitCombinedMessage(this, data)
}
@ -165,25 +165,6 @@ public class CombinedMessage @MessageChainConstructor constructor(
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (!isSameType(this, other)) return false
if (element != other.element) return false
if (tail != other.tail) return false
if (hasConstrainSingle != other.hasConstrainSingle) return false
return true
}
override fun hashCode(): Int {
var result = element.hashCode()
result = 31 * result + tail.hashCode()
result = 31 * result + hasConstrainSingle.hashCode()
return result
}
///////////////////////////////////////////////////////////////////////////
// slow operations
///////////////////////////////////////////////////////////////////////////

View File

@ -164,7 +164,7 @@ public interface Message {
* - ...
*
* @see toString 得到包含 mirai 消息元素代码的, 易读的字符串
* @see contentEquals
* @see chainEquals
* @see Message.content Kotlin 扩展
*/
public fun contentToString(): String
@ -175,8 +175,8 @@ public interface Message {
* [strict] `true` , 还会额外判断每个消息元素的类型, 顺序和属性. [Image] 会判断 [Image.imageId]
*
* **有关 [strict]:** 每个 [Image] [contentToString] 都是 `"[图片]"`,
* [strict] `false` [contentEquals] 会得到 `true`,
* 而为 `true` 时由于 [Image.imageId] 会被比较, 两张不同的图片的 [contentEquals] 会是 `false`.
* [strict] `false` [chainEquals] 会得到 `true`,
* 而为 `true` 时由于 [Image.imageId] 会被比较, 两张不同的图片的 [chainEquals] 会是 `false`.
*
* @param ignoreCase `true` 时忽略大小写
*/

View File

@ -363,7 +363,10 @@ public fun emptyMessageChain(): MessageChain = EmptyMessageChain
replaceWith = ReplaceWith("emptyMessageChain()", "net.mamoe.mirai.message.data.emptyMessageChain")
)
@DeprecatedSinceMirai(warningSince = "2.12")
public object EmptyMessageChain : MessageChain, List<SingleMessage> by emptyList(), MessageChainImpl {
@Suppress("EXPOSED_SUPER_CLASS")
public object EmptyMessageChain : MessageChain, List<SingleMessage> by emptyList(),
AbstractMessageChain(), DirectSizeAccess, DirectToStringAccess {
override val size: Int get() = 0
override fun toString(): String = ""
@ -378,9 +381,6 @@ public object EmptyMessageChain : MessageChain, List<SingleMessage> by emptyList
override fun appendMiraiCodeTo(builder: StringBuilder) {
}
override fun equals(other: Any?): Boolean = other === this
override fun hashCode(): Int = 1
override fun iterator(): Iterator<SingleMessage> = EmptyMessageChainIterator
@Suppress("DeprecatedCallableAddReplaceWith")

View File

@ -18,6 +18,8 @@ import net.mamoe.mirai.message.data.Image.Key.IMAGE_ID_REGEX
import net.mamoe.mirai.message.data.Image.Key.IMAGE_RESOURCE_ID_REGEX_1
import net.mamoe.mirai.message.data.Image.Key.IMAGE_RESOURCE_ID_REGEX_2
import net.mamoe.mirai.message.data.visitor.MessageVisitor
import net.mamoe.mirai.message.data.visitor.RecursiveMessageVisitor
import net.mamoe.mirai.message.data.visitor.acceptChildren
import net.mamoe.mirai.utils.MiraiInternalApi
import net.mamoe.mirai.utils.asImmutable
import net.mamoe.mirai.utils.castOrNull
@ -60,20 +62,59 @@ internal fun Message.contentEqualsStrictImpl(another: Message, ignoreCase: Boole
}
internal sealed interface MessageChainImpl : MessageChain {
internal sealed class AbstractMessageChain : MessageChain {
/**
* 去重算法 v1 - 2.12:
* 在连接时若只有 0-1 方包含 [ConstrainSingle], 则使用 [CombinedMessage] 优化性能. 否则使用旧版复杂去重算法构造 [LinearMessageChainImpl].
*/
@MiraiInternalApi
val hasConstrainSingle: Boolean
abstract val hasConstrainSingle: Boolean
override fun hashCode(): Int {
var result = 1
acceptChildren(object : RecursiveMessageVisitor<Unit>() {
// override fun visitMessageChain(messageChain: MessageChain, data: Unit) {
// result = 31 * result + messageChain.hashCode()
// // do not call children
// }
// ensure `messageChainOf(messageChainOf(AtAll))` and `messageChainOf(AtAll)` get same hash code.
override fun visitSingleMessage(message: SingleMessage, data: Unit) {
result = 31 * result + message.hashCode()
super.visitSingleMessage(message, data)
}
})
return result
}
override fun equals(other: Any?): Boolean {
if (other === null) return false
if (other !is MessageChain) return false
return chainEquals(this, other)
}
private companion object {
private fun chainEquals(a: MessageChain, b: MessageChain): Boolean {
if (a.size != b.size) return false // Averagely faster even if we may end up counting size.
val itr1 = a.iterator()
val itr2 = b.iterator()
for (singleMessage in itr1) {
if (!itr2.hasNext()) return false
val n = itr2.next()
if (singleMessage != n) return false
}
return true
}
}
}
internal val Message.hasConstrainSingle: Boolean
get() {
if (this is SingleMessage) return this is ConstrainSingle
// now `this` is MessageChain
return this.castOrNull<MessageChainImpl>()?.hasConstrainSingle ?: true // for external type, assume they do
return this.castOrNull<AbstractMessageChain>()?.hasConstrainSingle ?: true // for external type, assume they do
}
/**
@ -137,7 +178,7 @@ internal class LinearMessageChainImpl @MessageChainConstructor private construct
@JvmField
internal val delegate: List<SingleMessage>,
override val hasConstrainSingle: Boolean
) : Message, MessageChain, List<SingleMessage> by delegate, MessageChainImpl,
) : Message, MessageChain, List<SingleMessage> by delegate, AbstractMessageChain(),
DirectSizeAccess, DirectToStringAccess {
override val size: Int get() = delegate.size
override fun iterator(): Iterator<SingleMessage> = delegate.iterator()
@ -148,9 +189,6 @@ internal class LinearMessageChainImpl @MessageChainConstructor private construct
private val contentToStringTemp: String by lazy { this.delegate.joinToString("") { it.contentToString() } }
override fun contentToString(): String = contentToStringTemp
override fun hashCode(): Int = delegate.hashCode()
override fun equals(other: Any?): Boolean = other is LinearMessageChainImpl && other.delegate == this.delegate
override fun <D> acceptChildren(visitor: MessageVisitor<D, *>, data: D) {
for (singleMessage in delegate) {
singleMessage.accept(visitor, data)

View File

@ -16,10 +16,10 @@ internal class MessageChainImplTest {
@OptIn(MessageChainConstructor::class)
@Test
fun allInternalImplementationsOfMessageChainAreMessageChainImpl() {
assertIs<MessageChainImpl>(CombinedMessage(AtAll, AtAll, false))
assertIs<MessageChainImpl>(emptyMessageChain())
assertIs<AbstractMessageChain>(CombinedMessage(AtAll, AtAll, false))
assertIs<AbstractMessageChain>(emptyMessageChain())
val linear = LinearMessageChainImpl.create(listOf(AtAll), true)
assertIs<LinearMessageChainImpl>(linear)
assertIs<MessageChainImpl>(linear)
assertIs<AbstractMessageChain>(linear)
}
}