Redesign auth

This commit is contained in:
Him188 2023-03-12 17:09:40 +00:00
parent 5abcd7fb26
commit 992f9289ce
No known key found for this signature in database
GPG Key ID: BA439CDDCF652375
11 changed files with 637 additions and 183 deletions

View File

@ -102,8 +102,19 @@ public interface BotAuthInfo {
@NotStableForInheritance
public interface BotAuthSession {
/**
* @throws LoginFailedException
*/
public suspend fun authByPassword(password: String): BotAuthResult
/**
* @throws LoginFailedException
*/
public suspend fun authByPassword(passwordMd5: ByteArray): BotAuthResult
/**
* @throws LoginFailedException
*/
public suspend fun authByQRCode(): BotAuthResult
}

View File

@ -9,157 +9,108 @@
package net.mamoe.mirai.internal.network.auth
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import net.mamoe.mirai.auth.BotAuthInfo
import net.mamoe.mirai.auth.BotAuthResult
import net.mamoe.mirai.auth.BotAuthorization
import net.mamoe.mirai.internal.network.components.SsoProcessorImpl
import net.mamoe.mirai.internal.utils.subLogger
import net.mamoe.mirai.utils.*
import kotlin.coroutines.Continuation
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
import kotlin.jvm.Volatile
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.cancellation.CancellationException
/**
* Event sequence:
*
* 1. Starts a user coroutine [BotAuthorization.authorize].
* 2. User coroutine
*/
internal class AuthControl(
private val botAuthInfo: BotAuthInfo,
private val authorization: BotAuthorization,
private val logger: MiraiLogger,
scope: CoroutineScope,
parentCoroutineContext: CoroutineContext,
) {
internal val exceptionCollector = ExceptionCollector()
@Volatile
private var authorizationContinuation: Continuation<Unit>? = null
private val userDecisions: OnDemandConsumer<Throwable?, SsoProcessorImpl.AuthMethod> =
CoroutineOnDemandValueScope(parentCoroutineContext, logger.subLogger("AuthControl/UserDecisions")) {
/**
* Implements [BotAuthSessionInternal] from API, to be called by the user, to receive user's decisions.
*/
val sessionImpl = object : BotAuthSessionInternal() {
private val authResultImpl = object : BotAuthResult {}
@Volatile
private var authRspFuture = initCompletableDeferred()
@Volatile
private var isCompleted = false
private val rsp = object : BotAuthResult {}
@Suppress("RemoveExplicitTypeArguments")
@OptIn(TestOnly::class)
private val authComponent = object : SsoProcessorImpl.SsoProcessorAuthComponent() {
override val botAuthResult: BotAuthResult get() = rsp
override suspend fun emit(method: SsoProcessorImpl.AuthMethod) {
logger.verbose { "[AuthControl/emit] Trying emit $method" }
if (isCompleted) {
val msg = "[AuthControl/emit] Failed to emit $method because control completed"
error(msg.also { logger.verbose(it) })
}
suspendCoroutine<Unit> { next ->
val rspTarget = authRspFuture
if (!rspTarget.complete(method)) {
val msg = "[AuthControl/emit] Failed to emit $method because auth response completed"
error(msg.also { logger.verbose(it) })
override suspend fun authByPassword(passwordMd5: SecretsProtection.EscapedByteBuffer): BotAuthResult {
runWrapInternalException {
emit(SsoProcessorImpl.AuthMethod.Pwd(passwordMd5))
}?.let { throw it }
return authResultImpl
}
override suspend fun authByQRCode(): BotAuthResult {
runWrapInternalException {
emit(SsoProcessorImpl.AuthMethod.QRCode)
}?.let { throw it }
return authResultImpl
}
private inline fun <R> runWrapInternalException(block: () -> R): R {
try {
return block()
} catch (e: IllegalProducerStateException) {
if (e.lastStateWasSucceed) {
throw IllegalStateException(
"This login session has already completed. Please return the BotAuthResult you get from 'authBy*()' immediately",
e
)
} else {
throw e // internal bug
}
}
}
authorizationContinuation = next
logger.verbose { "[AuthControl/emit] Emitted $method to $rspTarget" }
}
logger.verbose { "[AuthControl/emit] Authorization resumed after $method" }
}
override suspend fun authByPassword(passwordMd5: SecretsProtection.EscapedByteBuffer): BotAuthResult {
emit(SsoProcessorImpl.AuthMethod.Pwd(passwordMd5))
return rsp
}
override suspend fun authByPassword(password: String): BotAuthResult {
return authByPassword(password.md5())
}
override suspend fun authByPassword(passwordMd5: ByteArray): BotAuthResult {
return authByPassword(SecretsProtection.EscapedByteBuffer(passwordMd5))
}
override suspend fun authByQRCode(): BotAuthResult {
emit(SsoProcessorImpl.AuthMethod.QRCode)
return rsp
}
}
init {
// start users' BotAuthorization.authorize
scope.launch {
try {
logger.verbose { "[AuthControl/auth] Authorization started" }
authorization.authorize(authComponent, botAuthInfo)
authorization.authorize(sessionImpl, botAuthInfo)
logger.verbose { "[AuthControl/auth] Authorization exited" }
isCompleted = true
authRspFuture.complete(SsoProcessorImpl.AuthMethod.NotAvailable)
finish()
} catch (e: CancellationException) {
logger.verbose { "[AuthControl/auth] Authorization cancelled" }
} catch (e: Throwable) {
logger.verbose({ "[AuthControl/auth] Authorization failed" }, e)
isCompleted = true
authRspFuture.complete(SsoProcessorImpl.AuthMethod.Error(e))
logger.verbose { "[AuthControl/auth] Authorization failed: $e" }
finishExceptionally(e)
}
}
init {
userDecisions.expectMore(null)
}
private fun onSpinWait() {}
// Does not throw
suspend fun acquireAuth(): SsoProcessorImpl.AuthMethod {
val authTarget = authRspFuture
logger.verbose { "[AuthControl/acquire] Acquiring auth method with $authTarget" }
val rsp = authTarget.await()
logger.debug { "[AuthControl/acquire] Authorization responded: $authTarget, $rsp" }
logger.verbose { "[AuthControl/acquire] Acquiring auth method" }
while (authorizationContinuation == null && !isCompleted) {
onSpinWait()
val rsp = try {
userDecisions.receiveOrNull() ?: SsoProcessorImpl.AuthMethod.NotAvailable
} catch (e: ProducerFailureException) {
SsoProcessorImpl.AuthMethod.Error(e)
}
logger.verbose { "[AuthControl/acquire] authorizationContinuation setup: $authorizationContinuation, $isCompleted" }
logger.debug { "[AuthControl/acquire] Authorization responded: $rsp" }
return rsp
}
fun actFailed(cause: Throwable) {
fun actMethodFailed(cause: Throwable) {
logger.verbose { "[AuthControl/resume] Fire auth failed with cause: $cause" }
authRspFuture = initCompletableDeferred()
authorizationContinuation!!.let { cont ->
authorizationContinuation = null
cont.resumeWith(Result.failure(cause))
}
}
@TestOnly // same as act failed
fun actResume() {
logger.verbose { "[AuthControl/resume] Fire auth resume" }
authRspFuture = initCompletableDeferred()
authorizationContinuation!!.let { cont ->
authorizationContinuation = null
cont.resume(Unit)
}
userDecisions.expectMore(cause)
}
fun actComplete() {
logger.verbose { "[AuthControl/resume] Fire auth completed" }
isCompleted = true
authRspFuture = CompletableDeferred(SsoProcessorImpl.AuthMethod.NotAvailable)
authorizationContinuation!!.let { cont ->
authorizationContinuation = null
cont.resume(Unit)
}
}
private fun initCompletableDeferred(): CompletableDeferred<SsoProcessorImpl.AuthMethod> {
return CompletableDeferred<SsoProcessorImpl.AuthMethod>().also { df ->
df.invokeOnCompletion {
logger.debug { "[AuthControl/cd] $df completed with $it" }
}
}
userDecisions.finish()
}
}

View File

@ -14,10 +14,20 @@ import net.mamoe.mirai.auth.BotAuthResult
import net.mamoe.mirai.auth.BotAuthSession
import net.mamoe.mirai.auth.BotAuthorization
import net.mamoe.mirai.utils.SecretsProtection
import net.mamoe.mirai.utils.md5
// With SecretsProtection support
internal abstract class BotAuthSessionInternal : BotAuthSession {
final override suspend fun authByPassword(password: String): BotAuthResult {
return authByPassword(password.md5())
}
final override suspend fun authByPassword(passwordMd5: ByteArray): BotAuthResult {
return authByPassword(SecretsProtection.EscapedByteBuffer(passwordMd5))
}
abstract suspend fun authByPassword(passwordMd5: SecretsProtection.EscapedByteBuffer): BotAuthResult
}

View File

@ -0,0 +1,196 @@
/*
* Copyright 2019-2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.network.auth
import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.loop
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.cancel
import kotlinx.coroutines.job
import kotlinx.coroutines.launch
import net.mamoe.mirai.utils.MiraiLogger
import net.mamoe.mirai.utils.childScope
import net.mamoe.mirai.utils.debug
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.cancellation.CancellationException
internal class IllegalProducerStateException(
private val state: ProducerState<*, *>,
message: String? = state.toString(),
cause: Throwable? = null,
) : IllegalStateException(message, cause) {
val lastStateWasSucceed get() = (state is ProducerState.Finished) && state.isSuccess
}
internal class CoroutineOnDemandValueScope<T, V>(
parentCoroutineContext: CoroutineContext,
private val logger: MiraiLogger,
private val producerCoroutine: suspend OnDemandProducerScope<T, V>.() -> Unit,
) : OnDemandConsumer<T, V> {
private val coroutineScope = parentCoroutineContext.childScope("CoroutineOnDemandValueScope")
private val state: AtomicRef<ProducerState<T, V>> = atomic(ProducerState.JustInitialized())
inner class Producer : OnDemandProducerScope<T, V> {
init {
coroutineScope.launch {
try {
producerCoroutine()
} catch (_: CancellationException) {
// ignored
} catch (e: Exception) {
finishExceptionally(e)
}
}
}
override suspend fun emit(value: V): T {
state.loop { state ->
when (state) {
is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
is ProducerState.Producing -> {
val deferred = state.deferred
val consumingState = ProducerState.Consuming(
state.producer,
state.deferred,
coroutineScope.coroutineContext
)
if (compareAndSetState(state, consumingState)) {
deferred.complete(value) // produce a value
return consumingState.producerLatch.acquire() // wait for producer to consume the previous value.
}
}
else -> throw IllegalProducerStateException(state)
}
}
}
override fun finishExceptionally(exception: Throwable) {
finishImpl(exception)
}
override fun finish() {
state.loop { state ->
when (state) {
is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
else -> {
if (compareAndSetState(state, ProducerState.Finished(state, null))) {
return
}
}
}
}
}
}
private fun finishImpl(exception: Throwable?) {
state.loop { state ->
when (state) {
is ProducerState.Finished -> throw state.createAlreadyFinishedException(exception)
else -> {
if (compareAndSetState(state, ProducerState.Finished(state, exception))) {
val cancellationException = kotlinx.coroutines.CancellationException("Finished", exception)
coroutineScope.cancel(cancellationException)
return
}
}
}
}
}
private fun compareAndSetState(state: ProducerState<T, V>, newState: ProducerState<T, V>): Boolean {
return this.state.compareAndSet(state, newState).also {
logger.debug { "CAS: $state -> $newState: $it" }
}
}
override suspend fun receiveOrNull(): V? {
state.loop { state ->
when (state) {
is ProducerState.Producing -> {
// still producing value
state.deferred.await() // just wait for value, but does not return it.
// The value will be completed in ProducerState.Consuming state,
// but you cannot thread-safely assume current state is Consuming.
// Here we will loop again, to atomically switch to Consumed state.
}
is ProducerState.Consuming -> {
// value is ready, switch state to ProducerReady
if (compareAndSetState(
state,
ProducerState.Consumed(state.producer, state.producerLatch)
)
) {
return try {
state.value.await() // won't suspend, since value is already completed
} catch (e: Exception) {
throw ProducerFailureException(cause = e)
}
}
}
is ProducerState.Finished -> return null
else -> throw IllegalProducerStateException(state)
}
}
}
override fun expectMore(ticket: T): Boolean {
state.loop { state ->
when (state) {
is ProducerState.JustInitialized -> {
compareAndSetState(state, ProducerState.CreatingProducer { Producer() })
// loop again
}
is ProducerState.CreatingProducer -> {
compareAndSetState(state, ProducerState.ProducerReady(state.producer))
// loop again
}
is ProducerState.ProducerReady -> {
val deferred = CompletableDeferred<V>(coroutineScope.coroutineContext.job)
if (!compareAndSetState(state, ProducerState.Producing(state.producer, deferred))) {
deferred.cancel() // avoid leak
}
// loop again
}
is ProducerState.Producing -> return true // ok
is ProducerState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
is ProducerState.Consumed -> {
if (compareAndSetState(state, ProducerState.ProducerReady(state.producer))) {
// wake up producer async.
state.producerLatch.resumeWith(Result.success(ticket))
// loop again to switch state atomically to Producing.
// Do not do switch state directly here — async producer may race with you!
}
}
is ProducerState.Finished -> return false
}
}
}
override fun finish() {
finishImpl(null)
}
}

View File

@ -0,0 +1,60 @@
/*
* Copyright 2019-2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.network.auth
import kotlinx.atomicfu.locks.reentrantLock
import kotlinx.atomicfu.locks.withLock
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Job
import kotlinx.coroutines.completeWith
import kotlin.coroutines.CoroutineContext
import kotlin.jvm.Volatile
internal interface Latch<T> {
/**
* Suspends and waits to acquire the latch.
* @throws Throwable if [resumeWith] is called with [Result.Failure]
*/
suspend fun acquire(): T
/**
* Release the latch, resuming the coroutines waiting for the latch.
*
* This function will return immediately unless a client is calling [acquire] concurrently.
*/
fun resumeWith(result: Result<T>)
}
internal fun <T> Latch(parentCoroutineContext: CoroutineContext): Latch<T> = LatchImpl(parentCoroutineContext)
private class LatchImpl<T>(
parentCoroutineContext: CoroutineContext
) : Latch<T> {
@Volatile
private var deferred: CompletableDeferred<T>? = CompletableDeferred(parentCoroutineContext[Job])
private val lock = reentrantLock()
override suspend fun acquire(): T = lock.withLock {
val deferred = this.deferred!!
return deferred.await().also {
this.deferred = null
}
}
override fun resumeWith(result: Result<T>): Unit = lock.withLock {
val deferred = this.deferred ?: CompletableDeferred<T>().also { this.deferred = it }
deferred.completeWith(result)
}
override fun toString(): String = "LatchImpl($deferred)"
}

View File

@ -0,0 +1,86 @@
/*
* Copyright 2019-2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.network.auth
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ReceiveChannel
import kotlin.coroutines.Continuation
import kotlin.coroutines.cancellation.CancellationException
/**
* 按需供给的值制造器.
*/
internal interface OnDemandProducerScope<T, V> {
/**
* 挂起协程, 直到 [OnDemandConsumer] 期望接收一个 [V], 届时将 [value] 传递给 [OnDemandConsumer.receiveOrNull], 成为其返回值.
*
* 若在调用 [emit] 时已经有 [OnDemandConsumer] 正在等待, 则该 [OnDemandConsumer] 协程会立即[恢复][Continuation.resumeWith].
*
* [OnDemandConsumer] 已经[完结][OnDemandConsumer.finish], [OnDemandProducerScope.emit] 会抛出 [IllegalProducerStateException].
*/
suspend fun emit(value: V): T
/**
* 标记此 [OnDemandProducerScope] 在生产 [V] 的过程中出现错误.
*
* 这也会终止此 [OnDemandProducerScope], 随后 [OnDemandConsumer.receiveOrNull] 将会抛出 [ProducerFailureException].
*/
fun finishExceptionally(exception: Throwable)
/**
* 标记此 [OnDemandProducerScope] 已经没有更多 [V] 可生产.
*
* 随后 [OnDemandConsumer.receiveOrNull] 将会抛出 [IllegalStateException].
*/
fun finish()
}
/**
* 按需消费者.
*
* [ReceiveChannel] 不同, [OnDemandConsumer] 只有在调用 [expectMore] 后才会期待[生产者][OnDemandProducerScope] 生产下一个 [V].
*/
internal interface OnDemandConsumer<T, V> {
/**
* 挂起协程并等待从 [OnDemandProducerScope] [接收][OnDemandProducerScope.emit]一个 [V].
*
* 当此函数被多个线程 (协程) 同时调用时, 只有一个线程挂起并获得 [V], 其他线程将会
*
* @throws ProducerFailureException [OnDemandProducerScope.finishExceptionally] 时抛出.
* @throws CancellationException 当协程被取消时抛出
* @throws IllegalProducerStateException 当状态异常, 如未调用 [expectMore] 时抛出
*/
@Throws(ProducerFailureException::class, CancellationException::class)
suspend fun receiveOrNull(): V?
/**
* 期待 [OnDemandProducerScope] 再生产一个 [V]. 期望生产后必须在之后调用 [receiveOrNull] [finish] 来消耗生产的 [V].
*
* 在成功发起期待后返回 `true`; [OnDemandProducerScope] 已经[完结][OnDemandProducerScope.finish] 时返回 `false`.
*
* @throws IllegalProducerStateException [expectMore] 被调用后, 没有调用 [receiveOrNull] 就又调用了 [expectMore] 时抛出
*/
fun expectMore(ticket: T): Boolean
/**
* 标记此 [OnDemandConsumer] 已经完结.
*
* 如果 [OnDemandProducerScope] 仍在运行, 将会 (正常地) 取消 [OnDemandProducerScope].
*
* 随后 [OnDemandProducerScope.emit] 将会抛出 [IllegalStateException].
*/
fun finish()
}
internal class ProducerFailureException(
override val message: String? = null,
override val cause: Throwable?
) : Exception()

View File

@ -0,0 +1,176 @@
/*
* Copyright 2019-2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
package net.mamoe.mirai.internal.network.auth
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Deferred
import kotlin.coroutines.CoroutineContext
/**
* Producer states.
*/
internal sealed interface ProducerState<T, V> {
/*
* 可变更状态的函数: [emit], [receiveOrNull], [expectMore], [finish], [finishExceptionally]
*
* [emit] [receiveOrNull] suspend 函数, 在图中 "(suspend)" 表示挂起它们的协程, "(resume)" 表示恢复它们的协程.
*
* "A ~~~~~~> B" 表示在切换为状态 A , 会挂起或恢复协程 B.
*
*
*
*
* JustInitialized
* |
* | 调用 [expectMore]
* |
* V
* CreatingProducer
* |
* |
* |
* V
* ProducerReady (从此用户协程作为 producer 在后台运行)
* |
* |
* | <--------------------------------------------------
* | \
* V |
* Producing ([expectMore] 结束) |
* | \ |
* 调用 | \ |
* [receiveOrNull] | \ 调用 [emit] |
* / \ |
* / \ |
* / \ |
* | \ |
* | \ |
* | |------------- |
* | | \ |
* | | | |
* | \ | |
* | \ | |
* | \ | |
* | | | |
* V (resume) V | |
* ([receiveOrNull] suspend) <~~~~~~~~~~~~ Consuming | |
* | / | |
* | / | |
* | /---------------/ | |
* | / 调用 [receiveOrNull] | |
* | / | |
* |/ | |
* | | |
* | | |
* V | |
* ([receiveOrNull] 结束) Consumed | |
* | | |
* | 调用 [expectMore] | |
* | | |
* V (resume) V |
* ProducerReady ~~~~~~~~~~~~~~~~> ([emit] suspend) |
* | | |
* | | |
* | V |
* | ([emit] 结束) |
* | |
* |------------------------------------------------------------+
* (返回顶部 Producing)
*
*
*
* 在任意状态调用 [finish] 以及 [finishExceptionally], 可将状态转移到最终状态 [Finished].
*
* 在一个状态中调用图中未说明的函数会抛出 [IllegalProducerStateException].
*/
/**
* Override this function to produce good debug information
*/
abstract override fun toString(): String
class JustInitialized<T, V> : ProducerState<T, V> {
override fun toString(): String = "JustInitialized"
}
sealed interface HasProducer<T, V> : ProducerState<T, V> {
val producer: OnDemandProducerScope<T, V>
}
// This is need — to ensure [launchProducer] is called exactly once.
class CreatingProducer<T, V>(
launchProducer: () -> OnDemandProducerScope<T, V>
) : HasProducer<T, V> {
override val producer: OnDemandProducerScope<T, V> by lazy(launchProducer)
override fun toString(): String = "CreatingProducer"
}
class ProducerReady<T, V>(
override val producer: OnDemandProducerScope<T, V>,
) : HasProducer<T, V> {
override fun toString(): String = "ProducerReady"
}
class Producing<T, V>(
override val producer: OnDemandProducerScope<T, V>,
val deferred: CompletableDeferred<V>,
) : HasProducer<T, V> {
override fun toString(): String = "Producing(deferred.completed=${deferred.isCompleted})"
}
class Consuming<T, V>(
override val producer: OnDemandProducerScope<T, V>,
val value: Deferred<V>,
parentCoroutineContext: CoroutineContext,
) : HasProducer<T, V> {
val producerLatch = Latch<T>(parentCoroutineContext)
override fun toString(): String {
val completed =
value.runCatching { getCompleted().toString() }.getOrNull() // getCompleted() is experimental
return "Consuming(value=$completed)"
}
}
class Consumed<T, V>(
override val producer: OnDemandProducerScope<T, V>,
val producerLatch: Latch<T>
) : HasProducer<T, V> {
override fun toString(): String = "Consumed($producerLatch)"
}
class Finished<T, V>(
val previousState: ProducerState<T, V>,
val exception: Throwable?,
) : ProducerState<T, V> {
val isSuccess get() = exception == null
fun createAlreadyFinishedException(cause: Throwable?): IllegalProducerStateException {
val exception = exception
return if (exception == null) {
IllegalProducerStateException(
this,
"Producer has already finished normally, but attempting to finish with the cause $cause. Previous state was: $previousState",
cause = cause
)
} else {
IllegalProducerStateException(
this,
"Producer has already finished with the suppressed exception, but attempting to finish with the cause $cause. Previous state was: $previousState",
cause = cause
).apply {
addSuppressed(exception)
}
}
}
override fun toString(): String = "Finished($previousState, $exception)"
}
}

View File

@ -17,7 +17,6 @@ import net.mamoe.mirai.internal.network.QQAndroidClient
import net.mamoe.mirai.internal.network.QRCodeLoginData
import net.mamoe.mirai.internal.network.WLoginSigInfo
import net.mamoe.mirai.internal.network.auth.AuthControl
import net.mamoe.mirai.internal.network.auth.BotAuthSessionInternal
import net.mamoe.mirai.internal.network.auth.BotAuthorizationWithSecretsProtection
import net.mamoe.mirai.internal.network.component.ComponentKey
import net.mamoe.mirai.internal.network.handler.NetworkHandler
@ -165,7 +164,7 @@ internal class SsoProcessorImpl(
botAuthInfo,
ssoContext.bot.account.authorization,
ssoContext.bot.network.logger,
ssoContext.bot,
ssoContext.bot.coroutineContext, // do not use network context because network may restart whilst auth control should keep alive
)
}
@ -218,8 +217,6 @@ internal class SsoProcessorImpl(
ssoContext.bot.components[EcdhInitialPublicKeyUpdater].refreshInitialPublicKeyAndApplyEcdh()
when (val authw = authControl0.acquireAuth().also { nextAuthMethod = it }) {
is AuthMethod.DirectError -> throw authw.exception
is AuthMethod.Error -> {
authControl = null
authControl0.exceptionCollector.collectThrow(authw.exception)
@ -247,20 +244,13 @@ internal class SsoProcessorImpl(
authControl = null
} catch (exception: Throwable) {
if (exception is SelectorRequireReconnectException) {
if (nextAuthMethod is AuthMethod.DirectError) { // @TestOnly
authControl0.actResume()
}
throw exception
}
ssoContext.bot.network.logger.warning({ "Failed with auth method: $nextAuthMethod" }, exception)
if (nextAuthMethod is AuthMethod.DirectError) { // @TestOnly
authControl0.actResume()
} else if (nextAuthMethod !is AuthMethod.Error && nextAuthMethod != null) {
authControl0.actFailed(exception)
if (nextAuthMethod !is AuthMethod.Error && nextAuthMethod != null) {
authControl0.actMethodFailed(exception)
}
if (exception is NetworkException) {
@ -280,51 +270,24 @@ internal class SsoProcessorImpl(
sealed class AuthMethod {
object NotAvailable : AuthMethod()
object NotAvailable : AuthMethod() {
override fun toString(): String = "NotAvailable"
}
object QRCode : AuthMethod() {
override fun toString(): String {
return "QRCode"
}
override fun toString(): String = "QRCode"
}
class Pwd(val passwordMd5: SecretsProtection.EscapedByteBuffer) : AuthMethod() {
override fun toString(): String {
return "Password@${hashCode()}"
}
override fun toString(): String = "Password@${hashCode()}"
}
/**
* Exception in [BotAuthorization]
*/
class Error(val exception: Throwable) : AuthMethod() {
override fun toString(): String {
return "Error[$exception]@${hashCode()}"
}
override fun toString(): String = "Error[$exception]@${hashCode()}"
}
/**
* For mocking a login method throw a exception
*/
@TestOnly
class DirectError(val exception: Throwable) : AuthMethod() {
override fun toString(): String {
return "DirectError[$exception]@${hashCode()}"
}
}
}
@TestOnly
internal abstract class SsoProcessorAuthComponent : BotAuthSessionInternal() {
abstract suspend fun emit(method: AuthMethod)
suspend fun emitDirectError(error: Throwable) {
emit(AuthMethod.DirectError(error))
}
abstract val botAuthResult: BotAuthResult
}
private var authControl: AuthControl? = null

View File

@ -1,5 +1,5 @@
/*
* Copyright 2019-2022 Mamoe Technologies and contributors.
* Copyright 2019-2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
@ -115,5 +115,10 @@ internal object MiraiCoreServices {
"net.mamoe.mirai.message.data.OfflineAudio.Factory",
"net.mamoe.mirai.internal.message.data.OfflineAudioFactoryImpl"
) { net.mamoe.mirai.internal.message.data.OfflineAudioFactoryImpl() }
Services.register(
"net.mamoe.mirai.auth.DefaultBotAuthorizationFactory",
"net.mamoe.mirai.internal.network.auth.DefaultBotAuthorizationFactoryImpl"
) { net.mamoe.mirai.internal.network.auth.DefaultBotAuthorizationFactoryImpl() }
}
}

View File

@ -15,19 +15,21 @@ import net.mamoe.mirai.auth.BotAuthInfo
import net.mamoe.mirai.auth.BotAuthResult
import net.mamoe.mirai.auth.BotAuthSession
import net.mamoe.mirai.auth.BotAuthorization
import net.mamoe.mirai.internal.network.auth.AuthControl
import net.mamoe.mirai.internal.network.components.SsoProcessorContext
import net.mamoe.mirai.internal.network.components.SsoProcessorImpl
import net.mamoe.mirai.internal.network.framework.AbstractCommonNHTest
import net.mamoe.mirai.network.CustomLoginFailedException
import net.mamoe.mirai.utils.BotConfiguration
import net.mamoe.mirai.utils.DeviceInfo
import net.mamoe.mirai.utils.EMPTY_BYTE_ARRAY
import kotlin.reflect.KClass
import kotlin.test.Test
import kotlin.test.assertFails
import kotlin.test.assertFailsWith
import kotlin.test.fail
internal class BotAuthControlTest : AbstractCommonNHTest() {
val botAuthInfo = object : BotAuthInfo {
private val botAuthInfo = object : BotAuthInfo {
override val id: Long
get() = bot.id
override val deviceInfo: DeviceInfo
@ -36,7 +38,7 @@ internal class BotAuthControlTest : AbstractCommonNHTest() {
get() = bot.configuration
}
private suspend fun SsoProcessorImpl.AuthControl.assertRequire(exceptedType: KClass<*>) {
private suspend fun AuthControl.assertRequire(exceptedType: KClass<*>) {
println("Requiring auth method")
val nextAuth = acquireAuth()
println("Got $nextAuth")
@ -52,11 +54,11 @@ internal class BotAuthControlTest : AbstractCommonNHTest() {
@Test
fun `auth test`() = runTest {
val control = SsoProcessorImpl.AuthControl(botAuthInfo, object : BotAuthorization {
val control = AuthControl(botAuthInfo, object : BotAuthorization {
override suspend fun authorize(session: BotAuthSession, info: BotAuthInfo): BotAuthResult {
return session.authByPassword(EMPTY_BYTE_ARRAY)
}
}, bot.logger, backgroundScope)
}, bot.logger, backgroundScope.coroutineContext)
control.assertRequire(SsoProcessorImpl.AuthMethod.Pwd::class)
control.actComplete()
@ -66,35 +68,37 @@ internal class BotAuthControlTest : AbstractCommonNHTest() {
@Test
fun `test auth failed and reselect`() = runTest {
class MyLoginFailedException : CustomLoginFailedException(killBot = false)
val control = SsoProcessorImpl.AuthControl(botAuthInfo, object : BotAuthorization {
val control = AuthControl(botAuthInfo, object : BotAuthorization {
override suspend fun authorize(session: BotAuthSession, info: BotAuthInfo): BotAuthResult {
assertFails { session.authByPassword(EMPTY_BYTE_ARRAY); println("!") }
assertFailsWith<MyLoginFailedException> { session.authByPassword(EMPTY_BYTE_ARRAY); println("!") }
println("114514")
return session.authByPassword(EMPTY_BYTE_ARRAY)
}
}, bot.logger, backgroundScope)
}, bot.logger, backgroundScope.coroutineContext)
control.assertRequire(SsoProcessorImpl.AuthMethod.Pwd::class)
control.actFailed(Throwable())
control.actMethodFailed(MyLoginFailedException())
control.assertRequire(SsoProcessorImpl.AuthMethod.Pwd::class)
control.actComplete()
control.assertRequire(SsoProcessorImpl.AuthMethod.NotAvailable::class)
}
@Test
fun `failed when login complete`() = runTest {
val control = SsoProcessorImpl.AuthControl(botAuthInfo, object : BotAuthorization {
val control = AuthControl(botAuthInfo, object : BotAuthorization {
override suspend fun authorize(session: BotAuthSession, info: BotAuthInfo): BotAuthResult {
val rsp = session.authByPassword(EMPTY_BYTE_ARRAY)
assertFails { session.authByPassword(EMPTY_BYTE_ARRAY) }
assertFails { session.authByPassword(EMPTY_BYTE_ARRAY) }
assertFails { session.authByPassword(EMPTY_BYTE_ARRAY) }
assertFailsWith<IllegalStateException> { session.authByPassword(EMPTY_BYTE_ARRAY) }
assertFailsWith<IllegalStateException> { session.authByPassword(EMPTY_BYTE_ARRAY) }
assertFailsWith<IllegalStateException> { session.authByPassword(EMPTY_BYTE_ARRAY) }
return rsp
}
}, bot.logger, backgroundScope)
}, bot.logger, backgroundScope.coroutineContext)
control.assertRequire(SsoProcessorImpl.AuthMethod.Pwd::class)
control.actComplete()

View File

@ -130,14 +130,6 @@ internal abstract class AbstractRealNetworkHandlerTest<H : NetworkHandler> : Abs
return rsp
}
override suspend fun authByPassword(password: String): BotAuthResult {
return authByPassword(password.md5())
}
override suspend fun authByPassword(passwordMd5: ByteArray): BotAuthResult {
return authByPassword(SecretsProtection.EscapedByteBuffer(passwordMd5))
}
override suspend fun authByQRCode(): BotAuthResult {
return rsp
}