diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/Jce.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/Jce.kt index 490ff1766..f709304d3 100644 --- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/Jce.kt +++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/Jce.kt @@ -8,8 +8,6 @@ import kotlinx.serialization.modules.EmptyModule import kotlinx.serialization.modules.SerialModule import net.mamoe.mirai.qqandroid.io.JceStruct import net.mamoe.mirai.qqandroid.io.ProtoBuf -import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestDataVersion3 -import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestPacket import net.mamoe.mirai.qqandroid.network.protocol.packet.withUse import net.mamoe.mirai.utils.io.readIoBuffer import net.mamoe.mirai.utils.io.readString @@ -27,22 +25,11 @@ enum class JceCharset(val kotlinCharset: Charset) { UTF8(Charset.forName("UTF8")) } -private val JCE_STRUCT_HEAD_OF_TAG_0 = byteArrayOf(0x0A) -private val JCE_STRUCT_TAIL_OF_TAG_0 = byteArrayOf(0x0B) - -/** - * 构造 [RequestPacket] 的 [RequestPacket.sBuffer] - */ -fun jceRequestSBuffer(name: String, serializer: SerializationStrategy, jceStruct: T): ByteArray { - return RequestDataVersion3( - mapOf( - name to JCE_STRUCT_HEAD_OF_TAG_0 + jceStruct.toByteArray(serializer) + JCE_STRUCT_TAIL_OF_TAG_0 - ) - ).toByteArray(RequestDataVersion3.serializer()) -} - internal fun getSerialId(desc: SerialDescriptor, index: Int): Int? = desc.findAnnotation(index)?.id +/** + * Jce 数据结构序列化和反序列化工具, 能将 kotlinx.serialization 通用的注解标记格式的 `class` 序列化为 [ByteArray] + */ class Jce private constructor(private val charset: JceCharset, context: SerialModule = EmptyModule) : AbstractSerialFormat(context), BinaryFormat { private inner class ListWriter( diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/ProtoBufWithNullableSupport.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/ProtoBufWithNullableSupport.kt index 1c06ac7f0..76167692e 100644 --- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/ProtoBufWithNullableSupport.kt +++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/ProtoBufWithNullableSupport.kt @@ -6,16 +6,12 @@ package net.mamoe.mirai.qqandroid.io.serialization import kotlinx.io.* import kotlinx.serialization.* -import kotlinx.serialization.CompositeDecoder.Companion.READ_DONE import kotlinx.serialization.internal.* import kotlinx.serialization.modules.EmptyModule import kotlinx.serialization.modules.SerialModule +import kotlinx.serialization.protobuf.ProtoBuf import kotlinx.serialization.protobuf.ProtoNumberType import kotlinx.serialization.protobuf.ProtoType -import kotlinx.serialization.protobuf.ProtobufDecodingException -import net.mamoe.mirai.qqandroid.io.serialization.ProtoBufWithNullableSupport.Varint.decodeSignedVarintInt -import net.mamoe.mirai.qqandroid.io.serialization.ProtoBufWithNullableSupport.Varint.decodeSignedVarintLong -import net.mamoe.mirai.qqandroid.io.serialization.ProtoBufWithNullableSupport.Varint.decodeVarint import net.mamoe.mirai.qqandroid.io.serialization.ProtoBufWithNullableSupport.Varint.encodeVarint internal typealias ProtoDesc = Pair @@ -28,6 +24,12 @@ internal fun extractParameters(desc: SerialDescriptor, index: Int, zeroBasedDefa } +/** + * 带有 null (optional) support 的 Protocol buffers 序列化器. + * 所有的为 null 的属性都将不会被序列化. 以此实现可选属性. + * + * 代码复制自 kotlinx.serialization. 修改部分已进行标注 (详见 "MIRAI MODIFY START") + */ class ProtoBufWithNullableSupport(context: SerialModule = EmptyModule) : AbstractSerialFormat(context), BinaryFormat { internal open inner class ProtobufWriter(val encoder: ProtobufEncoder) : TaggedEncoder() { @@ -171,178 +173,6 @@ class ProtoBufWithNullableSupport(context: SerialModule = EmptyModule) : Abstrac } } - private open inner class ProtobufReader(val decoder: ProtobufDecoder) : TaggedDecoder() { - override val context: SerialModule - get() = this@ProtoBufWithNullableSupport.context - - private val indexByTag: MutableMap = mutableMapOf() - private fun findIndexByTag(desc: SerialDescriptor, serialId: Int, zeroBasedDefault: Boolean = false): Int = - (0 until desc.elementsCount).firstOrNull { - extractParameters( - desc, - it, - zeroBasedDefault - ).first == serialId - } ?: -1 - - override fun beginStructure(desc: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder = when (desc.kind) { - StructureKind.LIST -> RepeatedReader(decoder, currentTag) - StructureKind.CLASS, UnionKind.OBJECT, is PolymorphicKind -> - ProtobufReader(makeDelimited(decoder, currentTagOrNull)) - StructureKind.MAP -> MapEntryReader(makeDelimited(decoder, currentTagOrNull), currentTagOrNull) - else -> throw SerializationException("Primitives are not supported at top-level") - } - - override fun decodeTaggedBoolean(tag: ProtoDesc): Boolean = when (val i = decoder.nextInt(ProtoNumberType.DEFAULT)) { - 0 -> false - 1 -> true - else -> throw ProtobufDecodingException("Expected boolean value (0 or 1), found $i") - } - - override fun decodeTaggedByte(tag: ProtoDesc): Byte = decoder.nextInt(tag.second).toByte() - override fun decodeTaggedShort(tag: ProtoDesc): Short = decoder.nextInt(tag.second).toShort() - override fun decodeTaggedInt(tag: ProtoDesc): Int = decoder.nextInt(tag.second) - override fun decodeTaggedLong(tag: ProtoDesc): Long = decoder.nextLong(tag.second) - override fun decodeTaggedFloat(tag: ProtoDesc): Float = decoder.nextFloat() - override fun decodeTaggedDouble(tag: ProtoDesc): Double = decoder.nextDouble() - override fun decodeTaggedChar(tag: ProtoDesc): Char = decoder.nextInt(tag.second).toChar() - override fun decodeTaggedString(tag: ProtoDesc): String = decoder.nextString() - override fun decodeTaggedEnum(tag: ProtoDesc, enumDescription: SerialDescriptor): Int = - findIndexByTag(enumDescription, decoder.nextInt(ProtoNumberType.DEFAULT), zeroBasedDefault = true) - - @Suppress("UNCHECKED_CAST") - override fun decodeSerializableValue(deserializer: DeserializationStrategy): T = when { - // encode maps as collection of map entries, not merged collection of key-values - deserializer.descriptor is MapLikeDescriptor -> { - val serializer = (deserializer as MapLikeSerializer) - val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer) - val setOfEntries = HashSetSerializer(mapEntrySerial).deserialize(this) - setOfEntries.associateBy({ it.key }, { it.value }) as T - } - deserializer.descriptor == ByteArraySerializer.descriptor -> decoder.nextObject() as T - else -> deserializer.deserialize(this) - } - - override fun SerialDescriptor.getTag(index: Int) = this.getProtoDesc(index) - - override fun decodeElementIndex(desc: SerialDescriptor): Int { - while (true) { - if (decoder.curId == -1) // EOF - return READ_DONE - val ind = indexByTag.getOrPut(decoder.curId) { findIndexByTag(desc, decoder.curId) } - if (ind == -1) // not found - decoder.skipElement() - else return ind - } - } - } - - private inner class RepeatedReader(decoder: ProtobufDecoder, val targetTag: ProtoDesc) : ProtobufReader(decoder) { - private var ind = -1 - - override fun decodeElementIndex(desc: SerialDescriptor) = if (decoder.curId == targetTag.first) ++ind else READ_DONE - override fun SerialDescriptor.getTag(index: Int): ProtoDesc = targetTag - } - - private inner class MapEntryReader(decoder: ProtobufDecoder, val parentTag: ProtoDesc?) : ProtobufReader(decoder) { - override fun SerialDescriptor.getTag(index: Int): ProtoDesc = - if (index % 2 == 0) 1 to (parentTag?.second ?: ProtoNumberType.DEFAULT) - else 2 to (parentTag?.second ?: ProtoNumberType.DEFAULT) - } - - internal class ProtobufDecoder(val inp: ByteArrayInputStream) { - val curId - get() = curTag.first - private var curTag: Pair = -1 to -1 - - init { - readTag() - } - - private fun readTag(): Pair { - val header = decode32(eofAllowed = true) - curTag = if (header == -1) { - -1 to -1 - } else { - val wireType = header and 0b111 - val fieldId = header ushr 3 - fieldId to wireType - } - return curTag - } - - fun skipElement() { - when (curTag.second) { - VARINT -> nextInt(ProtoNumberType.DEFAULT) - i64 -> nextLong(ProtoNumberType.FIXED) - SIZE_DELIMITED -> nextObject() - i32 -> nextInt(ProtoNumberType.FIXED) - else -> throw ProtobufDecodingException("Unsupported start group or end group wire type") - } - } - - @Suppress("NOTHING_TO_INLINE") - private inline fun assertWireType(expected: Int) { - if (curTag.second != expected) throw ProtobufDecodingException("Expected wire type $expected, but found ${curTag.second}") - } - - fun nextObject(): ByteArray { - assertWireType(SIZE_DELIMITED) - val len = decode32() - check(len >= 0) - val ans = inp.readExactNBytes(len) - readTag() - return ans - } - - fun nextInt(format: ProtoNumberType): Int { - val wireType = if (format == ProtoNumberType.FIXED) i32 else VARINT - assertWireType(wireType) - val ans = decode32(format) - readTag() - return ans - } - - fun nextLong(format: ProtoNumberType): Long { - val wireType = if (format == ProtoNumberType.FIXED) i64 else VARINT - assertWireType(wireType) - val ans = decode64(format) - readTag() - return ans - } - - fun nextFloat(): Float { - assertWireType(i32) - val ans = inp.readToByteBuffer(4).order(ByteOrder.LITTLE_ENDIAN).getFloat() - readTag() - return ans - } - - fun nextDouble(): Double { - assertWireType(i64) - val ans = inp.readToByteBuffer(8).order(ByteOrder.LITTLE_ENDIAN).getDouble() - readTag() - return ans - } - - fun nextString(): String { - val bytes = this.nextObject() - return stringFromUtf8Bytes(bytes) - } - - private fun decode32(format: ProtoNumberType = ProtoNumberType.DEFAULT, eofAllowed: Boolean = false): Int = when (format) { - ProtoNumberType.DEFAULT -> decodeVarint(inp, 64, eofAllowed).toInt() - ProtoNumberType.SIGNED -> decodeSignedVarintInt(inp) - ProtoNumberType.FIXED -> inp.readToByteBuffer(4).order(ByteOrder.LITTLE_ENDIAN).getInt() - } - - private fun decode64(format: ProtoNumberType = ProtoNumberType.DEFAULT): Long = when (format) { - ProtoNumberType.DEFAULT -> decodeVarint(inp, 64) - ProtoNumberType.SIGNED -> decodeSignedVarintLong(inp) - ProtoNumberType.FIXED -> inp.readToByteBuffer(8).order(ByteOrder.LITTLE_ENDIAN).getLong() - } - } - /** * Source for all varint operations: * https://github.com/addthis/stream-lib/blob/master/src/main/java/com/clearspring/analytics/util/Varint.java @@ -381,57 +211,10 @@ class ProtoBufWithNullableSupport(context: SerialModule = EmptyModule) : Abstrac } return out } - - internal fun decodeVarint(inp: InputStream, bitLimit: Int = 32, eofOnStartAllowed: Boolean = false): Long { - var result = 0L - var shift = 0 - var b: Int - do { - if (shift >= bitLimit) { - // Out of range - throw ProtobufDecodingException("Varint too long: exceeded $bitLimit bits") - } - // Get 7 bits from next byte - b = inp.read() - if (b == -1) { - if (eofOnStartAllowed && shift == 0) return -1 - else throw IOException("Unexpected EOF") - } - result = result or (b.toLong() and 0x7FL shl shift) - shift += 7 - } while (b and 0x80 != 0) - return result - } - - internal fun decodeSignedVarintInt(inp: InputStream): Int { - val raw = decodeVarint(inp, 32).toInt() - val temp = raw shl 31 shr 31 xor raw shr 1 - // This extra step lets us deal with the largest signed values by treating - // negative results from read unsigned methods as like unsigned values. - // Must re-flip the top bit if the original read value had it set. - return temp xor (raw and (1 shl 31)) - } - - internal fun decodeSignedVarintLong(inp: InputStream): Long { - val raw = decodeVarint(inp, 64) - val temp = raw shl 63 shr 63 xor raw shr 1 - // This extra step lets us deal with the largest signed values by treating - // negative results from read unsigned methods as like unsigned values - // Must re-flip the top bit if the original read value had it set. - return temp xor (raw and (1L shl 63)) - - } } companion object : BinaryFormat { - public override val context: SerialModule get() = plain.context - - // todo: make more memory-efficient - private fun makeDelimited(decoder: ProtobufDecoder, parentTag: ProtoDesc?): ProtobufDecoder { - if (parentTag == null) return decoder - val bytes = decoder.nextObject() - return ProtobufDecoder(ByteArrayInputStream(bytes)) - } + override val context: SerialModule get() = plain.context private fun SerialDescriptor.getProtoDesc(index: Int): ProtoDesc { return extractParameters(this, index) @@ -457,9 +240,7 @@ class ProtoBufWithNullableSupport(context: SerialModule = EmptyModule) : Abstrac } override fun load(deserializer: DeserializationStrategy, bytes: ByteArray): T { - val stream = ByteArrayInputStream(bytes) - val reader = ProtobufReader(ProtobufDecoder(stream)) - return reader.decode(deserializer) + return ProtoBuf.load(deserializer, bytes) } } diff --git a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/Bot.kt b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/Bot.kt index 27d55cf58..889564387 100644 --- a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/Bot.kt +++ b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/Bot.kt @@ -92,7 +92,7 @@ abstract class Bot : CoroutineScope { abstract val network: BotNetworkHandler /** - * 登录. + * 登录, 或重新登录. * * 最终调用 [net.mamoe.mirai.network.BotNetworkHandler.login] *