Avoid user injection

This commit is contained in:
Karlatemp 2021-11-09 18:20:56 +08:00
parent fa364b4b45
commit e1ca6dd6c9
No known key found for this signature in database
GPG Key ID: 21FBDDF664FF06F8
6 changed files with 33 additions and 18 deletions

View File

@ -243,19 +243,25 @@ internal suspend fun <C : User> SendMessageHandler<out C>.sendMessageImpl(
preSendEventConstructor: (C, Message) -> MessagePreSendEvent,
postSendEventConstructor: (C, MessageChain, Throwable?, MessageReceipt<C>?) -> MessagePostSendEvent<C>,
): MessageReceipt<C> {
require(!message.isContentEmpty()) { "message is empty" }
val isMiraiInternal = if (message is MessageChain) {
message.anyIsInstance<MiraiInternalMessageFlag>()
} else false
val chain = contact.broadcastMessagePreSendEvent(message, preSendEventConstructor)
require(isMiraiInternal || !message.isContentEmpty()) { "message is empty" }
val chain = contact.broadcastMessagePreSendEvent(message, isMiraiInternal, preSendEventConstructor)
val result = this
.runCatching { sendMessage(message, chain, SendMessageStep.FIRST) }
.runCatching { sendMessage(message, chain, isMiraiInternal, SendMessageStep.FIRST) }
if (result.isSuccess) {
// logMessageSent(result.getOrNull()?.source?.plus(chain) ?: chain) // log with source
contact.logMessageSent(chain)
}
postSendEventConstructor(contact, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
if (!isMiraiInternal) {
postSendEventConstructor(contact, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
}
return result.getOrThrow()
}

View File

@ -154,21 +154,25 @@ internal class GroupImpl constructor(
}
override suspend fun sendMessage(message: Message): MessageReceipt<Group> {
require(!message.isContentEmpty()) { "message is empty" }
val isMiraiInternal = if (message is MessageChain) {
message.anyIsInstance<MiraiInternalMessageFlag>()
} else false
require(isMiraiInternal || !message.isContentEmpty()) { "message is empty" }
check(!isBotMuted) { throw BotIsBeingMutedException(this) }
val chain = broadcastMessagePreSendEvent(message, ::GroupMessagePreSendEvent)
val chain = broadcastMessagePreSendEvent(message, isMiraiInternal, ::GroupMessagePreSendEvent)
val result = GroupSendMessageHandler(this)
.runCatching { sendMessage(message, chain, SendMessageStep.FIRST) }
.runCatching { sendMessage(message, chain, isMiraiInternal, SendMessageStep.FIRST) }
if (result.isSuccess) {
// logMessageSent(result.getOrNull()?.source?.plus(chain) ?: chain) // log with source
logMessageSent(chain)
}
GroupMessagePostSendEvent(this, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
if (!isMiraiInternal) {
GroupMessagePostSendEvent(this, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
}
return result.getOrThrow()
}

View File

@ -24,8 +24,10 @@ import net.mamoe.mirai.message.data.toMessageChain
*/
internal suspend fun <C : Contact> C.broadcastMessagePreSendEvent(
message: Message,
isMiraiInternal: Boolean,
eventConstructor: (C, Message) -> MessagePreSendEvent,
): MessageChain {
if (isMiraiInternal) return message.toMessageChain()
return kotlin.runCatching {
eventConstructor(this, message).broadcast()
}.onSuccess {

View File

@ -121,6 +121,7 @@ internal abstract class SendMessageHandler<C : Contact> {
originalMessage: Message,
transformedMessage: MessageChain,
finalMessage: MessageChain,
isMiraiInternal: Boolean,
step: SendMessageStep,
): MessageReceipt<C> {
bot.components[MessageSvcSyncer].joinSync()
@ -140,10 +141,10 @@ internal abstract class SendMessageHandler<C : Contact> {
if (resp is MessageSvcPbSendMsg.Response.MessageTooLarge) {
return when (step) {
SendMessageStep.FIRST -> {
sendMessageImpl(originalMessage, transformedMessage, SendMessageStep.LONG_MESSAGE)
sendMessageImpl(originalMessage, transformedMessage, isMiraiInternal, SendMessageStep.LONG_MESSAGE)
}
SendMessageStep.LONG_MESSAGE -> {
sendMessageImpl(originalMessage, transformedMessage, SendMessageStep.FRAGMENTED)
sendMessageImpl(originalMessage, transformedMessage, isMiraiInternal, SendMessageStep.FRAGMENTED)
}
else -> {
@ -312,6 +313,7 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.transformSpecialMessage
internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
originalMessage: Message,
transformedMessage: Message,
isMiraiInternal: Boolean,
step: SendMessageStep,
): MessageReceipt<C> = sendMessageImpl(
originalMessage,
@ -320,6 +322,7 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
preConversionTransformedMessage(transformedMessage)
)
),
isMiraiInternal,
step
)
@ -329,16 +332,19 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
private suspend fun <C : Contact> SendMessageHandler<C>.sendMessageImpl(
originalMessage: Message,
transformedMessage: MessageChain,
isMiraiInternal: Boolean,
step: SendMessageStep,
): MessageReceipt<C> { // Result cannot be in interface.
transformedMessage.verifySendingValid()
if (!isMiraiInternal && step == SendMessageStep.FIRST) {
transformedMessage.verifySendingValid()
}
val chain = transformedMessage.convertToLongMessageIfNeeded(step)
chain.findIsInstance<QuoteReply>()?.source?.ensureSequenceIdAvailable()
postTransformActions(chain)
return sendMessagePacket(originalMessage, transformedMessage, chain, step)
return sendMessagePacket(originalMessage, transformedMessage, chain, isMiraiInternal, step)
}
internal sealed class UserSendMessageHandler<C : AbstractUser>(

View File

@ -38,9 +38,6 @@ internal fun Message.verifySendingValid() {
fun fail(msg: String): Nothing = throw IllegalArgumentException(msg)
when (this) {
is MessageChain -> {
if (contains(MiraiInternalMessageFlag)) {
return
}
this.forEach { it.verifySendingValid() }
}
is FileMessage -> fail("Sending FileMessage is not in support")

View File

@ -73,7 +73,7 @@ internal class MessageReceiptTest : AbstractTestWithMiraiImpl() {
listOf()
}
}
val result = handler.sendMessage(message, message, SendMessageStep.FIRST)
val result = handler.sendMessage(message, message, false, SendMessageStep.FIRST)
assertIs<ForwardMessage>(result.source.originalMessage[ForwardMessage])
assertEquals(message, result.source.originalMessage)