diff --git a/backend/mirai-console/src/command/descriptor/CommandArgumentContext.kt b/backend/mirai-console/src/command/descriptor/CommandArgumentContext.kt index e8c6a4113..e9828c1f9 100644 --- a/backend/mirai-console/src/command/descriptor/CommandArgumentContext.kt +++ b/backend/mirai-console/src/command/descriptor/CommandArgumentContext.kt @@ -23,6 +23,9 @@ import net.mamoe.mirai.contact.* import net.mamoe.mirai.message.data.Image import net.mamoe.mirai.message.data.MessageContent import net.mamoe.mirai.message.data.PlainText +import java.util.* +import kotlin.collections.ArrayList +import kotlin.collections.HashSet import kotlin.contracts.InvocationKind.EXACTLY_ONCE import kotlin.contracts.contract import kotlin.internal.LowPriorityInOverloadResolution @@ -75,34 +78,53 @@ public interface CommandArgumentContext { public val EMPTY: CommandArgumentContext = EmptyCommandArgumentContext } + private object EnumCommandArgumentContext : CommandArgumentContext { + private val cache = WeakHashMap, CommandValueArgumentParser<*>>() + private val enumKlass = Enum::class + override fun get(kClass: KClass): CommandValueArgumentParser? { + return if (kClass.isSubclassOf(enumKlass)) { + val jclass = kClass.java.asSubclass(Enum::class.java) + @Suppress("UNCHECKED_CAST") + (cache[jclass] ?: kotlin.run { + EnumValueArgumentParser(jclass).also { cache[jclass] = it } + }) as CommandValueArgumentParser + } else null + } + + override fun toList(): List> = emptyList() + } + /** * 内建的默认 [CommandValueArgumentParser] */ - public object Builtins : CommandArgumentContext by (buildCommandArgumentContext { - Int::class with IntValueArgumentParser - Byte::class with ByteValueArgumentParser - Short::class with ShortValueArgumentParser - Boolean::class with BooleanValueArgumentParser - String::class with StringValueArgumentParser - Long::class with LongValueArgumentParser - Double::class with DoubleValueArgumentParser - Float::class with FloatValueArgumentParser + public object Builtins : CommandArgumentContext by listOf( + EnumCommandArgumentContext, + buildCommandArgumentContext { + Int::class with IntValueArgumentParser + Byte::class with ByteValueArgumentParser + Short::class with ShortValueArgumentParser + Boolean::class with BooleanValueArgumentParser + String::class with StringValueArgumentParser + Long::class with LongValueArgumentParser + Double::class with DoubleValueArgumentParser + Float::class with FloatValueArgumentParser - Image::class with ImageValueArgumentParser - PlainText::class with PlainTextValueArgumentParser + Image::class with ImageValueArgumentParser + PlainText::class with PlainTextValueArgumentParser - Contact::class with ExistingContactValueArgumentParser - User::class with ExistingUserValueArgumentParser - Member::class with ExistingMemberValueArgumentParser - Group::class with ExistingGroupValueArgumentParser - Friend::class with ExistingFriendValueArgumentParser - Bot::class with ExistingBotValueArgumentParser + Contact::class with ExistingContactValueArgumentParser + User::class with ExistingUserValueArgumentParser + Member::class with ExistingMemberValueArgumentParser + Group::class with ExistingGroupValueArgumentParser + Friend::class with ExistingFriendValueArgumentParser + Bot::class with ExistingBotValueArgumentParser - PermissionId::class with PermissionIdValueArgumentParser - PermitteeId::class with PermitteeIdValueArgumentParser + PermissionId::class with PermissionIdValueArgumentParser + PermitteeId::class with PermitteeIdValueArgumentParser - MessageContent::class with RawContentValueArgumentParser - }) + MessageContent::class with RawContentValueArgumentParser + }, + ).fold(EmptyCommandArgumentContext, CommandArgumentContext::plus) } /** @@ -127,7 +149,7 @@ public object EmptyCommandArgumentContext : CommandArgumentContext by SimpleComm * 合并两个 [buildCommandArgumentContext], [replacer] 将会替换 [this] 中重复的 parser. */ public operator fun CommandArgumentContext.plus(replacer: CommandArgumentContext): CommandArgumentContext { - if (replacer == EmptyCommandArgumentContext) return this + if (replacer === EmptyCommandArgumentContext) return this if (this == EmptyCommandArgumentContext) return replacer return object : CommandArgumentContext { override fun get(kClass: KClass): CommandValueArgumentParser? = @@ -142,7 +164,7 @@ public operator fun CommandArgumentContext.plus(replacer: CommandArgumentContext */ public operator fun CommandArgumentContext.plus(replacer: List>): CommandArgumentContext { if (replacer.isEmpty()) return this - if (this == EmptyCommandArgumentContext) return SimpleCommandArgumentContext(replacer) + if (this === EmptyCommandArgumentContext) return SimpleCommandArgumentContext(replacer) return object : CommandArgumentContext { @Suppress("UNCHECKED_CAST") override fun get(kClass: KClass): CommandValueArgumentParser? = diff --git a/backend/mirai-console/src/command/descriptor/CommandArgumentParserBuiltins.kt b/backend/mirai-console/src/command/descriptor/CommandArgumentParserBuiltins.kt index 16bc5dff0..dc4e0e29c 100644 --- a/backend/mirai-console/src/command/descriptor/CommandArgumentParserBuiltins.kt +++ b/backend/mirai-console/src/command/descriptor/CommandArgumentParserBuiltins.kt @@ -364,6 +364,102 @@ public object RawContentValueArgumentParser : CommandValueArgumentParser>( + private val type: Class, +) : InternalCommandValueArgumentParserExtensions() { + // 此 Exception 仅用于中断 enum 搜索, 不需要使用堆栈信息 + private object NoEnumException : RuntimeException() + + + init { + check(Enum::class.java.isAssignableFrom(type)) { + "$type not a enum class" + } + } + + private fun Sequence.hasDuplicates(): Boolean = iterator().hasDuplicates() + private fun Iterator.hasDuplicates(): Boolean { + val observed = HashSet() + for (elem in this) { + if (!observed.add(elem)) + return true + } + return false + } + + @Suppress("NOTHING_TO_INLINE") + private inline fun noConstant(): Nothing { + throw NoEnumException + } + + private val delegate: (String) -> T = kotlin.run { + val enums = type.enumConstants.asSequence() + // step 1: 分析是否能够忽略大小写 + if (enums.map { it.name.toLowerCase() }.hasDuplicates()) { + ({ java.lang.Enum.valueOf(type, it) }) + } else { // step 2: 分析是否能使用小驼峰命名 + val lowerCaseEnumDirection = enums.map { it.name.toLowerCase() to it }.toList().toMap() + + val camelCase = enums.mapNotNull { elm -> + val name = elm.name.split('_') + if (name.size == 1) { // No splitter + null + } else { + buildString { + val iterator = name.iterator() + append(iterator.next().toLowerCase()) + for (v in iterator) { + if (v.isEmpty()) continue + append(v[0].toUpperCase()) + append(v.substring(1, v.length).toLowerCase()) + } + } to elm + } + } + + val camelCaseDirection = if (( + enums.map { it.name.toLowerCase() } + camelCase.map { it.first.toLowerCase() } + ).hasDuplicates() + ) { // 确认驼峰命名与源没有冲突 + emptyMap() + } else { + camelCase.toList().toMap() + } + + ({ + camelCaseDirection[it] + ?: lowerCaseEnumDirection[it.toLowerCase()] + ?: noConstant() + }) + } + } + + override fun parse(raw: String, sender: CommandSender): T { + return try { + delegate(raw) + } catch (e: Throwable) { + illegalArgument("无法解析 $raw 为 ${type.simpleName}") + } + } +} + internal abstract class InternalCommandValueArgumentParserExtensions : AbstractCommandValueArgumentParser() { private fun String.parseToLongOrFail(): Long = toLongOrNull() ?: illegalArgument("无法解析 $this 为整数") diff --git a/backend/mirai-console/test/command/TestCommand.kt b/backend/mirai-console/test/command/TestCommand.kt index c2c060379..f00898787 100644 --- a/backend/mirai-console/test/command/TestCommand.kt +++ b/backend/mirai-console/test/command/TestCommand.kt @@ -64,9 +64,38 @@ object TestSimpleCommand : RawCommand(owner, "testSimple", "tsS") { } } +@Suppress("EnumEntryName") +object TestEnumArgCommand : CompositeCommand(owner, "testenum") { + enum class TestEnum { + V1, V2, V3 + } + enum class TestCase { + A, a + } + enum class TestCamelCase { + A, B, A_B + } + + @SubCommand("tcc") + fun CommandSender.testCamelCase(enum: TestCamelCase) { + Testing.ok(enum) + } + + @SubCommand("tc") + fun CommandSender.testCase(enum: TestCase) { + Testing.ok(enum) + } + + @SubCommand + fun CommandSender.e1(enum: TestEnum) { + Testing.ok(enum) + } +} + internal val sender by lazy { ConsoleCommandSender } internal object TestUnitCommandOwner : CommandOwner by ConsoleCommandOwner + internal val owner by lazy { TestUnitCommandOwner } @@ -137,6 +166,75 @@ internal class TestCommand { assertEquals(2, result.size) } + @Test + fun `test enum argument`() = runBlocking { + TestEnumArgCommand.withRegistration { + + assertEquals(TestEnumArgCommand.TestEnum.V1, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("e1"), PlainText("V1"))) + }) + assertEquals(TestEnumArgCommand.TestEnum.V2, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("e1"), PlainText("V2"))) + }) + assertEquals(TestEnumArgCommand.TestEnum.V3, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("e1"), PlainText("V3"))) + }) + withTesting { + assertFailure(TestEnumArgCommand.execute(sender, PlainText("e1"), PlainText("ENUM_NOT_FOUND"))) + Testing.ok(Unit) + } + assertEquals(TestEnumArgCommand.TestEnum.V1, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("e1"), PlainText("v1"))) + }) + assertEquals(TestEnumArgCommand.TestEnum.V2, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("e1"), PlainText("v2"))) + }) + assertEquals(TestEnumArgCommand.TestEnum.V3, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("e1"), PlainText("v3"))) + }) + + + assertEquals(TestEnumArgCommand.TestCase.A, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tc"), PlainText("A"))) + }) + assertEquals(TestEnumArgCommand.TestCase.a, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tc"), PlainText("a"))) + }) + withTesting { + assertFailure(TestEnumArgCommand.execute(sender, PlainText("tc"), PlainText("ENUM_NOT_FOUND"))) + Testing.ok(Unit) + } + + + assertEquals(TestEnumArgCommand.TestCamelCase.A, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tcc"), PlainText("A"))) + }) + assertEquals(TestEnumArgCommand.TestCamelCase.A, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tcc"), PlainText("a"))) + }) + assertEquals(TestEnumArgCommand.TestCamelCase.B, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tcc"), PlainText("B"))) + }) + assertEquals(TestEnumArgCommand.TestCamelCase.B, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tcc"), PlainText("b"))) + }) + assertEquals(TestEnumArgCommand.TestCamelCase.A_B, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tcc"), PlainText("A_B"))) + }) + assertEquals(TestEnumArgCommand.TestCamelCase.A_B, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tcc"), PlainText("a_b"))) + }) + assertEquals(TestEnumArgCommand.TestCamelCase.A_B, withTesting { + assertSuccess(TestEnumArgCommand.execute(sender, PlainText("tcc"), PlainText("aB"))) + }) + withTesting { + assertFailure(TestEnumArgCommand.execute(sender, PlainText("tc"), PlainText("ENUM_NOT_FOUND"))) + Testing.ok(Unit) + } + + } + } + @Test fun testSimpleArgsSplitting() = runBlocking { TestSimpleCommand.withRegistration { @@ -362,3 +460,10 @@ internal fun assertSuccess(result: CommandExecuteResult) { throw result.exception ?: AssertionError(result.toString()) } } + +@OptIn(ExperimentalCommandDescriptors::class) +internal fun assertFailure(result: CommandExecuteResult) { + if (!result.isFailure()) { + throw AssertionError("$result not a failure") + } +}