From 9cce64bc4a552a8f62b2d88dccd428beab55b2f9 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Mon, 27 Jan 2020 23:00:25 +0800
Subject: [PATCH] JceStruct serialization

---
 .../qqandroid/io/serialization/JceEncoder.kt  | 75 +++++++++++--------
 .../JceEncoderTest.kt                         | 23 +++++-
 2 files changed, 62 insertions(+), 36 deletions(-)

diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/JceEncoder.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/JceEncoder.kt
index e74e34355..0de297c38 100644
--- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/JceEncoder.kt
+++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/JceEncoder.kt
@@ -11,8 +11,7 @@ import kotlinx.serialization.internal.*
 import kotlinx.serialization.modules.EmptyModule
 import kotlinx.serialization.modules.SerialModule
 import net.mamoe.mirai.qqandroid.io.JceEncodeException
-import net.mamoe.mirai.qqandroid.io.JceOutput
-import net.mamoe.mirai.utils.io.toUHexString
+import net.mamoe.mirai.qqandroid.io.JceStruct
 import kotlin.reflect.KClass
 
 enum class JceCharset(val kotlinCharset: Charset) {
@@ -62,17 +61,16 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
         defaultStringCharset: JceCharset,
         private val count: Int,
         private val tag: JceDesc,
-        private val parentEncoder: JceEncoder,
-        private val stream: ByteArrayOutputStream = ByteArrayOutputStream()
-    ) : JceEncoder(defaultStringCharset, stream) {
+        private val parentEncoder: JceEncoder
+    ) : JceEncoder(defaultStringCharset, ByteArrayOutputStream()) {
         override fun SerialDescriptor.getTag(index: Int): JceDesc {
             return JceDesc(0, getCharset(index))
         }
 
         override fun endEncode(desc: SerialDescriptor) {
-            parentEncoder.writeHead(JceOutput.LIST, this.tag.id)
+            parentEncoder.writeHead(LIST, this.tag.id)
             parentEncoder.encodeTaggedInt(JceDesc.STUB_FOR_PRIMITIVE_NUMBERS_GBK, count)
-            parentEncoder.output.write(stream.toByteArray())
+            parentEncoder.output.write(this.output.toByteArray())
         }
     }
 
@@ -83,28 +81,35 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
         private val stream: ByteArrayOutputStream = ByteArrayOutputStream()
     ) : JceEncoder(defaultStringCharset, stream) {
         override fun endEncode(desc: SerialDescriptor) {
-            parentEncoder.writeHead(JceOutput.STRUCT_BEGIN, this.tag.id)
+            parentEncoder.writeHead(STRUCT_BEGIN, this.tag.id)
             parentEncoder.output.write(stream.toByteArray())
-            parentEncoder.writeHead(JceOutput.STRUCT_END, 0)
+            parentEncoder.writeHead(STRUCT_END, 0)
         }
     }
 
     private inner class JceMapWriter(
         defaultStringCharset: JceCharset,
-        private val count: Int,
-        private val tag: JceDesc,
-        private val parentEncoder: JceEncoder
-    ) : JceEncoder(defaultStringCharset, ByteArrayOutputStream()) {
+        output: ByteArrayOutputStream
+    ) : JceEncoder(defaultStringCharset, output) {
         override fun SerialDescriptor.getTag(index: Int): JceDesc {
             return if (index % 2 == 0) JceDesc(0, getCharset(index))
             else JceDesc(1, getCharset(index))
         }
 
+        /*
         override fun endEncode(desc: SerialDescriptor) {
-            parentEncoder.writeHead(JceOutput.MAP, this.tag.id)
+            parentEncoder.writeHead(MAP, this.tag.id)
             parentEncoder.encodeTaggedInt(JceDesc.STUB_FOR_PRIMITIVE_NUMBERS_GBK, count)
             println(this.output.toByteArray().toUHexString())
             parentEncoder.output.write(this.output.toByteArray())
+        }*/
+
+        override fun beginCollection(desc: SerialDescriptor, collectionSize: Int, vararg typeParams: KSerializer<*>): CompositeEncoder {
+            return this
+        }
+
+        override fun beginStructure(desc: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeEncoder {
+            return this
         }
     }
 
@@ -136,31 +141,27 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
         override fun beginStructure(desc: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeEncoder = when (desc.kind) {
             StructureKind.LIST -> this
             StructureKind.MAP -> this
-            StructureKind.CLASS, UnionKind.OBJECT -> {
-                val currentTag = currentTagOrNull
-                if (currentTag == null) {
-                    this
-                } else {
-                    JceStructWriter(defaultStringCharset, currentTag, this)
-                }
-            }
-            is PolymorphicKind -> error("unsupported: PolymorphicKind")
+            StructureKind.CLASS, UnionKind.OBJECT -> this
+            is PolymorphicKind -> this
             else -> throw SerializationException("Primitives are not supported at top-level")
         }
 
         @Suppress("UNCHECKED_CAST", "NAME_SHADOWING")
         override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) = when (serializer.descriptor) {
-            // encode maps as collection of map entries, not merged collection of key-values
             is MapLikeDescriptor -> {
+                println("hello")
                 val entries = (value as Map<*, *>).entries
                 val serializer = (serializer as MapLikeSerializer<Any?, Any?, T, *>)
                 val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
-                HashSetSerializer(mapEntrySerial).serialize(JceMapWriter(charset, entries.size, popTag(), this), entries)
+
+                this.writeHead(MAP, currentTag.id)
+                this.encodeTaggedInt(JceDesc.STUB_FOR_PRIMITIVE_NUMBERS_GBK, entries.count())
+                HashSetSerializer(mapEntrySerial).serialize(JceMapWriter(charset, this.output), entries)
             }
             ByteArraySerializer.descriptor -> encodeTaggedByteArray(popTag(), value as ByteArray)
             is PrimitiveArrayDescriptor -> {
                 if (value is ByteArray) {
-                    this.encodeTaggedByteArray(currentTag, value)
+                    this.encodeTaggedByteArray( popTag(), value)
                 } else{
                     serializer.serialize(
                         ListWriter(charset, when(value){
@@ -171,24 +172,34 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
                             is DoubleArray -> value.size
                             is CharArray -> value.size
                             else -> error("unknown array type: ${value.getClassName()}")
-                        }, currentTag, this),
+                        },  popTag(), this),
                         value
                     )
                 }
             }
             is ArrayClassDesc-> {
                 serializer.serialize(
-                    ListWriter(charset, (value as Array<*>).size, currentTag, this),
+                    ListWriter(charset, (value as Array<*>).size,  popTag(), this),
                     value
                 )
             }
             is ListLikeDescriptor -> {
                 serializer.serialize(
-                    ListWriter(charset, (value as Collection<*>).size, currentTag, this),
+                    ListWriter(charset, (value as Collection<*>).size,  popTag(), this),
                     value
                 )
             }
