Add ConstructorCallCodegen

This commit is contained in:
Him188 2021-08-25 13:53:26 +08:00
parent f31f525343
commit c4939a7446
5 changed files with 646 additions and 0 deletions

View File

@ -0,0 +1,105 @@
/*
* Copyright 2019-2021 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.utils.codegen
import kotlinx.serialization.Serializable
import net.mamoe.mirai.utils.cast
import kotlin.reflect.KParameter
import kotlin.reflect.KProperty1
import kotlin.reflect.KType
import kotlin.reflect.full.declaredMemberProperties
import kotlin.reflect.full.hasAnnotation
import kotlin.reflect.full.primaryConstructor
import kotlin.reflect.full.valueParameters
import kotlin.reflect.typeOf
object ConstructorCallCodegenFacade {
/**
* Analyze [value] and give its correspondent [ValueDesc].
*/
fun analyze(value: Any?, type: KType): ValueDesc {
if (value == null) return PlainValueDesc("null", null)
val clazz = value::class
if (clazz.isData || clazz.hasAnnotation<Serializable>()) {
val clazz1 = value::class
val primaryConstructor =
clazz1.primaryConstructor ?: error("$value does not have a primary constructor.")
val properties = clazz1.declaredMemberProperties
val map = mutableMapOf<KParameter, ValueDesc>()
for (valueParameter in primaryConstructor.valueParameters) {
val prop = properties.find { it.name == valueParameter.name }
?: error("Could not find corresponding property for parameter ${valueParameter.name}")
prop.cast<KProperty1<Any, Any?>>()
map[valueParameter] = analyze(prop.get(value), prop.returnType)
}
return ClassValueDesc(value, map)
}
ArrayValueDesc.createOrNull(value, type)?.let { return it }
if (value is Collection<*>) {
return CollectionValueDesc(value, arrayType = type, elementType = type.arguments.first().type!!)
} else if (value is Map<*, *>) {
return MapValueDesc(
value.cast(),
value.cast(),
type,
type.arguments.first().type!!,
type.arguments[1].type!!
)
}
return when (value) {
is CharSequence -> {
PlainValueDesc('"' + value.toString() + '"', value)
}
is Char -> {
PlainValueDesc("'$value'", value)
}
else -> PlainValueDesc(value.toString(), value)
}
}
/**
* Generate source code to construct the value represented by [desc].
*/
fun generate(desc: ValueDesc, context: CodegenContext = CodegenContext()): String {
if (context.configuration.removeDefaultValues) {
val def = AnalyzeDefaultValuesMappingVisitor()
desc.accept(def)
desc.accept(RemoveDefaultValuesVisitor(def.mappings))
}
ValueCodegen(context).generate(desc)
return context.getResult()
}
fun analyzeAndGenerate(value: Any?, type: KType, context: CodegenContext = CodegenContext()): String {
return generate(analyze(value, type), context)
}
}
@OptIn(ExperimentalStdlibApi::class)
inline fun <reified T> ConstructorCallCodegenFacade.analyze(value: T): ValueDesc {
return analyze(value, typeOf<T>())
}
@OptIn(ExperimentalStdlibApi::class)
inline fun <reified T> ConstructorCallCodegenFacade.analyzeAndGenerate(
value: T,
context: CodegenContext = CodegenContext()
): String {
return analyzeAndGenerate(value, typeOf<T>(), context)
}

View File

