From 6e06406a3abad05da939c4aff325f94094fd8127 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Mon, 26 Apr 2021 21:30:51 +0800
Subject: [PATCH] Support network selector retry

---
 .../network/handler/NetworkHandlerSupport.kt  | 12 ++++--
 ...AbstractKeepAliveNetworkHandlerSelector.kt | 38 +++++++++++++++----
 .../kotlin/network/framework/testUtils.kt     |  2 +-
 .../KeepAliveNetworkHandlerSelectorTest.kt    | 29 ++++++++++++--
 4 files changed, 65 insertions(+), 16 deletions(-)

diff --git a/mirai-core/src/commonMain/kotlin/network/handler/NetworkHandlerSupport.kt b/mirai-core/src/commonMain/kotlin/network/handler/NetworkHandlerSupport.kt
index 082367ded..3b77d5ceb 100644
--- a/mirai-core/src/commonMain/kotlin/network/handler/NetworkHandlerSupport.kt
+++ b/mirai-core/src/commonMain/kotlin/network/handler/NetworkHandlerSupport.kt
@@ -188,16 +188,20 @@ internal abstract class NetworkHandlerSupport(
     ) : CancellationException("State is switched from $old to $new")
 
 
+    /**
+     * Attempts to change state. Returns null if new state has same [class][KClass] as current.
+     */
     protected inline fun <reified S : BaseStateImpl> setState(noinline new: () -> S): S? = setState(S::class, new)
 
     /**
-     * Calculate [new state][new] and set it as the current, returning the new state, or `null` if state has concurrently been set to CLOSED.
+     * Calculate [new state][new] and set it as the current, returning the new state,
+     * or `null` if state has concurrently been set to CLOSED, or has same [class][KClass] as current.
      *
      * You may need to call [BaseStateImpl.resumeConnection] to activate the new state, as states are lazy.
      */
