获取重连间隔时间时传入总重连次数

This commit is contained in:
John Smith 2023-11-26 16:21:50 +08:00
parent b2c26b5f16
commit 5525bf5419
2 changed files with 13 additions and 7 deletions

View File

@ -105,7 +105,7 @@ class WebSocketClientBase:
self._need_init_room = True
self._handler: Optional[handlers.HandlerInterface] = None
"""消息处理器"""
self._get_reconnect_interval: Callable[[int], float] = DEFAULT_RECONNECT_POLICY
self._get_reconnect_interval: Callable[[int, int], float] = DEFAULT_RECONNECT_POLICY
"""重连间隔时间增长策略"""
# 在调用init_room后初始化的字段
@ -144,11 +144,11 @@ class WebSocketClientBase:
"""
self._handler = handler
def set_reconnect_policy(self, get_reconnect_interval: Callable[[int], float]):
def set_reconnect_policy(self, get_reconnect_interval: Callable[[int, int], float]):
"""
设置重连间隔时间增长策略
:param get_reconnect_interval: 一个可调用对象输入重试次数返回间隔时间
:param get_reconnect_interval: 一个可调用对象输入重试次数 (retry_count, total_retry_count)返回间隔时间
"""
self._get_reconnect_interval = get_reconnect_interval
@ -258,7 +258,9 @@ class WebSocketClientBase:
"""
网络协程负责连接服务器接收消息解包
"""
# retry_count在连接成功后会重置为0total_retry_count不会
retry_count = 0
total_retry_count = 0
while True:
try:
await self._on_before_ws_connect(retry_count)
@ -292,8 +294,12 @@ class WebSocketClientBase:
# 准备重连
retry_count += 1
logger.warning('room=%d is reconnecting, retry_count=%d', self.room_id, retry_count)
await asyncio.sleep(self._get_reconnect_interval(retry_count))
total_retry_count += 1
logger.warning(
'room=%d is reconnecting, retry_count=%d, total_retry_count=%d',
self.room_id, retry_count, total_retry_count
)
await asyncio.sleep(self._get_reconnect_interval(retry_count, total_retry_count))
async def _on_before_ws_connect(self, retry_count):
"""

View File

@ -5,13 +5,13 @@ USER_AGENT = (
def make_constant_retry_policy(interval: float):
def get_interval(_retry_count: int):
def get_interval(_retry_count: int, _total_retry_count: int):
return interval
return get_interval
def make_linear_retry_policy(start_interval: float, interval_step: float, max_interval: float):
def get_interval(retry_count: int):
def get_interval(retry_count: int, _total_retry_count: int):
return min(
start_interval + (retry_count - 1) * interval_step,
max_interval