From 2acca0c2e6678fe7be9ef6a374fc020503dda921 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Sun, 8 Mar 2020 16:34:38 +0800
Subject: [PATCH] Jce complex nesting support

---
 .../io/serialization/jce/JceDecoder.kt        | 54 +++++++++++--
 .../JceInputTest.kt                           | 80 +++++++++++++++++--
 2 files changed, 122 insertions(+), 12 deletions(-)

diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
index d766b36cf..3d7298d05 100644
--- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
+++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
@@ -111,6 +111,42 @@ internal class JceDecoder(
         }
     }
 
+
+    private val MapReader: MapReaderImpl = MapReaderImpl()
+
+    private inner class MapReaderImpl : AbstractDecoder() {
+        override fun decodeSequentially(): Boolean = true
+        override fun decodeElementIndex(descriptor: SerialDescriptor): Int = error("should not be reached")
+        override fun endStructure(descriptor: SerialDescriptor) {
+            this@JceDecoder.endStructure(descriptor)
+        }
+
+        private var state: Boolean = true
+        override fun beginStructure(descriptor: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder {
+            this@JceDecoder.pushTag(JceTag(if (state) 0 else 1, false))
+            state = !state
+            return this@JceDecoder.beginStructure(descriptor, *typeParams)
+        }
+
+        override fun decodeByte(): Byte = jce.useHead { jce.readJceByteValue(it) }
+        override fun decodeShort(): Short = jce.useHead { jce.readJceShortValue(it) }
+        override fun decodeInt(): Int = jce.useHead { jce.readJceIntValue(it) }
+        override fun decodeLong(): Long = jce.useHead { jce.readJceLongValue(it) }
+        override fun decodeFloat(): Float = jce.useHead { jce.readJceFloatValue(it) }
+        override fun decodeDouble(): Double = jce.useHead { jce.readJceDoubleValue(it) }
+        override fun decodeBoolean(): Boolean = jce.useHead { jce.readJceBooleanValue(it) }
+        override fun decodeChar(): Char = decodeByte().toChar()
+        override fun decodeEnum(enumDescriptor: SerialDescriptor): Int = decodeInt()
+        override fun decodeString(): String = jce.useHead { jce.readJceStringValue(it) }
+
+        override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
+            println("decodeCollectionSize in MapReader: ${descriptor.serialName}")
+            // 不读下一个 head
+            return jce.useHead { jce.readJceIntValue(it) }.also { println("listSize=$it") }
+        }
+    }
+
+
     override fun endStructure(descriptor: SerialDescriptor) {
         println("endStructure: $descriptor")
         if (currentTagOrNull?.isSimpleByteArray == true) {
@@ -123,9 +159,11 @@ internal class JceDecoder(
             while (true) {
                 val currentHead = jce.currentHeadOrNull ?: return
                 if (currentHead.type == Jce.STRUCT_END) {
+                    jce.prepareNextHead()
+                    println("current end")
                     break
                 }
-                println("skipping")
+                println("current $currentHead")
                 jce.skipField(currentHead.type)
                 jce.prepareNextHead()
             }
@@ -145,7 +183,7 @@ internal class JceDecoder(
                 println("!! MAP")
                 return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) {
                     it.checkType(Jce.MAP)
-                    ListReader
+                    MapReader
                 }
             }
             StructureKind.LIST -> {
@@ -171,9 +209,15 @@ internal class JceDecoder(
                 println("!! CLASS")
                 println("decoderTag: $currentTag")
                 println("jceHead: " + jce.currentHeadOrNull)
-                return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) {
-                    it.checkType(Jce.STRUCT_BEGIN)
-                    this@JceDecoder
+                return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) { jceHead ->
+                    jceHead.checkType(Jce.STRUCT_BEGIN)
+
+
+                    // TODO: 2020/3/8 检查是否需要 scope 化
+                    repeat(descriptor.elementsCount) {
+                        pushTag(descriptor.getTag(descriptor.elementsCount - it - 1)) // better performance
+                    }
+                    this // independent tag stack
                 }
             }
 
diff --git a/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt b/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
index 42fae90fe..c002ef2d2 100644
--- a/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
+++ b/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
@@ -42,7 +42,7 @@ internal const val ZERO_TYPE: Byte = 12
 internal class JceInputTest {
 
     @Test
-    fun testNestedJceStruct() {
+    fun testFuckingComprehensiveStruct() {
         @Serializable
         data class TestSerializableClassC(
             @JceId(5) val value3: Int = 123123
@@ -51,15 +51,77 @@ internal class JceInputTest {
         @Serializable
         data class TestSerializableClassB(
             @JceId(0) val value: Int,
-            @JceId(123) val nested2: TestSerializableClassC
+            @JceId(123) val nested2: TestSerializableClassC,
+            @JceId(5) val value5: Int
+        )
+
+        @Serializable
+        data class TestSerializableClassA(
+            @JceId(0) val map: Map<TestSerializableClassB, TestSerializableClassC>
+        )
+
+
+        val input = buildPacket {
+            writeJceHead(MAP, 0) // TestSerializableClassB
+            writeJceHead(BYTE, 0)
+            writeByte(1)
+
+            writeJceHead(STRUCT_BEGIN, 0);
+            {
+                writeJceHead(INT, 0)
+                writeInt(123)
+
+                writeJceHead(STRUCT_BEGIN, 123); // TestSerializableClassC
+                {
+                    writeJceHead(INT, 5)
+                    writeInt(123123)
+                }()
+                writeJceHead(STRUCT_END, 0)
+
+                writeJceHead(INT, 5)
+                writeInt(9)
+            }()
+            writeJceHead(STRUCT_END, 0)
+
+            writeJceHead(STRUCT_BEGIN, 1);
+            {
+                writeJceHead(INT, 5)
+                writeInt(123123)
+            }()
+            writeJceHead(STRUCT_END, 0)
+        }
+
+        assertEquals(
+            TestSerializableClassA(
+                mapOf(
+                    TestSerializableClassB(123, TestSerializableClassC(123123), 9)
+                            to TestSerializableClassC(123123)
+                )
+            ),
+            JceNew.UTF_8.load(TestSerializableClassA.serializer(), input)
+        )
+    }
+
+    @Test
+    fun testNestedJceStruct() {
+        @Serializable
+        data class TestSerializableClassC(
+            @JceId(5) val value3: Int
+        )
+
+        @Serializable
+        data class TestSerializableClassB(
+            @JceId(0) val value: Int,
+            @JceId(123) val nested2: TestSerializableClassC,
+            @JceId(5) val value5: Int
         )
 
         @Serializable
         data class TestSerializableClassA(
             @JceId(0) val value1: Int,
+            @JceId(4) val notOptional: Int,
             @JceId(1) val nestedStruct: TestSerializableClassB,
-            @JceId(2) val optional: Int = 3,
-            @JceId(4) val notOptional: Int
+            @JceId(2) val optional: Int = 3
         )
 
         val input = buildPacket {
@@ -78,8 +140,8 @@ internal class JceInputTest {
                 }()
                 writeJceHead(STRUCT_END, 0)
 
-                writeJceHead(INT, 2) // 多余
-                writeInt(123)
+                writeJceHead(INT, 5)
+                writeInt(9)
             }()
             writeJceHead(STRUCT_END, 0)
 
@@ -88,7 +150,11 @@ internal class JceInputTest {
         }
 
         assertEquals(
-            TestSerializableClassA(444, TestSerializableClassB(123, TestSerializableClassC(123123)), notOptional = 5),
+            TestSerializableClassA(
+                444,
+                5,
+                TestSerializableClassB(123, TestSerializableClassC(123123), 9)
+            ),
             JceNew.UTF_8.load(TestSerializableClassA.serializer(), input)
         )
     }