头像缓存持久化

This commit is contained in:
John Smith 2020-02-03 16:18:21 +08:00
parent cae06858fc
commit 8d55331e6c
11 changed files with 329 additions and 95 deletions

1
.gitignore vendored
View File

@ -105,3 +105,4 @@ venv.bak/
.idea/ .idea/
data/database.db

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import asyncio import asyncio
import datetime
import enum import enum
import json import json
import logging import logging
@ -12,6 +11,7 @@ import aiohttp
import tornado.websocket import tornado.websocket
import blivedm.blivedm as blivedm import blivedm.blivedm as blivedm
import models.avatar
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,74 +26,14 @@ class Command(enum.IntEnum):
DEL_SUPER_CHAT = 6 DEL_SUPER_CHAT = 6
DEFAULT_AVATAR_URL = '//static.hdslb.com/images/member/noface.gif'
_http_session = aiohttp.ClientSession() _http_session = aiohttp.ClientSession()
_avatar_url_cache: Dict[int, str] = {}
_last_fetch_avatar_time = datetime.datetime.now() room_manager: Optional['RoomManager'] = None
_last_avatar_failed_time = None
_uids_to_fetch_avatar = asyncio.Queue(15)
async def get_avatar_url(user_id): def init():
if user_id in _avatar_url_cache: global room_manager
return _avatar_url_cache[user_id] room_manager = RoomManager()
global _last_avatar_failed_time, _last_fetch_avatar_time
cur_time = datetime.datetime.now()
# 防止获取头像频率太高被ban
if (cur_time - _last_fetch_avatar_time).total_seconds() < 0.2:
# 由_fetch_avatar_loop过一段时间再获取并缓存
try:
_uids_to_fetch_avatar.put_nowait(user_id)
except asyncio.QueueFull:
pass
return DEFAULT_AVATAR_URL
if _last_avatar_failed_time is not None:
if (cur_time - _last_avatar_failed_time).total_seconds() < 3 * 60 + 3:
# 3分钟以内被ban解封大约要15分钟
return DEFAULT_AVATAR_URL
else:
_last_avatar_failed_time = None
_last_fetch_avatar_time = cur_time
try:
async with _http_session.get('https://api.bilibili.com/x/space/acc/info',
params={'mid': user_id}) as r:
if r.status != 200: # 可能会被B站ban
logger.warning('Failed to fetch avatar: status=%d %s uid=%d', r.status, r.reason, user_id)
_last_avatar_failed_time = cur_time
return DEFAULT_AVATAR_URL
data = await r.json()
except aiohttp.ClientConnectionError:
return DEFAULT_AVATAR_URL
url = data['data']['face'].replace('http:', '').replace('https:', '')
if not url.endswith('noface.gif'):
url += '@48w_48h'
_avatar_url_cache[user_id] = url
if len(_avatar_url_cache) > 50000:
for _, key in zip(range(100), _avatar_url_cache):
del _avatar_url_cache[key]
return url
async def _fetch_avatar_loop():
while True:
try:
user_id = await _uids_to_fetch_avatar.get()
if user_id in _avatar_url_cache:
continue
# 延时长一些使实时弹幕有机会获取头像
await asyncio.sleep(0.4 - (datetime.datetime.now() - _last_fetch_avatar_time).total_seconds())
asyncio.ensure_future(get_avatar_url(user_id))
except:
pass
asyncio.ensure_future(_fetch_avatar_loop())
class Room(blivedm.BLiveClient): class Room(blivedm.BLiveClient):
@ -119,7 +59,7 @@ class Room(blivedm.BLiveClient):
data = command['data'] data = command['data']
return self._on_receive_gift(blivedm.GiftMessage( return self._on_receive_gift(blivedm.GiftMessage(
data['giftName'], data['num'], data['uname'], data['face'], None, data['giftName'], data['num'], data['uname'], data['face'], None,
None, data['timestamp'], None, None, data['uid'], data['timestamp'], None, None,
None, None, None, data['coin_type'], data['total_coin'] None, None, None, data['coin_type'], data['total_coin']
)) ))
@ -135,7 +75,7 @@ class Room(blivedm.BLiveClient):
return self._on_super_chat(blivedm.SuperChatMessage( return self._on_super_chat(blivedm.SuperChatMessage(
data['price'], data['message'], None, data['start_time'], data['price'], data['message'], None, data['start_time'],
None, None, data['id'], None, None, None, data['id'], None,
None, None, data['user_info']['uname'], None, data['uid'], data['user_info']['uname'],
data['user_info']['face'], None, data['user_info']['face'], None,
None, None, None, None,
None, None, None, None, None, None,
@ -182,7 +122,7 @@ class Room(blivedm.BLiveClient):
else: else:
author_type = 0 author_type = 0
self.send_message(Command.ADD_TEXT, { self.send_message(Command.ADD_TEXT, {
'avatarUrl': await get_avatar_url(danmaku.uid), 'avatarUrl': await models.avatar.get_avatar_url(danmaku.uid),
'timestamp': danmaku.timestamp, 'timestamp': danmaku.timestamp,
'authorName': danmaku.uname, 'authorName': danmaku.uname,
'authorType': author_type, 'authorType': author_type,
@ -196,10 +136,12 @@ class Room(blivedm.BLiveClient):
}) })
async def _on_receive_gift(self, gift: blivedm.GiftMessage): async def _on_receive_gift(self, gift: blivedm.GiftMessage):
avatar_url = gift.face.replace('http:', '').replace('https:', '')
models.avatar.update_avatar_cache(gift.uid, avatar_url)
if gift.coin_type != 'gold': # 丢人 if gift.coin_type != 'gold': # 丢人
return return
self.send_message(Command.ADD_GIFT, { self.send_message(Command.ADD_GIFT, {
'avatarUrl': gift.face.replace('http:', '').replace('https:', ''), 'avatarUrl': avatar_url,
'timestamp': gift.timestamp, 'timestamp': gift.timestamp,
'authorName': gift.uname, 'authorName': gift.uname,
'giftName': gift.gift_name, 'giftName': gift.gift_name,
@ -212,14 +154,16 @@ class Room(blivedm.BLiveClient):
async def __on_buy_guard(self, message: blivedm.GuardBuyMessage): async def __on_buy_guard(self, message: blivedm.GuardBuyMessage):
self.send_message(Command.ADD_MEMBER, { self.send_message(Command.ADD_MEMBER, {
'avatarUrl': await get_avatar_url(message.uid), 'avatarUrl': await models.avatar.get_avatar_url(message.uid),
'timestamp': message.start_time, 'timestamp': message.start_time,
'authorName': message.username 'authorName': message.username
}) })
async def _on_super_chat(self, message: blivedm.SuperChatMessage): async def _on_super_chat(self, message: blivedm.SuperChatMessage):
avatar_url = message.face.replace('http:', '').replace('https:', '')
models.avatar.update_avatar_cache(message.uid, avatar_url)
self.send_message(Command.ADD_SUPER_CHAT, { self.send_message(Command.ADD_SUPER_CHAT, {
'avatarUrl': message.face.replace('http:', '').replace('https:', ''), 'avatarUrl': avatar_url,
'timestamp': message.start_time, 'timestamp': message.start_time,
'authorName': message.uname, 'authorName': message.uname,
'price': message.price, 'price': message.price,
@ -282,9 +226,6 @@ class RoomManager:
del self._rooms[room_id] del self._rooms[room_id]
room_manager = RoomManager()
# noinspection PyAbstractClass # noinspection PyAbstractClass
class ChatHandler(tornado.websocket.WebSocketHandler): class ChatHandler(tornado.websocket.WebSocketHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

43
config.py Normal file
View File

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
import configparser
import logging
import os
from typing import *
logger = logging.getLogger(__name__)
CONFIG_PATH = os.path.join('data', 'config.ini')
_config: Optional['AppConfig'] = None
def init():
reload()
def reload():
config = AppConfig()
if config.load(CONFIG_PATH):
global _config
_config = config
def get_config():
return _config
class AppConfig:
def __init__(self):
self.database_url = 'sqlite:///data/database.db'
def load(self, path):
config = configparser.ConfigParser()
config.read(path)
try:
app_section = config['app']
self.database_url = app_section['database_url']
except (KeyError, ValueError):
logger.exception('Failed to load config:')
return False
return True

8
data/config.ini Normal file
View File

@ -0,0 +1,8 @@
[app]
# See https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
database_url = sqlite:///data/database.db
# DON'T modify this section
[DEFAULT]
database_url = sqlite:///data/database.db

57
main.py
View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import argparse import argparse
import asyncio
import logging import logging
import os import os
import webbrowser import webbrowser
@ -9,52 +8,72 @@ import webbrowser
import tornado.ioloop import tornado.ioloop
import tornado.web import tornado.web
import api.chat
import api.main
import config
import models.avatar
import models.database
import update import update
import views.chat
import views.main
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
WEB_ROOT = os.path.join(os.path.dirname(__file__), 'frontend', 'dist') WEB_ROOT = os.path.join(os.path.dirname(__file__), 'frontend', 'dist')
routes = [
(r'/chat', api.chat.ChatHandler),
(r'/((css|fonts|img|js|static)/.*)', tornado.web.StaticFileHandler, {'path': WEB_ROOT}),
(r'/(favicon\.ico)', tornado.web.StaticFileHandler, {'path': WEB_ROOT}),
(r'/.*', api.main.MainHandler, {'path': WEB_ROOT})
]
def main(): def main():
args = parse_args()
init_logging(args.debug)
config.init()
models.database.init(args.debug)
models.avatar.init()
api.chat.init()
update.check_update()
run_server(args.host, args.port, args.debug)
def parse_args():
parser = argparse.ArgumentParser(description='用于OBS的仿YouTube风格的bilibili直播聊天层') parser = argparse.ArgumentParser(description='用于OBS的仿YouTube风格的bilibili直播聊天层')
parser.add_argument('--host', help='服务器host默认为127.0.0.1', default='127.0.0.1') parser.add_argument('--host', help='服务器host默认为127.0.0.1', default='127.0.0.1')
parser.add_argument('--port', help='服务器端口默认为12450', type=int, default=12450) parser.add_argument('--port', help='服务器端口默认为12450', type=int, default=12450)
parser.add_argument('--debug', help='调试模式', action='store_true') parser.add_argument('--debug', help='调试模式', action='store_true')
args = parser.parse_args() return parser.parse_args()
def init_logging(debug):
logging.basicConfig( logging.basicConfig(
format='{asctime} {levelname} [{name}]: {message}', format='{asctime} {levelname} [{name}]: {message}',
datefmt='%Y-%m-%d %H:%M:%S', datefmt='%Y-%m-%d %H:%M:%S',
style='{', style='{',
level=logging.INFO if not args.debug else logging.DEBUG level=logging.INFO if not debug else logging.DEBUG
) )
asyncio.ensure_future(update.check_update())
def run_server(host, port, debug):
app = tornado.web.Application( app = tornado.web.Application(
[ routes,
(r'/chat', views.chat.ChatHandler), websocket_ping_interval=10,
debug=debug,
(r'/((css|fonts|img|js|static)/.*)', tornado.web.StaticFileHandler, {'path': WEB_ROOT}),
(r'/(favicon\.ico)', tornado.web.StaticFileHandler, {'path': WEB_ROOT}),
(r'/.*', views.main.MainHandler, {'path': WEB_ROOT})
],
websocket_ping_interval=30,
debug=args.debug,
autoreload=False autoreload=False
) )
try: try:
app.listen(args.port, args.host) app.listen(port, host)
except OSError: except OSError:
logger.warning('Address is used %s:%d', args.host, args.port) logger.warning('Address is used %s:%d', host, port)
return return
finally: finally:
url = 'http://localhost' if args.port == 80 else f'http://localhost:{args.port}' url = 'http://localhost' if port == 80 else f'http://localhost:{port}'
webbrowser.open(url) webbrowser.open(url)
logger.info('Server started: %s:%d', args.host, args.port) logger.info('Server started: %s:%d', host, port)
tornado.ioloop.IOLoop.current().start() tornado.ioloop.IOLoop.current().start()

181
models/avatar.py Normal file
View File

@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
import asyncio
import datetime
import logging
from typing import *
import aiohttp
import sqlalchemy
import sqlalchemy.exc
import models.database
logger = logging.getLogger(__name__)
DEFAULT_AVATAR_URL = '//static.hdslb.com/images/member/noface.gif'
_main_event_loop = asyncio.get_event_loop()
_http_session = aiohttp.ClientSession()
# user_id -> avatar_url
_avatar_url_cache: Dict[int, str] = {}
# (user_id, future)
_fetch_task_queue = asyncio.Queue(15)
_last_fetch_failed_time: Optional[datetime.datetime] = None
def init():
asyncio.ensure_future(_get_avatar_url_from_web_consumer())
async def get_avatar_url(user_id):
avatar_url = get_avatar_url_from_memory(user_id)
if avatar_url is not None:
return avatar_url
avatar_url = await get_avatar_url_from_database(user_id)
if avatar_url is not None:
return avatar_url
return await get_avatar_url_from_web(user_id)
def get_avatar_url_from_memory(user_id):
return _avatar_url_cache.get(user_id, None)
def get_avatar_url_from_database(user_id) -> Awaitable[Optional[str]]:
return asyncio.get_event_loop().run_in_executor(
None, _do_get_avatar_url_from_database, user_id
)
def _do_get_avatar_url_from_database(user_id):
try:
with models.database.get_session() as session:
user = session.query(BilibiliUser).filter(BilibiliUser.uid == user_id).one_or_none()
if user is None:
return None
avatar_url = user.avatar_url
# 如果离上次更新太久就更新所有缓存
if (datetime.datetime.now() - user.update_time).days >= 3:
def refresh_cache():
try:
del _avatar_url_cache[user_id]
except KeyError:
pass
get_avatar_url_from_web(user_id)
_main_event_loop.call_soon(refresh_cache)
else:
# 否则只更新内存缓存
_update_avatar_cache_in_memory(user_id, avatar_url)
except sqlalchemy.exc.SQLAlchemyError:
return None
return avatar_url
def get_avatar_url_from_web(user_id) -> Awaitable[str]:
future = _main_event_loop.create_future()
try:
_fetch_task_queue.put_nowait((user_id, future))
except asyncio.QueueFull:
future.set_result(DEFAULT_AVATAR_URL)
return future
async def _get_avatar_url_from_web_consumer():
while True:
try:
user_id, future = await _fetch_task_queue.get()
# 先查缓存防止队列中出现相同uid时重复获取
avatar_url = get_avatar_url_from_memory(user_id)
if avatar_url is not None:
continue
# 防止在被ban的时候获取
global _last_fetch_failed_time
if _last_fetch_failed_time is not None:
cur_time = datetime.datetime.now()
if (cur_time - _last_fetch_failed_time).total_seconds() < 3 * 60 + 3:
# 3分钟以内被ban则先返回默认头像解封大约要15分钟
return DEFAULT_AVATAR_URL
else:
_last_fetch_failed_time = None
asyncio.ensure_future(_get_avatar_url_from_web_coroutine(user_id, future))
# 限制频率防止被B站ban
await asyncio.sleep(0.2)
except:
pass
async def _get_avatar_url_from_web_coroutine(user_id, future):
try:
avatar_url = await _do_get_avatar_url_from_web(user_id)
except BaseException as e:
future.set_exception(e)
return
future.set_result(avatar_url)
async def _do_get_avatar_url_from_web(user_id):
try:
async with _http_session.get('https://api.bilibili.com/x/space/acc/info',
params={'mid': user_id}) as r:
if r.status != 200:
# 可能被B站ban了
logger.warning('Failed to fetch avatar: status=%d %s uid=%d', r.status, r.reason, user_id)
global _last_fetch_failed_time
_last_fetch_failed_time = datetime.datetime.now()
return DEFAULT_AVATAR_URL
data = await r.json()
except (aiohttp.ClientConnectionError, asyncio.TimeoutError):
return DEFAULT_AVATAR_URL
avatar_url = data['data']['face'].replace('http:', '').replace('https:', '')
if not avatar_url.endswith('noface.gif'):
avatar_url += '@48w_48h'
update_avatar_cache(user_id, avatar_url)
return avatar_url
def update_avatar_cache(user_id, avatar_url):
_update_avatar_cache_in_memory(user_id, avatar_url)
asyncio.get_event_loop().run_in_executor(
None, _update_avatar_cache_in_database, user_id, avatar_url
)
def _update_avatar_cache_in_memory(user_id, avatar_url):
_avatar_url_cache[user_id] = avatar_url
if len(_avatar_url_cache) > 50000:
for _, key in zip(range(100), _avatar_url_cache):
del _avatar_url_cache[key]
def _update_avatar_cache_in_database(user_id, avatar_url):
try:
with models.database.get_session() as session:
user = session.query(BilibiliUser).filter(BilibiliUser.uid == user_id).one_or_none()
if user is None:
user = BilibiliUser(uid=user_id, avatar_url=avatar_url,
update_time=datetime.datetime.now())
session.add(user)
else:
user.avatar_url = avatar_url
user.update_time = datetime.datetime.now()
session.commit()
except sqlalchemy.exc.SQLAlchemyError:
# SQLite会锁整个文件忽略就行
logger.exception('_update_avatar_cache_in_database failed:')
class BilibiliUser(models.database.OrmBase):
__tablename__ = 'bilibili_users'
uid = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
avatar_url = sqlalchemy.Column(sqlalchemy.Text)
update_time = sqlalchemy.Column(sqlalchemy.DateTime)

34
models/database.py Normal file
View File

@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
import contextlib
from typing import *
import sqlalchemy.ext.declarative
import sqlalchemy.orm
import config
OrmBase = sqlalchemy.ext.declarative.declarative_base()
engine = None
DbSession: Optional[Type[sqlalchemy.orm.Session]] = None
def init(debug):
cfg = config.get_config()
global engine, DbSession
engine = sqlalchemy.create_engine(cfg.database_url, echo=debug)
DbSession = sqlalchemy.orm.sessionmaker(bind=engine)
OrmBase.metadata.create_all(engine)
@contextlib.contextmanager
def get_session():
session = DbSession()
try:
yield session
except:
session.rollback()
raise
finally:
session.close()

View File

@ -1,2 +1,3 @@
aiohttp==3.5.4 aiohttp==3.5.4
sqlalchemy==1.3.13
tornado==6.0.2 tornado==6.0.2

View File

@ -1,11 +1,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import asyncio
import aiohttp import aiohttp
VERSION = 'v1.2.4' VERSION = 'v1.2.4'
async def check_update(): def check_update():
asyncio.ensure_future(_do_check_update())
async def _do_check_update():
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get('https://api.github.com/repos/xfgryujk/blivechat/releases/latest') as r: async with session.get('https://api.github.com/repos/xfgryujk/blivechat/releases/latest') as r: