From f886506c398e0f32fd7af946dd36d07ac954a331 Mon Sep 17 00:00:00 2001 From: John Smith Date: Sat, 4 Nov 2023 16:52:07 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=8F=92=E4=BB=B6=E6=A1=86?= =?UTF-8?q?=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/plugin.py | 98 ++++++++++++++ blcsdk/__init__.py | 2 + blcsdk/client.py | 1 + blcsdk/models.py | 6 + main.py | 15 ++- plugins/msg-logging/main.py | 14 ++ services/avatar.py | 2 +- services/plugin.py | 252 ++++++++++++++++++++++++++++++++++++ services/translate.py | 2 +- update.py | 2 +- 10 files changed, 388 insertions(+), 6 deletions(-) create mode 100644 api/plugin.py create mode 100644 blcsdk/__init__.py create mode 100644 blcsdk/client.py create mode 100644 blcsdk/models.py create mode 100644 plugins/msg-logging/main.py create mode 100644 services/plugin.py diff --git a/api/plugin.py b/api/plugin.py new file mode 100644 index 0000000..8c09a6c --- /dev/null +++ b/api/plugin.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +import asyncio +import json +import logging +from typing import * + +import tornado.web +import tornado.websocket + +import api.base +import blcsdk.models as models +import services.plugin + +logger = logging.getLogger(__name__) + + +class _PluginHandlerBase(api.base.ApiHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.plugin: Optional[services.plugin.Plugin] = None + + def prepare(self): + try: + auth = self.request.headers['Authorization'] + if not auth.startswith('Bearer '): + raise ValueError(f'Bad authorization: {auth}') + token = auth[7:] + + self.plugin = services.plugin.get_plugin_by_token(token) + if self.plugin is None: + raise ValueError(f'Token error: {token}') + except (KeyError, ValueError) as e: + logger.warning('client=%s failed to find plugin: %r', self.request.remote_ip, e) + raise tornado.web.HTTPError(403) + + super().prepare() + + +def make_message_body(cmd, data, extra: Optional[dict] = None): + body = {'cmd': cmd, 'data': data} + if extra: + body['extra'] = extra + return json.dumps(body).encode('utf-8') + + +class PluginWsHandler(_PluginHandlerBase, tornado.websocket.WebSocketHandler): + HEARTBEAT_INTERVAL = 10 + RECEIVE_TIMEOUT = HEARTBEAT_INTERVAL + 5 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._heartbeat_timer_handle = None + self._receive_timeout_timer_handle = None + + def open(self): + logger.info('plugin=%s connected, client=%s', self.plugin.id, self.request.remote_ip) + self._heartbeat_timer_handle = asyncio.get_running_loop().call_later( + self.HEARTBEAT_INTERVAL, self._on_send_heartbeat + ) + self._refresh_receive_timeout_timer() + + self.plugin.on_client_connect(self) + + def _on_send_heartbeat(self): + self.send_cmd_data(models.Command.HEARTBEAT, {}) + self._heartbeat_timer_handle = asyncio.get_running_loop().call_later( + self.HEARTBEAT_INTERVAL, self._on_send_heartbeat + ) + + def _refresh_receive_timeout_timer(self): + if self._receive_timeout_timer_handle is not None: + self._receive_timeout_timer_handle.cancel() + self._receive_timeout_timer_handle = asyncio.get_running_loop().call_later( + self.RECEIVE_TIMEOUT, self._on_receive_timeout + ) + + def _on_receive_timeout(self): + logger.info('plugin=%s timed out', self.plugin.id) + self._receive_timeout_timer_handle = None + self.close() + + def on_close(self): + logger.info('plugin=%s disconnected', self.plugin.id) + self.plugin.on_client_close(self) + + def send_cmd_data(self, cmd, data, extra: Optional[dict] = None): + self.send_body_no_raise(make_message_body(cmd, data, extra)) + + def send_body_no_raise(self, body: Union[bytes, str, Dict[str, Any]]): + try: + self.write_message(body) + except tornado.websocket.WebSocketClosedError: + self.close() + + +ROUTES = [ + (r'/api/plugin/websocket', PluginWsHandler), +] diff --git a/blcsdk/__init__.py b/blcsdk/__init__.py new file mode 100644 index 0000000..2bb0f2c --- /dev/null +++ b/blcsdk/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +__version__ = '0.0.1' diff --git a/blcsdk/client.py b/blcsdk/client.py new file mode 100644 index 0000000..40a96af --- /dev/null +++ b/blcsdk/client.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/blcsdk/models.py b/blcsdk/models.py new file mode 100644 index 0000000..7411d31 --- /dev/null +++ b/blcsdk/models.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +import enum + + +class Command(enum.IntEnum): + HEARTBEAT = 0 diff --git a/main.py b/main.py index 9f4a629..2462340 100644 --- a/main.py +++ b/main.py @@ -15,10 +15,12 @@ import tornado.web import api.chat import api.main import api.open_live +import api.plugin import config import models.database import services.avatar import services.chat +import services.plugin import services.translate import update import utils.request @@ -29,6 +31,7 @@ ROUTES = [ *api.main.ROUTES, *api.chat.ROUTES, *api.open_live.ROUTES, + *api.plugin.ROUTES, *api.main.LAST_ROUTES, ] @@ -63,10 +66,14 @@ def init(): services.translate.init() services.chat.init() - update.check_update() - init_server(args.host, args.port, args.debug) - return server is not None + if server is None: + return False + + services.plugin.init() + + update.check_update() + return True def init_signal_handlers(): @@ -152,6 +159,8 @@ async def run(): async def shut_down(): + services.plugin.shut_down() + logger.info('Closing server') server.stop() await server.close_all_connections() diff --git a/plugins/msg-logging/main.py b/plugins/msg-logging/main.py new file mode 100644 index 0000000..4644440 --- /dev/null +++ b/plugins/msg-logging/main.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +import asyncio +import sys + +import blcsdk + + +async def main(): + print('hello world!', blcsdk.__version__) + return 0 + + +if __name__ == '__main__': + sys.exit(asyncio.run(main())) diff --git a/services/avatar.py b/services/avatar.py index e727f14..2ed8011 100644 --- a/services/avatar.py +++ b/services/avatar.py @@ -41,7 +41,7 @@ def init(): global _avatar_url_cache, _task_queue _avatar_url_cache = cachetools.TTLCache(cfg.avatar_cache_size, 10 * 60) _task_queue = asyncio.Queue(cfg.fetch_avatar_max_queue_size) - asyncio.get_running_loop().create_task(_do_init()) + asyncio.create_task(_do_init()) async def _do_init(): diff --git a/services/plugin.py b/services/plugin.py new file mode 100644 index 0000000..8c56e37 --- /dev/null +++ b/services/plugin.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- +import dataclasses +import datetime +import json +import logging +import os +import random +import string +import subprocess +from typing import * + +import api.plugin +import config + +logger = logging.getLogger(__name__) + +PLUGINS_PATH = os.path.join(config.DATA_PATH, 'plugins') + +_plugins: Dict[str, 'Plugin'] = {} + + +def init(): + plugin_ids = _discover_plugin_ids() + if not plugin_ids: + return + logger.info('Found plugins: %s', plugin_ids) + + for plugin_id in plugin_ids: + plugin = _create_plugin(plugin_id) + if plugin is not None: + _plugins[plugin_id] = plugin + + for plugin in _plugins.values(): + if plugin.enabled: + try: + plugin.start() + except StartPluginError: + pass + + +def shut_down(): + for plugin in _plugins.values(): + plugin.stop() + + +def _discover_plugin_ids(): + res = [] + try: + with os.scandir(PLUGINS_PATH) as it: + for entry in it: + if entry.is_dir() and os.path.isfile(os.path.join(entry.path, 'plugin.json')): + res.append(entry.name) + except OSError: + logger.exception('Failed to discover plugins:') + return res + + +def _create_plugin(plugin_id): + config_path = os.path.join(PLUGINS_PATH, plugin_id, 'plugin.json') + try: + plugin_config = PluginConfig.from_file(config_path) + except (OSError, json.JSONDecodeError, TypeError): + logger.exception('plugin=%s failed to load config:', plugin_id) + return None + return Plugin(plugin_id, plugin_config) + + +def iter_plugins() -> Iterable['Plugin']: + return _plugins.values() + + +def get_plugin_by_token(token): + if token == '': + return None + for plugin in _plugins.values(): + if plugin.token == token: + return plugin + return None + + +def broadcast_cmd_data(cmd, data, extra: Optional[dict] = None): + body = api.plugin.make_message_body(cmd, data, extra) + for plugin in _plugins.values(): + plugin.send_body_no_raise(body) + + +@dataclasses.dataclass +class PluginConfig: + name: str = '' + version: str = '' + author: str = '' + description: str = '' + run_cmd: str = '' + enabled: bool = False + + @classmethod + def from_file(cls, path): + with open(path, encoding='utf-8') as f: + cfg = json.load(f) + if not isinstance(cfg, dict): + raise TypeError(f'Config type error, type={type(cfg)}') + + return cls( + name=str(cfg.get('name', '')), + version=str(cfg.get('version', '')), + author=str(cfg.get('author', '')), + description=str(cfg.get('description', '')), + run_cmd=str(cfg.get('run', '')), + enabled=bool(cfg.get('enabled', False)), + ) + + def save(self, path): + try: + with open(path, encoding='utf-8') as f: + cfg = json.load(f) + if not isinstance(cfg, dict): + raise TypeError(f'Config type error, type={type(cfg)}') + except (OSError, json.JSONDecodeError, TypeError): + cfg = {} + + cfg['name'] = self.name + cfg['version'] = self.version + cfg['author'] = self.author + cfg['description'] = self.description + cfg['run_cmd'] = self.run_cmd + cfg['enabled'] = self.enabled + + tmp_path = path + '.tmp' + with open(tmp_path, encoding='utf-8') as f: + json.dump(cfg, f, ensure_ascii=False, indent=2) + os.replace(tmp_path, path) + + +class StartPluginError(Exception): + """启动插件时错误""" + + +class StartTooFrequently(StartPluginError): + """启动插件太频繁""" + + +class Plugin: + def __init__(self, plugin_id, cfg: PluginConfig): + self._id = plugin_id + self._config = cfg + + self._last_start_time = datetime.datetime.fromtimestamp(0) + self._token = '' + self._client: Optional['api.plugin.PluginWsHandler'] = None + + @property + def id(self): + return self._id + + @property + def enabled(self): + return self._config.enabled + + @enabled.setter + def enabled(self, value): + if self._config.enabled == value: + return + self._config.enabled = value + + config_path = os.path.join(self.base_path, 'plugin.json') + try: + self._config.save(config_path) + except OSError: + logger.exception('plugin=%s failed to save config', self._id) + + if value: + self.start() + else: + self.stop() + + @property + def base_path(self): + return os.path.join(PLUGINS_PATH, self._id) + + @property + def token(self): + return self._token + + @property + def is_started(self): + return self._token != '' + + @property + def is_connected(self): + return self._client is not None + + def start(self): + if self.is_started: + return + + cur_time = datetime.datetime.now() + if cur_time - self._last_start_time < datetime.timedelta(seconds=3): + raise StartTooFrequently(f'plugin={self._id} starts too frequently') + self._last_start_time = cur_time + + token = ''.join(random.choice(string.hexdigits) for _ in range(32)) + self._set_token(token) + + env = { + **os.environ, + 'BLC_PORT': str(12450), # TODO 读配置 + 'BLC_TOKEN': self._token, + } + try: + subprocess.Popen( + self._config.run_cmd, + shell=True, + cwd=self.base_path, + env=env, + ) + except OSError as e: + logger.exception('plugin=%s failed to start', self._id) + raise StartPluginError(str(e)) + + def stop(self): + if self.is_started: + self._set_token('') + + def _set_token(self, token): + if self._token == token: + return + self._token = token + + # 踢掉已经连接的客户端 + self._set_client(None) + + def _set_client(self, client: Optional['api.plugin.PluginWsHandler']): + if self._client is client: + return + if self._client is not None: + self._client.close() + self._client = client + + def on_client_connect(self, client: 'api.plugin.PluginWsHandler'): + self._set_client(client) + + def on_client_close(self, client: 'api.plugin.PluginWsHandler'): + if self._client is client: + self._set_client(None) + + def send_cmd_data(self, cmd, data, extra: Optional[dict] = None): + if self._client is not None: + self._client.send_cmd_data(cmd, data, extra) + + def send_body_no_raise(self, body): + if self._client is not None: + self._client.send_body_no_raise(body) diff --git a/services/translate.py b/services/translate.py index 7ccf3e2..0fe111d 100644 --- a/services/translate.py +++ b/services/translate.py @@ -56,7 +56,7 @@ def init(): _translate_cache = cachetools.LRUCache(cfg.translation_cache_size) # 总队列长度会超过translate_max_queue_size,不用这么严格 _task_queues = [asyncio.Queue(cfg.translate_max_queue_size) for _ in range(len(Priority))] - asyncio.get_running_loop().create_task(_do_init()) + asyncio.create_task(_do_init()) async def _do_init(): diff --git a/update.py b/update.py index 1ec3ad7..bb31df4 100644 --- a/update.py +++ b/update.py @@ -9,7 +9,7 @@ VERSION = 'v1.8.2' def check_update(): - asyncio.get_running_loop().create_task(_do_check_update()) + asyncio.create_task(_do_check_update()) async def _do_check_update():