blivechat/utils/request.py

168 lines
5.7 KiB
Python
Raw Normal View History

2022-02-15 00:18:46 +08:00
# -*- coding: utf-8 -*-
import asyncio
import datetime
import logging
2024-12-29 23:54:18 +08:00
import os
import pickle
2023-07-29 01:22:21 +08:00
from typing import *
2022-02-15 00:18:46 +08:00
import aiohttp
import circuitbreaker
import api.open_live
import config
import utils.async_io
logger = logging.getLogger(__name__)
2022-02-15 00:18:46 +08:00
2024-12-29 23:54:18 +08:00
COOKIE_JAR_PATH = os.path.join(config.DATA_PATH, 'cookie_jar.pickle')
# 不带这堆头部有时候也能成功请求,但是带上后成功的概率更高
BILIBILI_COMMON_HEADERS = {
'Origin': 'https://www.bilibili.com',
'Referer': 'https://www.bilibili.com/',
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)'
' Chrome/114.0.0.0 Safari/537.36'
}
2023-07-29 01:22:21 +08:00
http_session: Optional[aiohttp.ClientSession] = None
_COMMON_SERVER_DISCOVERY_URLS = [
'https://api1.blive.chat/api/endpoints',
'https://api2.blive.chat/api/endpoints',
]
_last_update_common_server_time: Optional[datetime.datetime] = None
_common_server_base_urls = [
'https://api1.blive.chat',
'https://api2.blive.chat',
]
_cur_common_server_base_url: Optional[str] = None
_common_server_base_url_to_circuit_breaker: Dict[str, circuitbreaker.CircuitBreaker] = {}
2023-07-29 01:22:21 +08:00
2023-07-29 12:48:57 +08:00
def init():
2024-12-29 23:54:18 +08:00
try:
cookie_jar = aiohttp.CookieJar()
cookie_jar.load(COOKIE_JAR_PATH)
except (OSError, pickle.PickleError):
cookie_jar = None
2023-09-06 21:34:04 +08:00
global http_session
http_session = aiohttp.ClientSession(
response_class=CustomClientResponse,
timeout=aiohttp.ClientTimeout(total=10),
2024-12-29 23:54:18 +08:00
cookie_jar=cookie_jar,
)
2023-07-29 12:48:57 +08:00
cfg = config.get_config()
if not cfg.is_open_live_configured:
_update_common_server_base_urls()
2023-09-06 21:34:04 +08:00
async def shut_down():
if http_session is not None:
await http_session.close()
class CustomClientResponse(aiohttp.ClientResponse):
# 因为aiohttp的BUG当底层连接断开时_wait_released可能会抛出CancelledError导致上层协程结束。这里改个错误类型
async def _wait_released(self):
try:
return await super()._wait_released()
except asyncio.CancelledError as e:
raise aiohttp.ClientConnectionError('Connection released') from e
def _update_common_server_base_urls():
global _last_update_common_server_time
cur_time = datetime.datetime.now()
if (
_last_update_common_server_time is not None
and cur_time - _last_update_common_server_time < datetime.timedelta(minutes=3)
):
return
_last_update_common_server_time = cur_time
utils.async_io.create_task_with_ref(_do_update_common_server_base_urls())
async def _do_update_common_server_base_urls():
global _last_update_common_server_time
_last_update_common_server_time = datetime.datetime.now()
async def request_get_urls(discovery_url):
async with http_session.get(discovery_url) as res:
res.raise_for_status()
data = await res.json()
return data['endpoints']
common_server_base_urls = []
futures = [
asyncio.create_task(request_get_urls(url))
for url in _COMMON_SERVER_DISCOVERY_URLS
]
for future in asyncio.as_completed(futures):
try:
common_server_base_urls = await future
break
except Exception as e:
logger.warning('Failed to discover common server endpoints from one source: %s', e)
for future in futures:
future.cancel()
if not common_server_base_urls:
logger.error('Failed to discover common server endpoints from any source')
return
# 按响应时间排序
sorted_common_server_base_urls = []
error_base_urls = []
async def test_endpoint(base_url):
try:
2024-11-12 20:09:42 +08:00
url = base_url + '/api/ping'
async with http_session.get(url, timeout=aiohttp.ClientTimeout(total=3)) as res:
res.raise_for_status()
sorted_common_server_base_urls.append(base_url)
except Exception: # noqa
error_base_urls.append(base_url)
await asyncio.gather(*(test_endpoint(base_url) for base_url in common_server_base_urls))
sorted_common_server_base_urls.extend(error_base_urls)
global _common_server_base_urls, _cur_common_server_base_url
_common_server_base_urls = sorted_common_server_base_urls
if _cur_common_server_base_url not in _common_server_base_urls:
_cur_common_server_base_url = None
logger.info('Found common server endpoints: %s', _common_server_base_urls)
def get_common_server_base_url_and_circuit_breaker() -> Tuple[Optional[str], Optional[circuitbreaker.CircuitBreaker]]:
_update_common_server_base_urls()
global _cur_common_server_base_url
if _cur_common_server_base_url is not None:
breaker = _get_or_add_common_server_circuit_breaker(_cur_common_server_base_url)
if breaker.state != circuitbreaker.STATE_OPEN:
return _cur_common_server_base_url, breaker
_cur_common_server_base_url = None
# 找第一个未熔断的
for base_url in _common_server_base_urls:
breaker = _get_or_add_common_server_circuit_breaker(base_url)
if breaker.state != circuitbreaker.STATE_OPEN:
_cur_common_server_base_url = base_url
logger.info('Switch common server endpoint to %s', _cur_common_server_base_url)
return _cur_common_server_base_url, breaker
return None, None
def _get_or_add_common_server_circuit_breaker(base_url):
breaker = _common_server_base_url_to_circuit_breaker.get(base_url, None)
if breaker is None:
breaker = _common_server_base_url_to_circuit_breaker[base_url] = circuitbreaker.CircuitBreaker(
failure_threshold=3,
recovery_timeout=60,
expected_exception=api.open_live.TransportError,
)
return breaker