@ -0,0 +1,137 @@
/*
* Copyright 2019-2021 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.utils.codegen
import net.mamoe.mirai.utils.encodeToString
import net.mamoe.mirai.utils.toUHexString
class ValueCodegen(
val context: CodegenContext
) {
fun generate(desc: ValueDesc) {
when (desc) {
is PlainValueDesc -> generate(desc)
is ObjectArrayValueDesc -> generate(desc)
is PrimitiveArrayValueDesc -> generate(desc)
is CollectionValueDesc -> generate(desc)
is ClassValueDesc<*> -> generate(desc)
is MapValueDesc -> generate(desc)
}
}
fun generate(desc: PlainValueDesc) {
context.append(desc.value)
}
fun generate(desc: MapValueDesc) {
context.run {
appendLine("mutableMapOf(")
for ((key, value) in desc.elements) {
generate(key)
append(" to ")
generate(value)
appendLine(",")
}
append(")")
}
}
fun <T : Any> generate(desc: ClassValueDesc<T>) {
context.run {
appendLine("${desc.type.qualifiedName}(")
for ((param, valueDesc) in desc.properties) {
append(param.name)
append("=")
generate(valueDesc)
appendLine(",")
}
append(")")
}
}
fun generate(desc: ArrayValueDesc) {
val array = desc.value
fun impl(funcName: String, elements: List<ValueDesc>) {
context.run {
append(funcName)
append('(')
val list = elements.toList()
list.forEachIndexed { index, desc ->
generate(desc)
if (index != list.lastIndex) append(", ")
}
append(')')
}
}
return when (array) {
is Array<*> -> impl("arrayOf", desc.elements)
is IntArray -> impl("intArrayOf", desc.elements)
is ByteArray -> {
if (array.size == 0) {
context.append("net.mamoe.mirai.utils.EMPTY_BYTE_ARRAY") // let IDE to shorten references.
return
} else {
if (array.encodeToString().all { Character.isUnicodeIdentifierPart(it) || it.isWhitespace() }) {
// prefers to show readable string
context.append(
"\"${
array.encodeToString().escapeQuotation()
}\".toByteArray() /* ${array.toUHexString()} */"
)
} else {
context.append("\"${array.toUHexString()}\".hexToBytes()")
}
return
}
}
is ShortArray -> impl("shortArrayOf", desc.elements)
is CharArray -> impl("charArrayOf", desc.elements)
is LongArray -> impl("longArrayOf", desc.elements)
is FloatArray -> impl("floatArrayOf", desc.elements)
is DoubleArray -> impl("doubleArrayOf", desc.elements)
is BooleanArray -> impl("booleanArrayOf", desc.elements)
is List<*> -> impl("mutableListOf", desc.elements)
is Set<*> -> impl("mutableSetOf", desc.elements)
else -> error("$array is not an array.")
}
}
}
class CodegenContext(
val sb: StringBuilder = StringBuilder(),
val configuration: CodegenConfiguration = CodegenConfiguration()
) : Appendable by sb {
fun getResult(): String {
return sb.toString()
}
}
class CodegenConfiguration(
var removeDefaultValues: Boolean = true,
)
private fun String.escapeQuotation(): String = buildString { this@escapeQuotation.escapeQuotationTo(this) }
private fun String.escapeQuotationTo(out: StringBuilder) {
for (i in 0 until length) {
when (val ch = this[i]) {
'\\' -> out.append("\\\\")
'\n' -> out.append("\\n")
'\r' -> out.append("\\r")
'\t' -> out.append("\\t")
'\"' -> out.append("\\\"")
else -> out.append(ch)
}
}
}

View File

