From c69cb6f3de63b36ae8952ae93b3afad934005255 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Sun, 8 Mar 2020 17:44:34 +0800
Subject: [PATCH] Fix Map entry tags, remove debugging logs

---
 .../io/serialization/jce/JceDecoder.kt        | 80 +++++++------------
 .../io/serialization/jce/JceInput.kt          |  5 +-
 .../qqandroid/io/serialization/jce/common.kt  | 34 +++-----
 .../JceInputTest.kt                           | 18 +++--
 4 files changed, 60 insertions(+), 77 deletions(-)

diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
index 3d7298d05..baaec2c35 100644
--- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
+++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceDecoder.kt
@@ -15,7 +15,6 @@ import kotlinx.serialization.*
 import kotlinx.serialization.builtins.AbstractDecoder
 import kotlinx.serialization.internal.TaggedDecoder
 import kotlinx.serialization.modules.SerialModule
-import net.mamoe.mirai.qqandroid.io.serialization.Jce
 
 
 @OptIn(InternalSerializationApi::class) // 将来 kotlinx 修改后再复制过来 mirai.
@@ -30,16 +29,13 @@ internal class JceDecoder(
 
         val id = annotations.filterIsInstance<JceId>().single().id
         // ?: error("cannot find @JceId or @ProtoId for ${this.getElementName(index)} in ${this.serialName}")
-        println("getTag: ${this.getElementName(index)}=$id")
+        //println("getTag: ${this.getElementName(index)}=$id")
 
-        return JceTag(
-            id,
-            this.getElementDescriptor(index).isNullable
-        )
+        return JceTagCommon(id)
     }
 
     private fun SerialDescriptor.getJceTagId(index: Int): Int {
-        println("getTag: ${getElementName(index)}")
+        //println("getTag: ${getElementName(index)}")
         return getElementAnnotations(index).filterIsInstance<JceId>().singleOrNull()?.id
             ?: error("missing @JceId for ${getElementName(index)} in ${this.serialName}")
     }
@@ -54,11 +50,11 @@ internal class JceDecoder(
         }
 
         override fun beginStructure(descriptor: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder {
-            this@JceDecoder.pushTag(JceTag(0, false))
+            this@JceDecoder.pushTag(JceTagListElement)
             return this@JceDecoder.beginStructure(descriptor, *typeParams)
         }
 
-        override fun decodeByte(): Byte = jce.input.readByte().also { println("decodeByte: $it") }
+        override fun decodeByte(): Byte = jce.input.readByte()
         override fun decodeShort(): Short = error("illegal access")
         override fun decodeInt(): Int = error("illegal access")
         override fun decodeLong(): Long = error("illegal access")
@@ -75,7 +71,7 @@ internal class JceDecoder(
 
         override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
             // 不要读下一个 head
-            return jce.currentHead.let { jce.readJceIntValue(it) }.also { println("simpleListSize=$it") }
+            return jce.currentHead.let { jce.readJceIntValue(it) }
         }
     }
 
@@ -89,7 +85,8 @@ internal class JceDecoder(
         }
 
         override fun beginStructure(descriptor: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder {
-            this@JceDecoder.pushTag(JceTag(0, false))
+            this@JceDecoder.pushTag(JceTagListElement)
+
             return this@JceDecoder.beginStructure(descriptor, *typeParams)
         }
 
@@ -105,9 +102,9 @@ internal class JceDecoder(
         override fun decodeString(): String = jce.useHead { jce.readJceStringValue(it) }
 
         override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
-            println("decodeCollectionSize: ${descriptor.serialName}")
+            //println("decodeCollectionSize: ${descriptor.serialName}")
             // 不读下一个 head
-            return jce.useHead { jce.readJceIntValue(it) }.also { println("listSize=$it") }
+            return jce.useHead { jce.readJceIntValue(it) }
         }
     }
 
