添加插件框架

This commit is contained in:
John Smith 2023-11-04 16:52:07 +08:00
parent 5d98bb5d14
commit f886506c39
10 changed files with 388 additions and 6 deletions

98
api/plugin.py Normal file
View File

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
import asyncio
import json
import logging
from typing import *
import tornado.web
import tornado.websocket
import api.base
import blcsdk.models as models
import services.plugin
logger = logging.getLogger(__name__)
class _PluginHandlerBase(api.base.ApiHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plugin: Optional[services.plugin.Plugin] = None
def prepare(self):
try:
auth = self.request.headers['Authorization']
if not auth.startswith('Bearer '):
raise ValueError(f'Bad authorization: {auth}')
token = auth[7:]
self.plugin = services.plugin.get_plugin_by_token(token)
if self.plugin is None:
raise ValueError(f'Token error: {token}')
except (KeyError, ValueError) as e:
logger.warning('client=%s failed to find plugin: %r', self.request.remote_ip, e)
raise tornado.web.HTTPError(403)
super().prepare()
def make_message_body(cmd, data, extra: Optional[dict] = None):
body = {'cmd': cmd, 'data': data}
if extra:
body['extra'] = extra
return json.dumps(body).encode('utf-8')
class PluginWsHandler(_PluginHandlerBase, tornado.websocket.WebSocketHandler):
HEARTBEAT_INTERVAL = 10
RECEIVE_TIMEOUT = HEARTBEAT_INTERVAL + 5
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._heartbeat_timer_handle = None
self._receive_timeout_timer_handle = None
def open(self):
logger.info('plugin=%s connected, client=%s', self.plugin.id, self.request.remote_ip)
self._heartbeat_timer_handle = asyncio.get_running_loop().call_later(
self.HEARTBEAT_INTERVAL, self._on_send_heartbeat
)
self._refresh_receive_timeout_timer()
self.plugin.on_client_connect(self)
def _on_send_heartbeat(self):
self.send_cmd_data(models.Command.HEARTBEAT, {})
self._heartbeat_timer_handle = asyncio.get_running_loop().call_later(
self.HEARTBEAT_INTERVAL, self._on_send_heartbeat
)
def _refresh_receive_timeout_timer(self):
if self._receive_timeout_timer_handle is not None:
self._receive_timeout_timer_handle.cancel()
self._receive_timeout_timer_handle = asyncio.get_running_loop().call_later(
self.RECEIVE_TIMEOUT, self._on_receive_timeout
)
def _on_receive_timeout(self):
logger.info('plugin=%s timed out', self.plugin.id)
self._receive_timeout_timer_handle = None
self.close()
def on_close(self):
logger.info('plugin=%s disconnected', self.plugin.id)
self.plugin.on_client_close(self)
def send_cmd_data(self, cmd, data, extra: Optional[dict] = None):
self.send_body_no_raise(make_message_body(cmd, data, extra))
def send_body_no_raise(self, body: Union[bytes, str, Dict[str, Any]]):
try:
self.write_message(body)
except tornado.websocket.WebSocketClosedError:
self.close()
ROUTES = [
(r'/api/plugin/websocket', PluginWsHandler),
]

2
blcsdk/__init__.py Normal file
View File

@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
__version__ = '0.0.1'

1
blcsdk/client.py Normal file
View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

6
blcsdk/models.py Normal file
View File

@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
import enum
class Command(enum.IntEnum):
HEARTBEAT = 0

15
main.py
View File

@ -15,10 +15,12 @@ import tornado.web
import api.chat import api.chat
import api.main import api.main
import api.open_live import api.open_live
import api.plugin
import config import config
import models.database import models.database
import services.avatar import services.avatar
import services.chat import services.chat
import services.plugin
import services.translate import services.translate
import update import update
import utils.request import utils.request
@ -29,6 +31,7 @@ ROUTES = [
*api.main.ROUTES, *api.main.ROUTES,
*api.chat.ROUTES, *api.chat.ROUTES,
*api.open_live.ROUTES, *api.open_live.ROUTES,
*api.plugin.ROUTES,
*api.main.LAST_ROUTES, *api.main.LAST_ROUTES,
] ]
@ -63,10 +66,14 @@ def init():
services.translate.init() services.translate.init()
services.chat.init() services.chat.init()
update.check_update()
init_server(args.host, args.port, args.debug) init_server(args.host, args.port, args.debug)
return server is not None if server is None:
return False
services.plugin.init()
update.check_update()
return True
def init_signal_handlers(): def init_signal_handlers():
@ -152,6 +159,8 @@ async def run():
async def shut_down(): async def shut_down():
services.plugin.shut_down()
logger.info('Closing server') logger.info('Closing server')
server.stop() server.stop()
await server.close_all_connections() await server.close_all_connections()

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
import asyncio
import sys
import blcsdk
async def main():
print('hello world!', blcsdk.__version__)
return 0
if __name__ == '__main__':
sys.exit(asyncio.run(main()))