-    protected fun <S : BaseStateImpl> setState(newType: KClass<S>, new: () -> S): S? = synchronized(this) {
-        if (_state::class == newType) return@synchronized null // already set to expected state by another thread.
-        if (_state.correspondingState == NetworkHandler.State.CLOSED) return null // error("Cannot change state while it has already been CLOSED.")
+    protected fun <S : BaseStateImpl> setState(newType: KClass<S>?, new: () -> S): S? = synchronized(this) {
+        if (newType != null && _state::class == newType) return@synchronized null // already set to expected state by another thread. Avoid replications.
+        if (_state.correspondingState == NetworkHandler.State.CLOSED) return null // CLOSED is final.
 
         val stateObserver = context.getOrNull(StateObserver)
 
diff --git a/mirai-core/src/commonMain/kotlin/network/handler/selector/AbstractKeepAliveNetworkHandlerSelector.kt b/mirai-core/src/commonMain/kotlin/network/handler/selector/AbstractKeepAliveNetworkHandlerSelector.kt
index 25d22842b..ef21380fa 100644
--- a/mirai-core/src/commonMain/kotlin/network/handler/selector/AbstractKeepAliveNetworkHandlerSelector.kt
+++ b/mirai-core/src/commonMain/kotlin/network/handler/selector/AbstractKeepAliveNetworkHandlerSelector.kt
@@ -13,6 +13,8 @@ import kotlinx.atomicfu.atomic
 import kotlinx.coroutines.yield
 import net.mamoe.mirai.internal.network.handler.NetworkHandler
 import net.mamoe.mirai.internal.network.handler.NetworkHandlerFactory
+import net.mamoe.mirai.utils.systemProp
+import net.mamoe.mirai.utils.toLongUnsigned
 import org.jetbrains.annotations.TestOnly
 
 /**
@@ -26,7 +28,14 @@ import org.jetbrains.annotations.TestOnly
  * and new connections are created only when calling [getResumedInstance] if the old connection was dead.
  */
 // may be replaced with a better name.
-internal abstract class AbstractKeepAliveNetworkHandlerSelector<H : NetworkHandler> : NetworkHandlerSelector<H> {
+internal abstract class AbstractKeepAliveNetworkHandlerSelector<H : NetworkHandler>(
+    private val maxAttempts: Int = DEFAULT_MAX_ATTEMPTS
+) : NetworkHandlerSelector<H> {
+
+    init {
+        require(maxAttempts >= 1) { "maxAttempts must >= 1" }
+    }
+
     private val current = atomic<H?>(null)
 
     @TestOnly
@@ -38,19 +47,23 @@ internal abstract class AbstractKeepAliveNetworkHandlerSelector<H : NetworkHandl
 
     final override fun getResumedInstance(): H? = current.value
 
-    final override tailrec suspend fun awaitResumeInstance(): H { // TODO: 2021/4/18 max 5 retry
+    final override suspend fun awaitResumeInstance(): H = awaitResumeInstanceImpl(0)
+
+    private tailrec suspend fun awaitResumeInstanceImpl(attempted: Int): H {
+        if (attempted >= maxAttempts) error("Failed to resume instance. Maximum attempts reached.")
         yield()
         val current = getResumedInstance()
         return if (current != null) {
-            when (current.state) {
+            when (val thisState = current.state) {
                 NetworkHandler.State.CLOSED -> {
                     this.current.compareAndSet(current, null) // invalidate the instance and try again.
-                    awaitResumeInstance() // will create new instance.
+                    awaitResumeInstanceImpl(attempted + 1) // will create new instance.
                 }
                 NetworkHandler.State.CONNECTING,
                 NetworkHandler.State.INITIALIZED -> {
-                    current.resumeConnection()
-                    return awaitResumeInstance()
+                    current.resumeConnection() // once finished, it should has been LOADING or OK
+                    check(current.state != thisState) { "State is still $thisState after successful resumeConnection." }
+                    return awaitResumeInstanceImpl(attempted) // does not count for an attempt.
                 }
                 NetworkHandler.State.LOADING -> {
                     return current
@@ -61,8 +74,17 @@ internal abstract class AbstractKeepAliveNetworkHandlerSelector<H : NetworkHandl
                 }
             }
         } else {
-            this.current.compareAndSet(current, createInstance())
-            awaitResumeInstance()
+            synchronized(this) { // avoid concurrent `createInstance()`
+                if (getResumedInstance() == null) this.current.compareAndSet(null, createInstance())
+            }
+            awaitResumeInstanceImpl(attempted) // directly retry, does not count for attempts.
         }
     }
+
+    companion object {
+        @JvmField
+        var DEFAULT_MAX_ATTEMPTS =
+            systemProp("mirai.network.handler.selector.max.attempts", 3)
+                .coerceIn(1..Int.MAX_VALUE.toLongUnsigned()).toInt()
+    }
 }
\ No newline at end of file
diff --git a/mirai-core/src/commonTest/kotlin/network/framework/testUtils.kt b/mirai-core/src/commonTest/kotlin/network/framework/testUtils.kt
index 72014becc..af95bbb58 100644
--- a/mirai-core/src/commonTest/kotlin/network/framework/testUtils.kt
+++ b/mirai-core/src/commonTest/kotlin/network/framework/testUtils.kt
@@ -69,7 +69,7 @@ internal open class TestNetworkHandler(
     }
 
     fun setState(correspondingState: NetworkHandler.State) {
-        setState { TestState(correspondingState) }
+        setState(null) { TestState(correspondingState) }
     }
 
     private val initialState = TestState(NetworkHandler.State.INITIALIZED)
diff --git a/mirai-core/src/commonTest/kotlin/network/handler/KeepAliveNetworkHandlerSelectorTest.kt b/mirai-core/src/commonTest/kotlin/network/handler/KeepAliveNetworkHandlerSelectorTest.kt
index 7c4aa6aac..c0b3aea70 100644
--- a/mirai-core/src/commonTest/kotlin/network/handler/KeepAliveNetworkHandlerSelectorTest.kt
+++ b/mirai-core/src/commonTest/kotlin/network/handler/KeepAliveNetworkHandlerSelectorTest.kt
@@ -17,9 +17,21 @@ import java.util.concurrent.atomic.AtomicInteger
 import kotlin.test.*
 import kotlin.time.seconds
 
-private class TestSelector(val createInstance0: () -> NetworkHandler) :
-    AbstractKeepAliveNetworkHandlerSelector<NetworkHandler>() {
-    val createInstanceCount = AtomicInteger(0)
+private class TestSelector :
+    AbstractKeepAliveNetworkHandlerSelector<NetworkHandler> {
+
+    val createInstance0: () -> NetworkHandler
+
+    constructor(createInstance0: () -> NetworkHandler) : super() {
+        this.createInstance0 = createInstance0
+    }
+
+    constructor(maxAttempts: Int, createInstance0: () -> NetworkHandler) : super(maxAttempts) {
+        this.createInstance0 = createInstance0
+    }
+
+    val createInstanceCount: AtomicInteger = AtomicInteger(0)
+
     override fun createInstance(): NetworkHandler {
         createInstanceCount.incrementAndGet()
         return this.createInstance0()
@@ -60,4 +72,15 @@ internal class KeepAliveNetworkHandlerSelectorTest : AbstractMockNetworkHandlerT
         runBlockingUnit(timeout = 3.seconds) { selector.awaitResumeInstance() }
         assertEquals(1, selector.createInstanceCount.get())
     }
+
+    @Test
+    fun `limited attempts`() = runBlockingUnit {
+        val selector = TestSelector(3) {
+            createNetworkHandler().apply { setState(State.CLOSED) }
+        }
+        assertFailsWith<IllegalStateException> {
+            selector.awaitResumeInstance()
+        }
+        assertEquals(3, selector.createInstanceCount.get())
+    }
 }
\ No newline at end of file