From e375d17f82467df5922951b87412d3ef2ecbbe30 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Sun, 8 Mar 2020 16:02:11 +0800
Subject: [PATCH] Support maps and nesting

---
 .../io/serialization/jce/JceDecoder.kt        | 82 +++++++++++++----
 .../qqandroid/io/serialization/jce/common.kt  |  4 +-
 .../JceInputTest.kt                           | 89 +++++++++++++++++++
 3 files changed, 155 insertions(+), 20 deletions(-)

diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
index 596e5481c..d766b36cf 100644
--- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
+++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
@@ -7,11 +7,12 @@
  * https://github.com/mamoe/mirai/blob/master/LICENSE
  */
 
+@file:Suppress("PrivatePropertyName")
+
 package net.mamoe.mirai.qqandroid.io.serialization.jce
 
 import kotlinx.serialization.*
 import kotlinx.serialization.builtins.AbstractDecoder
-import kotlinx.serialization.builtins.ByteArraySerializer
 import kotlinx.serialization.internal.TaggedDecoder
 import kotlinx.serialization.modules.SerialModule
 import net.mamoe.mirai.qqandroid.io.serialization.Jce
@@ -38,16 +39,14 @@ internal class JceDecoder(
     }
 
     private fun SerialDescriptor.getJceTagId(index: Int): Int {
-        return getElementAnnotations(index).filterIsInstance<JceId>().single().id
+        println("getTag: ${getElementName(index)}")
+        return getElementAnnotations(index).filterIsInstance<JceId>().singleOrNull()?.id
+            ?: error("missing @JceId for ${getElementName(index)} in ${this.serialName}")
     }
 
+    private val SimpleByteArrayReader: SimpleByteArrayReaderImpl = SimpleByteArrayReaderImpl()
 
