Fix support for Any and support primitive and reference arrays, close #1801

This commit is contained in:
Him188 2022-04-24 15:05:44 +01:00
parent 1d60cf57b8
commit 45defb32a2
3 changed files with 143 additions and 6 deletions

View File

@ -12,8 +12,7 @@
package net.mamoe.mirai.console.internal.data
import kotlinx.serialization.KSerializer
import kotlinx.serialization.builtins.ArraySerializer
import kotlinx.serialization.builtins.nullable
import kotlinx.serialization.builtins.*
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
@ -28,17 +27,32 @@ import kotlin.reflect.KClass
import kotlin.reflect.KType
/**
* Copied from kotlinx.serialization, modifications are marked with "/* mamoe modify */"
* Copyright 2017-2020 JetBrains s.r.o.
*/
@Suppress("UNCHECKED_CAST")
internal fun SerializersModule.serializerMirai(type: KType): KSerializer<Any?> {
fun serializerByKTypeImpl(type: KType): KSerializer<*> {
val rootClass = type.classifierAsKClass()
// In Kotlin 1.6.20, `typeOf<Array<Long>>?.classifier` surprisingly gives kotlin.LongArray
// https://youtrack.jetbrains.com/issue/KT-52170/
if (type.arguments.size == 1) { // can be typeOf<Array<...>>, so cannot be typeOf<IntArray>
val result: KSerializer<Any?>? = when (rootClass) {
ByteArray::class -> ArraySerializer(Byte.serializer()).cast()
ShortArray::class -> ArraySerializer(Short.serializer()).cast()
IntArray::class -> ArraySerializer(Int.serializer()).cast()
LongArray::class -> ArraySerializer(Long.serializer()).cast()
FloatArray::class -> ArraySerializer(Float.serializer()).cast()
DoubleArray::class -> ArraySerializer(Double.serializer()).cast()
CharArray::class -> ArraySerializer(Char.serializer()).cast()
BooleanArray::class -> ArraySerializer(Boolean.serializer()).cast()
else -> null
}
if (result != null) return result
}
this.serializerOrNull(type)?.let { return it } // Kotlin builtin and user-defined
MessageSerializers.serializersModule.serializerOrNull(type)?.let { return it } // Mirai Messages
if (type.classifier == Any::class) return if (type.isMarkedNullable) YamlNullableDynamicSerializer else YamlDynamicSerializer as KSerializer<Any?>
val typeArguments = type.arguments
.map { requireNotNull(it.type) { "Star projections in type arguments are not allowed, but had $type" } }
@ -47,6 +61,18 @@ internal fun SerializersModule.serializerMirai(type: KType): KSerializer<Any?> {
else -> {
val serializers = typeArguments.map(::serializerMirai)
when (rootClass) {
Collection::class, List::class, MutableList::class, ArrayList::class -> ListSerializer(serializers[0])
HashSet::class -> SetSerializer(serializers[0])
Set::class, MutableSet::class, LinkedHashSet::class -> SetSerializer(serializers[0])
HashMap::class -> MapSerializer(serializers[0], serializers[1])
Map::class, MutableMap::class, LinkedHashMap::class -> MapSerializer(
serializers[0],
serializers[1]
)
Map.Entry::class -> MapEntrySerializer(serializers[0], serializers[1])
Pair::class -> PairSerializer(serializers[0], serializers[1])
Triple::class -> TripleSerializer(serializers[0], serializers[1], serializers[2])
Any::class -> if (type.isMarkedNullable) YamlNullableDynamicSerializer else YamlDynamicSerializer
else -> {
if (rootClass.java.isArray) {

View File

@ -111,7 +111,15 @@ internal fun PluginData.valueFromKTypeImpl(type: KType): SerializerAwareValue<*>
}
}
private fun KClass<*>.isReferencingSamePlatformClass(other: KClass<*>): Boolean {
return this.qualifiedName == other.qualifiedName // not using .java for
}
internal fun KClass<*>.createInstanceSmart(): Any {
when {
isReferencingSamePlatformClass(Array::class) -> return emptyArray<Any?>()
}
return when (this) {
Byte::class -> 0.toByte()
Short::class -> 0.toShort()
@ -145,6 +153,15 @@ internal fun KClass<*>.createInstanceSmart(): Any {
ConcurrentMap::class,
-> ConcurrentHashMap<Any?, Any?>()
ByteArray::class -> byteArrayOf()
BooleanArray::class -> booleanArrayOf()
ShortArray::class -> shortArrayOf()
IntArray::class -> intArrayOf()
LongArray::class -> longArrayOf()
FloatArray::class -> floatArrayOf()
DoubleArray::class -> doubleArrayOf()
CharArray::class -> charArrayOf()
else -> createInstanceOrNull()
?: error("Cannot create instance or find a initial value for ${this.qualifiedNameOrTip}")
}

View File

@ -222,4 +222,98 @@ internal class PluginDataTest : AbstractConsoleInstanceTest() {
storage.load(mockPlugin, data)
assertEquals(serialized, storage.getPluginDataFileInternal(mockPlugin, data).readText())
}
class DefaultValueForArray : AutoSavePluginData("save") {
val byteArray: ByteArray by value()
val booleanArray: BooleanArray by value()
var shortArray: ShortArray by value()
val intArray: IntArray by value()
val longArray: LongArray by value()
val floatArray: FloatArray by value()
val doubleArray: DoubleArray by value()
val charArray: CharArray by value()
var stringArray: Array<String> by value()
var longObjectArray: Array<Long> by value()
}
@Test
fun `default value for array`() {
val instance = DefaultValueForArray()
assertEquals(
"""
byteArray: []
booleanArray: []
shortArray: []
intArray: []
longArray: []
floatArray: []
doubleArray: []
charArray: []
stringArray: []
longObjectArray: []
""".trimIndent(),
serializePluginData(instance)
)
instance.shortArray = shortArrayOf(1)
instance.stringArray = arrayOf("1234")
println(instance.findBackingFieldValueNode(instance::longObjectArray))
instance.longObjectArray = arrayOf(1234)
assertEquals(
"""
byteArray: []
booleanArray: []
shortArray:
- 1
intArray: []
longArray: []
floatArray: []
doubleArray: []
charArray: []
stringArray:
- 1234
longObjectArray:
- 1234
""".trimIndent(),
serializePluginData(instance)
)
serializeAndRereadPluginData(instance)
}
class DefaultValueForCollections : AutoSavePluginData("save") {
val map: Map<String, String> by value()
val mapAny: Map<String, Any> by value()
val hashMapAny: HashMap<String, Any> by value()
val linkedHashMapAny: LinkedHashMap<String, Any> by value()
val list: List<String> by value()
val listAny: List<Any> by value()
val set: Set<String> by value()
val setAny: Set<Any> by value()
}
@Test
fun `default value for collections`() {
val instance = DefaultValueForCollections()
assertEquals(
"""
map: {}
mapAny: {}
hashMapAny: {}
linkedHashMapAny: {}
list: []
listAny: []
set: []
setAny: []
""".trimIndent(),
serializePluginData(instance)
)
serializeAndRereadPluginData(instance)
}
}