Jce null support

This commit is contained in:
Him188 2020-01-30 17:18:49 +08:00
parent b9e9c00748
commit 3a11adf1c0
5 changed files with 116 additions and 128 deletions

View File

@ -1,43 +1,6 @@
package net.mamoe.mirai.qqandroid.io
import kotlinx.io.core.BytePacketBuilder
import kotlinx.io.core.Input
import kotlinx.io.core.readBytes
import kotlinx.io.core.writeFully
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerializationStrategy
import net.mamoe.mirai.qqandroid.io.serialization.ProtoBufWithNullableSupport
import net.mamoe.mirai.utils.io.toUHexString
/**
* 仅有标示作用
*/
interface ProtoBuf
fun <T : ProtoBuf> BytePacketBuilder.writeProtoBuf(serializer: SerializationStrategy<T>, v: T) {
this.writeFully(v.toByteArray(serializer).also {
println("发送 protobuf: ${it.toUHexString()}")
})
}
/**
* dump
*/
fun <T : ProtoBuf> T.toByteArray(serializer: SerializationStrategy<T>): ByteArray {
return ProtoBufWithNullableSupport.dump(serializer, this)
}
/**
* load
*/
fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrategy<T>): T {
return ProtoBufWithNullableSupport.load(deserializer, this)
}
/**
* load
*/
fun <T : ProtoBuf> Input.readRemainingAsProtoBuf(serializer: DeserializationStrategy<T>): T {
return ProtoBufWithNullableSupport.load(serializer, this.readBytes())
}

View File

