blivechat/utils/request.py
2024-12-29 23:54:18 +08:00

168 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import asyncio
import datetime
import logging
import os
import pickle
from typing import *
import aiohttp
import circuitbreaker
import api.open_live
import config
import utils.async_io
logger = logging.getLogger(__name__)

# Path of the pickled cookie jar that init() tries to load into the HTTP session.
COOKIE_JAR_PATH = os.path.join(config.DATA_PATH, 'cookie_jar.pickle')

# Requests sometimes succeed without these headers, but sending them raises the success rate.
BILIBILI_COMMON_HEADERS = {
    'Origin': 'https://www.bilibili.com',
    'Referer': 'https://www.bilibili.com/',
    'User-Agent': (
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)'
        ' Chrome/114.0.0.0 Safari/537.36'
    ),
}

# Shared client session; created by init() and closed by shut_down().
http_session: Optional[aiohttp.ClientSession] = None

# URLs queried to discover the current list of common server endpoints.
_COMMON_SERVER_DISCOVERY_URLS = [
    'https://api1.blive.chat/api/endpoints',
    'https://api2.blive.chat/api/endpoints',
]
# Time of the most recent endpoint refresh attempt; None until the first one.
_last_update_common_server_time: Optional[datetime.datetime] = None
# Known common server base URLs; these defaults are replaced once discovery succeeds.
_common_server_base_urls = [
    'https://api1.blive.chat',
    'https://api2.blive.chat',
]
# Base URL currently selected for requests; None when no endpoint is selected.
_cur_common_server_base_url: Optional[str] = None
# One circuit breaker per base URL, created lazily on first use.
_common_server_base_url_to_circuit_breaker: Dict[str, circuitbreaker.CircuitBreaker] = {}
def init():
    """Create the module-wide HTTP session, restoring saved cookies when possible.

    Cookies are loaded from ``COOKIE_JAR_PATH``. If the file is missing, unreadable
    or corrupted, the session is created with ``cookie_jar=None`` so aiohttp falls
    back to its own default jar. When open live is not configured, a refresh of the
    common server endpoint list is also kicked off.
    """
    # NOTE(security): CookieJar.load() unpickles the file, so COOKIE_JAR_PATH must
    # only ever point at data written by this application.
    try:
        cookie_jar = aiohttp.CookieJar()
        cookie_jar.load(COOKIE_JAR_PATH)
    except FileNotFoundError:
        # Normal on first run: there is nothing to restore yet.
        cookie_jar = None
    except (OSError, pickle.PickleError, EOFError, AttributeError, ImportError, IndexError) as e:
        # A truncated or otherwise damaged pickle does not always raise PickleError
        # (an empty file raises EOFError, for example). None of these should prevent
        # start-up, so discard the possibly half-loaded jar and start clean.
        logger.warning('Failed to load cookie jar from %s: %r', COOKIE_JAR_PATH, e)
        cookie_jar = None

    global http_session
    http_session = aiohttp.ClientSession(
        response_class=CustomClientResponse,
        timeout=aiohttp.ClientTimeout(total=10),
        cookie_jar=cookie_jar,
    )

    cfg = config.get_config()
    if not cfg.is_open_live_configured:
        _update_common_server_base_urls()
async def shut_down():
    """Close the shared HTTP session if one has been created."""
    session = http_session
    if session is None:
        return
    await session.close()
class CustomClientResponse(aiohttp.ClientResponse):
    # Workaround for an aiohttp bug: when the underlying connection is dropped,
    # _wait_released may raise CancelledError, which would terminate the calling
    # coroutine. Re-raise it as a connection error instead.
    async def _wait_released(self):
        try:
            released = await super()._wait_released()
        except asyncio.CancelledError as cancelled:
            raise aiohttp.ClientConnectionError('Connection released') from cancelled
        return released
def _update_common_server_base_urls():
    """Start a background endpoint refresh unless one was started under 3 minutes ago.

    Records the current time as the last update time and hands the refresh
    coroutine to ``utils.async_io.create_task_with_ref``; returns without doing
    anything when the previous update is still recent.
    """
    global _last_update_common_server_time

    now = datetime.datetime.now()
    previous = _last_update_common_server_time
    min_interval = datetime.timedelta(minutes=3)
    if previous is None or now - previous >= min_interval:
        _last_update_common_server_time = now
        utils.async_io.create_task_with_ref(_do_update_common_server_base_urls())