-            else -> serializer.serialize(this, value)
+            else -> {
+                if (value is JceStruct) {
+                    if (currentTagOrNull == null) {
+                        serializer.serialize(this, value)
+                    } else {
+                        this.writeHead(STRUCT_BEGIN, currentTag.id)
+                        serializer.serialize(this, value)
+                        this.writeHead(STRUCT_END, 0)
+                    }
+                } else serializer.serialize(this, value)
+            }
         }
 
         override fun encodeTaggedByte(tag: JceDesc, value: Byte) {
@@ -257,8 +268,8 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
         }
 
         fun encodeTaggedByteArray(tag: JceDesc, bytes: ByteArray) {
-            writeHead(JceOutput.SIMPLE_LIST, tag.id)
-            writeHead(JceOutput.BYTE, 0)
+            writeHead(SIMPLE_LIST, tag.id)
+            writeHead(BYTE, 0)
             encodeTaggedInt(JceDesc.STUB_FOR_PRIMITIVE_NUMBERS_GBK, bytes.size)
             output.write(bytes)
         }
diff --git a/mirai-core-qqandroid/src/jvmTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceEncoderTest.kt b/mirai-core-qqandroid/src/jvmTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceEncoderTest.kt
index a4341ff25..7051a85b6 100644
--- a/mirai-core-qqandroid/src/jvmTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceEncoderTest.kt
+++ b/mirai-core-qqandroid/src/jvmTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceEncoderTest.kt
@@ -3,6 +3,9 @@ package net.mamoe.mirai.qqandroid.io.serialization
 import kotlinx.io.core.readBytes
 import kotlinx.serialization.SerialId
 import kotlinx.serialization.Serializable
+import net.mamoe.mirai.qqandroid.io.CharsetUTF8
+import net.mamoe.mirai.qqandroid.io.JceOutput
+import net.mamoe.mirai.qqandroid.io.JceStruct
 import net.mamoe.mirai.qqandroid.io.buildJcePacket
 import net.mamoe.mirai.utils.io.toUHexString
 import kotlin.test.Test
@@ -20,7 +23,17 @@ class JceEncoderTest {
         @SerialId(4) val long: Long = 123,
         @SerialId(5) val float: Float = 123f,
         @SerialId(6) val double: Double = 123.0
-    )
+    ) : JceStruct() {
+        override fun writeTo(builder: JceOutput) = builder.run {
+            writeString("123", 0)
+            writeByte(123, 1)
+            writeShort(123, 2)
+            writeInt(123, 3)
+            writeLong(123, 4)
+            writeFloat(123f, 5)
+            writeDouble(123.0, 6)
+        }
+    }
 
     @Test
     fun testEncoder() {
@@ -44,12 +57,13 @@ class JceEncoderTest {
     @Test
     fun testEncoder2() {
         assertEquals(
-            buildJcePacket {
+            buildJcePacket(stringCharset = CharsetUTF8) {
                 writeFully(byteArrayOf(1, 2, 3), 7)
                 writeCollection(listOf(1, 2, 3), 8)
                 writeMap(mapOf("哈哈" to "嘿嘿"), 9)
+                writeJceStruct(TestSimpleJceStruct(), 10)
             }.readBytes().toUHexString(),
-            Jce.GBK.dump(
+            Jce.UTF8.dump(
                 TestComplexJceStruct.serializer(),
                 TestComplexJceStruct()
             ).toUHexString()
@@ -60,6 +74,7 @@ class JceEncoderTest {
     class TestComplexJceStruct(
         @SerialId(7) val byteArray: ByteArray = byteArrayOf(1, 2, 3),
         @SerialId(8) val byteList: List<Byte> = listOf(1, 2, 3),
-        @SerialId(9) val map: Map<String, String> = mapOf("哈哈" to "嘿嘿")
+        @SerialId(9) val map: Map<String, String> = mapOf("哈哈" to "嘿嘿"),
+        @SerialId(10) val nestedJceStruct: TestSimpleJceStruct = TestSimpleJceStruct()
     )
 }
\ No newline at end of file