diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/JceEncoder.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/JceEncoder.kt new file mode 100644 index 000000000..e74e34355 --- /dev/null +++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/JceEncoder.kt @@ -0,0 +1,349 @@ +package net.mamoe.mirai.qqandroid.io.serialization + +import kotlinx.io.ByteArrayOutputStream +import kotlinx.io.ByteBuffer +import kotlinx.io.ByteOrder +import kotlinx.io.charsets.Charset +import kotlinx.io.core.ExperimentalIoApi +import kotlinx.io.core.toByteArray +import kotlinx.serialization.* +import kotlinx.serialization.internal.* +import kotlinx.serialization.modules.EmptyModule +import kotlinx.serialization.modules.SerialModule +import net.mamoe.mirai.qqandroid.io.JceEncodeException +import net.mamoe.mirai.qqandroid.io.JceOutput +import net.mamoe.mirai.utils.io.toUHexString +import kotlin.reflect.KClass + +enum class JceCharset(val kotlinCharset: Charset) { + GBK(Charset.forName("GBK")), + UTF8(Charset.forName("UTF8")) +} + +@Target(AnnotationTarget.CLASS, AnnotationTarget.FIELD) +@Retention(AnnotationRetention.RUNTIME) +annotation class SerialCharset(val charset: JceCharset) + +internal object JceType { + const val BYTE: Int = 0 + const val DOUBLE: Int = 5 + const val FLOAT: Int = 4 + const val INT: Int = 2 + const val JCE_MAX_STRING_LENGTH = 104857600 + const val LIST: Int = 9 + const val LONG: Int = 3 + const val MAP: Int = 8 + const val SHORT: Int = 1 + const val SIMPLE_LIST: Int = 13 + const val STRING1: Int = 6 + const val STRING4: Int = 7 + const val STRUCT_BEGIN: Int = 10 + const val STRUCT_END: Int = 11 + const val ZERO_TYPE: Int = 12 + + private fun Any?.getClassName(): KClass = if (this == null) Unit::class else this::class +} + + +internal fun getSerialId(desc: SerialDescriptor, index: Int): Int? = desc.findAnnotation(index)?.id + +internal data class JceDesc( + val id: Int, + val charset: JceCharset +) { + companion object { + val STUB_FOR_PRIMITIVE_NUMBERS_GBK = JceDesc(0, JceCharset.GBK) + } +} + +class Jce private constructor(private val charset: JceCharset, context: SerialModule = EmptyModule) : AbstractSerialFormat(context), BinaryFormat { + + private inner class ListWriter( + defaultStringCharset: JceCharset, + private val count: Int, + private val tag: JceDesc, + private val parentEncoder: JceEncoder, + private val stream: ByteArrayOutputStream = ByteArrayOutputStream() + ) : JceEncoder(defaultStringCharset, stream) { + override fun SerialDescriptor.getTag(index: Int): JceDesc { + return JceDesc(0, getCharset(index)) + } + + override fun endEncode(desc: SerialDescriptor) { + parentEncoder.writeHead(JceOutput.LIST, this.tag.id) + parentEncoder.encodeTaggedInt(JceDesc.STUB_FOR_PRIMITIVE_NUMBERS_GBK, count) + parentEncoder.output.write(stream.toByteArray()) + } + } + + private inner class JceStructWriter( + defaultStringCharset: JceCharset, + private val tag: JceDesc, + private val parentEncoder: JceEncoder, + private val stream: ByteArrayOutputStream = ByteArrayOutputStream() + ) : JceEncoder(defaultStringCharset, stream) { + override fun endEncode(desc: SerialDescriptor) { + parentEncoder.writeHead(JceOutput.STRUCT_BEGIN, this.tag.id) + parentEncoder.output.write(stream.toByteArray()) + parentEncoder.writeHead(JceOutput.STRUCT_END, 0) + } + } + + private inner class JceMapWriter( + defaultStringCharset: JceCharset, + private val count: Int, + private val tag: JceDesc, + private val parentEncoder: JceEncoder + ) : JceEncoder(defaultStringCharset, ByteArrayOutputStream()) { + override fun SerialDescriptor.getTag(index: Int): JceDesc { + return if (index % 2 == 0) JceDesc(0, getCharset(index)) + else JceDesc(1, getCharset(index)) + } + + override fun endEncode(desc: SerialDescriptor) { + parentEncoder.writeHead(JceOutput.MAP, this.tag.id) + parentEncoder.encodeTaggedInt(JceDesc.STUB_FOR_PRIMITIVE_NUMBERS_GBK, count) + println(this.output.toByteArray().toUHexString()) + parentEncoder.output.write(this.output.toByteArray()) + } + } + + /** + * From: com.qq.taf.jce.JceOutputStream + */ + @Suppress("unused", "MemberVisibilityCanBePrivate") + @UseExperimental(ExperimentalIoApi::class) + private open inner class JceEncoder( + /** + * 标注在 class 上的 charset + */ + private val defaultStringCharset: JceCharset, + internal val output: ByteArrayOutputStream + ) : TaggedEncoder() { + override val context get() = this@Jce.context + + protected fun SerialDescriptor.getCharset(index: Int): JceCharset { + return findAnnotation(index)?.charset ?: defaultStringCharset + } + + override fun SerialDescriptor.getTag(index: Int): JceDesc { + return JceDesc(getSerialId(this, index) ?: error("cannot find @SerialId"), getCharset(index)) + } + + /** + * 序列化最开始的时候的 + */ + override fun beginStructure(desc: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeEncoder = when (desc.kind) { + StructureKind.LIST -> this + StructureKind.MAP -> this + StructureKind.CLASS, UnionKind.OBJECT -> { + val currentTag = currentTagOrNull + if (currentTag == null) { + this + } else { + JceStructWriter(defaultStringCharset, currentTag, this) + } + } + is PolymorphicKind -> error("unsupported: PolymorphicKind") + else -> throw SerializationException("Primitives are not supported at top-level") + } + + @Suppress("UNCHECKED_CAST", "NAME_SHADOWING") + override fun encodeSerializableValue(serializer: SerializationStrategy, value: T) = when (serializer.descriptor) { + // encode maps as collection of map entries, not merged collection of key-values + is MapLikeDescriptor -> { + val entries = (value as Map<*, *>).entries + val serializer = (serializer as MapLikeSerializer) + val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer) + HashSetSerializer(mapEntrySerial).serialize(JceMapWriter(charset, entries.size, popTag(), this), entries) + } + ByteArraySerializer.descriptor -> encodeTaggedByteArray(popTag(), value as ByteArray) + is PrimitiveArrayDescriptor -> { + if (value is ByteArray) { + this.encodeTaggedByteArray(currentTag, value) + } else{ + serializer.serialize( + ListWriter(charset, when(value){ + is ShortArray -> value.size + is IntArray -> value.size + is LongArray -> value.size + is FloatArray -> value.size + is DoubleArray -> value.size + is CharArray -> value.size + else -> error("unknown array type: ${value.getClassName()}") + }, currentTag, this), + value + ) + } + } + is ArrayClassDesc-> { + serializer.serialize( + ListWriter(charset, (value as Array<*>).size, currentTag, this), + value + ) + } + is ListLikeDescriptor -> { + serializer.serialize( + ListWriter(charset, (value as Collection<*>).size, currentTag, this), + value + ) + } + else -> serializer.serialize(this, value) + } + + override fun encodeTaggedByte(tag: JceDesc, value: Byte) { + if (value.toInt() == 0) { + writeHead(ZERO_TYPE, tag.id) + } else { + writeHead(BYTE, tag.id) + output.write(value.toInt()) + } + } + + override fun encodeTaggedShort(tag: JceDesc, value: Short) { + if (value in Byte.MIN_VALUE..Byte.MAX_VALUE) { + encodeTaggedByte(tag, value.toByte()) + } else { + writeHead(SHORT, tag.id) + output.write(ByteBuffer.allocate(2).order(ByteOrder.LITTLE_ENDIAN).putShort(value).array()) + } + } + + override fun encodeTaggedInt(tag: JceDesc, value: Int) { + if (value in Short.MIN_VALUE..Short.MAX_VALUE) { + encodeTaggedShort(tag, value.toShort()) + } else { + writeHead(INT, tag.id) + output.write(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(value).array()) + } + } + + override fun encodeTaggedFloat(tag: JceDesc, value: Float) { + writeHead(FLOAT, tag.id) + output.write(ByteBuffer.allocate(4).order(ByteOrder.BIG_ENDIAN).putFloat(value).array()) + } + + override fun encodeTaggedDouble(tag: JceDesc, value: Double) { + writeHead(DOUBLE, tag.id) + output.write(ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN).putDouble(value).array()) + } + + override fun encodeTaggedLong(tag: JceDesc, value: Long) { + if (value in Int.MIN_VALUE..Int.MAX_VALUE) { + encodeTaggedInt(tag, value.toInt()) + } else { + writeHead(LONG, tag.id) + output.write(ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(value).array()) + } + } + + override fun encodeTaggedBoolean(tag: JceDesc, value: Boolean) { + encodeTaggedByte(tag, if (value) 1 else 0) + } + + override fun encodeTaggedChar(tag: JceDesc, value: Char) { + encodeTaggedByte(tag, value.toByte()) + } + + override fun encodeTaggedEnum(tag: JceDesc, enumDescription: SerialDescriptor, ordinal: Int) { + TODO() + } + + override fun encodeTaggedNull(tag: JceDesc) { + } + + override fun encodeTaggedUnit(tag: JceDesc) { + encodeTaggedNull(tag) + } + + fun encodeTaggedByteArray(tag: JceDesc, bytes: ByteArray) { + writeHead(JceOutput.SIMPLE_LIST, tag.id) + writeHead(JceOutput.BYTE, 0) + encodeTaggedInt(JceDesc.STUB_FOR_PRIMITIVE_NUMBERS_GBK, bytes.size) + output.write(bytes) + } + + override fun encodeTaggedString(tag: JceDesc, value: String) { + val array = value.toByteArray(defaultStringCharset.kotlinCharset) + if (array.size > 255) { + writeHead(STRING4, tag.id) + output.write(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(array.size).array()) + output.write(array) + } else { + writeHead(STRING1, tag.id) + output.write(ByteBuffer.allocate(1).order(ByteOrder.LITTLE_ENDIAN).put(array.size.toByte()).array()) + output.write(array) + } + } + + override fun encodeTaggedValue(tag: JceDesc, value: Any) { + when (value) { + is Byte -> encodeTaggedByte(tag, value) + is Short -> encodeTaggedShort(tag, value) + is Int -> encodeTaggedInt(tag, value) + is Long -> encodeTaggedLong(tag, value) + is Float -> encodeTaggedFloat(tag, value) + is Double -> encodeTaggedDouble(tag, value) + is Boolean -> encodeTaggedBoolean(tag, value) + is String -> encodeTaggedString(tag, value) + is Unit -> encodeTaggedUnit(tag) + else -> error("unsupported type: ${value.getClassName()}") + } + } + + @PublishedApi + internal fun writeHead(type: Int, tag: Int) { + if (tag < 15) { + this.output.write((tag shl 4) or type) + return + } + if (tag < 256) { + this.output.write(type or 0xF0) + this.output.write(tag) + return + } + throw JceEncodeException("tag is too large: $tag") + } + } + + companion object { + val UTF8 = Jce(JceCharset.UTF8) + val GBK = Jce(JceCharset.GBK) + + internal const val BYTE: Int = 0 + internal const val DOUBLE: Int = 5 + internal const val FLOAT: Int = 4 + internal const val INT: Int = 2 + internal const val JCE_MAX_STRING_LENGTH = 104857600 + internal const val LIST: Int = 9 + internal const val LONG: Int = 3 + internal const val MAP: Int = 8 + internal const val SHORT: Int = 1 + internal const val SIMPLE_LIST: Int = 13 + internal const val STRING1: Int = 6 + internal const val STRING4: Int = 7 + internal const val STRUCT_BEGIN: Int = 10 + internal const val STRUCT_END: Int = 11 + internal const val ZERO_TYPE: Int = 12 + + private fun Any?.getClassName(): KClass = if (this == null) Unit::class else this::class + + internal const val VARINT = 0 + internal const val i64 = 1 + internal const val SIZE_DELIMITED = 2 + internal const val i32 = 5 + } + + override fun dump(serializer: SerializationStrategy, obj: T): ByteArray { + val encoder = ByteArrayOutputStream() + + val dumper = JceEncoder(charset, encoder) + dumper.encode(serializer, obj) + return encoder.toByteArray() + } + + override fun load(deserializer: DeserializationStrategy, bytes: ByteArray): T { + TODO() + } + +} diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/SerializationHelper.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/SerializationHelper.kt new file mode 100644 index 000000000..e75bde933 --- /dev/null +++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/SerializationHelper.kt @@ -0,0 +1,16 @@ +package net.mamoe.mirai.qqandroid.io.serialization + +import kotlinx.serialization.SerialDescriptor + +/* + * Helper for kotlinx.serialization + */ + +internal inline fun SerialDescriptor.findAnnotation(elementIndex: Int): A? { + val candidates = getElementAnnotations(elementIndex).filterIsInstance() + return when (candidates.size) { + 0 -> null + 1 -> candidates[0] + else -> throw IllegalStateException("There are duplicate annotations of type ${A::class} in the descriptor $this") + } +} diff --git a/mirai-core-qqandroid/src/jvmTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceEncoderTest.kt b/mirai-core-qqandroid/src/jvmTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceEncoderTest.kt new file mode 100644 index 000000000..a4341ff25 --- /dev/null +++ b/mirai-core-qqandroid/src/jvmTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceEncoderTest.kt @@ -0,0 +1,65 @@ +package net.mamoe.mirai.qqandroid.io.serialization + +import kotlinx.io.core.readBytes +import kotlinx.serialization.SerialId +import kotlinx.serialization.Serializable +import net.mamoe.mirai.qqandroid.io.buildJcePacket +import net.mamoe.mirai.utils.io.toUHexString +import kotlin.test.Test +import kotlin.test.assertEquals + + +class JceEncoderTest { + + @Serializable + class TestSimpleJceStruct( + @SerialId(0) val string: String = "123", + @SerialId(1) val byte: Byte = 123, + @SerialId(2) val short: Short = 123, + @SerialId(3) val int: Int = 123, + @SerialId(4) val long: Long = 123, + @SerialId(5) val float: Float = 123f, + @SerialId(6) val double: Double = 123.0 + ) + + @Test + fun testEncoder() { + assertEquals( + buildJcePacket { + writeString("123", 0) + writeByte(123, 1) + writeShort(123, 2) + writeInt(123, 3) + writeLong(123, 4) + writeFloat(123f, 5) + writeDouble(123.0, 6) + }.readBytes().toUHexString(), + Jce.GBK.dump( + TestSimpleJceStruct.serializer(), + TestSimpleJceStruct() + ).toUHexString() + ) + } + + @Test + fun testEncoder2() { + assertEquals( + buildJcePacket { + writeFully(byteArrayOf(1, 2, 3), 7) + writeCollection(listOf(1, 2, 3), 8) + writeMap(mapOf("哈哈" to "嘿嘿"), 9) + }.readBytes().toUHexString(), + Jce.GBK.dump( + TestComplexJceStruct.serializer(), + TestComplexJceStruct() + ).toUHexString() + ) + } + + @Serializable + class TestComplexJceStruct( + @SerialId(7) val byteArray: ByteArray = byteArrayOf(1, 2, 3), + @SerialId(8) val byteList: List = listOf(1, 2, 3), + @SerialId(9) val map: Map = mapOf("哈哈" to "嘿嘿") + ) +} \ No newline at end of file