From f0702a7ea92e5710714ecd36e7bc2a1431f0fae3 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Wed, 20 Nov 2019 19:15:11 +0800
Subject: [PATCH] Add more utilities

---
 .../net.mamoe.mirai/utils/io/InputUtils.kt    | 32 ++++++++++++++++---
 .../net.mamoe.mirai/utils/io/OutputUtils.kt   | 13 +++++---
 2 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/io/InputUtils.kt b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/io/InputUtils.kt
index 79d1d6c46..8d36a8697 100644
--- a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/io/InputUtils.kt
+++ b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/io/InputUtils.kt
@@ -15,7 +15,7 @@ fun ByteReadPacket.readIoBuffer(
     n: Int = remaining.toInt()//not that safe but adequate
 ): IoBuffer = IoBuffer.Pool.borrow().also { this.readFully(it, n) }
 
-fun ByteReadPacket.readIoBuffer(n: Number) = this.readIoBuffer(n.toInt())
+fun ByteReadPacket.readIoBuffer(n: Short) = this.readIoBuffer(n.toInt())
 
 fun Input.readIP(): String = buildString(4 + 3) {
     repeat(4) {
@@ -25,6 +25,8 @@ fun Input.readIP(): String = buildString(4 + 3) {
     }
 }
 
+fun Input.readPacket(length: Int): ByteReadPacket = this.readBytes(length).toReadPacket()
+
 fun Input.readUVarIntLVString(): String = String(this.readUVarIntByteArray())
 
 fun Input.readUShortLVString(): String = String(this.readUShortLVByteArray())
@@ -62,7 +64,14 @@ fun Input.readTLVMap(expectingEOF: Boolean = false, tagSize: Int = 1): MutableMa
                     ", duplicating value=${this.readUShortLVByteArray()}" +
                     ", remaining=" + if (expectingEOF) this.readBytes().toUHexString() else "[Not expecting EOF]"
         }
-        map[type.toUInt()] = this.readUShortLVByteArray()
+        try {
+            map[type.toUInt()] = this.readUShortLVByteArray()
+        } catch (e: RuntimeException) { // BufferUnderflowException
+            if (expectingEOF) {
+                return map
+            }
+            throw e
+        }
     }
     return map
 }
@@ -103,11 +112,24 @@ fun Input.readFlatTUVarIntMap(expectingEOF: Boolean = false, tagSize: Int = 1):
     return map
 }
 
-fun Map<UInt, ByteArray>.printTLVMap(name: String) =
-    debugPrintln("TLVMap $name= " + this.mapValues { (_, value) -> value.toUHexString() }.mapKeys { it.key.toInt().toUShort().toUHexString() })
+fun Map<UInt, ByteArray>.printTLVMap(name: String = "", keyLength: Int = 1) =
+    debugPrintln("TLVMap $name= " + this.mapValues { (_, value) -> value.toUHexString() }.mapKeys {
+        when (keyLength) {
+            1 -> it.key.toInt().toUByte().toUHexString()
+            2 -> it.key.toInt().toUShort().toUHexString()
+            4 -> it.key.toInt().toUInt().toUHexString()
+            else -> illegalArgument("Expecting 1, 2 or 4 for keyLength")
+        }
+    })
+
+@Suppress("NOTHING_TO_INLINE")
+internal inline fun unsupported(): Nothing = error("Unsupported")
+
+@Suppress("NOTHING_TO_INLINE")
+internal inline fun illegalArgument(message: String? = null): Nothing = error(message ?: "Illegal argument passed")
 
 @JvmName("printTLVStringMap")
-fun Map<UInt, String>.printTLVMap(name: String) =
+fun Map<UInt, String>.printTLVMap(name: String = "") =
     debugPrintln("TLVMap $name= " + this.mapKeys { it.key.toInt().toUShort().toUHexString() })
 
 fun Input.readString(length: Int): String = String(this.readBytes(length))
diff --git a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/io/OutputUtils.kt b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/io/OutputUtils.kt
index a64a99bdc..e82837d66 100644
--- a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/io/OutputUtils.kt
+++ b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/io/OutputUtils.kt
@@ -4,6 +4,8 @@ package net.mamoe.mirai.utils.io
 
 import kotlinx.io.core.*
 import kotlinx.io.pool.useInstance
+import kotlinx.serialization.SerializationStrategy
+import kotlinx.serialization.protobuf.ProtoBuf
 import net.mamoe.mirai.contact.GroupId
 import net.mamoe.mirai.contact.GroupInternalId
 import net.mamoe.mirai.network.protocol.tim.TIMProtocol
@@ -57,6 +59,9 @@ fun BytePacketBuilder.writeHex(uHex: String) {
     }
 }
 
+fun <T> BytePacketBuilder.writeProto(serializer: SerializationStrategy<T>, obj: T) = writeFully(ProtoBuf.dump(serializer, obj))
+
+
 fun BytePacketBuilder.writeTLV(tag: UByte, values: UByteArray) {
     writeUByte(tag)
     writeUVarInt(values.size.toUInt())
@@ -104,17 +109,17 @@ fun BytePacketBuilder.writeTByteArray(tag: UByte, value: UByteArray) {
 /**
  * 会使用 [ByteArrayPool] 缓存
  */
-fun BytePacketBuilder.encryptAndWrite(key: ByteArray, encoder: BytePacketBuilder.() -> Unit) =
+inline fun BytePacketBuilder.encryptAndWrite(key: ByteArray, encoder: BytePacketBuilder.() -> Unit) =
     BytePacketBuilder().apply(encoder).build().encryptBy(key) { decrypted -> writeFully(decrypted) }
 
-fun BytePacketBuilder.encryptAndWrite(key: IoBuffer, encoder: BytePacketBuilder.() -> Unit) = ByteArrayPool.useInstance {
+inline fun BytePacketBuilder.encryptAndWrite(key: IoBuffer, encoder: BytePacketBuilder.() -> Unit) = ByteArrayPool.useInstance {
     key.readFully(it, 0, key.readRemaining)
     encryptAndWrite(it, encoder)
 }
 
-fun BytePacketBuilder.encryptAndWrite(key: DecrypterByteArray, encoder: BytePacketBuilder.() -> Unit) = encryptAndWrite(key.value, encoder)
+inline fun BytePacketBuilder.encryptAndWrite(key: DecrypterByteArray, encoder: BytePacketBuilder.() -> Unit) = encryptAndWrite(key.value, encoder)
 
-fun BytePacketBuilder.encryptAndWrite(keyHex: String, encoder: BytePacketBuilder.() -> Unit) = encryptAndWrite(keyHex.hexToBytes(), encoder)
+inline fun BytePacketBuilder.encryptAndWrite(keyHex: String, encoder: BytePacketBuilder.() -> Unit) = encryptAndWrite(keyHex.hexToBytes(), encoder)
 
 fun BytePacketBuilder.writeTLV0006(qq: UInt, password: String, loginTime: Int, loginIP: String, privateKey: PrivateKey) {
     val firstMD5 = md5(password)