Remove ProtoBufWithNullableSupport.kt, use Kotlin's ProtoBuf instead, fix potential serialization problems

This commit is contained in:
Him188 2021-01-23 01:29:34 +08:00
parent 9b38b9a54d
commit 76459aca34
2 changed files with 5 additions and 290 deletions

View File

@ -1,287 +0,0 @@
/*
* Copyright 2019-2020 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/master/LICENSE
*/
@file:Suppress("DEPRECATION_ERROR")
package net.mamoe.mirai.internal.utils.io.serialization
import kotlinx.serialization.*
import kotlinx.serialization.builtins.ByteArraySerializer
import kotlinx.serialization.builtins.MapEntrySerializer
import kotlinx.serialization.builtins.SetSerializer
import kotlinx.serialization.descriptors.PolymorphicKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.CompositeEncoder
import kotlinx.serialization.internal.MapLikeSerializer
import kotlinx.serialization.internal.TaggedEncoder
import kotlinx.serialization.modules.EmptySerializersModule
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.protobuf.ProtoBuf
import kotlinx.serialization.protobuf.ProtoIntegerType
import kotlinx.serialization.protobuf.ProtoType
import net.mamoe.mirai.internal.utils.io.serialization.ProtoBufWithNullableSupport.Varint.encodeVarint
import net.mamoe.mirai.internal.utils.io.serialization.tars.TarsId
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
internal typealias ProtoDesc = Pair<Int, ProtoIntegerType>
internal fun getSerialId(desc: SerialDescriptor, index: Int): Int? = desc.findAnnotation<TarsId>(index)?.id
internal fun extractParameters(desc: SerialDescriptor, index: Int, zeroBasedDefault: Boolean = false): ProtoDesc {
val idx = getSerialId(desc, index) ?: (if (zeroBasedDefault) index else index + 1)
val format = desc.findAnnotation<ProtoType>(index)?.type
?: ProtoIntegerType.DEFAULT
return idx to format
}
/**
* 带有 null (optional) support Protocol buffers 序列化器.
* 所有的为 null 的属性都将不会被序列化. 以此实现可选属性.
*
* 代码复制自 kotlinx.serialization. 修改部分已进行标注 (详见 "MIRAI MODIFY START")
*/
@OptIn(InternalSerializationApi::class)
internal class ProtoBufWithNullableSupport(override val serializersModule: SerializersModule = EmptySerializersModule) :
SerialFormat, BinaryFormat {
internal open inner class ProtobufWriter(private val encoder: ProtobufEncoder) : TaggedEncoder<ProtoDesc>() {
override val serializersModule
get() = this@ProtoBufWithNullableSupport.serializersModule
@Suppress("OverridingDeprecatedMember")
override fun beginStructure(
descriptor: SerialDescriptor,
): CompositeEncoder = when (descriptor.kind) {
StructureKind.LIST -> RepeatedWriter(encoder, currentTag)
StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> ObjectWriter(currentTagOrNull, encoder)
StructureKind.MAP -> MapRepeatedWriter(currentTagOrNull, encoder)
else -> throw SerializationException("Primitives are not supported at top-level")
}
override fun encodeTaggedInt(tag: ProtoDesc, value: Int) = encoder.writeInt(value, tag.first, tag.second)
override fun encodeTaggedByte(tag: ProtoDesc, value: Byte) = encoder.writeInt(value.toInt(), tag.first, tag.second)
override fun encodeTaggedShort(tag: ProtoDesc, value: Short) = encoder.writeInt(value.toInt(), tag.first, tag.second)
override fun encodeTaggedLong(tag: ProtoDesc, value: Long) = encoder.writeLong(value, tag.first, tag.second)
override fun encodeTaggedFloat(tag: ProtoDesc, value: Float) = encoder.writeFloat(value, tag.first)
override fun encodeTaggedDouble(tag: ProtoDesc, value: Double) = encoder.writeDouble(value, tag.first)
override fun encodeTaggedBoolean(tag: ProtoDesc, value: Boolean) = encoder.writeInt(if (value) 1 else 0, tag.first, ProtoIntegerType.DEFAULT)
override fun encodeTaggedChar(tag: ProtoDesc, value: Char) = encoder.writeInt(value.toInt(), tag.first, tag.second)
override fun encodeTaggedString(tag: ProtoDesc, value: String) = encoder.writeString(value, tag.first)
override fun encodeTaggedEnum(
tag: ProtoDesc,
enumDescriptor: SerialDescriptor,
ordinal: Int
) = encoder.writeInt(
extractParameters(enumDescriptor, ordinal, zeroBasedDefault = true).first,
tag.first,
ProtoIntegerType.DEFAULT
)
override fun SerialDescriptor.getTag(index: Int) = this.getProtoDesc(index)
// MIRAI MODIFY START
override fun encodeTaggedNull(tag: ProtoDesc) {
}
override fun <T : Any> encodeNullableSerializableValue(serializer: SerializationStrategy<T>, value: T?) {
if (value == null) {
encodeTaggedNull(popTag())
} else encodeSerializableValue(serializer, value)
}
// MIRAI MODIFY END
@Suppress("UNCHECKED_CAST", "NAME_SHADOWING")
override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) = when {
// encode maps as collection of map entries, not merged collection of key-values
serializer.descriptor.kind == StructureKind.MAP -> {
val serializer = (serializer as MapLikeSerializer<Any?, Any?, T, *>)
val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
SetSerializer(mapEntrySerial).serialize(this, (value as Map<*, *>).entries)
}
serializer.descriptor == ByteArraySerializer().descriptor -> encoder.writeBytes(
value as ByteArray,
popTag().first
)
else -> serializer.serialize(this, value)
}
}
internal open inner class ObjectWriter(
val parentTag: ProtoDesc?, private val parentEncoder: ProtobufEncoder,
private val stream: ByteArrayOutputStream = ByteArrayOutputStream()
) : ProtobufWriter(
ProtobufEncoder(
stream
)
) {
override fun endEncode(descriptor: SerialDescriptor) {
if (parentTag != null) {
parentEncoder.writeBytes(stream.toByteArray(), parentTag.first)
} else {
parentEncoder.out.write(stream.toByteArray())
}
}
}
internal inner class MapRepeatedWriter(parentTag: ProtoDesc?, parentEncoder: ProtobufEncoder) : ObjectWriter(parentTag, parentEncoder) {
override fun SerialDescriptor.getTag(index: Int): ProtoDesc =
if (index % 2 == 0) 1 to (parentTag?.second ?: ProtoIntegerType.DEFAULT)
else 2 to (parentTag?.second ?: ProtoIntegerType.DEFAULT)
}
internal inner class RepeatedWriter(encoder: ProtobufEncoder, private val curTag: ProtoDesc) :
ProtobufWriter(encoder) {
override fun SerialDescriptor.getTag(index: Int) = curTag
}
internal class ProtobufEncoder(val out: ByteArrayOutputStream) {
fun writeBytes(bytes: ByteArray, tag: Int) {
val header = encode32((tag shl 3) or SIZE_DELIMITED)
val len = encode32(bytes.size)
out.write(header)
out.write(len)
out.write(bytes)
}
fun writeInt(value: Int, tag: Int, format: ProtoIntegerType) {
val wireType = if (format == ProtoIntegerType.FIXED) i32 else VARINT
val header = encode32((tag shl 3) or wireType)
val content = encode32(value, format)
out.write(header)
out.write(content)
}
fun writeLong(value: Long, tag: Int, format: ProtoIntegerType) {
val wireType = if (format == ProtoIntegerType.FIXED) i64 else VARINT
val header = encode32((tag shl 3) or wireType)
val content = encode64(value, format)
out.write(header)
out.write(content)
}
@OptIn(ExperimentalStdlibApi::class)
fun writeString(value: String, tag: Int) {
val bytes = value.encodeToByteArray()
writeBytes(bytes, tag)
}
fun writeDouble(value: Double, tag: Int) {
val header = encode32((tag shl 3) or i64)
val content = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putDouble(value).array()
out.write(header)
out.write(content)
}
fun writeFloat(value: Float, tag: Int) {
val header = encode32((tag shl 3) or i32)
val content = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(value).array()
out.write(header)
out.write(content)
}
private fun encode32(number: Int, format: ProtoIntegerType = ProtoIntegerType.DEFAULT): ByteArray =
when (format) {
ProtoIntegerType.FIXED -> ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(number).array()
ProtoIntegerType.DEFAULT -> encodeVarint(number.toLong())
ProtoIntegerType.SIGNED -> encodeVarint(((number shl 1) xor (number shr 31)))
}
private fun encode64(number: Long, format: ProtoIntegerType = ProtoIntegerType.DEFAULT): ByteArray =
when (format) {
ProtoIntegerType.FIXED -> ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(number).array()
ProtoIntegerType.DEFAULT -> encodeVarint(number)
ProtoIntegerType.SIGNED -> encodeVarint((number shl 1) xor (number shr 63))
}
}
/**
* Source for all varint operations:
* https://github.com/addthis/stream-lib/blob/master/src/main/java/com/clearspring/analytics/util/Varint.java
*/
@Suppress("unused")
internal object Varint {
internal fun encodeVarint(inp: Int): ByteArray {
var value = inp
val byteArrayList = ByteArray(10)
var i = 0
while (value and 0xFFFFFF80.toInt() != 0) {
byteArrayList[i++] = ((value and 0x7F) or 0x80).toByte()
value = value ushr 7
}
byteArrayList[i] = (value and 0x7F).toByte()
val out = ByteArray(i + 1)
while (i >= 0) {
out[i] = byteArrayList[i]
i--
}
return out
}
internal fun encodeVarint(inp: Long): ByteArray {
var value = inp
val byteArrayList = ByteArray(10)
var i = 0
while (value and 0x7FL.inv() != 0L) {
byteArrayList[i++] = ((value and 0x7F) or 0x80).toByte()
value = value ushr 7
}
byteArrayList[i] = (value and 0x7F).toByte()
val out = ByteArray(i + 1)
while (i >= 0) {
out[i] = byteArrayList[i]
i--
}
return out
}
}
companion object : BinaryFormat {
override val serializersModule: SerializersModule
get() = plain.serializersModule
private fun SerialDescriptor.getProtoDesc(index: Int): ProtoDesc {
return extractParameters(this, index)
}
internal const val VARINT = 0
internal const val i64 = 1
internal const val SIZE_DELIMITED = 2
internal const val i32 = 5
private val plain = ProtoBufWithNullableSupport()
override fun <T> encodeToByteArray(serializer: SerializationStrategy<T>, value: T): ByteArray {
return plain.encodeToByteArray(serializer, value)
}
override fun <T> decodeFromByteArray(deserializer: DeserializationStrategy<T>, bytes: ByteArray): T {
return plain.decodeFromByteArray(deserializer, bytes)
}
}
override fun <T> encodeToByteArray(serializer: SerializationStrategy<T>, value: T): ByteArray {
val encoder = ByteArrayOutputStream()
val dumper = ProtobufWriter(ProtobufEncoder(encoder))
dumper.encodeSerializableValue(serializer, value)
return encoder.toByteArray()
}
override fun <T> decodeFromByteArray(deserializer: DeserializationStrategy<T>, bytes: ByteArray): T {
return ProtoBuf.decodeFromByteArray(deserializer, bytes)
}
}

