Avoid user injection

This commit is contained in:
Karlatemp 2021-11-09 18:20:56 +08:00
parent fa364b4b45
commit e1ca6dd6c9
No known key found for this signature in database
GPG Key ID: 21FBDDF664FF06F8
6 changed files with 33 additions and 18 deletions

View File

@ -243,19 +243,25 @@ internal suspend fun <C : User> SendMessageHandler<out C>.sendMessageImpl(
preSendEventConstructor: (C, Message) -> MessagePreSendEvent, preSendEventConstructor: (C, Message) -> MessagePreSendEvent,
postSendEventConstructor: (C, MessageChain, Throwable?, MessageReceipt<C>?) -> MessagePostSendEvent<C>, postSendEventConstructor: (C, MessageChain, Throwable?, MessageReceipt<C>?) -> MessagePostSendEvent<C>,
): MessageReceipt<C> { ): MessageReceipt<C> {
require(!message.isContentEmpty()) { "message is empty" } val isMiraiInternal = if (message is MessageChain) {
message.anyIsInstance<MiraiInternalMessageFlag>()
} else false
val chain = contact.broadcastMessagePreSendEvent(message, preSendEventConstructor) require(isMiraiInternal || !message.isContentEmpty()) { "message is empty" }
val chain = contact.broadcastMessagePreSendEvent(message, isMiraiInternal, preSendEventConstructor)
val result = this val result = this
.runCatching { sendMessage(message, chain, SendMessageStep.FIRST) } .runCatching { sendMessage(message, chain, isMiraiInternal, SendMessageStep.FIRST) }
if (result.isSuccess) { if (result.isSuccess) {
// logMessageSent(result.getOrNull()?.source?.plus(chain) ?: chain) // log with source // logMessageSent(result.getOrNull()?.source?.plus(chain) ?: chain) // log with source
contact.logMessageSent(chain) contact.logMessageSent(chain)
} }
postSendEventConstructor(contact, chain, result.exceptionOrNull(), result.getOrNull()).broadcast() if (!isMiraiInternal) {
postSendEventConstructor(contact, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
}
return result.getOrThrow() return result.getOrThrow()
} }

View File