-    companion object {
-        private val ByteArraySerializer: KSerializer<ByteArray> = ByteArraySerializer()
-    }
-
-    // TODO: 2020/3/6 can be object
-    private inner class SimpleByteArrayReader : AbstractDecoder() {
+    private inner class SimpleByteArrayReaderImpl : AbstractDecoder() {
         override fun decodeSequentially(): Boolean = true
 
         override fun endStructure(descriptor: SerialDescriptor) {
@@ -80,8 +79,9 @@ internal class JceDecoder(
         }
     }
 
-    // TODO: 2020/3/6 can be object
-    private inner class ListReader : AbstractDecoder() {
+    private val ListReader: ListReaderImpl = ListReaderImpl()
+
+    private inner class ListReaderImpl : AbstractDecoder() {
         override fun decodeSequentially(): Boolean = true
         override fun decodeElementIndex(descriptor: SerialDescriptor): Int = error("should not be reached")
         override fun endStructure(descriptor: SerialDescriptor) {
@@ -113,33 +113,73 @@ internal class JceDecoder(
 
     override fun endStructure(descriptor: SerialDescriptor) {
         println("endStructure: $descriptor")
-        if (descriptor == ByteArraySerializer.descriptor) {
-            jce.prepareNextHead() // list 里面没读 head
-        } else jce.prepareNextHead() // TODO ?? 测试这里
-        super.endStructure(descriptor)
+        if (currentTagOrNull?.isSimpleByteArray == true) {
+            jce.prepareNextHead() // read to next head
+        }
+        if (descriptor.kind == StructureKind.CLASS) {
+            if (currentTagOrNull == null) {
+                return
+            }
+            while (true) {
+                val currentHead = jce.currentHeadOrNull ?: return
+                if (currentHead.type == Jce.STRUCT_END) {
+                    break
+                }
+                println("skipping")
+                jce.skipField(currentHead.type)
+                jce.prepareNextHead()
+            }
+            // pushTag(JceTag(0, true))
+            // skip STRUCT_END
+            // popTag()
+        }
     }
 
     override fun beginStructure(descriptor: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder {
+        println()
         println("beginStructure: ${descriptor.serialName}")
         return when (descriptor.kind) {
+            is PrimitiveKind -> this@JceDecoder
+
             StructureKind.MAP -> {
-                error("map")
+                println("!! MAP")
+                return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) {
+                    it.checkType(Jce.MAP)
+                    ListReader
+                }
             }
             StructureKind.LIST -> {
                 println("!! ByteArray")
                 println("decoderTag: $currentTagOrNull")
                 println("jceHead: " + jce.currentHeadOrNull)
-                return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) {
+                return jce.skipToHeadAndUseIfPossibleOrFail(currentTag.id) {
                     println("listHead: $it")
                     when (it.type) {
-                        Jce.SIMPLE_LIST -> SimpleByteArrayReader().also { jce.prepareNextHead() } // 无用的元素类型
-                        Jce.LIST -> ListReader()
+                        Jce.SIMPLE_LIST -> {
+                            currentTag.isSimpleByteArray = true
+                            jce.prepareNextHead() // 无用的元素类型
+                            SimpleByteArrayReader
+                        }
+                        Jce.LIST -> ListReader
                         else -> error("type mismatch. Expected SIMPLE_LIST or LIST, got ${it.type} instead")
                     }
                 }
             }
+            StructureKind.CLASS -> {
+                val currentTag = currentTagOrNull ?: return this@JceDecoder
 
-            else -> this@JceDecoder
+                println("!! CLASS")
+                println("decoderTag: $currentTag")
+                println("jceHead: " + jce.currentHeadOrNull)
+                return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) {
+                    it.checkType(Jce.STRUCT_BEGIN)
+                    this@JceDecoder
+                }
+            }
+
+            StructureKind.OBJECT -> error("unsupported StructureKind.OBJECT: ${descriptor.serialName}")
+            is UnionKind -> error("unsupported UnionKind: ${descriptor.serialName}")
+            is PolymorphicKind -> error("unsupported PolymorphicKind: ${descriptor.serialName}")
         }
     }
 
@@ -154,6 +194,10 @@ internal class JceDecoder(
     override fun decodeSequentially(): Boolean = false
     override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
         val jceHead = jce.currentHeadOrNull ?: return CompositeDecoder.READ_DONE
+        if (jceHead.type == Jce.STRUCT_END) {
+            return CompositeDecoder.READ_DONE
+        }
+
         repeat(descriptor.elementsCount) {
             val tag = descriptor.getJceTagId(it)
             if (tag == jceHead.tag) {
diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/common.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/common.kt
index 84c7a7021..6b015794d 100644
--- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/common.kt
+++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/common.kt
@@ -30,7 +30,9 @@ annotation class JceId(val id: Int)
 internal data class JceTag(
     val id: Int,
     val isNullable: Boolean
-)
+){
+    internal var isSimpleByteArray: Boolean = false
+}
 
 fun JceHead.checkType(type: Byte) {
     check(this.type == type) {"type mismatch. Expected $type, actual ${this.type}"}
diff --git a/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt b/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
index 65d7dd5be..42fae90fe 100644
--- a/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
+++ b/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
@@ -41,6 +41,57 @@ internal const val ZERO_TYPE: Byte = 12
 @Suppress("INVISIBLE_MEMBER") // bug
 internal class JceInputTest {
 
+    @Test
+    fun testNestedJceStruct() {
+        @Serializable
+        data class TestSerializableClassC(
+            @JceId(5) val value3: Int = 123123
+        )
+
+        @Serializable
+        data class TestSerializableClassB(
+            @JceId(0) val value: Int,
+            @JceId(123) val nested2: TestSerializableClassC
+        )
+
+        @Serializable
+        data class TestSerializableClassA(
+            @JceId(0) val value1: Int,
+            @JceId(1) val nestedStruct: TestSerializableClassB,
+            @JceId(2) val optional: Int = 3,
+            @JceId(4) val notOptional: Int
+        )
+
+        val input = buildPacket {
+            writeJceHead(INT, 0)
+            writeInt(444)
+
+            writeJceHead(STRUCT_BEGIN, 1); // TestSerializableClassB
+            {
+                writeJceHead(INT, 0)
+                writeInt(123)
+
+                writeJceHead(STRUCT_BEGIN, 123); // TestSerializableClassC
+                {
+                    writeJceHead(INT, 5)
+                    writeInt(123123)
+                }()
+                writeJceHead(STRUCT_END, 0)
+
+                writeJceHead(INT, 2) // 多余
+                writeInt(123)
+            }()
+            writeJceHead(STRUCT_END, 0)
+
+            writeJceHead(INT, 4)
+            writeInt(5)
+        }
+
+        assertEquals(
+            TestSerializableClassA(444, TestSerializableClassB(123, TestSerializableClassC(123123)), notOptional = 5),
+            JceNew.UTF_8.load(TestSerializableClassA.serializer(), input)
+        )
+    }
 
     @Test
     fun testNestedList() {
@@ -80,6 +131,44 @@ internal class JceInputTest {
         assertEquals(TestSerializableClassA(), JceNew.UTF_8.load(TestSerializableClassA.serializer(), input))
     }
 
+    @Test
+    fun testMap() {
+        @Serializable
+        data class TestSerializableClassA(
+            @JceId(0) val byteArray: Map<Int, Int>
+        )
+
+        val input = buildPacket {
+            writeJceHead(MAP, 0)
+
+            mapOf(1 to 2, 33 to 44).let {
+                writeJceHead(BYTE, 0)
+                writeByte(it.size.toByte())
+
+                it.forEach { (key, value) ->
+                    writeJceHead(INT, 0)
+                    writeInt(key)
+
+                    writeJceHead(INT, 1)
+                    writeInt(value)
+                }
+            }
+
+            writeJceHead(SIMPLE_LIST, 3)
+            writeJceHead(BYTE, 0)
+
+            byteArrayOf(1, 2, 3, 4).let {
+                writeJceHead(BYTE, 0)
+                writeByte(it.size.toByte())
+                writeFully(it)
+            }
+        }
+
+        assertEquals(
+            TestSerializableClassA(mapOf(1 to 2, 33 to 44)),
+            JceNew.UTF_8.load(TestSerializableClassA.serializer(), input)
+        )
+    }
 
     @Test
     fun testSimpleByteArray() {