@ -8,57 +8,18 @@ import kotlinx.serialization.modules.EmptyModule
import kotlinx.serialization.modules.SerialModule
import net.mamoe.mirai.qqandroid.io.JceStruct
import net.mamoe.mirai.qqandroid.io.ProtoBuf
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestDataVersion2
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestDataVersion3
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestPacket
import net.mamoe.mirai.qqandroid.network.protocol.packet.withUse
import net.mamoe.mirai.utils.firstValue
import net.mamoe.mirai.utils.io.read
import net.mamoe.mirai.utils.io.readIoBuffer
import net.mamoe.mirai.utils.io.readString
import net.mamoe.mirai.utils.io.toIoBuffer
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
@PublishedApi
internal val CharsetGBK = Charset.forName("GBK")
@PublishedApi
internal val CharsetUTF8 = Charset.forName("UTF8")
fun <T> ByteArray.loadAs(deserializer: DeserializationStrategy<T>, c: JceCharset = JceCharset.UTF8): T {
return Jce.byCharSet(c).load(deserializer, this)
}
fun <T> BytePacketBuilder.writeJceStruct(serializer: SerializationStrategy<T>, struct: T, charset: JceCharset = JceCharset.GBK) {
this.writePacket(Jce.byCharSet(charset).dumpAsPacket(serializer, struct))
}
fun <T> ByteReadPacket.readRemainingAsJceStruct(serializer: DeserializationStrategy<T>, charset: JceCharset = JceCharset.UTF8): T {
return Jce.byCharSet(charset).load(serializer, this)
}
/**
* 先解析为 [RequestPacket], `UniRequest`, 再按版本解析 map, 再找出指定数据并反序列化
*/
fun <T : JceStruct> ByteReadPacket.decodeUniPacket(deserializer: DeserializationStrategy<T>, name: String? = null): T {
val request = this.readRemainingAsJceStruct(RequestPacket.serializer())
fun ByteArray.doReadInner(): T = read {
discardExact(1)
this.readRemainingAsJceStruct(deserializer)
}
return if (name == null) when (request.iVersion.toInt()) {
2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.firstValue().firstValue().doReadInner()
3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.firstValue().doReadInner()
else -> error("unsupported version ${request.iVersion}")
} else when (request.iVersion.toInt()) {
2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.getOrElse(name) { error("cannot find $name") }.firstValue().doReadInner()
3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.getOrElse(name) { error("cannot find $name") }.doReadInner()
else -> error("unsupported version ${request.iVersion}")
}
}
fun <T : JceStruct> T.toByteArray(serializer: SerializationStrategy<T>, c: JceCharset = JceCharset.GBK): ByteArray = Jce.byCharSet(c).dump(serializer, this)
enum class JceCharset(val kotlinCharset: Charset) {
GBK(Charset.forName("GBK")),
UTF8(Charset.forName("UTF8"))
@ -410,11 +371,6 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
is MapLikeDescriptor -> {
val tag = currentTagOrNull
if (tag != null && input.skipToTagOrNull(tag) { popTag() } == null && desc.isNullable) {
return NullReader(this.input)
}
if (tag != null) {
popTag()
}
@ -461,7 +417,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
@Suppress("UNCHECKED_CAST")
override fun <T : Any> decodeNullableSerializableValue(deserializer: DeserializationStrategy<T?>): T? {
//println("decodeNullableSerializableValue: ${deserializer.getClassName()}")
println("decodeNullableSerializableValue: ${deserializer.getClassName()}")
if (deserializer is NullReader) {
return null
}
@ -487,20 +443,40 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
return if (isTagOptional(tag)) input.readByteArrayOrNull(tag)?.toMutableList() as? T
else input.readByteArray(tag).toMutableList() as T
}
return super.decodeSerializableValue(deserializer)
val tag = popTag()
println(tag)
@Suppress("SENSELESS_COMPARISON") // false positive
if (input.skipToTagOrNull(tag) {
return deserializer.deserialize(JceListReader(input.readInt(0), input))
} == null) {
if (isTagOptional(tag)) {
return null
} else error("property is notnull but cannot find tag $tag")
}
error("UNREACHABLE CODE")
}
is MapLikeDescriptor -> {
// 将 mapOf(k1 to v1, k2 to v2, ...) 转换为 listOf(k1, v1, k2, v2, ...) 以便于写入.
val serializer = (deserializer as MapLikeSerializer<Any?, Any?, T, *>)
val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
val setOfEntries = HashSetSerializer(mapEntrySerial).deserialize(this)
return setOfEntries.associateBy({ it.key }, { it.value }) as T
val tag = popTag()
@Suppress("SENSELESS_COMPARISON")
if (input.skipToTagOrNull(tag) {
// 将 mapOf(k1 to v1, k2 to v2, ...) 转换为 listOf(k1, v1, k2, v2, ...) 以便于写入.
val serializer = (deserializer as MapLikeSerializer<Any?, Any?, T, *>)
val mapEntrySerial = MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
val setOfEntries = HashSetSerializer(mapEntrySerial).deserialize(JceMapReader(input.readInt(0), input))
return setOfEntries.associateBy({ it.key }, { it.value }) as T
} == null) {
if (isTagOptional(tag)) {
return null
} else error("property is notnull but cannot find tag $tag")
}
error("UNREACHABLE CODE")
}
}
val tag = currentTagOrNull ?: return deserializer.deserialize(this)
return if (this.decodeTaggedNotNullMark(tag)) {
deserializer.deserialize(this)
} else {
// popTag()
null
}
}
@ -514,7 +490,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
@UseExperimental(ExperimentalUnsignedTypes::class)
private inner class JceInput(
internal inner class JceInput(
@PublishedApi
internal val input: IoBuffer
) : Closeable {
@ -732,24 +708,6 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
else -> error("invalid type: $type")
}
internal inline fun <R> skipToTagOrNull(tag: Int, block: (JceHead) -> R): R? {
while (true) {
if (this.input.endOfInput) {
return null
}
val head = peakHead()
if (head.tag > tag) {
return null
}
readHead()
if (head.tag == tag) {
return block(head)
}
this.skipField(head.type)
}
}
}
companion object {
@ -809,6 +767,28 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
}
}
@UseExperimental(ExperimentalContracts::class)
internal inline fun <R> Jce.JceInput.skipToTagOrNull(tag: Int, block: (JceHead) -> R): R? {
contract {
callsInPlace(block, kotlin.contracts.InvocationKind.UNKNOWN)
}
while (true) {
if (this.input.endOfInput) {
return null
}
val head = peakHead()
if (head.tag > tag) {
return null
}
readHead()
if (head.tag == tag) {
return block(head)
}
this.skipField(head.type)
}
}
@UseExperimental(ExperimentalUnsignedTypes::class)
inline class JceHead(private val value: Long) {
constructor(tag: Int, type: Byte) : this(tag.toLong().shl(32) or type.toLong())

View File

@ -9,9 +9,9 @@ import net.mamoe.mirai.message.FriendMessage
import net.mamoe.mirai.message.data.MessageChain
import net.mamoe.mirai.qqandroid.QQAndroidBot
import net.mamoe.mirai.qqandroid.event.ForceOfflineEvent
import net.mamoe.mirai.qqandroid.io.readRemainingAsProtoBuf
import net.mamoe.mirai.qqandroid.io.serialization.decodeUniPacket
import net.mamoe.mirai.qqandroid.io.writeProtoBuf
import net.mamoe.mirai.qqandroid.io.serialization.readRemainingAsProtoBuf
import net.mamoe.mirai.qqandroid.io.serialization.writeProtoBuf
import net.mamoe.mirai.qqandroid.network.QQAndroidClient
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestPushForceOffline
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.RequestPushNotify

View File

@ -5,7 +5,6 @@ import net.mamoe.mirai.data.Packet
import net.mamoe.mirai.qqandroid.QQAndroidBot
import net.mamoe.mirai.qqandroid.io.serialization.toByteArray
import net.mamoe.mirai.qqandroid.io.serialization.writeJceStruct
import net.mamoe.mirai.qqandroid.io.toByteArray
import net.mamoe.mirai.qqandroid.io.writeJcePacket
import net.mamoe.mirai.qqandroid.network.QQAndroidClient
import net.mamoe.mirai.qqandroid.network.protocol.data.jce.GetFriendListReq

View File

@ -8,9 +8,9 @@ import net.mamoe.mirai.qqandroid.io.JceStruct
import net.mamoe.mirai.qqandroid.io.buildJcePacket
import net.mamoe.mirai.utils.cryptor.contentToString
import kotlin.test.Test
import kotlin.test.assertEquals
class JceDecoderTest {
@Serializable
class TestSimpleJceStruct(
@SerialId(0) val string: String = "123",
@ -37,7 +37,7 @@ class JceDecoderTest {
class TestComplexJceStruct(
@SerialId(6) val string: String = "haha",
@SerialId(7) val byteArray: ByteArray = ByteArray(500),
@SerialId(8) val byteList: List<Byte> = listOf(1, 2, 3), // error here
@SerialId(8) val byteList: List<Long> = listOf(1, 2, 3), // error here
@SerialId(9) val map: Map<String, Map<String, ByteArray>> = mapOf("哈哈" to mapOf("哈哈" to byteArrayOf(1, 2, 3))),
// @SerialId(10) val nestedJceStruct: TestSimpleJceStruct = TestSimpleJceStruct(),
@SerialId(11) val byteList2: List<List<Int>> = listOf(listOf(1, 2, 3), listOf(1, 2, 3))
@ -47,7 +47,7 @@ class JceDecoderTest {
class TestComplexNullableJceStruct(
@SerialId(6) val string: String = "haha",
@SerialId(7) val byteArray: ByteArray = ByteArray(2000),
@SerialId(8) val byteList: List<Byte>? = listOf(1, 2, 3), // error here
@SerialId(8) val byteList: List<Long>? = listOf(1, 2, 3), // error here
@SerialId(9) val map: Map<String, Map<String, ByteArray>>? = mapOf("哈哈" to mapOf("哈哈" to byteArrayOf(1, 2, 3))),
@SerialId(10) val nestedJceStruct: TestComplexJceStruct? = TestComplexJceStruct(),
@SerialId(11) val byteList2: List<List<Int>>? = listOf(listOf(1, 2, 3), listOf(1, 2, 3))
@ -78,7 +78,7 @@ class JceDecoderTest {
@Serializable
class TestNestedList(
@SerialId(7) val array: List<List<Int>> = listOf(listOf(1, 2, 3), listOf(1, 2, 3), listOf(1, 2, 3))
)
) : JceStruct
println(buildJcePacket {
writeCollection(listOf(listOf(1, 2, 3), listOf(1, 2, 3), listOf(1, 2, 3)), 7)
@ -90,7 +90,7 @@ class JceDecoderTest {
@Serializable
class TestNestedArray(
@SerialId(7) val array: Array<Array<Int>> = arrayOf(arrayOf(1, 2, 3), arrayOf(1, 2, 3), arrayOf(1, 2, 3))
)
) : JceStruct
println(buildJcePacket {
writeFully(arrayOf(arrayOf(1, 2, 3), arrayOf(1, 2, 3), arrayOf(1, 2, 3)), 7)
@ -101,24 +101,24 @@ class JceDecoderTest {
fun testSimpleMap() {
@Serializable
class TestSimpleMap(
data class TestSimpleMap(
@SerialId(7) val map: Map<String, Long> = mapOf("byteArrayOf(1)" to 2222L)
)
println(buildJcePacket {
) : JceStruct
assertEquals(buildJcePacket {
writeMap(mapOf("byteArrayOf(1)" to 2222), 7)
}.readBytes().loadAs(TestSimpleMap.serializer()).contentToString())
}.readBytes().loadAs(TestSimpleMap.serializer()).toString(), TestSimpleMap().toString())
}
@Test
fun testSimpleList() {
@Serializable
class TestSimpleList(
data class TestSimpleList(
@SerialId(7) val list: List<String> = listOf("asd", "asdasdasd")
)
println(buildJcePacket {
) : JceStruct
assertEquals(buildJcePacket {
writeCollection(listOf("asd", "asdasdasd"), 7)
}.readBytes().loadAs(TestSimpleList.serializer()).contentToString())
}.readBytes().loadAs(TestSimpleList.serializer()).toString(), TestSimpleList().toString())
}
@Test
@ -126,9 +126,55 @@ class JceDecoderTest {
@Serializable
class TestNestedMap(
@SerialId(7) val map: Map<ByteArray, Map<ByteArray, ShortArray>> = mapOf(byteArrayOf(1) to mapOf(byteArrayOf(1) to shortArrayOf(2)))
)
println(buildJcePacket {
) : JceStruct
assertEquals(buildJcePacket {
writeMap(mapOf(byteArrayOf(1) to mapOf(byteArrayOf(1) to shortArrayOf(2))), 7)
}.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString())
}.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString(), "{01=[0x0002(2)]}")
}
@Test
fun testNullableEncode() {
@Serializable
data class AllNullJce(
@SerialId(6) val string: String? = null,
@SerialId(7) val byteArray: ByteArray? = null,
@SerialId(8) val byteList: List<Long>? = null,
@SerialId(9) val map: Map<String, Map<String, ByteArray>>? = null,
@SerialId(10) val nestedJceStruct: TestComplexJceStruct? = null,
@SerialId(11) val byteList2: List<List<Int>>? = null
) : JceStruct {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as AllNullJce
if (string != other.string) return false
if (byteArray != null) {
if (other.byteArray == null) return false
if (!byteArray.contentEquals(other.byteArray)) return false
} else if (other.byteArray != null) return false
if (byteList != other.byteList) return false
if (map != other.map) return false
if (nestedJceStruct != other.nestedJceStruct) return false
if (byteList2 != other.byteList2) return false
return true
}
override fun hashCode(): Int {
var result = string?.hashCode() ?: 0
result = 31 * result + (byteArray?.contentHashCode() ?: 0)
result = 31 * result + (byteList?.hashCode() ?: 0)
result = 31 * result + (map?.hashCode() ?: 0)
result = 31 * result + (nestedJceStruct?.hashCode() ?: 0)
result = 31 * result + (byteList2?.hashCode() ?: 0)
return result
}
}
assert(AllNullJce().toByteArray(AllNullJce.serializer()).isEmpty())
assertEquals(ByteArray(0).loadAs(AllNullJce.serializer()), AllNullJce())
}
}