From 089b403a062080afd884a2c62cff7f4a7e907258 Mon Sep 17 00:00:00 2001 From: Him188 Date: Sat, 30 Apr 2022 17:23:19 +0100 Subject: [PATCH] Allow nulls in TypeSafeMap --- .../src/commonMain/kotlin/TypeSafeMap.kt | 48 +++++++++++++------ .../net/mamoe/mirai/utils/TypeSafeMapTest.kt | 15 +++++- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/mirai-core-utils/src/commonMain/kotlin/TypeSafeMap.kt b/mirai-core-utils/src/commonMain/kotlin/TypeSafeMap.kt index e14ce975b..ff519c116 100644 --- a/mirai-core-utils/src/commonMain/kotlin/TypeSafeMap.kt +++ b/mirai-core-utils/src/commonMain/kotlin/TypeSafeMap.kt @@ -34,8 +34,8 @@ public sealed interface TypeSafeMap { public operator fun get(key: TypeKey, defaultValue: S): S public operator fun contains(key: TypeKey): Boolean = get(key) != null - public fun toMapBoxed(): Map, Any?> - public fun toMap(): Map + public fun toMapBoxed(): Map, Any> + public fun toMap(): Map public companion object { public val EMPTY: TypeSafeMap = TypeSafeMapImpl(emptyMap()) @@ -59,10 +59,11 @@ public sealed interface MutableTypeSafeMap : TypeSafeMap { public fun setAll(other: TypeSafeMap) } +private val NULL: Any = Symbol("NULL")!! @PublishedApi internal open class TypeSafeMapImpl( - @PublishedApi internal open val map: Map = ConcurrentHashMap() + @PublishedApi internal open val map: Map = ConcurrentHashMap() ) : TypeSafeMap { override val size: Int get() = map.size @@ -78,21 +79,29 @@ internal open class TypeSafeMapImpl( return "TypeSafeMapImpl(map=$map)" } - override operator fun get(key: TypeKey): T = - map[key.name]?.uncheckedCast() ?: throw NoSuchElementException(key.toString()) + override operator fun get(key: TypeKey): T { + val value = map[key.name] + if (value === NULL) { + return null.uncheckedCast() + } + return value?.uncheckedCast() ?: throw NoSuchElementException(key.toString()) + } - override operator fun get(key: TypeKey, defaultValue: S): S = - map[key.name]?.uncheckedCast() ?: defaultValue + override operator fun get(key: TypeKey, defaultValue: S): S { + val value = map[key.name] + if (value === NULL) return defaultValue + return value?.uncheckedCast() ?: defaultValue + } override operator fun contains(key: TypeKey): Boolean = map.containsKey(key.name) - override fun toMapBoxed(): Map, Any?> = map.mapKeys { TypeKey(it.key) } - override fun toMap(): Map = map + override fun toMapBoxed(): Map, Any> = map.mapKeys { TypeKey(it.key) } + override fun toMap(): Map = map } @PublishedApi internal class MutableTypeSafeMapImpl( - @PublishedApi override val map: MutableMap = ConcurrentHashMap() + @PublishedApi override val map: MutableMap = ConcurrentHashMap() ) : TypeSafeMap, MutableTypeSafeMap, TypeSafeMapImpl(map) { override fun equals(other: Any?): Boolean { return other is MutableTypeSafeMapImpl && other.map == this.map @@ -107,7 +116,11 @@ internal class MutableTypeSafeMapImpl( } override operator fun set(key: TypeKey, value: T) { - map[key.name] = value + if (value == null) { + map[key.name] = NULL + } else { + map[key.name] = value + } } override fun setAll(other: TypeSafeMap) { @@ -118,18 +131,23 @@ internal class MutableTypeSafeMapImpl( } } - override fun remove(key: TypeKey): T? = map.remove(key.name)?.uncheckedCast() + override fun remove(key: TypeKey): T? { + val value = map.remove(key.name) + return if (value == NULL) { + null + } else { + value?.uncheckedCast() + } + } } public fun TypeSafeMap.toMutableTypeSafeMap(): MutableTypeSafeMap = MutableTypeSafeMap(this.toMap()) public inline fun MutableTypeSafeMap(): MutableTypeSafeMap = MutableTypeSafeMapImpl() -public inline fun MutableTypeSafeMap(map: Map): MutableTypeSafeMap = +public inline fun MutableTypeSafeMap(map: Map): MutableTypeSafeMap = MutableTypeSafeMapImpl().also { it.map.putAll(map) } public inline fun TypeSafeMap(): TypeSafeMap = TypeSafeMap.EMPTY -public inline fun TypeSafeMap(map: Map): TypeSafeMap = - MutableTypeSafeMapImpl().also { it.map.putAll(map) } public inline fun buildTypeSafeMap(block: MutableTypeSafeMap.() -> Unit): MutableTypeSafeMap { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } diff --git a/mirai-core-utils/src/commonTest/kotlin/net/mamoe/mirai/utils/TypeSafeMapTest.kt b/mirai-core-utils/src/commonTest/kotlin/net/mamoe/mirai/utils/TypeSafeMapTest.kt index e10e31394..bc8f8ddce 100644 --- a/mirai-core-utils/src/commonTest/kotlin/net/mamoe/mirai/utils/TypeSafeMapTest.kt +++ b/mirai-core-utils/src/commonTest/kotlin/net/mamoe/mirai/utils/TypeSafeMapTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2019-2021 Mamoe Technologies and contributors. + * Copyright 2019-2022 Mamoe Technologies and contributors. * * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证. * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link. @@ -17,6 +17,8 @@ import kotlin.test.assertEquals internal class TypeSafeMapTest { private val myKey = TypeKey("test") + private val myNullableKey = TypeKey("testNullable") + private val myNullableKey2 = TypeKey("testNullable2") private val myKey2 = TypeKey("test2") @Test @@ -27,6 +29,17 @@ internal class TypeSafeMapTest { assertEquals(2, map.size) assertEquals("str", map[myKey]) assertEquals("str2", map[myKey2]) + + } + + @Test + fun `test nulls`() { + val map = MutableTypeSafeMap() + map[myNullableKey] = null + map[myNullableKey2] = "str2" + assertEquals(2, map.size) + assertEquals(null, map[myNullableKey]) + assertEquals("str2", map[myNullableKey2]) } @Test