@ -154,21 +154,25 @@ internal class GroupImpl constructor(
} }
override suspend fun sendMessage(message: Message): MessageReceipt<Group> { override suspend fun sendMessage(message: Message): MessageReceipt<Group> {
require(!message.isContentEmpty()) { "message is empty" } val isMiraiInternal = if (message is MessageChain) {
message.anyIsInstance<MiraiInternalMessageFlag>()
} else false
require(isMiraiInternal || !message.isContentEmpty()) { "message is empty" }
check(!isBotMuted) { throw BotIsBeingMutedException(this) } check(!isBotMuted) { throw BotIsBeingMutedException(this) }
val chain = broadcastMessagePreSendEvent(message, ::GroupMessagePreSendEvent) val chain = broadcastMessagePreSendEvent(message, isMiraiInternal, ::GroupMessagePreSendEvent)
val result = GroupSendMessageHandler(this) val result = GroupSendMessageHandler(this)
.runCatching { sendMessage(message, chain, SendMessageStep.FIRST) } .runCatching { sendMessage(message, chain, isMiraiInternal, SendMessageStep.FIRST) }
if (result.isSuccess) { if (result.isSuccess) {
// logMessageSent(result.getOrNull()?.source?.plus(chain) ?: chain) // log with source // logMessageSent(result.getOrNull()?.source?.plus(chain) ?: chain) // log with source
logMessageSent(chain) logMessageSent(chain)
} }
if (!isMiraiInternal) {
GroupMessagePostSendEvent(this, chain, result.exceptionOrNull(), result.getOrNull()).broadcast() GroupMessagePostSendEvent(this, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
}
return result.getOrThrow() return result.getOrThrow()
} }

View File

@ -24,8 +24,10 @@ import net.mamoe.mirai.message.data.toMessageChain
*/ */
internal suspend fun <C : Contact> C.broadcastMessagePreSendEvent( internal suspend fun <C : Contact> C.broadcastMessagePreSendEvent(
message: Message, message: Message,
isMiraiInternal: Boolean,
eventConstructor: (C, Message) -> MessagePreSendEvent, eventConstructor: (C, Message) -> MessagePreSendEvent,
): MessageChain { ): MessageChain {
if (isMiraiInternal) return message.toMessageChain()
return kotlin.runCatching { return kotlin.runCatching {
eventConstructor(this, message).broadcast() eventConstructor(this, message).broadcast()
}.onSuccess { }.onSuccess {

View File

@ -121,6 +121,7 @@ internal abstract class SendMessageHandler<C : Contact> {
originalMessage: Message, originalMessage: Message,
transformedMessage: MessageChain, transformedMessage: MessageChain,
finalMessage: MessageChain, finalMessage: MessageChain,
isMiraiInternal: Boolean,
step: SendMessageStep, step: SendMessageStep,
): MessageReceipt<C> { ): MessageReceipt<C> {
bot.components[MessageSvcSyncer].joinSync() bot.components[MessageSvcSyncer].joinSync()
@ -140,10 +141,10 @@ internal abstract class SendMessageHandler<C : Contact> {
if (resp is MessageSvcPbSendMsg.Response.MessageTooLarge) { if (resp is MessageSvcPbSendMsg.Response.MessageTooLarge) {
return when (step) { return when (step) {
SendMessageStep.FIRST -> { SendMessageStep.FIRST -> {
sendMessageImpl(originalMessage, transformedMessage, SendMessageStep.LONG_MESSAGE) sendMessageImpl(originalMessage, transformedMessage, isMiraiInternal, SendMessageStep.LONG_MESSAGE)
} }
SendMessageStep.LONG_MESSAGE -> { SendMessageStep.LONG_MESSAGE -> {
sendMessageImpl(originalMessage, transformedMessage, SendMessageStep.FRAGMENTED) sendMessageImpl(originalMessage, transformedMessage, isMiraiInternal, SendMessageStep.FRAGMENTED)
} }
else -> { else -> {
@ -312,6 +313,7 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.transformSpecialMessage
internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage( internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
originalMessage: Message, originalMessage: Message,
transformedMessage: Message, transformedMessage: Message,
isMiraiInternal: Boolean,
step: SendMessageStep, step: SendMessageStep,
): MessageReceipt<C> = sendMessageImpl( ): MessageReceipt<C> = sendMessageImpl(
originalMessage, originalMessage,
@ -320,6 +322,7 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
preConversionTransformedMessage(transformedMessage) preConversionTransformedMessage(transformedMessage)
) )
), ),
isMiraiInternal,
step step
) )
@ -329,16 +332,19 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
private suspend fun <C : Contact> SendMessageHandler<C>.sendMessageImpl( private suspend fun <C : Contact> SendMessageHandler<C>.sendMessageImpl(
originalMessage: Message, originalMessage: Message,
transformedMessage: MessageChain, transformedMessage: MessageChain,
isMiraiInternal: Boolean,
step: SendMessageStep, step: SendMessageStep,
): MessageReceipt<C> { // Result cannot be in interface. ): MessageReceipt<C> { // Result cannot be in interface.
transformedMessage.verifySendingValid() if (!isMiraiInternal && step == SendMessageStep.FIRST) {
transformedMessage.verifySendingValid()
}
val chain = transformedMessage.convertToLongMessageIfNeeded(step) val chain = transformedMessage.convertToLongMessageIfNeeded(step)
chain.findIsInstance<QuoteReply>()?.source?.ensureSequenceIdAvailable() chain.findIsInstance<QuoteReply>()?.source?.ensureSequenceIdAvailable()
postTransformActions(chain) postTransformActions(chain)
return sendMessagePacket(originalMessage, transformedMessage, chain, step) return sendMessagePacket(originalMessage, transformedMessage, chain, isMiraiInternal, step)
} }
internal sealed class UserSendMessageHandler<C : AbstractUser>( internal sealed class UserSendMessageHandler<C : AbstractUser>(

View File

@ -38,9 +38,6 @@ internal fun Message.verifySendingValid() {
fun fail(msg: String): Nothing = throw IllegalArgumentException(msg) fun fail(msg: String): Nothing = throw IllegalArgumentException(msg)
when (this) { when (this) {
is MessageChain -> { is MessageChain -> {
if (contains(MiraiInternalMessageFlag)) {
return
}
this.forEach { it.verifySendingValid() } this.forEach { it.verifySendingValid() }
} }
is FileMessage -> fail("Sending FileMessage is not in support") is FileMessage -> fail("Sending FileMessage is not in support")

View File

@ -73,7 +73,7 @@ internal class MessageReceiptTest : AbstractTestWithMiraiImpl() {
listOf() listOf()
} }
} }
val result = handler.sendMessage(message, message, SendMessageStep.FIRST) val result = handler.sendMessage(message, message, false, SendMessageStep.FIRST)
assertIs<ForwardMessage>(result.source.originalMessage[ForwardMessage]) assertIs<ForwardMessage>(result.source.originalMessage[ForwardMessage])
assertEquals(message, result.source.originalMessage) assertEquals(message, result.source.originalMessage)