@ -0,0 +1,152 @@
/*
* Copyright 2019-2021 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.utils.codegen
import kotlin.reflect.KClass
import kotlin.reflect.KParameter
import kotlin.reflect.KType
import kotlin.reflect.full.createType
import kotlin.reflect.typeOf
sealed interface ValueDesc {
val origin: Any?
fun accept(visitor: ValueDescVisitor)
}
sealed interface ArrayValueDesc : ValueDesc {
val value: Any
val arrayType: KType
val elementType: KType
val elements: MutableList<ValueDesc>
companion object {
@OptIn(ExperimentalStdlibApi::class)
fun createOrNull(array: Any, type: KType): ArrayValueDesc? {
if (array is Array<*>) return ObjectArrayValueDesc(array, arrayType = type)
return when (array) {
is IntArray -> PrimitiveArrayValueDesc(array, arrayType = type, elementType = typeOf<Int>())
is ByteArray -> PrimitiveArrayValueDesc(array, arrayType = type, elementType = typeOf<Byte>())
is ShortArray -> PrimitiveArrayValueDesc(array, arrayType = type, elementType = typeOf<Short>())
is CharArray -> PrimitiveArrayValueDesc(array, arrayType = type, elementType = typeOf<Char>())
is LongArray -> PrimitiveArrayValueDesc(array, arrayType = type, elementType = typeOf<Long>())
is FloatArray -> PrimitiveArrayValueDesc(array, arrayType = type, elementType = typeOf<Float>())
is DoubleArray -> PrimitiveArrayValueDesc(array, arrayType = type, elementType = typeOf<Double>())
is BooleanArray -> PrimitiveArrayValueDesc(array, arrayType = type, elementType = typeOf<Boolean>())
else -> return null
}
}
}
}
class ObjectArrayValueDesc(
override var value: Array<*>,
override val origin: Array<*> = value,
override val arrayType: KType,
override val elementType: KType = arrayType.arguments.first().type ?: Any::class.createType()
) : ArrayValueDesc {
override val elements: MutableList<ValueDesc> by lazy {
value.mapTo(mutableListOf()) {
ConstructorCallCodegenFacade.analyze(it, elementType)
}
}
override fun accept(visitor: ValueDescVisitor) {
visitor.visitObjectArray(this)
}
}
class CollectionValueDesc(
override var value: Collection<*>,
override val origin: Collection<*> = value,
override val arrayType: KType,
override val elementType: KType = arrayType.arguments.first().type ?: Any::class.createType()
) : ArrayValueDesc {
override val elements: MutableList<ValueDesc> by lazy {
value.mapTo(mutableListOf()) {
ConstructorCallCodegenFacade.analyze(it, elementType)
}
}
override fun accept(visitor: ValueDescVisitor) {
visitor.visitCollection(this)
}
}
class MapValueDesc(
var value: Map<Any?, Any?>,
override val origin: Map<Any?, Any?> = value,
val mapType: KType,
val keyType: KType = mapType.arguments.first().type ?: Any::class.createType(),
val valueType: KType = mapType.arguments[1].type ?: Any::class.createType(),
) : ValueDesc {
val elements: MutableMap<ValueDesc, ValueDesc> by lazy {
value.map {
ConstructorCallCodegenFacade.analyze(it.key, keyType) to ConstructorCallCodegenFacade.analyze(
it.value,
valueType
)
}.toMap(mutableMapOf())
}
override fun accept(visitor: ValueDescVisitor) {
visitor.visitMap(this)
}
}
class PrimitiveArrayValueDesc(
override var value: Any,
override val origin: Any = value,
override val arrayType: KType,
override val elementType: KType
) : ArrayValueDesc {
override val elements: MutableList<ValueDesc> by lazy {
when (val value = value) {
is IntArray -> value.mapTo(mutableListOf()) { ConstructorCallCodegenFacade.analyze(it, elementType) }
is ByteArray -> value.mapTo(mutableListOf()) { ConstructorCallCodegenFacade.analyze(it, elementType) }
is ShortArray -> value.mapTo(mutableListOf()) { ConstructorCallCodegenFacade.analyze(it, elementType) }
is CharArray -> value.mapTo(mutableListOf()) { ConstructorCallCodegenFacade.analyze(it, elementType) }
is LongArray -> value.mapTo(mutableListOf()) { ConstructorCallCodegenFacade.analyze(it, elementType) }
is FloatArray -> value.mapTo(mutableListOf()) { ConstructorCallCodegenFacade.analyze(it, elementType) }
is DoubleArray -> value.mapTo(mutableListOf()) { ConstructorCallCodegenFacade.analyze(it, elementType) }
is BooleanArray -> value.mapTo(mutableListOf()) { ConstructorCallCodegenFacade.analyze(it, elementType) }
else -> error("$value is not an array.")
}
}
override fun accept(visitor: ValueDescVisitor) {
visitor.visitPrimitiveArray(this)
}
}
class PlainValueDesc(
var value: String,
override val origin: Any?
) : ValueDesc {
init {
require(value.isNotBlank())
}
override fun accept(visitor: ValueDescVisitor) {
visitor.visitPlain(this)
}
}
class ClassValueDesc<T : Any>(
override val origin: T,
val properties: MutableMap<KParameter, ValueDesc>,
) : ValueDesc {
val type: KClass<out T> by lazy { origin::class }
override fun accept(visitor: ValueDescVisitor) {
visitor.visitClass(this)
}
}

View File

@ -0,0 +1,140 @@
/*
* Copyright 2019-2021 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.utils.codegen
import net.mamoe.mirai.utils.cast
import kotlin.reflect.KClass
import kotlin.reflect.KParameter
import kotlin.reflect.KProperty
import kotlin.reflect.KProperty1
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.primaryConstructor
interface ValueDescVisitor {
fun visitValue(desc: ValueDesc) {}
fun visitPlain(desc: PlainValueDesc) {
visitValue(desc)
}
fun visitArray(desc: ArrayValueDesc) {
visitValue(desc)
for (element in desc.elements) {
element.accept(this)
}
}
fun visitObjectArray(desc: ObjectArrayValueDesc) {
visitArray(desc)
}
fun visitCollection(desc: CollectionValueDesc) {
visitArray(desc)
}
fun visitMap(desc: MapValueDesc) {
visitValue(desc)
for ((key, value) in desc.elements.entries) {
key.accept(this)
value.accept(this)
}
}
fun visitPrimitiveArray(desc: PrimitiveArrayValueDesc) {
visitArray(desc)
}
fun <T : Any> visitClass(desc: ClassValueDesc<T>) {
visitValue(desc)
desc.properties.forEach { (_, u) ->
u.accept(this)
}
}
}
class DefaultValuesMapping(
val forClass: KClass<*>,
val mapping: MutableMap<String, Any?> = mutableMapOf()
) {
operator fun get(property: KProperty<*>): Any? = mapping[property.name]
}
class AnalyzeDefaultValuesMappingVisitor : ValueDescVisitor {
val mappings: MutableList<DefaultValuesMapping> = mutableListOf()
override fun <T : Any> visitClass(desc: ClassValueDesc<T>) {
super.visitClass(desc)
if (mappings.any { it.forClass == desc.type }) return
val defaultInstance =
createInstanceWithMostDefaultValues(desc.type, desc.properties.mapValues { it.value.origin })
val optionalParameters = desc.type.primaryConstructor!!.parameters.filter { it.isOptional }
mappings.add(
DefaultValuesMapping(
desc.type,
optionalParameters.associateTo(mutableMapOf()) { param ->
val value = findCorrespondingProperty(desc, param).get(defaultInstance)
param.name!! to value
}
)
)
}
private fun <T : Any> findCorrespondingProperty(
desc: ClassValueDesc<T>,
param: KParameter
) = desc.type.memberProperties.single { it.name == param.name }.cast<KProperty1<Any, Any>>()
private fun <T : Any> createInstanceWithMostDefaultValues(clazz: KClass<T>, arguments: Map<KParameter, Any?>): T {
val primaryConstructor = clazz.primaryConstructor ?: error("Type $clazz does not have primary constructor.")
return primaryConstructor.callBy(arguments.filter { !it.key.isOptional })
}
}
class RemoveDefaultValuesVisitor(
private val mappings: MutableList<DefaultValuesMapping>,
) : ValueDescVisitor {
override fun <T : Any> visitClass(desc: ClassValueDesc<T>) {
super.visitClass(desc)
val mapping = mappings.find { it.forClass == desc.type }?.mapping ?: return
// remove properties who have the same values as their default values, this would significantly reduce code size.
mapping.forEach { (name, defaultValue) ->
if (desc.properties.entries.removeIf {
it.key.name == name && equals(it.value.origin, defaultValue)
}
) {
return@forEach // by removing one property, there will not by any other matches
}
}
}
fun equals(a: Any?, b: Any?): Boolean {
return when {
a === b -> true
a == b -> true
a is Array<*> && b is Array<*> -> a.contentEquals(b)
a is IntArray && b is IntArray -> a.contentEquals(b)
a is ByteArray && b is ByteArray -> a.contentEquals(b)
a is ShortArray && b is ShortArray -> a.contentEquals(b)
a is LongArray && b is LongArray -> a.contentEquals(b)
a is CharArray && b is CharArray -> a.contentEquals(b)
a is FloatArray && b is FloatArray -> a.contentEquals(b)
a is DoubleArray && b is DoubleArray -> a.contentEquals(b)
a is BooleanArray && b is BooleanArray -> a.contentEquals(b)
else -> false
}
}
}

View File

@ -0,0 +1,112 @@
/*
* Copyright 2019-2021 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.utils.codegen.test
import net.mamoe.mirai.internal.utils.codegen.ConstructorCallCodegenFacade
import net.mamoe.mirai.internal.utils.codegen.analyzeAndGenerate
import kotlin.test.Test
import kotlin.test.assertEquals
class ConstructorCallCodegenTest {
@Test
fun `test plain`() {
assertEquals(
"\"test\"",
ConstructorCallCodegenFacade.analyzeAndGenerate("test")
)
assertEquals(
"1",
ConstructorCallCodegenFacade.analyzeAndGenerate(1)
)
assertEquals(
"1.0",
ConstructorCallCodegenFacade.analyzeAndGenerate(1.0)
)
}
@Test
fun `test array`() {
assertEquals(
"arrayOf(1, 2)",
ConstructorCallCodegenFacade.analyzeAndGenerate(arrayOf(1, 2))
)
assertEquals(
"arrayOf(5.0)",
ConstructorCallCodegenFacade.analyzeAndGenerate(arrayOf(5.0))
)
assertEquals(
"arrayOf(\"1\")",
ConstructorCallCodegenFacade.analyzeAndGenerate(arrayOf("1"))
)
assertEquals(
"arrayOf(arrayOf(1))",
ConstructorCallCodegenFacade.analyzeAndGenerate(arrayOf(arrayOf(1)))
)
}
data class TestClass(
val value: String
)
data class TestClass2(
val value: Any
)
@Test
fun `test class`() {
assertEquals(
"""
${TestClass::class.qualifiedName!!}(
value="test",
)
""".trimIndent(),
ConstructorCallCodegenFacade.analyzeAndGenerate(TestClass("test"))
)
assertEquals(
"""
${TestClass2::class.qualifiedName!!}(
value="test",
)
""".trimIndent(),
ConstructorCallCodegenFacade.analyzeAndGenerate(TestClass2("test"))
)
assertEquals(
"""
${TestClass2::class.qualifiedName!!}(
value=1,
)
""".trimIndent(),
ConstructorCallCodegenFacade.analyzeAndGenerate(TestClass2(1))
)
}
data class TestNesting(
val nested: Nested
) {
data class Nested(
val value: String
)
}
@Test
fun `test nesting`() {
assertEquals(
"""
net.mamoe.mirai.internal.utils.codegen.test.ConstructorCallCodegenTest.TestNesting(
nested=net.mamoe.mirai.internal.utils.codegen.test.ConstructorCallCodegenTest.TestNesting.Nested(
value="test",
),
)
""".trimIndent(),
ConstructorCallCodegenFacade.analyzeAndGenerate(TestNesting(TestNesting.Nested("test")))
)
}
}