mirror of
https://github.com/mamoe/mirai.git
synced 2025-02-26 20:20:14 +08:00
Redesign auth
This commit is contained in:
parent
5abcd7fb26
commit
992f9289ce
@ -102,8 +102,19 @@ public interface BotAuthInfo {
|
||||
|
||||
@NotStableForInheritance
|
||||
public interface BotAuthSession {
|
||||
/**
|
||||
* @throws LoginFailedException
|
||||
*/
|
||||
public suspend fun authByPassword(password: String): BotAuthResult
|
||||
|
||||
/**
|
||||
* @throws LoginFailedException
|
||||
*/
|
||||
public suspend fun authByPassword(passwordMd5: ByteArray): BotAuthResult
|
||||
|
||||
/**
|
||||
* @throws LoginFailedException
|
||||
*/
|
||||
public suspend fun authByQRCode(): BotAuthResult
|
||||
}
|
||||
|
||||
|
@ -9,157 +9,108 @@
|
||||
|
||||
package net.mamoe.mirai.internal.network.auth
|
||||
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
import net.mamoe.mirai.auth.BotAuthInfo
|
||||
import net.mamoe.mirai.auth.BotAuthResult
|
||||
import net.mamoe.mirai.auth.BotAuthorization
|
||||
import net.mamoe.mirai.internal.network.components.SsoProcessorImpl
|
||||
import net.mamoe.mirai.internal.utils.subLogger
|
||||
import net.mamoe.mirai.utils.*
|
||||
import kotlin.coroutines.Continuation
|
||||
import kotlin.coroutines.resume
|
||||
import kotlin.coroutines.suspendCoroutine
|
||||
import kotlin.jvm.Volatile
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.coroutines.cancellation.CancellationException
|
||||
|
||||
|
||||
/**
|
||||
* Event sequence:
|
||||
*
|
||||
* 1. Starts a user coroutine [BotAuthorization.authorize].
|
||||
* 2. User coroutine
|
||||
*/
|
||||
internal class AuthControl(
|
||||
private val botAuthInfo: BotAuthInfo,
|
||||
private val authorization: BotAuthorization,
|
||||
private val logger: MiraiLogger,
|
||||
scope: CoroutineScope,
|
||||
parentCoroutineContext: CoroutineContext,
|
||||
) {
|
||||
internal val exceptionCollector = ExceptionCollector()
|
||||
|
||||
@Volatile
|
||||
private var authorizationContinuation: Continuation<Unit>? = null
|
||||
private val userDecisions: OnDemandConsumer<Throwable?, SsoProcessorImpl.AuthMethod> =
|
||||
CoroutineOnDemandValueScope(parentCoroutineContext, logger.subLogger("AuthControl/UserDecisions")) {
|
||||
/**
|
||||
* Implements [BotAuthSessionInternal] from API, to be called by the user, to receive user's decisions.
|
||||
*/
|
||||
val sessionImpl = object : BotAuthSessionInternal() {
|
||||
private val authResultImpl = object : BotAuthResult {}
|
||||
|
||||
@Volatile
|
||||
private var authRspFuture = initCompletableDeferred()
|
||||
|
||||
@Volatile
|
||||
private var isCompleted = false
|
||||
|
||||
private val rsp = object : BotAuthResult {}
|
||||
|
||||
@Suppress("RemoveExplicitTypeArguments")
|
||||
@OptIn(TestOnly::class)
|
||||
private val authComponent = object : SsoProcessorImpl.SsoProcessorAuthComponent() {
|
||||
override val botAuthResult: BotAuthResult get() = rsp
|
||||
|
||||
override suspend fun emit(method: SsoProcessorImpl.AuthMethod) {
|
||||
logger.verbose { "[AuthControl/emit] Trying emit $method" }
|
||||
|
||||
if (isCompleted) {
|
||||
val msg = "[AuthControl/emit] Failed to emit $method because control completed"
|
||||
|
||||
error(msg.also { logger.verbose(it) })
|
||||
}
|
||||
suspendCoroutine<Unit> { next ->
|
||||
val rspTarget = authRspFuture
|
||||
if (!rspTarget.complete(method)) {
|
||||
val msg = "[AuthControl/emit] Failed to emit $method because auth response completed"
|
||||
|
||||
error(msg.also { logger.verbose(it) })
|
||||
override suspend fun authByPassword(passwordMd5: SecretsProtection.EscapedByteBuffer): BotAuthResult {
|
||||
runWrapInternalException {
|
||||
emit(SsoProcessorImpl.AuthMethod.Pwd(passwordMd5))
|
||||
}?.let { throw it }
|
||||
return authResultImpl
|
||||
}
|
||||
|
||||
override suspend fun authByQRCode(): BotAuthResult {
|
||||
runWrapInternalException {
|
||||
emit(SsoProcessorImpl.AuthMethod.QRCode)
|
||||
}?.let { throw it }
|
||||
return authResultImpl
|
||||
}
|
||||
|
||||
private inline fun <R> runWrapInternalException(block: () -> R): R {
|
||||
try {
|
||||
return block()
|
||||
} catch (e: IllegalProducerStateException) {
|
||||
if (e.lastStateWasSucceed) {
|
||||
throw IllegalStateException(
|
||||
"This login session has already completed. Please return the BotAuthResult you get from 'authBy*()' immediately",
|
||||
e
|
||||
)
|
||||
} else {
|
||||
throw e // internal bug
|
||||
}
|
||||
}
|
||||
}
|
||||
authorizationContinuation = next
|
||||
logger.verbose { "[AuthControl/emit] Emitted $method to $rspTarget" }
|
||||
}
|
||||
logger.verbose { "[AuthControl/emit] Authorization resumed after $method" }
|
||||
}
|
||||
|
||||
override suspend fun authByPassword(passwordMd5: SecretsProtection.EscapedByteBuffer): BotAuthResult {
|
||||
emit(SsoProcessorImpl.AuthMethod.Pwd(passwordMd5))
|
||||
return rsp
|
||||
}
|
||||
|
||||
override suspend fun authByPassword(password: String): BotAuthResult {
|
||||
return authByPassword(password.md5())
|
||||
}
|
||||
|
||||
override suspend fun authByPassword(passwordMd5: ByteArray): BotAuthResult {
|
||||
return authByPassword(SecretsProtection.EscapedByteBuffer(passwordMd5))
|
||||
}
|
||||
|
||||
override suspend fun authByQRCode(): BotAuthResult {
|
||||
emit(SsoProcessorImpl.AuthMethod.QRCode)
|
||||
return rsp
|
||||
}
|
||||
}
|
||||
|
||||
init {
|
||||
// start users' BotAuthorization.authorize
|
||||
scope.launch {
|
||||
try {
|
||||
logger.verbose { "[AuthControl/auth] Authorization started" }
|
||||
|
||||
authorization.authorize(authComponent, botAuthInfo)
|
||||
authorization.authorize(sessionImpl, botAuthInfo)
|
||||
|
||||
logger.verbose { "[AuthControl/auth] Authorization exited" }
|
||||
|
||||
isCompleted = true
|
||||
authRspFuture.complete(SsoProcessorImpl.AuthMethod.NotAvailable)
|
||||
|
||||
finish()
|
||||
} catch (e: CancellationException) {
|
||||
logger.verbose { "[AuthControl/auth] Authorization cancelled" }
|
||||
} catch (e: Throwable) {
|
||||
logger.verbose({ "[AuthControl/auth] Authorization failed" }, e)
|
||||
|
||||
isCompleted = true
|
||||
authRspFuture.complete(SsoProcessorImpl.AuthMethod.Error(e))
|
||||
logger.verbose { "[AuthControl/auth] Authorization failed: $e" }
|
||||
finishExceptionally(e)
|
||||
}
|
||||
}
|
||||
|
||||
init {
|
||||
userDecisions.expectMore(null)
|
||||
}
|
||||
|
||||
private fun onSpinWait() {}
|
||||
// Does not throw
|
||||
suspend fun acquireAuth(): SsoProcessorImpl.AuthMethod {
|
||||
val authTarget = authRspFuture
|
||||
logger.verbose { "[AuthControl/acquire] Acquiring auth method with $authTarget" }
|
||||
val rsp = authTarget.await()
|
||||
logger.debug { "[AuthControl/acquire] Authorization responded: $authTarget, $rsp" }
|
||||
logger.verbose { "[AuthControl/acquire] Acquiring auth method" }
|
||||
|
||||
while (authorizationContinuation == null && !isCompleted) {
|
||||
onSpinWait()
|
||||
val rsp = try {
|
||||
userDecisions.receiveOrNull() ?: SsoProcessorImpl.AuthMethod.NotAvailable
|
||||
} catch (e: ProducerFailureException) {
|
||||
SsoProcessorImpl.AuthMethod.Error(e)
|
||||
}
|
||||
logger.verbose { "[AuthControl/acquire] authorizationContinuation setup: $authorizationContinuation, $isCompleted" }
|
||||
|
||||
logger.debug { "[AuthControl/acquire] Authorization responded: $rsp" }
|
||||
return rsp
|
||||
}
|
||||
|
||||
fun actFailed(cause: Throwable) {
|
||||
fun actMethodFailed(cause: Throwable) {
|
||||
logger.verbose { "[AuthControl/resume] Fire auth failed with cause: $cause" }
|
||||
|
||||
authRspFuture = initCompletableDeferred()
|
||||
authorizationContinuation!!.let { cont ->
|
||||
authorizationContinuation = null
|
||||
cont.resumeWith(Result.failure(cause))
|
||||
}
|
||||
}
|
||||
|
||||
@TestOnly // same as act failed
|
||||
fun actResume() {
|
||||
logger.verbose { "[AuthControl/resume] Fire auth resume" }
|
||||
|
||||
authRspFuture = initCompletableDeferred()
|
||||
authorizationContinuation!!.let { cont ->
|
||||
authorizationContinuation = null
|
||||
cont.resume(Unit)
|
||||
}
|
||||
userDecisions.expectMore(cause)
|
||||
}
|
||||
|
||||
fun actComplete() {
|
||||
logger.verbose { "[AuthControl/resume] Fire auth completed" }
|
||||
|
||||
isCompleted = true
|
||||
authRspFuture = CompletableDeferred(SsoProcessorImpl.AuthMethod.NotAvailable)
|
||||
authorizationContinuation!!.let { cont ->
|
||||
authorizationContinuation = null
|
||||
cont.resume(Unit)
|
||||
}
|
||||
}
|
||||
|
||||
private fun initCompletableDeferred(): CompletableDeferred<SsoProcessorImpl.AuthMethod> {
|
||||
return CompletableDeferred<SsoProcessorImpl.AuthMethod>().also { df ->
|
||||
df.invokeOnCompletion {
|
||||
logger.debug { "[AuthControl/cd] $df completed with $it" }
|
||||
}
|
||||
}
|
||||
userDecisions.finish()
|
||||
}
|
||||
}
|
||||
|
@ -14,10 +14,20 @@ import net.mamoe.mirai.auth.BotAuthResult
|
||||
import net.mamoe.mirai.auth.BotAuthSession
|
||||
import net.mamoe.mirai.auth.BotAuthorization
|
||||
import net.mamoe.mirai.utils.SecretsProtection
|
||||
import net.mamoe.mirai.utils.md5
|
||||
|
||||
|
||||
// With SecretsProtection support
|
||||
internal abstract class BotAuthSessionInternal : BotAuthSession {
|
||||
|
||||
final override suspend fun authByPassword(password: String): BotAuthResult {
|
||||
return authByPassword(password.md5())
|
||||
}
|
||||
|
||||
final override suspend fun authByPassword(passwordMd5: ByteArray): BotAuthResult {
|
||||
return authByPassword(SecretsProtection.EscapedByteBuffer(passwordMd5))
|
||||
}
|
||||
|
||||
abstract suspend fun authByPassword(passwordMd5: SecretsProtection.EscapedByteBuffer): BotAuthResult
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,196 @@
|
||||
/*
|
||||
* Copyright 2019-2023 Mamoe Technologies and contributors.
|
||||
*
|
||||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
|
||||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
|
||||
*
|
||||
* https://github.com/mamoe/mirai/blob/dev/LICENSE
|
||||
*/
|
||||
|
||||
package net.mamoe.mirai.internal.network.auth
|
||||
|
||||
import kotlinx.atomicfu.AtomicRef
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.atomicfu.loop
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.job
|
||||
import kotlinx.coroutines.launch
|
||||
import net.mamoe.mirai.utils.MiraiLogger
|
||||
import net.mamoe.mirai.utils.childScope
|
||||
import net.mamoe.mirai.utils.debug
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.coroutines.cancellation.CancellationException
|
||||
|
||||
|
||||
internal class IllegalProducerStateException(
|
||||
private val state: ProducerState<*, *>,
|
||||
message: String? = state.toString(),
|
||||
cause: Throwable? = null,
|
||||
) : IllegalStateException(message, cause) {
|
||||
val lastStateWasSucceed get() = (state is ProducerState.Finished) && state.isSuccess
|
||||
}
|
||||
|
||||
internal class CoroutineOnDemandValueScope<T, V>(
|
||||
parentCoroutineContext: CoroutineContext,
|
||||
private val logger: MiraiLogger,
|
||||
private val producerCoroutine: suspend OnDemandProducerScope<T, V>.() -> Unit,
|
||||
) : OnDemandConsumer<T, V> {
|
||||
private val coroutineScope = parentCoroutineContext.childScope("CoroutineOnDemandValueScope")
|
||||
|
||||
private val state: AtomicRef<ProducerState<T, V>> = atomic(ProducerState.JustInitialized())
|
||||
|
||||
|
||||
inner class Producer : OnDemandProducerScope<T, V> {
|
||||
init {
|
||||
coroutineScope.launch {
|
||||
try {
|
||||
producerCoroutine()
|
||||
} catch (_: CancellationException) {
|
||||
// ignored
|
||||
} catch (e: Exception) {
|
||||
finishExceptionally(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun emit(value: V): T {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
|
||||
is ProducerState.Producing -> {
|
||||
val deferred = state.deferred
|
||||
val consumingState = ProducerState.Consuming(
|
||||
state.producer,
|
||||
state.deferred,
|
||||
coroutineScope.coroutineContext
|
||||
)
|
||||
if (compareAndSetState(state, consumingState)) {
|
||||
deferred.complete(value) // produce a value
|
||||
return consumingState.producerLatch.acquire() // wait for producer to consume the previous value.
|
||||
}
|
||||
}
|
||||
|
||||
else -> throw IllegalProducerStateException(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun finishExceptionally(exception: Throwable) {
|
||||
finishImpl(exception)
|
||||
}
|
||||
|
||||
override fun finish() {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
|
||||
else -> {
|
||||
if (compareAndSetState(state, ProducerState.Finished(state, null))) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun finishImpl(exception: Throwable?) {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ProducerState.Finished -> throw state.createAlreadyFinishedException(exception)
|
||||
else -> {
|
||||
if (compareAndSetState(state, ProducerState.Finished(state, exception))) {
|
||||
val cancellationException = kotlinx.coroutines.CancellationException("Finished", exception)
|
||||
coroutineScope.cancel(cancellationException)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun compareAndSetState(state: ProducerState<T, V>, newState: ProducerState<T, V>): Boolean {
|
||||
return this.state.compareAndSet(state, newState).also {
|
||||
logger.debug { "CAS: $state -> $newState: $it" }
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun receiveOrNull(): V? {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ProducerState.Producing -> {
|
||||
// still producing value
|
||||
|
||||
state.deferred.await() // just wait for value, but does not return it.
|
||||
|
||||
// The value will be completed in ProducerState.Consuming state,
|
||||
// but you cannot thread-safely assume current state is Consuming.
|
||||
|
||||
// Here we will loop again, to atomically switch to Consumed state.
|
||||
}
|
||||
|
||||
is ProducerState.Consuming -> {
|
||||
// value is ready, switch state to ProducerReady
|
||||
|
||||
if (compareAndSetState(
|
||||
state,
|
||||
ProducerState.Consumed(state.producer, state.producerLatch)
|
||||
)
|
||||
) {
|
||||
return try {
|
||||
state.value.await() // won't suspend, since value is already completed
|
||||
} catch (e: Exception) {
|
||||
throw ProducerFailureException(cause = e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
is ProducerState.Finished -> return null
|
||||
else -> throw IllegalProducerStateException(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun expectMore(ticket: T): Boolean {
|
||||
state.loop { state ->
|
||||
when (state) {
|
||||
is ProducerState.JustInitialized -> {
|
||||
compareAndSetState(state, ProducerState.CreatingProducer { Producer() })
|
||||
// loop again
|
||||
}
|
||||
|
||||
is ProducerState.CreatingProducer -> {
|
||||
compareAndSetState(state, ProducerState.ProducerReady(state.producer))
|
||||
// loop again
|
||||
}
|
||||
|
||||
is ProducerState.ProducerReady -> {
|
||||
val deferred = CompletableDeferred<V>(coroutineScope.coroutineContext.job)
|
||||
if (!compareAndSetState(state, ProducerState.Producing(state.producer, deferred))) {
|
||||
deferred.cancel() // avoid leak
|
||||
}
|
||||
// loop again
|
||||
}
|
||||
|
||||
is ProducerState.Producing -> return true // ok
|
||||
|
||||
is ProducerState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
|
||||
|
||||
is ProducerState.Consumed -> {
|
||||
if (compareAndSetState(state, ProducerState.ProducerReady(state.producer))) {
|
||||
// wake up producer async.
|
||||
state.producerLatch.resumeWith(Result.success(ticket))
|
||||
// loop again to switch state atomically to Producing.
|
||||
// Do not do switch state directly here — async producer may race with you!
|
||||
}
|
||||
}
|
||||
|
||||
is ProducerState.Finished -> return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun finish() {
|
||||
finishImpl(null)
|
||||
}
|
||||
}
|
60
mirai-core/src/commonMain/kotlin/network/auth/Latch.kt
Normal file
60
mirai-core/src/commonMain/kotlin/network/auth/Latch.kt
Normal file
@ -0,0 +1,60 @@
|
||||
/*
|
||||
* Copyright 2019-2023 Mamoe Technologies and contributors.
|
||||
*
|
||||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
|
||||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
|
||||
*
|
||||
* https://github.com/mamoe/mirai/blob/dev/LICENSE
|
||||
*/
|
||||
|
||||
package net.mamoe.mirai.internal.network.auth
|
||||
|
||||
import kotlinx.atomicfu.locks.reentrantLock
|
||||
import kotlinx.atomicfu.locks.withLock
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.completeWith
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.jvm.Volatile
|
||||
|
||||
|
||||
internal interface Latch<T> {
|
||||
/**
|
||||
* Suspends and waits to acquire the latch.
|
||||
* @throws Throwable if [resumeWith] is called with [Result.Failure]
|
||||
*/
|
||||
suspend fun acquire(): T
|
||||
|
||||
/**
|
||||
* Release the latch, resuming the coroutines waiting for the latch.
|
||||
*
|
||||
* This function will return immediately unless a client is calling [acquire] concurrently.
|
||||
*/
|
||||
fun resumeWith(result: Result<T>)
|
||||
}
|
||||
|
||||
|
||||
internal fun <T> Latch(parentCoroutineContext: CoroutineContext): Latch<T> = LatchImpl(parentCoroutineContext)
|
||||
|
||||
private class LatchImpl<T>(
|
||||
parentCoroutineContext: CoroutineContext
|
||||
) : Latch<T> {
|
||||
@Volatile
|
||||
private var deferred: CompletableDeferred<T>? = CompletableDeferred(parentCoroutineContext[Job])
|
||||
|
||||
private val lock = reentrantLock()
|
||||
|
||||
override suspend fun acquire(): T = lock.withLock {
|
||||
val deferred = this.deferred!!
|
||||
return deferred.await().also {
|
||||
this.deferred = null
|
||||
}
|
||||
}
|
||||
|
||||
override fun resumeWith(result: Result<T>): Unit = lock.withLock {
|
||||
val deferred = this.deferred ?: CompletableDeferred<T>().also { this.deferred = it }
|
||||
deferred.completeWith(result)
|
||||
}
|
||||
|
||||
override fun toString(): String = "LatchImpl($deferred)"
|
||||
}
|
@ -0,0 +1,86 @@
|
||||
/*
|
||||
* Copyright 2019-2023 Mamoe Technologies and contributors.
|
||||
*
|
||||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
|
||||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
|
||||
*
|
||||
* https://github.com/mamoe/mirai/blob/dev/LICENSE
|
||||
*/
|
||||
|
||||
package net.mamoe.mirai.internal.network.auth
|
||||
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.channels.ReceiveChannel
|
||||
import kotlin.coroutines.Continuation
|
||||
import kotlin.coroutines.cancellation.CancellationException
|
||||
|
||||
|
||||
/**
|
||||
* 按需供给的值制造器.
|
||||
*/
|
||||
internal interface OnDemandProducerScope<T, V> {
|
||||
/**
|
||||
* 挂起协程, 直到 [OnDemandConsumer] 期望接收一个 [V], 届时将 [value] 传递给 [OnDemandConsumer.receiveOrNull], 成为其返回值.
|
||||
*
|
||||
* 若在调用 [emit] 时已经有 [OnDemandConsumer] 正在等待, 则该 [OnDemandConsumer] 协程会立即[恢复][Continuation.resumeWith].
|
||||
*
|
||||
* 若 [OnDemandConsumer] 已经[完结][OnDemandConsumer.finish], [OnDemandProducerScope.emit] 会抛出 [IllegalProducerStateException].
|
||||
*/
|
||||
suspend fun emit(value: V): T
|
||||
|
||||
/**
|
||||
* 标记此 [OnDemandProducerScope] 在生产 [V] 的过程中出现错误.
|
||||
*
|
||||
* 这也会终止此 [OnDemandProducerScope], 随后 [OnDemandConsumer.receiveOrNull] 将会抛出 [ProducerFailureException].
|
||||
*/
|
||||
fun finishExceptionally(exception: Throwable)
|
||||
|
||||
/**
|
||||
* 标记此 [OnDemandProducerScope] 已经没有更多 [V] 可生产.
|
||||
*
|
||||
* 随后 [OnDemandConsumer.receiveOrNull] 将会抛出 [IllegalStateException].
|
||||
*/
|
||||
fun finish()
|
||||
}
|
||||
|
||||
/**
|
||||
* 按需消费者.
|
||||
*
|
||||
* 与 [ReceiveChannel] 不同, [OnDemandConsumer] 只有在调用 [expectMore] 后才会期待[生产者][OnDemandProducerScope] 生产下一个 [V].
|
||||
*/
|
||||
internal interface OnDemandConsumer<T, V> {
|
||||
/**
|
||||
* 挂起协程并等待从 [OnDemandProducerScope] [接收][OnDemandProducerScope.emit]一个 [V].
|
||||
*
|
||||
* 当此函数被多个线程 (协程) 同时调用时, 只有一个线程挂起并获得 [V], 其他线程将会
|
||||
*
|
||||
* @throws ProducerFailureException 当 [OnDemandProducerScope.finishExceptionally] 时抛出.
|
||||
* @throws CancellationException 当协程被取消时抛出
|
||||
* @throws IllegalProducerStateException 当状态异常, 如未调用 [expectMore] 时抛出
|
||||
*/
|
||||
@Throws(ProducerFailureException::class, CancellationException::class)
|
||||
suspend fun receiveOrNull(): V?
|
||||
|
||||
/**
|
||||
* 期待 [OnDemandProducerScope] 再生产一个 [V]. 期望生产后必须在之后调用 [receiveOrNull] 或 [finish] 来消耗生产的 [V].
|
||||
*
|
||||
* 在成功发起期待后返回 `true`; 在 [OnDemandProducerScope] 已经[完结][OnDemandProducerScope.finish] 时返回 `false`.
|
||||
*
|
||||
* @throws IllegalProducerStateException 当 [expectMore] 被调用后, 没有调用 [receiveOrNull] 就又调用了 [expectMore] 时抛出
|
||||
*/
|
||||
fun expectMore(ticket: T): Boolean
|
||||
|
||||
/**
|
||||
* 标记此 [OnDemandConsumer] 已经完结.
|
||||
*
|
||||
* 如果 [OnDemandProducerScope] 仍在运行, 将会 (正常地) 取消 [OnDemandProducerScope].
|
||||
*
|
||||
* 随后 [OnDemandProducerScope.emit] 将会抛出 [IllegalStateException].
|
||||
*/
|
||||
fun finish()
|
||||
}
|
||||
|
||||
internal class ProducerFailureException(
|
||||
override val message: String? = null,
|
||||
override val cause: Throwable?
|
||||
) : Exception()
|
176
mirai-core/src/commonMain/kotlin/network/auth/ProducerState.kt
Normal file
176
mirai-core/src/commonMain/kotlin/network/auth/ProducerState.kt
Normal file
@ -0,0 +1,176 @@
|
||||
/*
|
||||
* Copyright 2019-2023 Mamoe Technologies and contributors.
|
||||
*
|
||||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
|
||||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
|
||||
*
|
||||
* https://github.com/mamoe/mirai/blob/dev/LICENSE
|
||||
*/
|
||||
|
||||
package net.mamoe.mirai.internal.network.auth
|
||||
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.Deferred
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
/**
|
||||
* Producer states.
|
||||
*/
|
||||
internal sealed interface ProducerState<T, V> {
|
||||
/*
|
||||
* 可变更状态的函数: [emit], [receiveOrNull], [expectMore], [finish], [finishExceptionally]
|
||||
*
|
||||
* [emit] 和 [receiveOrNull] 为 suspend 函数, 在图中 "(suspend)" 表示挂起它们的协程, "(resume)" 表示恢复它们的协程.
|
||||
*
|
||||
* "A ~~~~~~> B" 表示在切换为状态 A 后, 会挂起或恢复协程 B.
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
* JustInitialized
|
||||
* |
|
||||
* | 调用 [expectMore]
|
||||
* |
|
||||
* V
|
||||
* CreatingProducer
|
||||
* |
|
||||
* |
|
||||
* |
|
||||
* V
|
||||
* ProducerReady (从此用户协程作为 producer 在后台运行)
|
||||
* |
|
||||
* |
|
||||
* | <--------------------------------------------------
|
||||
* | \
|
||||
* V |
|
||||
* Producing ([expectMore] 结束) |
|
||||
* | \ |
|
||||
* 调用 | \ |
|
||||
* [receiveOrNull] | \ 调用 [emit] |
|
||||
* / \ |
|
||||
* / \ |
|
||||
* / \ |
|
||||
* | \ |
|
||||
* | \ |
|
||||
* | |------------- |
|
||||
* | | \ |
|
||||
* | | | |
|
||||
* | \ | |
|
||||
* | \ | |
|
||||
* | \ | |
|
||||
* | | | |
|
||||
* V (resume) V | |
|
||||
* ([receiveOrNull] suspend) <~~~~~~~~~~~~ Consuming | |
|
||||
* | / | |
|
||||
* | / | |
|
||||
* | /---------------/ | |
|
||||
* | / 调用 [receiveOrNull] | |
|
||||
* | / | |
|
||||
* |/ | |
|
||||
* | | |
|
||||
* | | |
|
||||
* V | |
|
||||
* ([receiveOrNull] 结束) Consumed | |
|
||||
* | | |
|
||||
* | 调用 [expectMore] | |
|
||||
* | | |
|
||||
* V (resume) V |
|
||||
* ProducerReady ~~~~~~~~~~~~~~~~> ([emit] suspend) |
|
||||
* | | |
|
||||
* | | |
|
||||
* | V |
|
||||
* | ([emit] 结束) |
|
||||
* | |
|
||||
* |------------------------------------------------------------+
|
||||
* (返回顶部 Producing)
|
||||
*
|
||||
*
|
||||
*
|
||||
* 在任意状态调用 [finish] 以及 [finishExceptionally], 可将状态转移到最终状态 [Finished].
|
||||
*
|
||||
* 在一个状态中调用图中未说明的函数会抛出 [IllegalProducerStateException].
|
||||
*/
|
||||
|
||||
/**
|
||||
* Override this function to produce good debug information
|
||||
*/
|
||||
abstract override fun toString(): String
|
||||
|
||||
class JustInitialized<T, V> : ProducerState<T, V> {
|
||||
override fun toString(): String = "JustInitialized"
|
||||
}
|
||||
|
||||
sealed interface HasProducer<T, V> : ProducerState<T, V> {
|
||||
val producer: OnDemandProducerScope<T, V>
|
||||
}
|
||||
|
||||
// This is need — to ensure [launchProducer] is called exactly once.
|
||||
class CreatingProducer<T, V>(
|
||||
launchProducer: () -> OnDemandProducerScope<T, V>
|
||||
) : HasProducer<T, V> {
|
||||
override val producer: OnDemandProducerScope<T, V> by lazy(launchProducer)
|
||||
override fun toString(): String = "CreatingProducer"
|
||||
}
|
||||
|
||||
class ProducerReady<T, V>(
|
||||
override val producer: OnDemandProducerScope<T, V>,
|
||||
) : HasProducer<T, V> {
|
||||
override fun toString(): String = "ProducerReady"
|
||||
}
|
||||
|
||||
class Producing<T, V>(
|
||||
override val producer: OnDemandProducerScope<T, V>,
|
||||
val deferred: CompletableDeferred<V>,
|
||||
) : HasProducer<T, V> {
|
||||
override fun toString(): String = "Producing(deferred.completed=${deferred.isCompleted})"
|
||||
}
|
||||
|
||||
class Consuming<T, V>(
|
||||
override val producer: OnDemandProducerScope<T, V>,
|
||||
val value: Deferred<V>,
|
||||
parentCoroutineContext: CoroutineContext,
|
||||
) : HasProducer<T, V> {
|
||||
val producerLatch = Latch<T>(parentCoroutineContext)
|
||||
|
||||
override fun toString(): String {
|
||||
val completed =
|
||||
value.runCatching { getCompleted().toString() }.getOrNull() // getCompleted() is experimental
|
||||
return "Consuming(value=$completed)"
|
||||
}
|
||||
}
|
||||
|
||||
class Consumed<T, V>(
|
||||
override val producer: OnDemandProducerScope<T, V>,
|
||||
val producerLatch: Latch<T>
|
||||
) : HasProducer<T, V> {
|
||||
override fun toString(): String = "Consumed($producerLatch)"
|
||||
}
|
||||
|
||||
class Finished<T, V>(
|
||||
val previousState: ProducerState<T, V>,
|
||||
val exception: Throwable?,
|
||||
) : ProducerState<T, V> {
|
||||
val isSuccess get() = exception == null
|
||||
|
||||
fun createAlreadyFinishedException(cause: Throwable?): IllegalProducerStateException {
|
||||
val exception = exception
|
||||
return if (exception == null) {
|
||||
IllegalProducerStateException(
|
||||
this,
|
||||
"Producer has already finished normally, but attempting to finish with the cause $cause. Previous state was: $previousState",
|
||||
cause = cause
|
||||
)
|
||||
} else {
|
||||
IllegalProducerStateException(
|
||||
this,
|
||||
"Producer has already finished with the suppressed exception, but attempting to finish with the cause $cause. Previous state was: $previousState",
|
||||
cause = cause
|
||||
).apply {
|
||||
addSuppressed(exception)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String = "Finished($previousState, $exception)"
|
||||
}
|
||||
}
|
@ -17,7 +17,6 @@ import net.mamoe.mirai.internal.network.QQAndroidClient
|
||||
import net.mamoe.mirai.internal.network.QRCodeLoginData
|
||||
import net.mamoe.mirai.internal.network.WLoginSigInfo
|
||||
import net.mamoe.mirai.internal.network.auth.AuthControl
|
||||
import net.mamoe.mirai.internal.network.auth.BotAuthSessionInternal
|
||||
import net.mamoe.mirai.internal.network.auth.BotAuthorizationWithSecretsProtection
|
||||
import net.mamoe.mirai.internal.network.component.ComponentKey
|
||||
import net.mamoe.mirai.internal.network.handler.NetworkHandler
|
||||
@ -165,7 +164,7 @@ internal class SsoProcessorImpl(
|
||||
botAuthInfo,
|
||||
ssoContext.bot.account.authorization,
|
||||
ssoContext.bot.network.logger,
|
||||
ssoContext.bot,
|
||||
ssoContext.bot.coroutineContext, // do not use network context because network may restart whilst auth control should keep alive
|
||||
)
|
||||
}
|
||||
|
||||
@ -218,8 +217,6 @@ internal class SsoProcessorImpl(
|
||||
ssoContext.bot.components[EcdhInitialPublicKeyUpdater].refreshInitialPublicKeyAndApplyEcdh()
|
||||
|
||||
when (val authw = authControl0.acquireAuth().also { nextAuthMethod = it }) {
|
||||
is AuthMethod.DirectError -> throw authw.exception
|
||||
|
||||
is AuthMethod.Error -> {
|
||||
authControl = null
|
||||
authControl0.exceptionCollector.collectThrow(authw.exception)
|
||||
@ -247,20 +244,13 @@ internal class SsoProcessorImpl(
|
||||
authControl = null
|
||||
} catch (exception: Throwable) {
|
||||
if (exception is SelectorRequireReconnectException) {
|
||||
|
||||
if (nextAuthMethod is AuthMethod.DirectError) { // @TestOnly
|
||||
authControl0.actResume()
|
||||
}
|
||||
|
||||
throw exception
|
||||
}
|
||||
|
||||
ssoContext.bot.network.logger.warning({ "Failed with auth method: $nextAuthMethod" }, exception)
|
||||
|
||||
if (nextAuthMethod is AuthMethod.DirectError) { // @TestOnly
|
||||
authControl0.actResume()
|
||||
} else if (nextAuthMethod !is AuthMethod.Error && nextAuthMethod != null) {
|
||||
authControl0.actFailed(exception)
|
||||
if (nextAuthMethod !is AuthMethod.Error && nextAuthMethod != null) {
|
||||
authControl0.actMethodFailed(exception)
|
||||
}
|
||||
|
||||
if (exception is NetworkException) {
|
||||
@ -280,51 +270,24 @@ internal class SsoProcessorImpl(
|
||||
|
||||
|
||||
sealed class AuthMethod {
|
||||
object NotAvailable : AuthMethod()
|
||||
object NotAvailable : AuthMethod() {
|
||||
override fun toString(): String = "NotAvailable"
|
||||
}
|
||||
|
||||
object QRCode : AuthMethod() {
|
||||
override fun toString(): String {
|
||||
return "QRCode"
|
||||
}
|
||||
override fun toString(): String = "QRCode"
|
||||
}
|
||||
|
||||
class Pwd(val passwordMd5: SecretsProtection.EscapedByteBuffer) : AuthMethod() {
|
||||
override fun toString(): String {
|
||||
return "Password@${hashCode()}"
|
||||
}
|
||||
override fun toString(): String = "Password@${hashCode()}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Exception in [BotAuthorization]
|
||||
*/
|
||||
class Error(val exception: Throwable) : AuthMethod() {
|
||||
override fun toString(): String {
|
||||
return "Error[$exception]@${hashCode()}"
|
||||
}
|
||||
override fun toString(): String = "Error[$exception]@${hashCode()}"
|
||||
}
|
||||
|
||||
/**
|
||||
* For mocking a login method throw a exception
|
||||
*/
|
||||
@TestOnly
|
||||
class DirectError(val exception: Throwable) : AuthMethod() {
|
||||
override fun toString(): String {
|
||||
return "DirectError[$exception]@${hashCode()}"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@TestOnly
|
||||
internal abstract class SsoProcessorAuthComponent : BotAuthSessionInternal() {
|
||||
abstract suspend fun emit(method: AuthMethod)
|
||||
|
||||
suspend fun emitDirectError(error: Throwable) {
|
||||
emit(AuthMethod.DirectError(error))
|
||||
}
|
||||
|
||||
|
||||
abstract val botAuthResult: BotAuthResult
|
||||
}
|
||||
|
||||
private var authControl: AuthControl? = null
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2019-2022 Mamoe Technologies and contributors.
|
||||
* Copyright 2019-2023 Mamoe Technologies and contributors.
|
||||
*
|
||||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
|
||||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
|
||||
@ -115,5 +115,10 @@ internal object MiraiCoreServices {
|
||||
"net.mamoe.mirai.message.data.OfflineAudio.Factory",
|
||||
"net.mamoe.mirai.internal.message.data.OfflineAudioFactoryImpl"
|
||||
) { net.mamoe.mirai.internal.message.data.OfflineAudioFactoryImpl() }
|
||||
|
||||
Services.register(
|
||||
"net.mamoe.mirai.auth.DefaultBotAuthorizationFactory",
|
||||
"net.mamoe.mirai.internal.network.auth.DefaultBotAuthorizationFactoryImpl"
|
||||
) { net.mamoe.mirai.internal.network.auth.DefaultBotAuthorizationFactoryImpl() }
|
||||
}
|
||||
}
|
@ -15,19 +15,21 @@ import net.mamoe.mirai.auth.BotAuthInfo
|
||||
import net.mamoe.mirai.auth.BotAuthResult
|
||||
import net.mamoe.mirai.auth.BotAuthSession
|
||||
import net.mamoe.mirai.auth.BotAuthorization
|
||||
import net.mamoe.mirai.internal.network.auth.AuthControl
|
||||
import net.mamoe.mirai.internal.network.components.SsoProcessorContext
|
||||
import net.mamoe.mirai.internal.network.components.SsoProcessorImpl
|
||||
import net.mamoe.mirai.internal.network.framework.AbstractCommonNHTest
|
||||
import net.mamoe.mirai.network.CustomLoginFailedException
|
||||
import net.mamoe.mirai.utils.BotConfiguration
|
||||
import net.mamoe.mirai.utils.DeviceInfo
|
||||
import net.mamoe.mirai.utils.EMPTY_BYTE_ARRAY
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertFails
|
||||
import kotlin.test.assertFailsWith
|
||||
import kotlin.test.fail
|
||||
|
||||
internal class BotAuthControlTest : AbstractCommonNHTest() {
|
||||
val botAuthInfo = object : BotAuthInfo {
|
||||
private val botAuthInfo = object : BotAuthInfo {
|
||||
override val id: Long
|
||||
get() = bot.id
|
||||
override val deviceInfo: DeviceInfo
|
||||
@ -36,7 +38,7 @@ internal class BotAuthControlTest : AbstractCommonNHTest() {
|
||||
get() = bot.configuration
|
||||
}
|
||||
|
||||
private suspend fun SsoProcessorImpl.AuthControl.assertRequire(exceptedType: KClass<*>) {
|
||||
private suspend fun AuthControl.assertRequire(exceptedType: KClass<*>) {
|
||||
println("Requiring auth method")
|
||||
val nextAuth = acquireAuth()
|
||||
println("Got $nextAuth")
|
||||
@ -52,11 +54,11 @@ internal class BotAuthControlTest : AbstractCommonNHTest() {
|
||||
@Test
|
||||
fun `auth test`() = runTest {
|
||||
|
||||
val control = SsoProcessorImpl.AuthControl(botAuthInfo, object : BotAuthorization {
|
||||
val control = AuthControl(botAuthInfo, object : BotAuthorization {
|
||||
override suspend fun authorize(session: BotAuthSession, info: BotAuthInfo): BotAuthResult {
|
||||
return session.authByPassword(EMPTY_BYTE_ARRAY)
|
||||
}
|
||||
}, bot.logger, backgroundScope)
|
||||
}, bot.logger, backgroundScope.coroutineContext)
|
||||
|
||||
control.assertRequire(SsoProcessorImpl.AuthMethod.Pwd::class)
|
||||
control.actComplete()
|
||||
@ -66,35 +68,37 @@ internal class BotAuthControlTest : AbstractCommonNHTest() {
|
||||
|
||||
@Test
|
||||
fun `test auth failed and reselect`() = runTest {
|
||||
class MyLoginFailedException : CustomLoginFailedException(killBot = false)
|
||||
|
||||
val control = SsoProcessorImpl.AuthControl(botAuthInfo, object : BotAuthorization {
|
||||
val control = AuthControl(botAuthInfo, object : BotAuthorization {
|
||||
override suspend fun authorize(session: BotAuthSession, info: BotAuthInfo): BotAuthResult {
|
||||
assertFails { session.authByPassword(EMPTY_BYTE_ARRAY); println("!") }
|
||||
assertFailsWith<MyLoginFailedException> { session.authByPassword(EMPTY_BYTE_ARRAY); println("!") }
|
||||
println("114514")
|
||||
return session.authByPassword(EMPTY_BYTE_ARRAY)
|
||||
}
|
||||
}, bot.logger, backgroundScope)
|
||||
}, bot.logger, backgroundScope.coroutineContext)
|
||||
|
||||
control.assertRequire(SsoProcessorImpl.AuthMethod.Pwd::class)
|
||||
control.actFailed(Throwable())
|
||||
control.actMethodFailed(MyLoginFailedException())
|
||||
|
||||
control.assertRequire(SsoProcessorImpl.AuthMethod.Pwd::class)
|
||||
control.actComplete()
|
||||
|
||||
control.assertRequire(SsoProcessorImpl.AuthMethod.NotAvailable::class)
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `failed when login complete`() = runTest {
|
||||
|
||||
val control = SsoProcessorImpl.AuthControl(botAuthInfo, object : BotAuthorization {
|
||||
val control = AuthControl(botAuthInfo, object : BotAuthorization {
|
||||
override suspend fun authorize(session: BotAuthSession, info: BotAuthInfo): BotAuthResult {
|
||||
val rsp = session.authByPassword(EMPTY_BYTE_ARRAY)
|
||||
assertFails { session.authByPassword(EMPTY_BYTE_ARRAY) }
|
||||
assertFails { session.authByPassword(EMPTY_BYTE_ARRAY) }
|
||||
assertFails { session.authByPassword(EMPTY_BYTE_ARRAY) }
|
||||
assertFailsWith<IllegalStateException> { session.authByPassword(EMPTY_BYTE_ARRAY) }
|
||||
assertFailsWith<IllegalStateException> { session.authByPassword(EMPTY_BYTE_ARRAY) }
|
||||
assertFailsWith<IllegalStateException> { session.authByPassword(EMPTY_BYTE_ARRAY) }
|
||||
return rsp
|
||||
}
|
||||
}, bot.logger, backgroundScope)
|
||||
}, bot.logger, backgroundScope.coroutineContext)
|
||||
|
||||
control.assertRequire(SsoProcessorImpl.AuthMethod.Pwd::class)
|
||||
control.actComplete()
|
||||
|
@ -130,14 +130,6 @@ internal abstract class AbstractRealNetworkHandlerTest<H : NetworkHandler> : Abs
|
||||
return rsp
|
||||
}
|
||||
|
||||
override suspend fun authByPassword(password: String): BotAuthResult {
|
||||
return authByPassword(password.md5())
|
||||
}
|
||||
|
||||
override suspend fun authByPassword(passwordMd5: ByteArray): BotAuthResult {
|
||||
return authByPassword(SecretsProtection.EscapedByteBuffer(passwordMd5))
|
||||
}
|
||||
|
||||
override suspend fun authByQRCode(): BotAuthResult {
|
||||
return rsp
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user