View File

@ -41,7 +41,7 @@ def init():
global _avatar_url_cache, _task_queue global _avatar_url_cache, _task_queue
_avatar_url_cache = cachetools.TTLCache(cfg.avatar_cache_size, 10 * 60) _avatar_url_cache = cachetools.TTLCache(cfg.avatar_cache_size, 10 * 60)
_task_queue = asyncio.Queue(cfg.fetch_avatar_max_queue_size) _task_queue = asyncio.Queue(cfg.fetch_avatar_max_queue_size)
asyncio.get_running_loop().create_task(_do_init()) asyncio.create_task(_do_init())
async def _do_init(): async def _do_init():

252
services/plugin.py Normal file
View File

@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
import dataclasses
import datetime
import json
import logging
import os
import random
import string
import subprocess
from typing import *
import api.plugin
import config
logger = logging.getLogger(__name__)
PLUGINS_PATH = os.path.join(config.DATA_PATH, 'plugins')
_plugins: Dict[str, 'Plugin'] = {}
def init():
plugin_ids = _discover_plugin_ids()
if not plugin_ids:
return
logger.info('Found plugins: %s', plugin_ids)
for plugin_id in plugin_ids:
plugin = _create_plugin(plugin_id)
if plugin is not None:
_plugins[plugin_id] = plugin
for plugin in _plugins.values():
if plugin.enabled:
try:
plugin.start()
except StartPluginError:
pass
def shut_down():
for plugin in _plugins.values():
plugin.stop()
def _discover_plugin_ids():
res = []
try:
with os.scandir(PLUGINS_PATH) as it:
for entry in it:
if entry.is_dir() and os.path.isfile(os.path.join(entry.path, 'plugin.json')):
res.append(entry.name)
except OSError:
logger.exception('Failed to discover plugins:')
return res
def _create_plugin(plugin_id):
config_path = os.path.join(PLUGINS_PATH, plugin_id, 'plugin.json')
try:
plugin_config = PluginConfig.from_file(config_path)
except (OSError, json.JSONDecodeError, TypeError):
logger.exception('plugin=%s failed to load config:', plugin_id)
return None
return Plugin(plugin_id, plugin_config)
def iter_plugins() -> Iterable['Plugin']:
return _plugins.values()
def get_plugin_by_token(token):
if token == '':
return None
for plugin in _plugins.values():
if plugin.token == token:
return plugin
return None
def broadcast_cmd_data(cmd, data, extra: Optional[dict] = None):
body = api.plugin.make_message_body(cmd, data, extra)
for plugin in _plugins.values():
plugin.send_body_no_raise(body)
@dataclasses.dataclass
class PluginConfig:
name: str = ''
version: str = ''
author: str = ''
description: str = ''
run_cmd: str = ''
enabled: bool = False
@classmethod
def from_file(cls, path):
with open(path, encoding='utf-8') as f:
cfg = json.load(f)
if not isinstance(cfg, dict):
raise TypeError(f'Config type error, type={type(cfg)}')
return cls(
name=str(cfg.get('name', '')),
version=str(cfg.get('version', '')),
author=str(cfg.get('author', '')),
description=str(cfg.get('description', '')),
run_cmd=str(cfg.get('run', '')),
enabled=bool(cfg.get('enabled', False)),
)
def save(self, path):
try:
with open(path, encoding='utf-8') as f:
cfg = json.load(f)
if not isinstance(cfg, dict):
raise TypeError(f'Config type error, type={type(cfg)}')
except (OSError, json.JSONDecodeError, TypeError):
cfg = {}
cfg['name'] = self.name
cfg['version'] = self.version
cfg['author'] = self.author
cfg['description'] = self.description
cfg['run_cmd'] = self.run_cmd
cfg['enabled'] = self.enabled
tmp_path = path + '.tmp'
with open(tmp_path, encoding='utf-8') as f:
json.dump(cfg, f, ensure_ascii=False, indent=2)
os.replace(tmp_path, path)
class StartPluginError(Exception):
"""启动插件时错误"""
class StartTooFrequently(StartPluginError):
"""启动插件太频繁"""
class Plugin:
def __init__(self, plugin_id, cfg: PluginConfig):
self._id = plugin_id
self._config = cfg
self._last_start_time = datetime.datetime.fromtimestamp(0)
self._token = ''
self._client: Optional['api.plugin.PluginWsHandler'] = None
@property
def id(self):
return self._id
@property
def enabled(self):
return self._config.enabled
@enabled.setter
def enabled(self, value):
if self._config.enabled == value:
return
self._config.enabled = value
config_path = os.path.join(self.base_path, 'plugin.json')
try:
self._config.save(config_path)
except OSError:
logger.exception('plugin=%s failed to save config', self._id)
if value:
self.start()
else:
self.stop()
@property
def base_path(self):
return os.path.join(PLUGINS_PATH, self._id)
@property
def token(self):
return self._token
@property
def is_started(self):
return self._token != ''
@property
def is_connected(self):
return self._client is not None
def start(self):
if self.is_started:
return
cur_time = datetime.datetime.now()
if cur_time - self._last_start_time < datetime.timedelta(seconds=3):
raise StartTooFrequently(f'plugin={self._id} starts too frequently')
self._last_start_time = cur_time
token = ''.join(random.choice(string.hexdigits) for _ in range(32))
self._set_token(token)
env = {
**os.environ,
'BLC_PORT': str(12450), # TODO 读配置
'BLC_TOKEN': self._token,
}
try:
subprocess.Popen(
self._config.run_cmd,
shell=True,
cwd=self.base_path,
env=env,
)
except OSError as e:
logger.exception('plugin=%s failed to start', self._id)
raise StartPluginError(str(e))
def stop(self):
if self.is_started:
self._set_token('')
def _set_token(self, token):
if self._token == token:
return
self._token = token
# 踢掉已经连接的客户端
self._set_client(None)
def _set_client(self, client: Optional['api.plugin.PluginWsHandler']):
if self._client is client:
return
if self._client is not None:
self._client.close()
self._client = client
def on_client_connect(self, client: 'api.plugin.PluginWsHandler'):
self._set_client(client)
def on_client_close(self, client: 'api.plugin.PluginWsHandler'):
if self._client is client:
self._set_client(None)
def send_cmd_data(self, cmd, data, extra: Optional[dict] = None):
if self._client is not None:
self._client.send_cmd_data(cmd, data, extra)
def send_body_no_raise(self, body):
if self._client is not None:
self._client.send_body_no_raise(body)

View File

@ -56,7 +56,7 @@ def init():
_translate_cache = cachetools.LRUCache(cfg.translation_cache_size) _translate_cache = cachetools.LRUCache(cfg.translation_cache_size)
# 总队列长度会超过translate_max_queue_size不用这么严格 # 总队列长度会超过translate_max_queue_size不用这么严格
_task_queues = [asyncio.Queue(cfg.translate_max_queue_size) for _ in range(len(Priority))] _task_queues = [asyncio.Queue(cfg.translate_max_queue_size) for _ in range(len(Priority))]
asyncio.get_running_loop().create_task(_do_init()) asyncio.create_task(_do_init())
async def _do_init(): async def _do_init():

View File

@ -9,7 +9,7 @@ VERSION = 'v1.8.2'
def check_update(): def check_update():
asyncio.get_running_loop().create_task(_do_check_update()) asyncio.create_task(_do_check_update())
async def _do_check_update(): async def _do_check_update():