添加插件框架

This commit is contained in:
John Smith 2023-11-04 16:52:07 +08:00
parent 5d98bb5d14
commit f886506c39
10 changed files with 388 additions and 6 deletions

98
api/plugin.py Normal file
View File

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
import asyncio
import json
import logging
from typing import *
import tornado.web
import tornado.websocket
import api.base
import blcsdk.models as models
import services.plugin
logger = logging.getLogger(__name__)
class _PluginHandlerBase(api.base.ApiHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plugin: Optional[services.plugin.Plugin] = None
def prepare(self):
try:
auth = self.request.headers['Authorization']
if not auth.startswith('Bearer '):
raise ValueError(f'Bad authorization: {auth}')
token = auth[7:]
self.plugin = services.plugin.get_plugin_by_token(token)
if self.plugin is None:
raise ValueError(f'Token error: {token}')
except (KeyError, ValueError) as e:
logger.warning('client=%s failed to find plugin: %r', self.request.remote_ip, e)
raise tornado.web.HTTPError(403)
super().prepare()
def make_message_body(cmd, data, extra: Optional[dict] = None):
body = {'cmd': cmd, 'data': data}
if extra:
body['extra'] = extra
return json.dumps(body).encode('utf-8')
class PluginWsHandler(_PluginHandlerBase, tornado.websocket.WebSocketHandler):
HEARTBEAT_INTERVAL = 10
RECEIVE_TIMEOUT = HEARTBEAT_INTERVAL + 5
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._heartbeat_timer_handle = None
self._receive_timeout_timer_handle = None
def open(self):
logger.info('plugin=%s connected, client=%s', self.plugin.id, self.request.remote_ip)
self._heartbeat_timer_handle = asyncio.get_running_loop().call_later(
self.HEARTBEAT_INTERVAL, self._on_send_heartbeat
)
self._refresh_receive_timeout_timer()
self.plugin.on_client_connect(self)
def _on_send_heartbeat(self):
self.send_cmd_data(models.Command.HEARTBEAT, {})
self._heartbeat_timer_handle = asyncio.get_running_loop().call_later(
self.HEARTBEAT_INTERVAL, self._on_send_heartbeat
)
def _refresh_receive_timeout_timer(self):
if self._receive_timeout_timer_handle is not None:
self._receive_timeout_timer_handle.cancel()
self._receive_timeout_timer_handle = asyncio.get_running_loop().call_later(
self.RECEIVE_TIMEOUT, self._on_receive_timeout
)
def _on_receive_timeout(self):
logger.info('plugin=%s timed out', self.plugin.id)
self._receive_timeout_timer_handle = None
self.close()
def on_close(self):
logger.info('plugin=%s disconnected', self.plugin.id)
self.plugin.on_client_close(self)
def send_cmd_data(self, cmd, data, extra: Optional[dict] = None):
self.send_body_no_raise(make_message_body(cmd, data, extra))
def send_body_no_raise(self, body: Union[bytes, str, Dict[str, Any]]):
try:
self.write_message(body)
except tornado.websocket.WebSocketClosedError:
self.close()
ROUTES = [
(r'/api/plugin/websocket', PluginWsHandler),
]

2
blcsdk/__init__.py Normal file
View File

@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
__version__ = '0.0.1'

1
blcsdk/client.py Normal file
View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

6
blcsdk/models.py Normal file
View File

@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
import enum
class Command(enum.IntEnum):
HEARTBEAT = 0

15
main.py
View File

@ -15,10 +15,12 @@ import tornado.web
import api.chat
import api.main
import api.open_live
import api.plugin
import config
import models.database
import services.avatar
import services.chat
import services.plugin
import services.translate
import update
import utils.request
@ -29,6 +31,7 @@ ROUTES = [
*api.main.ROUTES,
*api.chat.ROUTES,
*api.open_live.ROUTES,
*api.plugin.ROUTES,
*api.main.LAST_ROUTES,
]
@ -63,10 +66,14 @@ def init():
services.translate.init()
services.chat.init()
update.check_update()
init_server(args.host, args.port, args.debug)
return server is not None
if server is None:
return False
services.plugin.init()
update.check_update()
return True
def init_signal_handlers():
@ -152,6 +159,8 @@ async def run():
async def shut_down():
services.plugin.shut_down()
logger.info('Closing server')
server.stop()
await server.close_all_connections()

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
import asyncio
import sys
import blcsdk
async def main():
print('hello world!', blcsdk.__version__)
return 0
if __name__ == '__main__':
sys.exit(asyncio.run(main()))

View File

@ -41,7 +41,7 @@ def init():
global _avatar_url_cache, _task_queue
_avatar_url_cache = cachetools.TTLCache(cfg.avatar_cache_size, 10 * 60)
_task_queue = asyncio.Queue(cfg.fetch_avatar_max_queue_size)
asyncio.get_running_loop().create_task(_do_init())
asyncio.create_task(_do_init())
async def _do_init():

252
services/plugin.py Normal file
View File

@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
import dataclasses
import datetime
import json
import logging
import os
import random
import string
import subprocess
from typing import *
import api.plugin
import config
logger = logging.getLogger(__name__)
PLUGINS_PATH = os.path.join(config.DATA_PATH, 'plugins')
_plugins: Dict[str, 'Plugin'] = {}
def init():
plugin_ids = _discover_plugin_ids()
if not plugin_ids:
return
logger.info('Found plugins: %s', plugin_ids)
for plugin_id in plugin_ids:
plugin = _create_plugin(plugin_id)
if plugin is not None:
_plugins[plugin_id] = plugin
for plugin in _plugins.values():
if plugin.enabled:
try:
plugin.start()
except StartPluginError:
pass
def shut_down():
for plugin in _plugins.values():
plugin.stop()
def _discover_plugin_ids():
res = []
try:
with os.scandir(PLUGINS_PATH) as it:
for entry in it:
if entry.is_dir() and os.path.isfile(os.path.join(entry.path, 'plugin.json')):
res.append(entry.name)
except OSError:
logger.exception('Failed to discover plugins:')
return res
def _create_plugin(plugin_id):
config_path = os.path.join(PLUGINS_PATH, plugin_id, 'plugin.json')
try:
plugin_config = PluginConfig.from_file(config_path)
except (OSError, json.JSONDecodeError, TypeError):
logger.exception('plugin=%s failed to load config:', plugin_id)
return None
return Plugin(plugin_id, plugin_config)
def iter_plugins() -> Iterable['Plugin']:
return _plugins.values()
def get_plugin_by_token(token):
if token == '':
return None
for plugin in _plugins.values():
if plugin.token == token:
return plugin
return None
def broadcast_cmd_data(cmd, data, extra: Optional[dict] = None):
body = api.plugin.make_message_body(cmd, data, extra)
for plugin in _plugins.values():
plugin.send_body_no_raise(body)
@dataclasses.dataclass
class PluginConfig:
name: str = ''
version: str = ''
author: str = ''
description: str = ''
run_cmd: str = ''
enabled: bool = False
@classmethod
def from_file(cls, path):
with open(path, encoding='utf-8') as f:
cfg = json.load(f)
if not isinstance(cfg, dict):
raise TypeError(f'Config type error, type={type(cfg)}')
return cls(
name=str(cfg.get('name', '')),
version=str(cfg.get('version', '')),
author=str(cfg.get('author', '')),
description=str(cfg.get('description', '')),
run_cmd=str(cfg.get('run', '')),
enabled=bool(cfg.get('enabled', False)),
)
def save(self, path):
try:
with open(path, encoding='utf-8') as f:
cfg = json.load(f)
if not isinstance(cfg, dict):
raise TypeError(f'Config type error, type={type(cfg)}')
except (OSError, json.JSONDecodeError, TypeError):
cfg = {}
cfg['name'] = self.name
cfg['version'] = self.version
cfg['author'] = self.author
cfg['description'] = self.description
cfg['run_cmd'] = self.run_cmd
cfg['enabled'] = self.enabled
tmp_path = path + '.tmp'
with open(tmp_path, encoding='utf-8') as f:
json.dump(cfg, f, ensure_ascii=False, indent=2)
os.replace(tmp_path, path)
class StartPluginError(Exception):
"""启动插件时错误"""
class StartTooFrequently(StartPluginError):
"""启动插件太频繁"""
class Plugin:
def __init__(self, plugin_id, cfg: PluginConfig):
self._id = plugin_id
self._config = cfg
self._last_start_time = datetime.datetime.fromtimestamp(0)
self._token = ''
self._client: Optional['api.plugin.PluginWsHandler'] = None
@property
def id(self):
return self._id
@property
def enabled(self):
return self._config.enabled
@enabled.setter
def enabled(self, value):
if self._config.enabled == value:
return
self._config.enabled = value
config_path = os.path.join(self.base_path, 'plugin.json')
try:
self._config.save(config_path)
except OSError:
logger.exception('plugin=%s failed to save config', self._id)
if value:
self.start()
else:
self.stop()
@property
def base_path(self):
return os.path.join(PLUGINS_PATH, self._id)
@property
def token(self):
return self._token
@property
def is_started(self):
return self._token != ''
@property
def is_connected(self):
return self._client is not None
def start(self):
if self.is_started:
return
cur_time = datetime.datetime.now()
if cur_time - self._last_start_time < datetime.timedelta(seconds=3):
raise StartTooFrequently(f'plugin={self._id} starts too frequently')
self._last_start_time = cur_time
token = ''.join(random.choice(string.hexdigits) for _ in range(32))
self._set_token(token)
env = {
**os.environ,
'BLC_PORT': str(12450), # TODO 读配置
'BLC_TOKEN': self._token,
}
try:
subprocess.Popen(
self._config.run_cmd,
shell=True,
cwd=self.base_path,
env=env,
)
except OSError as e:
logger.exception('plugin=%s failed to start', self._id)
raise StartPluginError(str(e))
def stop(self):
if self.is_started:
self._set_token('')
def _set_token(self, token):
if self._token == token:
return
self._token = token
# 踢掉已经连接的客户端
self._set_client(None)
def _set_client(self, client: Optional['api.plugin.PluginWsHandler']):
if self._client is client:
return
if self._client is not None:
self._client.close()
self._client = client
def on_client_connect(self, client: 'api.plugin.PluginWsHandler'):
self._set_client(client)
def on_client_close(self, client: 'api.plugin.PluginWsHandler'):
if self._client is client:
self._set_client(None)
def send_cmd_data(self, cmd, data, extra: Optional[dict] = None):
if self._client is not None:
self._client.send_cmd_data(cmd, data, extra)
def send_body_no_raise(self, body):
if self._client is not None:
self._client.send_body_no_raise(body)

View File

@ -56,7 +56,7 @@ def init():
_translate_cache = cachetools.LRUCache(cfg.translation_cache_size)
# 总队列长度会超过translate_max_queue_size不用这么严格
_task_queues = [asyncio.Queue(cfg.translate_max_queue_size) for _ in range(len(Priority))]
asyncio.get_running_loop().create_task(_do_init())
asyncio.create_task(_do_init())
async def _do_init():

View File

@ -9,7 +9,7 @@ VERSION = 'v1.8.2'
def check_update():
asyncio.get_running_loop().create_task(_do_check_update())
asyncio.create_task(_do_check_update())
async def _do_check_update():