mirror of
https://github.com/mamoe/mirai.git
synced 2025-01-18 07:52:27 +08:00
Remove ProtoBufWithNullableSupport.kt, use Kotlin's ProtoBuf instead, fix potential serialization problems
This commit is contained in:
parent
9b38b9a54d
commit
76459aca34
@ -1,287 +0,0 @@
|
||||
/*
|
||||
* Copyright 2019-2020 Mamoe Technologies and contributors.
|
||||
*
|
||||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
|
||||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
|
||||
*
|
||||
* https://github.com/mamoe/mirai/blob/master/LICENSE
|
||||
*/
|
||||
|
||||
@file:Suppress("DEPRECATION_ERROR")
|
||||
|
||||
package net.mamoe.mirai.internal.utils.io.serialization
|
||||
|
||||
import kotlinx.serialization.*
|
||||
import kotlinx.serialization.builtins.ByteArraySerializer
|
||||
import kotlinx.serialization.builtins.MapEntrySerializer
|
||||
import kotlinx.serialization.builtins.SetSerializer
|
||||
import kotlinx.serialization.descriptors.PolymorphicKind
|
||||
import kotlinx.serialization.descriptors.SerialDescriptor
|
||||
import kotlinx.serialization.descriptors.StructureKind
|
||||
import kotlinx.serialization.encoding.CompositeEncoder
|
||||
import kotlinx.serialization.internal.MapLikeSerializer
|
||||
import kotlinx.serialization.internal.TaggedEncoder
|
||||
import kotlinx.serialization.modules.EmptySerializersModule
|
||||
import kotlinx.serialization.modules.SerializersModule
|
||||
import kotlinx.serialization.protobuf.ProtoBuf
|
||||
import kotlinx.serialization.protobuf.ProtoIntegerType
|
||||
import kotlinx.serialization.protobuf.ProtoType
|
||||
import net.mamoe.mirai.internal.utils.io.serialization.ProtoBufWithNullableSupport.Varint.encodeVarint
|
||||
import net.mamoe.mirai.internal.utils.io.serialization.tars.TarsId
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
internal typealias ProtoDesc = Pair<Int, ProtoIntegerType>
|
||||
|
||||
internal fun getSerialId(desc: SerialDescriptor, index: Int): Int? = desc.findAnnotation<TarsId>(index)?.id
|
||||
|
||||
internal fun extractParameters(desc: SerialDescriptor, index: Int, zeroBasedDefault: Boolean = false): ProtoDesc {
|
||||
val idx = getSerialId(desc, index) ?: (if (zeroBasedDefault) index else index + 1)
|
||||
val format = desc.findAnnotation<ProtoType>(index)?.type
|
||||
?: ProtoIntegerType.DEFAULT
|
||||
return idx to format
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 带有 null (optional) support 的 Protocol buffers 序列化器.
|
||||
* 所有的为 null 的属性都将不会被序列化. 以此实现可选属性.
|
||||
*
|
||||
* 代码复制自 kotlinx.serialization. 修改部分已进行标注 (详见 "MIRAI MODIFY START")
|
||||
*/
|
||||
@OptIn(InternalSerializationApi::class)
|
||||
internal class ProtoBufWithNullableSupport(override val serializersModule: SerializersModule = EmptySerializersModule) :
|
||||
SerialFormat, BinaryFormat {
|
||||
|
||||
internal open inner class ProtobufWriter(private val encoder: ProtobufEncoder) : TaggedEncoder<ProtoDesc>() {
|
||||
override val serializersModule
|
||||
get() = this@ProtoBufWithNullableSupport.serializersModule
|
||||
|
||||
@Suppress("OverridingDeprecatedMember")
|
||||
override fun beginStructure(
|
||||
descriptor: SerialDescriptor,
|
||||
): CompositeEncoder = when (descriptor.kind) {
|
||||
StructureKind.LIST -> RepeatedWriter(encoder, currentTag)
|
||||
StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> ObjectWriter(currentTagOrNull, encoder)
|
||||
StructureKind.MAP -> MapRepeatedWriter(currentTagOrNull, encoder)
|
||||
else -> throw SerializationException("Primitives are not supported at top-level")
|
||||
}
|
||||
|
||||
override fun encodeTaggedInt(tag: ProtoDesc, value: Int) = encoder.writeInt(value, tag.first, tag.second)
|
||||
override fun encodeTaggedByte(tag: ProtoDesc, value: Byte) = encoder.writeInt(value.toInt(), tag.first, tag.second)
|
||||
override fun encodeTaggedShort(tag: ProtoDesc, value: Short) = encoder.writeInt(value.toInt(), tag.first, tag.second)
|
||||
override fun encodeTaggedLong(tag: ProtoDesc, value: Long) = encoder.writeLong(value, tag.first, tag.second)
|
||||
override fun encodeTaggedFloat(tag: ProtoDesc, value: Float) = encoder.writeFloat(value, tag.first)
|
||||
override fun encodeTaggedDouble(tag: ProtoDesc, value: Double) = encoder.writeDouble(value, tag.first)
|
||||
override fun encodeTaggedBoolean(tag: ProtoDesc, value: Boolean) = encoder.writeInt(if (value) 1 else 0, tag.first, ProtoIntegerType.DEFAULT)
|
||||
override fun encodeTaggedChar(tag: ProtoDesc, value: Char) = encoder.writeInt(value.toInt(), tag.first, tag.second)
|
||||
override fun encodeTaggedString(tag: ProtoDesc, value: String) = encoder.writeString(value, tag.first)
|
||||
override fun encodeTaggedEnum(
|
||||
tag: ProtoDesc,
|
||||
enumDescriptor: SerialDescriptor,
|
||||
ordinal: Int
|
||||
) = encoder.writeInt(
|
||||
extractParameters(enumDescriptor, ordinal, zeroBasedDefault = true).first,
|
||||
tag.first,
|
||||
ProtoIntegerType.DEFAULT
|
||||
)
|
||||
|
||||
override fun SerialDescriptor.getTag(index: Int) = this.getProtoDesc(index)
|
||||
|
||||
// MIRAI MODIFY START
|
||||
override fun encodeTaggedNull(tag: ProtoDesc) {
|
||||
|
||||
}
|
||||
|
||||
override fun <T : Any> encodeNullableSerializableValue(serializer: SerializationStrategy<T>, value: T?) {
|
||||
if (value == null) {
|
||||
encodeTaggedNull(popTag())
|
||||
} else encodeSerializableValue(serializer, value)
|
||||
}
|
||||
// MIRAI MODIFY END
|
||||
|
||||
@Suppress("UNCHECKED_CAST", "NAME_SHADOWING")
|
||||
override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) = when {
|
||||
// encode maps as collection of map entries, not merged collection of key-values
|
||||
serializer.descriptor.kind == StructureKind.MAP -> {
|
||||
val serializer = (serializer as MapLikeSerializer<Any?, Any?, T, *>)
|
||||
val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
|
||||
SetSerializer(mapEntrySerial).serialize(this, (value as Map<*, *>).entries)
|
||||
}
|
||||
serializer.descriptor == ByteArraySerializer().descriptor -> encoder.writeBytes(
|
||||
value as ByteArray,
|
||||
popTag().first
|
||||
)
|
||||
else -> serializer.serialize(this, value)
|
||||
}
|
||||
}
|
||||
|
||||
internal open inner class ObjectWriter(
|
||||
val parentTag: ProtoDesc?, private val parentEncoder: ProtobufEncoder,
|
||||
private val stream: ByteArrayOutputStream = ByteArrayOutputStream()
|
||||
) : ProtobufWriter(
|
||||
ProtobufEncoder(
|
||||
stream
|
||||
)
|
||||
) {
|
||||
override fun endEncode(descriptor: SerialDescriptor) {
|
||||
if (parentTag != null) {
|
||||
parentEncoder.writeBytes(stream.toByteArray(), parentTag.first)
|
||||
} else {
|
||||
parentEncoder.out.write(stream.toByteArray())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal inner class MapRepeatedWriter(parentTag: ProtoDesc?, parentEncoder: ProtobufEncoder) : ObjectWriter(parentTag, parentEncoder) {
|
||||
override fun SerialDescriptor.getTag(index: Int): ProtoDesc =
|
||||
if (index % 2 == 0) 1 to (parentTag?.second ?: ProtoIntegerType.DEFAULT)
|
||||
else 2 to (parentTag?.second ?: ProtoIntegerType.DEFAULT)
|
||||
}
|
||||
|
||||
internal inner class RepeatedWriter(encoder: ProtobufEncoder, private val curTag: ProtoDesc) :
|
||||
ProtobufWriter(encoder) {
|
||||
override fun SerialDescriptor.getTag(index: Int) = curTag
|
||||
}
|
||||
|
||||
internal class ProtobufEncoder(val out: ByteArrayOutputStream) {
|
||||
|
||||
fun writeBytes(bytes: ByteArray, tag: Int) {
|
||||
val header = encode32((tag shl 3) or SIZE_DELIMITED)
|
||||
val len = encode32(bytes.size)
|
||||
out.write(header)
|
||||
out.write(len)
|
||||
out.write(bytes)
|
||||
}
|
||||
|
||||
fun writeInt(value: Int, tag: Int, format: ProtoIntegerType) {
|
||||
val wireType = if (format == ProtoIntegerType.FIXED) i32 else VARINT
|
||||
val header = encode32((tag shl 3) or wireType)
|
||||
val content = encode32(value, format)
|
||||
out.write(header)
|
||||
out.write(content)
|
||||
}
|
||||
|
||||
fun writeLong(value: Long, tag: Int, format: ProtoIntegerType) {
|
||||
val wireType = if (format == ProtoIntegerType.FIXED) i64 else VARINT
|
||||
val header = encode32((tag shl 3) or wireType)
|
||||
val content = encode64(value, format)
|
||||
out.write(header)
|
||||
out.write(content)
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalStdlibApi::class)
|
||||
fun writeString(value: String, tag: Int) {
|
||||
val bytes = value.encodeToByteArray()
|
||||
writeBytes(bytes, tag)
|
||||
}
|
||||
|
||||
fun writeDouble(value: Double, tag: Int) {
|
||||
val header = encode32((tag shl 3) or i64)
|
||||
val content = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putDouble(value).array()
|
||||
out.write(header)
|
||||
out.write(content)
|
||||
}
|
||||
|
||||
fun writeFloat(value: Float, tag: Int) {
|
||||
val header = encode32((tag shl 3) or i32)
|
||||
val content = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(value).array()
|
||||
out.write(header)
|
||||
out.write(content)
|
||||
}
|
||||
|
||||
private fun encode32(number: Int, format: ProtoIntegerType = ProtoIntegerType.DEFAULT): ByteArray =
|
||||
when (format) {
|
||||
ProtoIntegerType.FIXED -> ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(number).array()
|
||||
ProtoIntegerType.DEFAULT -> encodeVarint(number.toLong())
|
||||
ProtoIntegerType.SIGNED -> encodeVarint(((number shl 1) xor (number shr 31)))
|
||||
}
|
||||
|
||||
|
||||
private fun encode64(number: Long, format: ProtoIntegerType = ProtoIntegerType.DEFAULT): ByteArray =
|
||||
when (format) {
|
||||
ProtoIntegerType.FIXED -> ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(number).array()
|
||||
ProtoIntegerType.DEFAULT -> encodeVarint(number)
|
||||
ProtoIntegerType.SIGNED -> encodeVarint((number shl 1) xor (number shr 63))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Source for all varint operations:
|
||||
* https://github.com/addthis/stream-lib/blob/master/src/main/java/com/clearspring/analytics/util/Varint.java
|
||||
*/
|
||||
@Suppress("unused")
|
||||
internal object Varint {
|
||||
internal fun encodeVarint(inp: Int): ByteArray {
|
||||
var value = inp
|
||||
val byteArrayList = ByteArray(10)
|
||||
var i = 0
|
||||
while (value and 0xFFFFFF80.toInt() != 0) {
|
||||
byteArrayList[i++] = ((value and 0x7F) or 0x80).toByte()
|
||||
value = value ushr 7
|
||||
}
|
||||
byteArrayList[i] = (value and 0x7F).toByte()
|
||||
val out = ByteArray(i + 1)
|
||||
while (i >= 0) {
|
||||
out[i] = byteArrayList[i]
|
||||
i--
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
internal fun encodeVarint(inp: Long): ByteArray {
|
||||
var value = inp
|
||||
val byteArrayList = ByteArray(10)
|
||||
var i = 0
|
||||
while (value and 0x7FL.inv() != 0L) {
|
||||
byteArrayList[i++] = ((value and 0x7F) or 0x80).toByte()
|
||||
value = value ushr 7
|
||||
}
|
||||
byteArrayList[i] = (value and 0x7F).toByte()
|
||||
val out = ByteArray(i + 1)
|
||||
while (i >= 0) {
|
||||
out[i] = byteArrayList[i]
|
||||
i--
|
||||
}
|
||||
return out
|
||||
}
|
||||
}
|
||||
|
||||
companion object : BinaryFormat {
|
||||
override val serializersModule: SerializersModule
|
||||
get() = plain.serializersModule
|
||||
|
||||
private fun SerialDescriptor.getProtoDesc(index: Int): ProtoDesc {
|
||||
return extractParameters(this, index)
|
||||
}
|
||||
|
||||
internal const val VARINT = 0
|
||||
internal const val i64 = 1
|
||||
internal const val SIZE_DELIMITED = 2
|
||||
internal const val i32 = 5
|
||||
|
||||
private val plain = ProtoBufWithNullableSupport()
|
||||
|
||||
override fun <T> encodeToByteArray(serializer: SerializationStrategy<T>, value: T): ByteArray {
|
||||
return plain.encodeToByteArray(serializer, value)
|
||||
}
|
||||
|
||||
override fun <T> decodeFromByteArray(deserializer: DeserializationStrategy<T>, bytes: ByteArray): T {
|
||||
return plain.decodeFromByteArray(deserializer, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
override fun <T> encodeToByteArray(serializer: SerializationStrategy<T>, value: T): ByteArray {
|
||||
val encoder = ByteArrayOutputStream()
|
||||
val dumper = ProtobufWriter(ProtobufEncoder(encoder))
|
||||
dumper.encodeSerializableValue(serializer, value)
|
||||
return encoder.toByteArray()
|
||||
}
|
||||
|
||||
override fun <T> decodeFromByteArray(deserializer: DeserializationStrategy<T>, bytes: ByteArray): T {
|
||||
return ProtoBuf.decodeFromByteArray(deserializer, bytes)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -27,6 +27,8 @@ import net.mamoe.mirai.utils.readPacketExact
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
internal typealias KtProtoBuf = kotlinx.serialization.protobuf.ProtoBuf
|
||||
|
||||
internal fun <T : JceStruct> ByteArray.loadWithUniPacket(
|
||||
deserializer: DeserializationStrategy<T>,
|
||||
name: String? = null
|
||||
@ -127,14 +129,14 @@ internal fun <T : ProtoBuf> BytePacketBuilder.writeProtoBuf(serializer: Serializ
|
||||
* dump
|
||||
*/
|
||||
internal fun <T : ProtoBuf> T.toByteArray(serializer: SerializationStrategy<T>): ByteArray {
|
||||
return ProtoBufWithNullableSupport.encodeToByteArray(serializer, this)
|
||||
return KtProtoBuf.encodeToByteArray(serializer, this)
|
||||
}
|
||||
|
||||
/**
|
||||
* load
|
||||
*/
|
||||
internal fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrategy<T>): T {
|
||||
return ProtoBufWithNullableSupport.decodeFromByteArray(deserializer, this)
|
||||
return KtProtoBuf.decodeFromByteArray(deserializer, this)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -143,7 +145,7 @@ internal fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrate
|
||||
internal fun <T : ProtoBuf> ByteReadPacket.readProtoBuf(
|
||||
serializer: DeserializationStrategy<T>,
|
||||
length: Int = this.remaining.toInt()
|
||||
): T = ProtoBufWithNullableSupport.decodeFromByteArray(serializer, this.readBytes(length))
|
||||
): T = KtProtoBuf.decodeFromByteArray(serializer, this.readBytes(length))
|
||||
|
||||
/**
|
||||
* 构造 [RequestPacket] 的 [RequestPacket.sBuffer]
|
||||
|
Loading…
Reference in New Issue
Block a user