async def _do_update_common_server_base_urls():
    """Refresh the list of common server base URLs and order it by responsiveness.

    Steps:
      1. Query every discovery URL concurrently and take the first source that
         returns a non-empty endpoint list. Sources that fail or return nothing
         are logged and skipped.
      2. Ping each discovered endpoint (``/api/ping``, 3 second timeout).
         Endpoints are recorded in the order their pings complete, so faster
         ones come first; endpoints whose ping failed are appended at the end.
      3. Replace ``_common_server_base_urls`` with the ordered list and clear
         ``_cur_common_server_base_url`` if it is no longer listed.

    If no source yields any endpoint, an error is logged and the existing list
    is left untouched.
    """
    global _last_update_common_server_time
    _last_update_common_server_time = datetime.datetime.now()

    async def request_get_urls(discovery_url):
        async with http_session.get(discovery_url) as res:
            res.raise_for_status()
            data = await res.json()
        return data['endpoints']

    common_server_base_urls = []
    futures = [
        asyncio.create_task(request_get_urls(url))
        for url in _COMMON_SERVER_DISCOVERY_URLS
    ]
    try:
        for future in asyncio.as_completed(futures):
            try:
                common_server_base_urls = await future
            except Exception as e:
                logger.warning('Failed to discover common server endpoints from one source: %s', e)
                continue
            if common_server_base_urls:
                break
            # An empty answer is useless; keep waiting for the remaining sources
            # instead of giving up on the whole refresh.
            logger.warning(
                'Failed to discover common server endpoints from one source: %s', 'empty endpoint list'
            )
    finally:
        # Always stop the remaining discovery requests, including when this
        # coroutine itself is cancelled while waiting. Cancelling a finished
        # task is a no-op.
        for future in futures:
            future.cancel()
    if not common_server_base_urls:
        logger.error('Failed to discover common server endpoints from any source')
        return

    # Sort by response time: successful pings append in completion order.
    sorted_common_server_base_urls = []
    error_base_urls = []

    async def test_endpoint(base_url):
        try:
            url = base_url + '/api/ping'
            async with http_session.get(url, timeout=aiohttp.ClientTimeout(total=3)) as res:
                res.raise_for_status()
            sorted_common_server_base_urls.append(base_url)
        except Exception:  # noqa
            # Best effort: an unreachable endpoint is kept, but ranked last.
            error_base_urls.append(base_url)

    await asyncio.gather(*(test_endpoint(base_url) for base_url in common_server_base_urls))
    sorted_common_server_base_urls.extend(error_base_urls)

    global _common_server_base_urls, _cur_common_server_base_url
    _common_server_base_urls = sorted_common_server_base_urls
    if _cur_common_server_base_url not in _common_server_base_urls:
        _cur_common_server_base_url = None
    logger.info('Found common server endpoints: %s', _common_server_base_urls)
def get_common_server_base_url_and_circuit_breaker() -> Tuple[Optional[str], Optional[circuitbreaker.CircuitBreaker]]:
    """Pick a common server endpoint whose circuit breaker is not open.

    Triggers the (rate-limited) endpoint refresh first. The currently selected
    endpoint is kept while its breaker is not open; otherwise the first endpoint
    in ``_common_server_base_urls`` with a non-open breaker becomes the new
    selection.

    :return: ``(base_url, breaker)`` for the chosen endpoint, or ``(None, None)``
        when every known endpoint's breaker is open
    """
    global _cur_common_server_base_url
    _update_common_server_base_urls()

    current_url = _cur_common_server_base_url
    if current_url is not None:
        current_breaker = _get_or_add_common_server_circuit_breaker(current_url)
        if current_breaker.state == circuitbreaker.STATE_OPEN:
            # The selected endpoint has tripped; drop it and look for another one.
            _cur_common_server_base_url = None
        else:
            return current_url, current_breaker

    # Find the first endpoint whose breaker has not tripped.
    for candidate_url in _common_server_base_urls:
        candidate_breaker = _get_or_add_common_server_circuit_breaker(candidate_url)
        if candidate_breaker.state == circuitbreaker.STATE_OPEN:
            continue
        _cur_common_server_base_url = candidate_url
        logger.info('Switch common server endpoint to %s', _cur_common_server_base_url)
        return candidate_url, candidate_breaker
    return None, None
def _get_or_add_common_server_circuit_breaker(base_url):
    """Return the circuit breaker registered for ``base_url``, creating it on first use.

    New breakers are configured with a failure threshold of 3, a recovery timeout
    of 60 and ``api.open_live.TransportError`` as the expected exception.
    """
    existing = _common_server_base_url_to_circuit_breaker.get(base_url, None)
    if existing is not None:
        return existing

    created = circuitbreaker.CircuitBreaker(
        failure_threshold=3,
        recovery_timeout=60,
        expected_exception=api.open_live.TransportError,
    )
    _common_server_base_url_to_circuit_breaker[base_url] = created
    return created