View File

@ -27,6 +27,8 @@ import net.mamoe.mirai.utils.readPacketExact
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
internal typealias KtProtoBuf = kotlinx.serialization.protobuf.ProtoBuf
internal fun <T : JceStruct> ByteArray.loadWithUniPacket( internal fun <T : JceStruct> ByteArray.loadWithUniPacket(
deserializer: DeserializationStrategy<T>, deserializer: DeserializationStrategy<T>,
name: String? = null name: String? = null
@ -127,14 +129,14 @@ internal fun <T : ProtoBuf> BytePacketBuilder.writeProtoBuf(serializer: Serializ
* dump * dump
*/ */
internal fun <T : ProtoBuf> T.toByteArray(serializer: SerializationStrategy<T>): ByteArray { internal fun <T : ProtoBuf> T.toByteArray(serializer: SerializationStrategy<T>): ByteArray {
return ProtoBufWithNullableSupport.encodeToByteArray(serializer, this) return KtProtoBuf.encodeToByteArray(serializer, this)
} }
/** /**
* load * load
*/ */
internal fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrategy<T>): T { internal fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrategy<T>): T {
return ProtoBufWithNullableSupport.decodeFromByteArray(deserializer, this) return KtProtoBuf.decodeFromByteArray(deserializer, this)
} }
/** /**
@ -143,7 +145,7 @@ internal fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrate
internal fun <T : ProtoBuf> ByteReadPacket.readProtoBuf( internal fun <T : ProtoBuf> ByteReadPacket.readProtoBuf(
serializer: DeserializationStrategy<T>, serializer: DeserializationStrategy<T>,
length: Int = this.remaining.toInt() length: Int = this.remaining.toInt()
): T = ProtoBufWithNullableSupport.decodeFromByteArray(serializer, this.readBytes(length)) ): T = KtProtoBuf.decodeFromByteArray(serializer, this.readBytes(length))
/** /**
* 构造 [RequestPacket] [RequestPacket.sBuffer] * 构造 [RequestPacket] [RequestPacket.sBuffer]