Support network selector retry

This commit is contained in:
Him188 2021-04-26 21:30:51 +08:00
parent 83a81961ca
commit 6e06406a3a
4 changed files with 65 additions and 16 deletions

View File

@ -188,16 +188,20 @@ internal abstract class NetworkHandlerSupport(
) : CancellationException("State is switched from $old to $new")
/**
* Attempts to change state. Returns null if new state has same [class][KClass] as current.
*/
protected inline fun <reified S : BaseStateImpl> setState(noinline new: () -> S): S? = setState(S::class, new)
/**
* Calculate [new state][new] and set it as the current, returning the new state, or `null` if state has concurrently been set to CLOSED.
* Calculate [new state][new] and set it as the current, returning the new state,
* or `null` if state has concurrently been set to CLOSED, or has same [class][KClass] as current.
*
* You may need to call [BaseStateImpl.resumeConnection] to activate the new state, as states are lazy.
*/
protected fun <S : BaseStateImpl> setState(newType: KClass<S>, new: () -> S): S? = synchronized(this) {
if (_state::class == newType) return@synchronized null // already set to expected state by another thread.
if (_state.correspondingState == NetworkHandler.State.CLOSED) return null // error("Cannot change state while it has already been CLOSED.")
protected fun <S : BaseStateImpl> setState(newType: KClass<S>?, new: () -> S): S? = synchronized(this) {
if (newType != null && _state::class == newType) return@synchronized null // already set to expected state by another thread. Avoid replications.
if (_state.correspondingState == NetworkHandler.State.CLOSED) return null // CLOSED is final.
val stateObserver = context.getOrNull(StateObserver)

View File

@ -13,6 +13,8 @@ import kotlinx.atomicfu.atomic
import kotlinx.coroutines.yield
import net.mamoe.mirai.internal.network.handler.NetworkHandler
import net.mamoe.mirai.internal.network.handler.NetworkHandlerFactory
import net.mamoe.mirai.utils.systemProp
import net.mamoe.mirai.utils.toLongUnsigned
import org.jetbrains.annotations.TestOnly
/**
@ -26,7 +28,14 @@ import org.jetbrains.annotations.TestOnly
* and new connections are created only when calling [getResumedInstance] if the old connection was dead.
*/
// may be replaced with a better name.
internal abstract class AbstractKeepAliveNetworkHandlerSelector<H : NetworkHandler> : NetworkHandlerSelector<H> {
internal abstract class AbstractKeepAliveNetworkHandlerSelector<H : NetworkHandler>(
private val maxAttempts: Int = DEFAULT_MAX_ATTEMPTS
) : NetworkHandlerSelector<H> {
init {
require(maxAttempts >= 1) { "maxAttempts must >= 1" }
}
private val current = atomic<H?>(null)
@TestOnly
@ -38,19 +47,23 @@ internal abstract class AbstractKeepAliveNetworkHandlerSelector<H : NetworkHandl
final override fun getResumedInstance(): H? = current.value
final override tailrec suspend fun awaitResumeInstance(): H { // TODO: 2021/4/18 max 5 retry
final override suspend fun awaitResumeInstance(): H = awaitResumeInstanceImpl(0)
private tailrec suspend fun awaitResumeInstanceImpl(attempted: Int): H {
if (attempted >= maxAttempts) error("Failed to resume instance. Maximum attempts reached.")
yield()
val current = getResumedInstance()
return if (current != null) {
when (current.state) {
when (val thisState = current.state) {
NetworkHandler.State.CLOSED -> {
this.current.compareAndSet(current, null) // invalidate the instance and try again.
awaitResumeInstance() // will create new instance.
awaitResumeInstanceImpl(attempted + 1) // will create new instance.
}
NetworkHandler.State.CONNECTING,
NetworkHandler.State.INITIALIZED -> {
current.resumeConnection()
return awaitResumeInstance()
current.resumeConnection() // once finished, it should has been LOADING or OK
check(current.state != thisState) { "State is still $thisState after successful resumeConnection." }
return awaitResumeInstanceImpl(attempted) // does not count for an attempt.
}
NetworkHandler.State.LOADING -> {
return current
@ -61,8 +74,17 @@ internal abstract class AbstractKeepAliveNetworkHandlerSelector<H : NetworkHandl
}
}
} else {
this.current.compareAndSet(current, createInstance())
awaitResumeInstance()
synchronized(this) { // avoid concurrent `createInstance()`
if (getResumedInstance() == null) this.current.compareAndSet(null, createInstance())
}
awaitResumeInstanceImpl(attempted) // directly retry, does not count for attempts.
}
}
companion object {
@JvmField
var DEFAULT_MAX_ATTEMPTS =
systemProp("mirai.network.handler.selector.max.attempts", 3)
.coerceIn(1..Int.MAX_VALUE.toLongUnsigned()).toInt()
}
}

View File

@ -69,7 +69,7 @@ internal open class TestNetworkHandler(
}
fun setState(correspondingState: NetworkHandler.State) {
setState { TestState(correspondingState) }
setState(null) { TestState(correspondingState) }
}
private val initialState = TestState(NetworkHandler.State.INITIALIZED)

View File

@ -17,9 +17,21 @@ import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.*
import kotlin.time.seconds
private class TestSelector(val createInstance0: () -> NetworkHandler) :
AbstractKeepAliveNetworkHandlerSelector<NetworkHandler>() {
val createInstanceCount = AtomicInteger(0)
private class TestSelector :
AbstractKeepAliveNetworkHandlerSelector<NetworkHandler> {
val createInstance0: () -> NetworkHandler
constructor(createInstance0: () -> NetworkHandler) : super() {
this.createInstance0 = createInstance0
}
constructor(maxAttempts: Int, createInstance0: () -> NetworkHandler) : super(maxAttempts) {
this.createInstance0 = createInstance0
}
val createInstanceCount: AtomicInteger = AtomicInteger(0)
override fun createInstance(): NetworkHandler {
createInstanceCount.incrementAndGet()
return this.createInstance0()
@ -60,4 +72,15 @@ internal class KeepAliveNetworkHandlerSelectorTest : AbstractMockNetworkHandlerT
runBlockingUnit(timeout = 3.seconds) { selector.awaitResumeInstance() }
assertEquals(1, selector.createInstanceCount.get())
}
@Test
fun `limited attempts`() = runBlockingUnit {
val selector = TestSelector(3) {
createNetworkHandler().apply { setState(State.CLOSED) }
}
assertFailsWith<IllegalStateException> {
selector.awaitResumeInstance()
}
assertEquals(3, selector.createInstanceCount.get())
}
}