@@ -116,14 +113,16 @@ internal class JceDecoder(
 
     private inner class MapReaderImpl : AbstractDecoder() {
         override fun decodeSequentially(): Boolean = true
-        override fun decodeElementIndex(descriptor: SerialDescriptor): Int = error("should not be reached")
+        override fun decodeElementIndex(descriptor: SerialDescriptor): Int = error("stub")
+
         override fun endStructure(descriptor: SerialDescriptor) {
             this@JceDecoder.endStructure(descriptor)
         }
 
         private var state: Boolean = true
+
         override fun beginStructure(descriptor: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder {
-            this@JceDecoder.pushTag(JceTag(if (state) 0 else 1, false))
+            this@JceDecoder.pushTag(if (jce.currentHead.tag == 0) JceTagMapEntryKey else JceTagMapEntryValue)
             state = !state
             return this@JceDecoder.beginStructure(descriptor, *typeParams)
         }
@@ -134,21 +133,22 @@ internal class JceDecoder(
         override fun decodeLong(): Long = jce.useHead { jce.readJceLongValue(it) }
         override fun decodeFloat(): Float = jce.useHead { jce.readJceFloatValue(it) }
         override fun decodeDouble(): Double = jce.useHead { jce.readJceDoubleValue(it) }
+
         override fun decodeBoolean(): Boolean = jce.useHead { jce.readJceBooleanValue(it) }
         override fun decodeChar(): Char = decodeByte().toChar()
         override fun decodeEnum(enumDescriptor: SerialDescriptor): Int = decodeInt()
         override fun decodeString(): String = jce.useHead { jce.readJceStringValue(it) }
 
         override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
-            println("decodeCollectionSize in MapReader: ${descriptor.serialName}")
+            //println("decodeCollectionSize in MapReader: ${descriptor.serialName}")
             // 不读下一个 head
-            return jce.useHead { jce.readJceIntValue(it) }.also { println("listSize=$it") }
+            return jce.useHead { jce.readJceIntValue(it) }
         }
     }
 
 
     override fun endStructure(descriptor: SerialDescriptor) {
-        println("endStructure: $descriptor")
+        //println("endStructure: $descriptor")
         if (currentTagOrNull?.isSimpleByteArray == true) {
             jce.prepareNextHead() // read to next head
         }
@@ -160,10 +160,10 @@ internal class JceDecoder(
                 val currentHead = jce.currentHeadOrNull ?: return
                 if (currentHead.type == Jce.STRUCT_END) {
                     jce.prepareNextHead()
-                    println("current end")
+                    //println("current end")
                     break
                 }
-                println("current $currentHead")
+                //println("current $currentHead")
                 jce.skipField(currentHead.type)
                 jce.prepareNextHead()
             }
@@ -174,24 +174,24 @@ internal class JceDecoder(
     }
 
     override fun beginStructure(descriptor: SerialDescriptor, vararg typeParams: KSerializer<*>): CompositeDecoder {
-        println()
-        println("beginStructure: ${descriptor.serialName}")
+        //println()
+        //println("beginStructure: ${descriptor.serialName}")
         return when (descriptor.kind) {
             is PrimitiveKind -> this@JceDecoder
 
             StructureKind.MAP -> {
-                println("!! MAP")
+                //println("!! MAP")
                 return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) {
                     it.checkType(Jce.MAP)
                     MapReader
                 }
             }
             StructureKind.LIST -> {
-                println("!! ByteArray")
-                println("decoderTag: $currentTagOrNull")
-                println("jceHead: " + jce.currentHeadOrNull)
+                //println("!! ByteArray")
+                //println("decoderTag: $currentTagOrNull")
+                //println("jceHead: " + jce.currentHeadOrNull)
                 return jce.skipToHeadAndUseIfPossibleOrFail(currentTag.id) {
-                    println("listHead: $it")
+                    //println("listHead: $it")
                     when (it.type) {
                         Jce.SIMPLE_LIST -> {
                             currentTag.isSimpleByteArray = true
@@ -204,11 +204,11 @@ internal class JceDecoder(
                 }
             }
             StructureKind.CLASS -> {
-                val currentTag = currentTagOrNull ?: return this@JceDecoder
+                currentTagOrNull ?: return this@JceDecoder // outermost
 
-                println("!! CLASS")
-                println("decoderTag: $currentTag")
-                println("jceHead: " + jce.currentHeadOrNull)
+                //println("!! CLASS")
+                //println("decoderTag: $currentTag")
+                //println("jceHead: " + jce.currentHeadOrNull)
                 return jce.skipToHeadAndUseIfPossibleOrFail(popTag().id) { jceHead ->
                     jceHead.checkType(Jce.STRUCT_BEGIN)
 
@@ -227,14 +227,6 @@ internal class JceDecoder(
         }
     }
 
-    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
-        println(
-            "decodeSerializableValue: ${deserializer.descriptor.toString().substringBefore('(')
-                .substringAfterLast('.')}"
-        )
-        return super.decodeSerializableValue(deserializer)
-    }
-
     override fun decodeSequentially(): Boolean = false
     override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
         val jceHead = jce.currentHeadOrNull ?: return CompositeDecoder.READ_DONE
@@ -252,16 +244,6 @@ internal class JceDecoder(
         return CompositeDecoder.READ_DONE // optional support
     }
 
-    override fun decodeTaggedNull(tag: JceTag): Nothing? {
-        println("decodeTaggedNull")
-        return super.decodeTaggedNull(tag)
-    }
-
-    override fun decodeTaggedValue(tag: JceTag): Any {
-        println("decodeTaggedValue")
-        return super.decodeTaggedValue(tag)
-    }
-
     override fun decodeTaggedInt(tag: JceTag): Int =
         jce.skipToHeadAndUseIfPossibleOrFail(tag.id) { jce.readJceIntValue(it) }
 
diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceInput.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceInput.kt
index 37f968d97..fe4b7524d 100644
--- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceInput.kt
+++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/JceInput.kt
@@ -166,7 +166,7 @@ internal class JceInput(
 
     // region readers
     fun readJceIntValue(head: JceHead): Int {
-        println("readJceIntValue: $head")
+        //println("readJceIntValue: $head")
         return when (head.type) {
             Jce.ZERO_TYPE -> 0
             Jce.BYTE -> input.readByte().toInt()
@@ -197,7 +197,7 @@ internal class JceInput(
     }
 
     fun readJceByteValue(head: JceHead): Byte {
-        println("readJceByteValue: $head")
+        //println("readJceByteValue: $head")
         return when (head.type) {
             Jce.ZERO_TYPE -> 0
             Jce.BYTE -> input.readByte()
@@ -215,6 +215,7 @@ internal class JceInput(
 
     @OptIn(ExperimentalUnsignedTypes::class)
     fun readJceStringValue(head: JceHead): String {
+        //println("readJceStringValue: $head")
         return when (head.type) {
             Jce.STRING1 -> input.readString(input.readUByte().toInt(), charset = charset.kotlinCharset)
             Jce.STRING4 -> input.readString(
diff --git a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/common.kt b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/common.kt
index d172a43b7..2065a87ea 100644
--- a/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/common.kt
+++ b/mirai-core-qqandroid/src/commonMain/kotlin/net/mamoe/mirai/qqandroid/io/serialization/jce/common.kt
@@ -28,41 +28,33 @@ annotation class JceId(val id: Int)
 @PublishedApi
 internal abstract class JceTag {
     abstract val id: Int
-    abstract val isNullable: Boolean
 
     internal var isSimpleByteArray: Boolean = false
 }
 
-internal sealed class JceTagListElement(
-    override val isNullable: Boolean
-) : JceTag(){
+internal object JceTagListElement : JceTag() {
     override val id: Int get() = 0
-
-    object Nullable : JceTagListElement(true)
-    object NotNull : JceTagListElement(false)
+    override fun toString(): String {
+        return "JceTagListElement"
+    }
 }
 
-internal sealed class JceTagMapEntryKey(
-    override val isNullable: Boolean
-) : JceTag(){
+internal object JceTagMapEntryKey : JceTag() {
     override val id: Int get() = 0
-
-    object Nullable : JceTagMapEntryKey(true)
-    object NotNull : JceTagMapEntryKey(false)
+    override fun toString(): String {
+        return "JceTagMapEntryKey"
+    }
 }
 
-internal sealed class JceTagMapEntryValue(
-    override val isNullable: Boolean
-) : JceTag() {
+internal object JceTagMapEntryValue : JceTag() {
     override val id: Int get() = 1
-
-    object Nullable : JceTagMapEntryValue(true)
-    object NotNull : JceTagMapEntryValue(false)
+    override fun toString(): String {
+        return "JceTagMapEntryValue"
+    }
 }
 
 internal data class JceTagCommon(
-    override val id: Int,
-    override val isNullable: Boolean
+    override val id: Int
 ) : JceTag()
 
 fun JceHead.checkType(type: Byte) {
diff --git a/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt b/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
index 7dcba9a4e..36d421604 100644
--- a/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
+++ b/mirai-core-qqandroid/src/commonTest/kotlin/net.mamoe.mirai.qqandroid.io.serialization/JceInputTest.kt
@@ -3,6 +3,7 @@
 package net.mamoe.mirai.qqandroid.io.serialization
 
 import kotlinx.io.core.buildPacket
+import kotlinx.io.core.toByteArray
 import kotlinx.io.core.writeFully
 import kotlinx.serialization.MissingFieldException
 import kotlinx.serialization.Serializable
@@ -271,8 +272,13 @@ internal class JceInputTest {
     fun testMapStringByteArray() {
         @Serializable
         data class TestSerializableClassA(
-            @JceId(0) val byteArray: Map<String, ByteArray>
-        )
+            @JceId(0) val map: Map<String, ByteArray>
+        ) {
+            override fun toString(): String {
+                @Suppress("EXPERIMENTAL_API_USAGE")
+                return map.entries.joinToString { "${it.key}=${it.value.contentToString()}" }
+            }
+        }
 
         val input = buildPacket {
             writeJceHead(MAP, 0)
@@ -284,18 +290,20 @@ internal class JceInputTest {
                 it.forEach { (key, value) ->
                     writeJceHead(STRING1, 0)
                     writeByte(key.length.toByte())
-                    writeStringUtf8(key)
+                    writeFully(key.toByteArray())
 
                     writeJceHead(SIMPLE_LIST, 1)
                     writeJceHead(BYTE, 0)
+                    writeJceHead(INT, 0)
+                    writeInt(value.size)
                     writeFully(value)
                 }
             }
         }
 
         assertEquals(
-            TestSerializableClassA(mapOf("str1" to byteArrayOf(2, 3, 4), "str2" to byteArrayOf(2, 3, 4))),
-            Jce.UTF_8.load(TestSerializableClassA.serializer(), input)
+            TestSerializableClassA(mapOf("str1" to byteArrayOf(2, 3, 4), "str2" to byteArrayOf(2, 3, 4))).toString(),
+            Jce.UTF_8.load(TestSerializableClassA.serializer(), input).toString()
         )
     }