[core] Fix polymorphic serialization

This commit is contained in:
Karlatemp 2023-01-03 13:19:20 +08:00
parent 1acf51e308
commit 4d9f6e88f9
No known key found for this signature in database
GPG Key ID: BA173CA2B9956C59
12 changed files with 75 additions and 25 deletions

View File

@ -84,7 +84,7 @@ public sealed interface OfflineAnnouncement : Announcement {
}
internal object Serializer : KSerializer<OfflineAnnouncement> by OfflineAnnouncementImpl.serializer().map(
resultantDescriptor = OfflineAnnouncementImpl.serializer().descriptor.copy(SERIAL_NAME),
resultantDescriptor = OfflineAnnouncementImpl.serializer().descriptor,
deserialize = { it },
serialize = { it.safeCast<OfflineAnnouncementImpl>() ?: create(content, parameters).cast() }
)

View File

@ -38,6 +38,12 @@ public open class MessageSourceSerializerImpl(serialName: String) :
Mirai.constructMessageSource(botId, kind, fromId, targetId, ids, time, internalIds, originalMessage)
}
) {
@MiraiInternalApi
public companion object {
public fun serialDataSerializer(): KSerializer<*> = SerialData.serializer()
}
@SerialName(MessageSource.SERIAL_NAME)
@Serializable
internal class SerialData(

View File

@ -25,7 +25,6 @@ import net.mamoe.mirai.message.code.CodableMessage
import net.mamoe.mirai.message.data.visitor.MessageVisitor
import net.mamoe.mirai.utils.MiraiInternalApi
import net.mamoe.mirai.utils.NotStableForInheritance
import net.mamoe.mirai.utils.copy
import net.mamoe.mirai.utils.map
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName
@ -111,9 +110,9 @@ public expect interface FileMessage : MessageContent, ConstrainSingle, CodableMe
}
@MiraiInternalApi
internal open class FallbackFileMessageSerializer constructor(serialName: String) :
internal open class FallbackFileMessageSerializer :
KSerializer<FileMessage> by Delegate.serializer().map(
Delegate.serializer().descriptor.copy(serialName),
Delegate.serializer().descriptor,
serialize = { Delegate(id, internalId, name, size) },
deserialize = { Mirai.createFileMessage(id, internalId, name, size) },
) {

View File

@ -102,7 +102,7 @@ public enum class RockPaperScissors(
}
internal object Serializer : KSerializer<RockPaperScissors> by Surrogate.serializer().map(
resultantDescriptor = Surrogate.serializer().descriptor.copy(SERIAL_NAME),
resultantDescriptor = Surrogate.serializer().descriptor,
deserialize = { valueOf(it.name) },
serialize = { Surrogate(name) },
) {

View File

@ -62,7 +62,7 @@ public interface UnsupportedMessage : MessageContent {
}
public object Serializer : KSerializer<UnsupportedMessage> by Surrogate.serializer().map(
resultantDescriptor = Surrogate.serializer().descriptor.copy(SERIAL_NAME),
resultantDescriptor = Surrogate.serializer().descriptor,
deserialize = { Mirai.createUnsupportedMessage(struct.hexToBytes()) },
serialize = { Surrogate(struct.toUHexString("")) }
) {

View File

@ -11,6 +11,7 @@ package net.mamoe.mirai.utils
import io.ktor.utils.io.core.*
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient
import kotlinx.serialization.builtins.serializer
@ -292,7 +293,7 @@ internal object DeviceInfoManager {
)
private object DeviceInfoVersionSerializer : KSerializer<DeviceInfo.Version> by SerialData.serializer().map(
resultantDescriptor = SerialData.serializer().descriptor.copy("Version"),
resultantDescriptor = SerialData.serializer().descriptor,
deserialize = {
DeviceInfo.Version(incremental, release, codename, sdk)
},
@ -300,6 +301,7 @@ internal object DeviceInfoManager {
SerialData(incremental, release, codename, sdk)
}
) {
@SerialName("Version")
@Serializable
private class SerialData(
val incremental: ByteArray = "5891938".toByteArray(),

View File

@ -120,5 +120,5 @@ public actual interface FileMessage : MessageContent, ConstrainSingle, CodableMe
}
public actual object Serializer :
KSerializer<FileMessage> by FallbackFileMessageSerializer(SERIAL_NAME) // not polymorphic
KSerializer<FileMessage> by FallbackFileMessageSerializer()
}

View File

@ -104,5 +104,5 @@ public actual interface FileMessage : MessageContent, ConstrainSingle, CodableMe
}
public actual object Serializer :
KSerializer<FileMessage> by FallbackFileMessageSerializer(SERIAL_NAME) // not polymorphic
KSerializer<FileMessage> by FallbackFileMessageSerializer()
}

View File

@ -163,7 +163,7 @@ internal class OnlineAudioImpl(
}
object Serializer : KSerializer<OnlineAudioImpl> by Surrogate.serializer().map(
resultantDescriptor = Surrogate.serializer().descriptor.copy(OnlineAudio.SERIAL_NAME),
resultantDescriptor = Surrogate.serializer().descriptor,
deserialize = {
OnlineAudioImpl(
filename = filename,
@ -251,7 +251,7 @@ internal class OfflineAudioImpl(
}
object Serializer : KSerializer<OfflineAudioImpl> by Surrogate.serializer().map(
resultantDescriptor = Surrogate.serializer().descriptor.copy(OfflineAudio.SERIAL_NAME),
resultantDescriptor = Surrogate.serializer().descriptor,
deserialize = {
OfflineAudioImpl(
filename = filename,

View File

@ -10,6 +10,7 @@
package net.mamoe.mirai.internal.message.protocol.impl
import net.mamoe.mirai.contact.AnonymousMember
import net.mamoe.mirai.internal.message.MessageSourceSerializerImpl
import net.mamoe.mirai.internal.message.protocol.MessageProtocol
import net.mamoe.mirai.internal.message.protocol.ProcessorCollector
import net.mamoe.mirai.internal.message.protocol.decode.MessageDecoder
@ -24,7 +25,6 @@ import net.mamoe.mirai.internal.message.protocol.serialization.MessageSerializer
import net.mamoe.mirai.internal.message.source.*
import net.mamoe.mirai.internal.network.protocol.data.proto.ImMsgBody
import net.mamoe.mirai.message.data.*
import net.mamoe.mirai.utils.copy
import net.mamoe.mirai.utils.map
internal class QuoteReplyProtocol : MessageProtocol(PRIORITY_METADATA) {
@ -100,7 +100,7 @@ internal class QuoteReplyProtocol : MessageProtocol(PRIORITY_METADATA) {
MessageSerializer(
MessageSource::class,
OfflineMessageSourceImplData.serializer().map(
OfflineMessageSourceImplData.serializer().descriptor.copy(MessageSource.SERIAL_NAME),
MessageSourceSerializerImpl.serialDataSerializer().descriptor,
{ it },
{
OfflineMessageSourceImplData(

View File

@ -564,6 +564,11 @@ internal class MarketFaceProtocolTest : AbstractMessageProtocolTest() {
override val message: Dice
) : PolymorphicWrapper
@Serializable
data class StaticWrapperRockPaperScissors(
override val message: RockPaperScissors
) : PolymorphicWrapper
private fun <M : MarketFace> testPolymorphicInMarketFace(
data: M,
expectedSerialName: String,
@ -591,6 +596,19 @@ internal class MarketFaceProtocolTest : AbstractMessageProtocolTest() {
)
})
private fun testStaticRockPaperScissors(
data: RockPaperScissors,
expectedInstance: RockPaperScissors = data,
) = listOf(dynamicTest("testStaticRockPaperScissors") {
testPolymorphicIn(
polySerializer = StaticWrapperRockPaperScissors.serializer(),
polyConstructor = ::StaticWrapperRockPaperScissors,
data = data,
expectedSerialName = null,
expectedInstance = expectedInstance,
)
})
@TestFactory
fun `test serialization for MarketFaceImpl`(): DynamicTestsResult {
val data = MarketFaceImpl(
@ -617,6 +635,22 @@ internal class MarketFaceProtocolTest : AbstractMessageProtocolTest() {
)
}
@TestFactory
fun `test serialization for RockPaperScissors`(): DynamicTestsResult {
val data = RockPaperScissors.PAPER
val serialName = RockPaperScissors.SERIAL_NAME
return runDynamicTests(
testPolymorphicInMarketFace(data, serialName),
testPolymorphicInMessageContent(data, serialName),
testPolymorphicInSingleMessage(data, serialName),
testInsideMessageChain(data, serialName),
testContextual(data, serialName),
testContextual(data, serialName, targetType = MarketFace::class),
testStaticRockPaperScissors(data),
)
}
@TestFactory
fun `test serialization for Dice`(): DynamicTestsResult {
val data = Dice(1)

View File

@ -9,8 +9,8 @@
package net.mamoe.mirai.internal.message.protocol.impl
import kotlinx.serialization.Polymorphic
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import net.mamoe.mirai.internal.message.protocol.MessageProtocol
import net.mamoe.mirai.internal.message.source.OfflineMessageSourceImplData
import net.mamoe.mirai.internal.message.toMessageChainOnline
@ -19,7 +19,6 @@ import net.mamoe.mirai.internal.testFramework.TestFactory
import net.mamoe.mirai.internal.testFramework.dynamicTest
import net.mamoe.mirai.internal.testFramework.runDynamicTests
import net.mamoe.mirai.internal.utils.runCoroutineInPlace
import net.mamoe.mirai.message.MessageSerializers
import net.mamoe.mirai.message.data.*
import net.mamoe.mirai.message.data.MessageSource.Key.quote
import net.mamoe.mirai.utils.EMPTY_BYTE_ARRAY
@ -488,18 +487,13 @@ internal class QuoteReplyProtocolTest : AbstractMessageProtocolTest() {
///////////////////////////////////////////////////////////////////////////
// TODO: 2022/7/20 MessageSource 在 MessageMetadata 的 scope 多态序列化后会输出 'type' = 'MessageSource', 这是期望的行为.
// 但是在反序列化时会错误 unknown field 'type'
override val format: Json
get() = Json {
prettyPrint = true
serializersModule = MessageSerializers.serializersModule
ignoreUnknownKeys = true
}
@Serializable
data class PolymorphicWrapperMessageSource(
override val message: @Polymorphic MessageSource
) : PolymorphicWrapper
@Serializable
data class StaticWrapperMessageSource(
override val message: MessageSource
) : PolymorphicWrapper
@ -516,6 +510,19 @@ internal class QuoteReplyProtocolTest : AbstractMessageProtocolTest() {
)
})
private fun <M : MessageSource> testStaticInMessageSource(
data: M,
expectedInstance: M = data,
) = listOf(dynamicTest("testStaticInMessageSource") {
testPolymorphicIn(
polySerializer = StaticWrapperMessageSource.serializer(),
polyConstructor = ::StaticWrapperMessageSource,
data = data,
expectedInstance = expectedInstance,
expectedSerialName = null,
)
})
@TestFactory
fun `test serialization for OfflineMessageSource`(): DynamicTestsResult {
val data = MessageSourceBuilder()
@ -533,6 +540,7 @@ internal class QuoteReplyProtocolTest : AbstractMessageProtocolTest() {
testPolymorphicInSingleMessage(data, serialName),
testInsideMessageChain(data, serialName),
testContextual(data, serialName),
testStaticInMessageSource(data),
)
}
@ -548,6 +556,7 @@ internal class QuoteReplyProtocolTest : AbstractMessageProtocolTest() {
testPolymorphicInSingleMessage(data, serialName, expectedInstance = expected),
testInsideMessageChain(data, serialName, expectedInstance = expected),
testContextual(data, serialName, expectedInstance = expected),
testStaticInMessageSource(data, expectedInstance = expected),
)
}
}