mirror of
https://github.com/mamoe/mirai.git
synced 2025-01-23 14:20:24 +08:00
Jce null support
This commit is contained in:
parent
b9e9c00748
commit
3a11adf1c0
@ -1,43 +1,6 @@
|
||||
package net.mamoe.mirai.qqandroid.io
|
||||
|
||||
import kotlinx.io.core.BytePacketBuilder
|
||||
import kotlinx.io.core.Input
|
||||
import kotlinx.io.core.readBytes
|
||||
import kotlinx.io.core.writeFully
|
||||
import kotlinx.serialization.DeserializationStrategy
|
||||
import kotlinx.serialization.SerializationStrategy
|
||||
import net.mamoe.mirai.qqandroid.io.serialization.ProtoBufWithNullableSupport
|
||||
import net.mamoe.mirai.utils.io.toUHexString
|
||||
|
||||
/**
|
||||
* 仅有标示作用
|
||||
*/
|
||||
interface ProtoBuf
|
||||
|
||||
fun <T : ProtoBuf> BytePacketBuilder.writeProtoBuf(serializer: SerializationStrategy<T>, v: T) {
|
||||
|
||||
this.writeFully(v.toByteArray(serializer).also {
|
||||
println("发送 protobuf: ${it.toUHexString()}")
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* dump
|
||||
*/
|
||||
fun <T : ProtoBuf> T.toByteArray(serializer: SerializationStrategy<T>): ByteArray {
|
||||
return ProtoBufWithNullableSupport.dump(serializer, this)
|
||||
}
|
||||
|
||||
/**
|
||||
* load
|
||||
*/
|
||||
fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrategy<T>): T {
|
||||
return ProtoBufWithNullableSupport.load(deserializer, this)
|
||||
}
|
||||
|
||||
/**
|
||||
* load
|
||||
*/
|
||||
fun <T : ProtoBuf> Input.readRemainingAsProtoBuf(serializer: DeserializationStrategy<T>): T {
|
||||
return ProtoBufWithNullableSupport.load(serializer, this.readBytes())
|
||||
}
|
@ -8,57 +8,18 @@ import kotlinx.serialization.modules.EmptyModule
|
||||
import kotlinx.serialization.modules.SerialModule
|
||||
import net.mamoe.mirai.qqandroid.io.JceStruct
|
||||
import net.mamoe.mirai.qqandroid.io.ProtoBuf
|
||||
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestDataVersion2
|
||||
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestDataVersion3
|
||||
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestPacket
|
||||
import net.mamoe.mirai.qqandroid.network.protocol.packet.withUse
|
||||
import net.mamoe.mirai.utils.firstValue
|
||||
import net.mamoe.mirai.utils.io.read
|
||||
import net.mamoe.mirai.utils.io.readIoBuffer
|
||||
import net.mamoe.mirai.utils.io.readString
|
||||
import net.mamoe.mirai.utils.io.toIoBuffer
|
||||
import kotlin.contracts.ExperimentalContracts
|
||||
import kotlin.contracts.contract
|
||||
|
||||
@PublishedApi
|
||||
internal val CharsetGBK = Charset.forName("GBK")
|
||||
@PublishedApi
|
||||
internal val CharsetUTF8 = Charset.forName("UTF8")
|
||||
|
||||
fun <T> ByteArray.loadAs(deserializer: DeserializationStrategy<T>, c: JceCharset = JceCharset.UTF8): T {
|
||||
return Jce.byCharSet(c).load(deserializer, this)
|
||||
}
|
||||
|
||||
fun <T> BytePacketBuilder.writeJceStruct(serializer: SerializationStrategy<T>, struct: T, charset: JceCharset = JceCharset.GBK) {
|
||||
this.writePacket(Jce.byCharSet(charset).dumpAsPacket(serializer, struct))
|
||||
}
|
||||
|
||||
fun <T> ByteReadPacket.readRemainingAsJceStruct(serializer: DeserializationStrategy<T>, charset: JceCharset = JceCharset.UTF8): T {
|
||||
return Jce.byCharSet(charset).load(serializer, this)
|
||||
}
|
||||
|
||||
/**
|
||||
* 先解析为 [RequestPacket], 即 `UniRequest`, 再按版本解析 map, 再找出指定数据并反序列化
|
||||
*/
|
||||
fun <T : JceStruct> ByteReadPacket.decodeUniPacket(deserializer: DeserializationStrategy<T>, name: String? = null): T {
|
||||
val request = this.readRemainingAsJceStruct(RequestPacket.serializer())
|
||||
|
||||
fun ByteArray.doReadInner(): T = read {
|
||||
discardExact(1)
|
||||
this.readRemainingAsJceStruct(deserializer)
|
||||
}
|
||||
|
||||
return if (name == null) when (request.iVersion.toInt()) {
|
||||
2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.firstValue().firstValue().doReadInner()
|
||||
3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.firstValue().doReadInner()
|
||||
else -> error("unsupported version ${request.iVersion}")
|
||||
} else when (request.iVersion.toInt()) {
|
||||
2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.getOrElse(name) { error("cannot find $name") }.firstValue().doReadInner()
|
||||
3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.getOrElse(name) { error("cannot find $name") }.doReadInner()
|
||||
else -> error("unsupported version ${request.iVersion}")
|
||||
}
|
||||
}
|
||||
|
||||
fun <T : JceStruct> T.toByteArray(serializer: SerializationStrategy<T>, c: JceCharset = JceCharset.GBK): ByteArray = Jce.byCharSet(c).dump(serializer, this)
|
||||
|
||||
enum class JceCharset(val kotlinCharset: Charset) {
|
||||
GBK(Charset.forName("GBK")),
|
||||
UTF8(Charset.forName("UTF8"))
|
||||
@ -410,11 +371,6 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
|
||||
|
||||
is MapLikeDescriptor -> {
|
||||
val tag = currentTagOrNull
|
||||
|
||||
if (tag != null && input.skipToTagOrNull(tag) { popTag() } == null && desc.isNullable) {
|
||||
return NullReader(this.input)
|
||||
}
|
||||
|
||||
if (tag != null) {
|
||||
popTag()
|
||||
}
|
||||
@ -461,7 +417,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <T : Any> decodeNullableSerializableValue(deserializer: DeserializationStrategy<T?>): T? {
|
||||
//println("decodeNullableSerializableValue: ${deserializer.getClassName()}")
|
||||
println("decodeNullableSerializableValue: ${deserializer.getClassName()}")
|
||||
if (deserializer is NullReader) {
|
||||
return null
|
||||
}
|
||||
@ -487,20 +443,40 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
|
||||
return if (isTagOptional(tag)) input.readByteArrayOrNull(tag)?.toMutableList() as? T
|
||||
else input.readByteArray(tag).toMutableList() as T
|
||||
}
|
||||
return super.decodeSerializableValue(deserializer)
|
||||
val tag = popTag()
|
||||
println(tag)
|
||||
@Suppress("SENSELESS_COMPARISON") // false positive
|
||||
if (input.skipToTagOrNull(tag) {
|
||||
return deserializer.deserialize(JceListReader(input.readInt(0), input))
|
||||
} == null) {
|
||||
if (isTagOptional(tag)) {
|
||||
return null
|
||||
} else error("property is notnull but cannot find tag $tag")
|
||||
}
|
||||
error("UNREACHABLE CODE")
|
||||
}
|
||||
is MapLikeDescriptor -> {
|
||||
// 将 mapOf(k1 to v1, k2 to v2, ...) 转换为 listOf(k1, v1, k2, v2, ...) 以便于写入.
|
||||
val serializer = (deserializer as MapLikeSerializer<Any?, Any?, T, *>)
|
||||
val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
|
||||
val setOfEntries = HashSetSerializer(mapEntrySerial).deserialize(this)
|
||||
return setOfEntries.associateBy({ it.key }, { it.value }) as T
|
||||
val tag = popTag()
|
||||
@Suppress("SENSELESS_COMPARISON")
|
||||
if (input.skipToTagOrNull(tag) {
|
||||
// 将 mapOf(k1 to v1, k2 to v2, ...) 转换为 listOf(k1, v1, k2, v2, ...) 以便于写入.
|
||||
val serializer = (deserializer as MapLikeSerializer<Any?, Any?, T, *>)
|
||||
val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
|
||||
val setOfEntries = HashSetSerializer(mapEntrySerial).deserialize(JceMapReader(input.readInt(0), input))
|
||||
return setOfEntries.associateBy({ it.key }, { it.value }) as T
|
||||
} == null) {
|
||||
if (isTagOptional(tag)) {
|
||||
return null
|
||||
} else error("property is notnull but cannot find tag $tag")
|
||||
}
|
||||
error("UNREACHABLE CODE")
|
||||
}
|
||||
}
|
||||
val tag = currentTagOrNull ?: return deserializer.deserialize(this)
|
||||
return if (this.decodeTaggedNotNullMark(tag)) {
|
||||
deserializer.deserialize(this)
|
||||
} else {
|
||||
// popTag()
|
||||
null
|
||||
}
|
||||
}
|
||||
@ -514,7 +490,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
|
||||
|
||||
|
||||
@UseExperimental(ExperimentalUnsignedTypes::class)
|
||||
private inner class JceInput(
|
||||
internal inner class JceInput(
|
||||
@PublishedApi
|
||||
internal val input: IoBuffer
|
||||
) : Closeable {
|
||||
@ -732,24 +708,6 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
|
||||
else -> error("invalid type: $type")
|
||||
}
|
||||
|
||||
internal inline fun <R> skipToTagOrNull(tag: Int, block: (JceHead) -> R): R? {
|
||||
while (true) {
|
||||
if (this.input.endOfInput) {
|
||||
return null
|
||||
}
|
||||
|
||||
val head = peakHead()
|
||||
if (head.tag > tag) {
|
||||
return null
|
||||
}
|
||||
readHead()
|
||||
if (head.tag == tag) {
|
||||
return block(head)
|
||||
}
|
||||
this.skipField(head.type)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
companion object {
|
||||
@ -809,6 +767,28 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
|
||||
}
|
||||
}
|
||||
|
||||
@UseExperimental(ExperimentalContracts::class)
|
||||
internal inline fun <R> Jce.JceInput.skipToTagOrNull(tag: Int, block: (JceHead) -> R): R? {
|
||||
contract {
|
||||
callsInPlace(block, kotlin.contracts.InvocationKind.UNKNOWN)
|
||||
}
|
||||
while (true) {
|
||||
if (this.input.endOfInput) {
|
||||
return null
|
||||
}
|
||||
|
||||
val head = peakHead()
|
||||
if (head.tag > tag) {
|
||||
return null
|
||||
}
|
||||
readHead()
|
||||
if (head.tag == tag) {
|
||||
return block(head)
|
||||
}
|
||||
this.skipField(head.type)
|
||||
}
|
||||
}
|
||||
|
||||
@UseExperimental(ExperimentalUnsignedTypes::class)
|
||||
inline class JceHead(private val value: Long) {
|
||||
constructor(tag: Int, type: Byte) : this(tag.toLong().shl(32) or type.toLong())
|
||||
|
@ -9,9 +9,9 @@ import net.mamoe.mirai.message.FriendMessage
|
||||
import net.mamoe.mirai.message.data.MessageChain
|
||||
import net.mamoe.mirai.qqandroid.QQAndroidBot
|
||||
import net.mamoe.mirai.qqandroid.event.ForceOfflineEvent
|
||||
import net.mamoe.mirai.qqandroid.io.readRemainingAsProtoBuf
|
||||
import net.mamoe.mirai.qqandroid.io.serialization.decodeUniPacket
|
||||
import net.mamoe.mirai.qqandroid.io.writeProtoBuf
|
||||
import net.mamoe.mirai.qqandroid.io.serialization.readRemainingAsProtoBuf
|
||||
import net.mamoe.mirai.qqandroid.io.serialization.writeProtoBuf
|
||||
import net.mamoe.mirai.qqandroid.network.QQAndroidClient
|
||||
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestPushForceOffline
|
||||
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestPushNotify
|
||||
|
@ -5,7 +5,6 @@ import net.mamoe.mirai.data.Packet
|
||||
import net.mamoe.mirai.qqandroid.QQAndroidBot
|
||||
import net.mamoe.mirai.qqandroid.io.serialization.toByteArray
|
||||
import net.mamoe.mirai.qqandroid.io.serialization.writeJceStruct
|
||||
import net.mamoe.mirai.qqandroid.io.toByteArray
|
||||
import net.mamoe.mirai.qqandroid.io.writeJcePacket
|
||||
import net.mamoe.mirai.qqandroid.network.QQAndroidClient
|
||||
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.GetFriendListReq
|
||||
|
@ -8,9 +8,9 @@ import net.mamoe.mirai.qqandroid.io.JceStruct
|
||||
import net.mamoe.mirai.qqandroid.io.buildJcePacket
|
||||
import net.mamoe.mirai.utils.cryptor.contentToString
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class JceDecoderTest {
|
||||
|
||||
@Serializable
|
||||
class TestSimpleJceStruct(
|
||||
@SerialId(0) val string: String = "123",
|
||||
@ -37,7 +37,7 @@ class JceDecoderTest {
|
||||
class TestComplexJceStruct(
|
||||
@SerialId(6) val string: String = "haha",
|
||||
@SerialId(7) val byteArray: ByteArray = ByteArray(500),
|
||||
@SerialId(8) val byteList: List<Byte> = listOf(1, 2, 3), // error here
|
||||
@SerialId(8) val byteList: List<Long> = listOf(1, 2, 3), // error here
|
||||
@SerialId(9) val map: Map<String, Map<String, ByteArray>> = mapOf("哈哈" to mapOf("哈哈" to byteArrayOf(1, 2, 3))),
|
||||
// @SerialId(10) val nestedJceStruct: TestSimpleJceStruct = TestSimpleJceStruct(),
|
||||
@SerialId(11) val byteList2: List<List<Int>> = listOf(listOf(1, 2, 3), listOf(1, 2, 3))
|
||||
@ -47,7 +47,7 @@ class JceDecoderTest {
|
||||
class TestComplexNullableJceStruct(
|
||||
@SerialId(6) val string: String = "haha",
|
||||
@SerialId(7) val byteArray: ByteArray = ByteArray(2000),
|
||||
@SerialId(8) val byteList: List<Byte>? = listOf(1, 2, 3), // error here
|
||||
@SerialId(8) val byteList: List<Long>? = listOf(1, 2, 3), // error here
|
||||
@SerialId(9) val map: Map<String, Map<String, ByteArray>>? = mapOf("哈哈" to mapOf("哈哈" to byteArrayOf(1, 2, 3))),
|
||||
@SerialId(10) val nestedJceStruct: TestComplexJceStruct? = TestComplexJceStruct(),
|
||||
@SerialId(11) val byteList2: List<List<Int>>? = listOf(listOf(1, 2, 3), listOf(1, 2, 3))
|
||||
@ -78,7 +78,7 @@ class JceDecoderTest {
|
||||
@Serializable
|
||||
class TestNestedList(
|
||||
@SerialId(7) val array: List<List<Int>> = listOf(listOf(1, 2, 3), listOf(1, 2, 3), listOf(1, 2, 3))
|
||||
)
|
||||
) : JceStruct
|
||||
|
||||
println(buildJcePacket {
|
||||
writeCollection(listOf(listOf(1, 2, 3), listOf(1, 2, 3), listOf(1, 2, 3)), 7)
|
||||
@ -90,7 +90,7 @@ class JceDecoderTest {
|
||||
@Serializable
|
||||
class TestNestedArray(
|
||||
@SerialId(7) val array: Array<Array<Int>> = arrayOf(arrayOf(1, 2, 3), arrayOf(1, 2, 3), arrayOf(1, 2, 3))
|
||||
)
|
||||
) : JceStruct
|
||||
|
||||
println(buildJcePacket {
|
||||
writeFully(arrayOf(arrayOf(1, 2, 3), arrayOf(1, 2, 3), arrayOf(1, 2, 3)), 7)
|
||||
@ -101,24 +101,24 @@ class JceDecoderTest {
|
||||
fun testSimpleMap() {
|
||||
|
||||
@Serializable
|
||||
class TestSimpleMap(
|
||||
data class TestSimpleMap(
|
||||
@SerialId(7) val map: Map<String, Long> = mapOf("byteArrayOf(1)" to 2222L)
|
||||
)
|
||||
println(buildJcePacket {
|
||||
) : JceStruct
|
||||
assertEquals(buildJcePacket {
|
||||
writeMap(mapOf("byteArrayOf(1)" to 2222), 7)
|
||||
}.readBytes().loadAs(TestSimpleMap.serializer()).contentToString())
|
||||
}.readBytes().loadAs(TestSimpleMap.serializer()).toString(), TestSimpleMap().toString())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSimpleList() {
|
||||
|
||||
@Serializable
|
||||
class TestSimpleList(
|
||||
data class TestSimpleList(
|
||||
@SerialId(7) val list: List<String> = listOf("asd", "asdasdasd")
|
||||
)
|
||||
println(buildJcePacket {
|
||||
) : JceStruct
|
||||
assertEquals(buildJcePacket {
|
||||
writeCollection(listOf("asd", "asdasdasd"), 7)
|
||||
}.readBytes().loadAs(TestSimpleList.serializer()).contentToString())
|
||||
}.readBytes().loadAs(TestSimpleList.serializer()).toString(), TestSimpleList().toString())
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -126,9 +126,55 @@ class JceDecoderTest {
|
||||
@Serializable
|
||||
class TestNestedMap(
|
||||
@SerialId(7) val map: Map<ByteArray, Map<ByteArray, ShortArray>> = mapOf(byteArrayOf(1) to mapOf(byteArrayOf(1) to shortArrayOf(2)))
|
||||
)
|
||||
println(buildJcePacket {
|
||||
) : JceStruct
|
||||
assertEquals(buildJcePacket {
|
||||
writeMap(mapOf(byteArrayOf(1) to mapOf(byteArrayOf(1) to shortArrayOf(2))), 7)
|
||||
}.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString())
|
||||
}.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString(), "{01=[0x0002(2)]}")
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun testNullableEncode() {
|
||||
@Serializable
|
||||
data class AllNullJce(
|
||||
@SerialId(6) val string: String? = null,
|
||||
@SerialId(7) val byteArray: ByteArray? = null,
|
||||
@SerialId(8) val byteList: List<Long>? = null,
|
||||
@SerialId(9) val map: Map<String, Map<String, ByteArray>>? = null,
|
||||
@SerialId(10) val nestedJceStruct: TestComplexJceStruct? = null,
|
||||
@SerialId(11) val byteList2: List<List<Int>>? = null
|
||||
) : JceStruct {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (javaClass != other?.javaClass) return false
|
||||
|
||||
other as AllNullJce
|
||||
|
||||
if (string != other.string) return false
|
||||
if (byteArray != null) {
|
||||
if (other.byteArray == null) return false
|
||||
if (!byteArray.contentEquals(other.byteArray)) return false
|
||||
} else if (other.byteArray != null) return false
|
||||
if (byteList != other.byteList) return false
|
||||
if (map != other.map) return false
|
||||
if (nestedJceStruct != other.nestedJceStruct) return false
|
||||
if (byteList2 != other.byteList2) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = string?.hashCode() ?: 0
|
||||
result = 31 * result + (byteArray?.contentHashCode() ?: 0)
|
||||
result = 31 * result + (byteList?.hashCode() ?: 0)
|
||||
result = 31 * result + (map?.hashCode() ?: 0)
|
||||
result = 31 * result + (nestedJceStruct?.hashCode() ?: 0)
|
||||
result = 31 * result + (byteList2?.hashCode() ?: 0)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
assert(AllNullJce().toByteArray(AllNullJce.serializer()).isEmpty())
|
||||
assertEquals(ByteArray(0).loadAs(AllNullJce.serializer()), AllNullJce())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user