├── src └── satori │ ├── client │ ├── network │ │ ├── __init__.py │ │ ├── base.py │ │ ├── util.py │ │ ├── webhook.py │ │ └── websocket.py │ ├── config.py │ └── account.py │ ├── adapters │ ├── satori │ │ ├── __init__.py │ │ └── main.py │ ├── console │ │ ├── __init__.py │ │ ├── message.py │ │ ├── main.py │ │ ├── backend.py │ │ └── api.py │ ├── onebot11 │ │ ├── __init__.py │ │ ├── events │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── request.py │ │ │ └── message.py │ │ ├── utils.py │ │ ├── reverse.py │ │ ├── forward.py │ │ └── message.py │ └── milky │ │ ├── __init__.py │ │ ├── events │ │ ├── __init__.py │ │ ├── base.py │ │ ├── group.py │ │ ├── message.py │ │ └── request.py │ │ ├── utils.py │ │ ├── webhook.py │ │ ├── main.py │ │ └── api.py │ ├── server │ ├── formdata.py │ ├── utils.py │ ├── model.py │ ├── connection.py │ ├── adapter.py │ └── route.py │ ├── exception.py │ ├── utils.py │ ├── event.py │ ├── __init__.py │ └── const.py ├── .github ├── dependabot.yml ├── workflows │ ├── ruff.yml │ ├── release-adapter-milky.yml │ ├── release-adapter-satori.yml │ ├── release-adapter-console.yml │ ├── release-adapter-onebot11.yml │ ├── release-core.yml │ ├── release-client.yml │ ├── release-server.yml │ └── release.yml ├── ISSUE_TEMPLATE │ ├── feature.md │ └── bug.md └── actions │ └── setup-python │ └── action.yml ├── experimental ├── model.py ├── model.pyi └── _model_msgspec.py ├── example ├── server_webhook.py ├── server_with_adapter.py ├── client_webhook.py ├── client.py ├── adapter.py └── server.py ├── .pre-commit-config.yaml ├── .mina ├── adapter_satori.toml ├── adapter_milky.toml ├── adapter_onebot11.toml ├── adapter_console.toml ├── core.toml ├── client.toml └── server.toml ├── LICENSE ├── exam.py ├── pyproject.toml ├── .gitignore └── README.md /src/satori/client/network/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/satori/adapters/satori/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import SatoriAdapter as SatoriAdapter 2 | -------------------------------------------------------------------------------- /src/satori/adapters/console/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import ConsoleAdapter as ConsoleAdapter 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: .github/workflows 5 | schedule: 6 | interval: weekly 7 | -------------------------------------------------------------------------------- /src/satori/adapters/onebot11/__init__.py: -------------------------------------------------------------------------------- 1 | from .forward import OneBot11ForwardAdapter as OneBot11ForwardAdapter 2 | from .reverse import OneBot11ReverseAdapter as OneBot11ReverseAdapter 3 | -------------------------------------------------------------------------------- /src/satori/adapters/onebot11/events/__init__.py: -------------------------------------------------------------------------------- 1 | from . import message as message # noqa: F401 2 | from . import notice as notice # noqa: F401 3 | from . 
import request as request # noqa: F401 4 | -------------------------------------------------------------------------------- /src/satori/adapters/milky/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import MilkyAdapter as MilkyAdapter 2 | from .webhook import MilkyWebhookAdapter as MilkyWebhookAdapter 3 | 4 | __all__ = ["MilkyAdapter", "MilkyWebhookAdapter"] 5 | -------------------------------------------------------------------------------- /experimental/model.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | 3 | if find_spec("msgspec") is not None: 4 | from ._model_msgspec import * # noqa: F403,F401 5 | else: 6 | from ._model_dcls import * # noqa: F403,F401 7 | -------------------------------------------------------------------------------- /src/satori/adapters/milky/events/__init__.py: -------------------------------------------------------------------------------- 1 | # import event handlers to register them 2 | from . import group as group # noqa: F401 3 | from . import message as message # noqa: F401 4 | from . 
import request as request # noqa: F401 5 | from .base import event_handlers, register_event 6 | 7 | __all__ = ["event_handlers", "register_event"] 8 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Ruff Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | ruff: 11 | name: Ruff Lint 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v5 15 | 16 | - name: Run Ruff Lint 17 | uses: chartboost/ruff-action@v1 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature 特性请求 3 | about: 为 SDK 加份菜 4 | title: "[Feature] " 5 | labels: enhancement, triage 6 | assignees: "" 7 | --- 8 | 9 | ## 请确认: 10 | 11 | * [ ] 新特性的目的明确 12 | * [ ] 我已经阅读了[相关文档](https://satori.js.org/zh-CN/) 并且找不到类似特性 13 | 14 | 15 | ## Feature 16 | ### 概要 17 | 18 | 19 | 20 | ### 是否已有相关实现 21 | 22 | 暂无 23 | 24 | 25 | ### 其他内容 26 | 27 | 暂无 28 | -------------------------------------------------------------------------------- /example/server_webhook.py: -------------------------------------------------------------------------------- 1 | from adapter import ExampleAdapter 2 | 3 | from satori import Api, Channel, ChannelType 4 | from satori.server import Server, WebhookEndpoint 5 | 6 | server = Server(host="localhost", port=12345, webhooks=[WebhookEndpoint("http://localhost:8080/bar")]) 7 | server.apply(ExampleAdapter()) 8 | 9 | 10 | @server.route(Api.CHANNEL_GET) 11 | async def handle(*args, **kwargs): 12 | return Channel("1234567890", ChannelType.TEXT, "test").dump() 13 | 14 | 15 | server.run() 16 | -------------------------------------------------------------------------------- /src/satori/adapters/onebot11/events/base.py: 
-------------------------------------------------------------------------------- 1 | from collections.abc import Awaitable, Callable 2 | 3 | from satori.model import Event, Login 4 | 5 | from ..utils import OneBotNetwork 6 | 7 | events: dict[str, Callable[[Login, OneBotNetwork, dict], Awaitable[Event | None]]] = {} 8 | 9 | 10 | def register_event(event_type: str): 11 | def wrapper(func: Callable[[Login, OneBotNetwork, dict], Awaitable[Event | None]]): 12 | events[event_type] = func 13 | return func 14 | 15 | return wrapper 16 | -------------------------------------------------------------------------------- /src/satori/server/formdata.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def parse_content_disposition(header_value): 5 | match = re.match(r"""form-data; (?P.+)""", header_value) 6 | if match: 7 | parameters = match.groupdict()["parameters"] 8 | parsed_data = {} 9 | for param in parameters.split(";"): 10 | key, value = param.strip().split("=") 11 | parsed_data[key.strip('"')] = value.strip('"') 12 | return parsed_data 13 | raise ValueError(header_value) 14 | -------------------------------------------------------------------------------- /.github/actions/setup-python/action.yml: -------------------------------------------------------------------------------- 1 | name: Setup Python 2 | description: Setup Python 3 | 4 | inputs: 5 | python-version: 6 | description: Python version 7 | required: false 8 | default: "3.10" 9 | 10 | runs: 11 | using: "composite" 12 | steps: 13 | - uses: pdm-project/setup-pdm@v3 14 | name: Setup PDM 15 | with: 16 | python-version: ${{ inputs.python-version }} 17 | architecture: "x64" 18 | cache: true 19 | 20 | - run: pdm sync -G:all --no-isolation 21 | shell: bash 22 | -------------------------------------------------------------------------------- /src/satori/adapters/milky/events/base.py: -------------------------------------------------------------------------------- 1 
| from __future__ import annotations 2 | 3 | from collections.abc import Awaitable, Callable 4 | 5 | from satori.model import Event, Login 6 | 7 | from ..utils import MilkyNetwork 8 | 9 | EventHandler = Callable[[Login, MilkyNetwork, dict], Awaitable[Event | None]] 10 | 11 | event_handlers: dict[str, EventHandler] = {} 12 | 13 | 14 | def register_event(event_type: str): 15 | def decorator(func: EventHandler): 16 | event_handlers[event_type] = func 17 | return func 18 | 19 | return decorator 20 | -------------------------------------------------------------------------------- /src/satori/exception.py: -------------------------------------------------------------------------------- 1 | class ActionFailed(Exception): 2 | pass 3 | 4 | 5 | class BadRequestException(ActionFailed): 6 | pass 7 | 8 | 9 | class UnauthorizedException(ActionFailed): 10 | pass 11 | 12 | 13 | class ForbiddenException(ActionFailed): 14 | pass 15 | 16 | 17 | class NotFoundException(ActionFailed): 18 | pass 19 | 20 | 21 | class MethodNotAllowedException(ActionFailed): 22 | pass 23 | 24 | 25 | class ServerException(ActionFailed): 26 | pass 27 | 28 | 29 | class NetworkError(Exception): 30 | pass 31 | 32 | 33 | class ApiNotAvailable(Exception): 34 | pass 35 | -------------------------------------------------------------------------------- /example/server_with_adapter.py: -------------------------------------------------------------------------------- 1 | from adapter import ExampleAdapter 2 | 3 | from satori import Api, Channel, ChannelType 4 | from satori.server import Response, Server, StarletteRequest 5 | 6 | server = Server(host="localhost", port=12345, path="foo") 7 | server.apply(ExampleAdapter()) 8 | 9 | 10 | @server.route(Api.CHANNEL_GET) 11 | async def handle(*args, **kwargs): 12 | return Channel("1234567890", ChannelType.TEXT, "test").dump() 13 | 14 | 15 | @server.asgi_route("/api/v1/test") 16 | async def exam_route(request: StarletteRequest): 17 | return Response(str(dict(request.items()))) 
18 | 19 | 20 | server.run() 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug 报告 3 | about: 有关 bug 的报告 4 | title: "[Bug]" 5 | labels: bug, triage 6 | assignees: "" 7 | --- 8 | 9 | ## 请确认: 10 | 11 | * [ ] 问题的标题明确 12 | * [ ] 我翻阅过其他的 issue 并且找不到类似的问题 13 | * [ ] 我已经阅读了[相关文档](https://satori.js.org/zh-CN/) 并仍然认为这是一个Bug 14 | 15 | # Bug 16 | 17 | ## 问题 18 | 19 | 20 | ## 如何复现 21 | 22 | 23 | ## 预期行为 24 | 25 | 26 | ## 使用环境: 27 | - 操作系统 (Windows/Linux/Mac): 28 | - Python 版本: 29 | - SDK 版本: 30 | - 使用的 Satori 服务端 (例如 Chronocat): 31 | 32 | ## 日志/截图 33 | 34 | -------------------------------------------------------------------------------- /example/client_webhook.py: -------------------------------------------------------------------------------- 1 | from satori import EventType 2 | from satori.client import Account, App, WebhookInfo 3 | from satori.event import MessageEvent 4 | 5 | app = App(WebhookInfo(server_port=12345, path="bar")) 6 | 7 | 8 | @app.register_on(EventType.MESSAGE_CREATED) 9 | async def on_message(account: Account, event: MessageEvent): 10 | if event.user and event.user.id == "9876543210": 11 | print(await account.channel_get(event.channel.id)) # noqa: T201 12 | await account.send_message(event.channel.id, "Hello, World!") 13 | 14 | 15 | @app.lifecycle 16 | async def record(account, state): 17 | print(account, state) # noqa: T201 18 | 19 | 20 | app.run() 21 | -------------------------------------------------------------------------------- /src/satori/utils.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | 4 | def get_public_ip(): 5 | st = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 6 | try: 7 | st.connect(("10.255.255.255", 1)) 8 | IP = st.getsockname()[0] 9 | except Exception: 10 | IP = "localhost" 11 | finally: 12 | st.close() 13 
| return IP 14 | 15 | 16 | try: 17 | from msgspec.json import Decoder, Encoder # noqa: F401 18 | 19 | decoder = Decoder() 20 | encoder = Encoder() 21 | 22 | decode = decoder.decode 23 | 24 | def encode(obj): 25 | return encoder.encode(obj).decode() 26 | 27 | except ImportError: 28 | import json 29 | 30 | def encode(obj): 31 | return json.dumps(obj, separators=(",", ":"), ensure_ascii=False) 32 | 33 | decode = json.loads 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_install_hook_types: [pre-commit, prepare-commit-msg] 2 | ci: 3 | autofix_commit_msg: ":rotating_light: auto fix by pre-commit hooks" 4 | autofix_prs: true 5 | autoupdate_branch: main # the branch the Ruff Lint workflow runs on 6 | autoupdate_schedule: monthly 7 | autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks" 8 | repos: 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: v0.3.2 11 | hooks: 12 | - id: ruff 13 | args: [--fix, --exit-non-zero-on-fix] 14 | stages: [pre-commit] # "commit" is the pre-3.2 stage name and is deprecated 15 | 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.13.2 18 | hooks: 19 | - id: isort 20 | stages: [pre-commit] 21 | 22 | - repo: https://github.com/psf/black # fetch directly, not through a third-party mirror proxy 23 | rev: 24.2.0 24 | hooks: 25 | - id: black 26 | stages: [pre-commit] 27 | -------------------------------------------------------------------------------- /src/satori/adapters/onebot11/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol 2 | 3 | from satori import Role 4 | 5 | 6 | class OneBotNetwork(Protocol): 7 | async def call_api(self, action: str, params: dict | None = None) -> dict: ... 8 | 9 | 10 | SPECIAL_POST_TYPE = {"message_sent": "message"} 11 | 12 | 13 | def onebot11_event_type(raw: dict) -> str: 14 | return ( 15 | f"{(post := raw['post_type'])}."
16 | f"{raw.get(f'{SPECIAL_POST_TYPE.get(post, post)}_type', '_')}" 17 | f"{f'.{sub}' if (sub:=raw.get('sub_type')) else ''}" 18 | ) 19 | 20 | 21 | USER_AVATAR_URL = "https://q2.qlogo.cn/headimg_dl?dst_uin={uin}&spec=640" 22 | GROUP_AVATAR_URL = "https://p.qlogo.cn/gh/{group}/{group}/" 23 | 24 | ROLE_MAPPING = { 25 | "member": Role("MEMBER", "群成员"), 26 | "admin": Role("ADMINISTRATOR", "管理员"), 27 | "owner": Role("OWNER", "群主"), 28 | } 29 | -------------------------------------------------------------------------------- /src/satori/adapters/console/message.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from nonechat.message import ConsoleMessage, Markdown 4 | from nonechat.message import Text as ConsoleText 5 | 6 | from satori.element import At, Text, transform 7 | from satori.parser import parse 8 | 9 | 10 | def encode_message(message: ConsoleMessage) -> str: 11 | content = str(message) 12 | content = re.sub(r"@(\w+)", r"@", content) # Handle mentions 13 | return content 14 | 15 | 16 | def decode_message(content: str) -> ConsoleMessage: 17 | elements = [] 18 | msg = transform(parse(content)) 19 | for seg in msg: 20 | if isinstance(seg, Text): 21 | elements.append(ConsoleText(seg.text)) 22 | elif isinstance(seg, At): 23 | elements.append(ConsoleText(f"@{seg.id}")) 24 | else: 25 | elements.append(Markdown(str(seg))) 26 | return ConsoleMessage(elements) 27 | -------------------------------------------------------------------------------- /.github/workflows/release-adapter-milky.yml: -------------------------------------------------------------------------------- 1 | name: Milky Release 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | release: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | id-token: write 11 | contents: write 12 | steps: 13 | - uses: actions/checkout@v5 14 | 15 | - uses: actions/setup-python@v5 16 | name: Ensure Python Runtime 17 | with: 18 | python-version: '3.11' 19 | architecture: 
'x64' 20 | 21 | - name: Ensure PDM 22 | run: | 23 | python3 -m pip install pdm==2.13.2 pdm-mina "mina-build<0.6" "pdm-backend<2.4.0" 24 | 25 | - name: Build Package 26 | run: | 27 | export MINA_BUILD_TARGET=adapter_milky && pdm mina build adapter_milky --no-isolation 28 | 29 | - name: Publish Package 30 | run: | 31 | pdm publish --no-build 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /.github/workflows/release-adapter-satori.yml: -------------------------------------------------------------------------------- 1 | name: Satori Release 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | release: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | id-token: write 11 | contents: write 12 | steps: 13 | - uses: actions/checkout@v5 14 | 15 | - uses: actions/setup-python@v5 16 | name: Ensure Python Runtime 17 | with: 18 | python-version: '3.11' 19 | architecture: 'x64' 20 | 21 | - name: Ensure PDM 22 | run: | 23 | python3 -m pip install pdm==2.13.2 pdm-mina "mina-build<0.6" "pdm-backend<2.4.0" 24 | 25 | - name: Build Package 26 | run: | 27 | export MINA_BUILD_TARGET=adapter_satori && pdm mina build adapter_satori --no-isolation 28 | 29 | - name: Publish Package 30 | run: | 31 | pdm publish --no-build 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /.github/workflows/release-adapter-console.yml: -------------------------------------------------------------------------------- 1 | name: Console Release 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | release: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | id-token: write 11 | contents: write 12 | steps: 13 | - uses: actions/checkout@v5 14 | 15 | - uses: actions/setup-python@v5 16 | name: Ensure Python Runtime 17 | with: 18 | python-version: '3.11' 19 | architecture: 'x64' 20 | 21 | - name: Ensure PDM 22 | run: | 23 | python3 -m pip 
install pdm==2.13.2 pdm-mina "mina-build<0.6" "pdm-backend<2.4.0" 24 | 25 | - name: Build Package 26 | run: | 27 | export MINA_BUILD_TARGET=adapter_console && pdm mina build adapter_console --no-isolation 28 | 29 | - name: Publish Package 30 | run: | 31 | pdm publish --no-build 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /.github/workflows/release-adapter-onebot11.yml: -------------------------------------------------------------------------------- 1 | name: Onebot11 Release 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | release: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | id-token: write 11 | contents: write 12 | steps: 13 | - uses: actions/checkout@v5 14 | 15 | - uses: actions/setup-python@v5 16 | name: Ensure Python Runtime 17 | with: 18 | python-version: '3.11' 19 | architecture: 'x64' 20 | 21 | - name: Ensure PDM 22 | run: | 23 | python3 -m pip install pdm==2.13.2 pdm-mina "mina-build<0.6" "pdm-backend<2.4.0" 24 | 25 | - name: Build Package 26 | run: | 27 | export MINA_BUILD_TARGET=adapter_onebot11 && pdm mina build adapter_onebot11 --no-isolation 28 | 29 | - name: Publish Package 30 | run: | 31 | pdm publish --no-build 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /src/satori/client/network/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from typing import TYPE_CHECKING, Generic, TypeVar 5 | 6 | from launart import Service 7 | 8 | from ..config import Config as Config 9 | 10 | if TYPE_CHECKING: 11 | from satori.client import Account, App 12 | 13 | TConfig = TypeVar("TConfig", bound=Config) 14 | 15 | 16 | class BaseNetwork(Generic[TConfig], Service): 17 | close_signal: asyncio.Event 18 | sequence: int 19 | 20 | def __init__(self, app: App, config: 
TConfig): 21 | super().__init__() 22 | self.app = app 23 | self.config = config 24 | self.accounts: dict[str, Account] = {} 25 | self.close_signal = asyncio.Event() 26 | self.sequence = -1 27 | self.proxy_urls = [] 28 | 29 | async def wait_for_available(self): ... 30 | 31 | @property 32 | def alive(self) -> bool: ... 33 | 34 | async def connection_closed(self): 35 | self.close_signal.set() 36 | -------------------------------------------------------------------------------- /.mina/adapter_satori.toml: -------------------------------------------------------------------------------- 1 | includes = ["src/satori/adapters/satori"] 2 | raw-dependencies = ["satori-python >= 0.17.2"] 3 | 4 | [project] 5 | name = "satori-python-adapter-satori" 6 | version = "0.2.3" 7 | authors = [ 8 | {name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com"} 9 | ] 10 | dependencies = [] 11 | description = "Satori Protocol SDK for python, adapter for Satori" 12 | license = {text = "MIT"} 13 | readme = "README.md" 14 | requires-python = ">=3.10,<4.0" 15 | classifiers = [ 16 | "Typing :: Typed", 17 | "Development Status :: 4 - Beta", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3.8", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Operating System :: OS Independent", 25 | ] 26 | 27 | [project.urls] 28 | homepage = "https://github.com/RF-Tar-Railt/satori-python" 29 | repository = "https://github.com/RF-Tar-Railt/satori-python" 30 | -------------------------------------------------------------------------------- /.mina/adapter_milky.toml: -------------------------------------------------------------------------------- 1 | includes = ["src/satori/adapters/milky"] 2 | raw-dependencies = ["satori-python-server >= 0.17.6"] 3 | 4 | [project] 5 | name = "satori-python-adapter-milky" 6 | version = "0.1.2" 7 | authors = [ 
8 | {name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com"} 9 | ] 10 | dependencies = [] 11 | description = "Satori Protocol SDK for python, adapter for Milky" 12 | license = {text = "MIT"} 13 | readme = "README.md" 14 | requires-python = ">=3.10,<4.0" 15 | classifiers = [ 16 | "Typing :: Typed", 17 | "Development Status :: 4 - Beta", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3.8", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Operating System :: OS Independent", 25 | ] 26 | 27 | [project.urls] 28 | homepage = "https://github.com/RF-Tar-Railt/satori-python" 29 | repository = "https://github.com/RF-Tar-Railt/satori-python" 30 | -------------------------------------------------------------------------------- /.mina/adapter_onebot11.toml: -------------------------------------------------------------------------------- 1 | includes = ["src/satori/adapters/onebot11"] 2 | raw-dependencies = ["satori-python-server >= 0.17.6"] 3 | 4 | [project] 5 | name = "satori-python-adapter-onebot11" 6 | version = "0.2.7" 7 | authors = [ 8 | {name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com"} 9 | ] 10 | dependencies = [] 11 | description = "Satori Protocol SDK for python, adapter for OneBot 11" 12 | license = {text = "MIT"} 13 | readme = "README.md" 14 | requires-python = ">=3.10,<4.0" 15 | classifiers = [ 16 | "Typing :: Typed", 17 | "Development Status :: 4 - Beta", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3.8", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Operating System :: OS Independent", 25 | ] 26 | 27 | [project.urls] 28 | homepage = "https://github.com/RF-Tar-Railt/satori-python" 29 | repository 
= "https://github.com/RF-Tar-Railt/satori-python" 30 | -------------------------------------------------------------------------------- /src/satori/server/utils.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | from collections import deque 3 | 4 | 5 | class Deque: 6 | def __init__(self, maxlen: int): 7 | self.data = deque(maxlen=maxlen) 8 | self.offset = 0 9 | 10 | def append(self, x): 11 | if len(self.data) == self.data.maxlen: 12 | self.offset += 1 13 | self.data.append(x) 14 | 15 | def __getitem__(self, i: int): 16 | index = i - self.offset 17 | if index < 0 or index >= len(self.data): 18 | return 19 | return self.data[index] 20 | 21 | def after(self, i: int): 22 | if i < self.offset: 23 | i = self.offset - 1 24 | return list(self.data)[i + 1 - self.offset :] 25 | 26 | 27 | ctx = ssl.create_default_context() 28 | ctx.set_ciphers("DEFAULT") 29 | 30 | 31 | if __name__ == "__main__": 32 | d = Deque(3) 33 | d.append(0) 34 | d.append(1) 35 | d.append(2) 36 | print(d.after(0)) # noqa: T201 37 | d.append(3) 38 | d.append(4) 39 | d.append(5) 40 | print(d.data) # noqa: T201 41 | print(d.after(2)) # noqa: T201 42 | -------------------------------------------------------------------------------- /.github/workflows/release-core.yml: -------------------------------------------------------------------------------- 1 | name: Core Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | version: 7 | description: 'Release Version' 8 | required: true 9 | 10 | jobs: 11 | release: 12 | runs-on: ubuntu-latest 13 | permissions: 14 | id-token: write 15 | contents: write 16 | steps: 17 | - uses: actions/checkout@v5 18 | 19 | - uses: actions/setup-python@v5 20 | name: Ensure Python Runtime 21 | with: 22 | python-version: '3.11' 23 | architecture: 'x64' 24 | 25 | - name: Ensure PDM 26 | run: | 27 | python3 -m pip install pdm==2.13.2 pdm-mina "mina-build<0.6" "pdm-backend<2.4.0" 28 | 29 | - name: Build Package 30 | run: | 31 | 
export MINA_BUILD_TARGET=core && pdm mina build core --no-isolation 32 | 33 | - name: Publish Package 34 | run: | 35 | pdm publish --no-build 36 | gh release upload --clobber ${{ inputs.version }} dist/*.tar.gz dist/*.whl 37 | env: 38 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 39 | -------------------------------------------------------------------------------- /.github/workflows/release-client.yml: -------------------------------------------------------------------------------- 1 | name: Client Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | version: 7 | description: 'Release Version' 8 | required: true 9 | 10 | jobs: 11 | release: 12 | runs-on: ubuntu-latest 13 | permissions: 14 | id-token: write 15 | contents: write 16 | steps: 17 | - uses: actions/checkout@v5 18 | 19 | - uses: actions/setup-python@v5 20 | name: Ensure Python Runtime 21 | with: 22 | python-version: '3.11' 23 | architecture: 'x64' 24 | 25 | - name: Ensure PDM 26 | run: | 27 | python3 -m pip install pdm==2.13.2 pdm-mina "mina-build<0.6" "pdm-backend<2.4.0" 28 | 29 | - name: Build Package 30 | run: | 31 | export MINA_BUILD_TARGET=client && pdm mina build client --no-isolation 32 | 33 | - name: Publish Package 34 | run: | 35 | pdm publish --no-build 36 | gh release upload --clobber ${{ inputs.version }} dist/*.tar.gz dist/*.whl 37 | env: 38 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 39 | -------------------------------------------------------------------------------- /.github/workflows/release-server.yml: -------------------------------------------------------------------------------- 1 | name: Server Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | version: 7 | description: 'Release Version' 8 | required: true 9 | 10 | jobs: 11 | release: 12 | runs-on: ubuntu-latest 13 | permissions: 14 | id-token: write 15 | contents: write 16 | steps: 17 | - uses: actions/checkout@v5 18 | 19 | - uses: actions/setup-python@v5 20 | name: Ensure Python Runtime 21 | with: 22 | python-version: 
'3.11' 23 | architecture: 'x64' 24 | 25 | - name: Ensure PDM 26 | run: | 27 | python3 -m pip install pdm==2.13.2 pdm-mina "mina-build<0.6" "pdm-backend<2.4.0" 28 | 29 | - name: Build Package 30 | run: | 31 | export MINA_BUILD_TARGET=server && pdm mina build server --no-isolation 32 | 33 | - name: Publish Package 34 | run: | 35 | pdm publish --no-build 36 | gh release upload --clobber ${{ inputs.version }} dist/*.tar.gz dist/*.whl 37 | env: 38 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 39 | -------------------------------------------------------------------------------- /.mina/adapter_console.toml: -------------------------------------------------------------------------------- 1 | includes = ["src/satori/adapters/console"] 2 | raw-dependencies = [ 3 | "satori-python-server >= 0.17.2", 4 | "nonechat<0.7.0,>=0.6.0", 5 | ] 6 | 7 | [project] 8 | name = "satori-python-adapter-console" 9 | version = "0.4.1" 10 | authors = [ 11 | {name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com"} 12 | ] 13 | dependencies = [] 14 | description = "Satori Protocol SDK for python, adapter for Nonechat" 15 | license = {text = "MIT"} 16 | readme = "README.md" 17 | requires-python = ">=3.10,<4.0" 18 | classifiers = [ 19 | "Typing :: Typed", 20 | "Development Status :: 4 - Beta", 21 | "License :: OSI Approved :: MIT License", 22 | "Programming Language :: Python :: 3.8", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: 3.12", 27 | "Operating System :: OS Independent", 28 | ] 29 | 30 | [project.urls] 31 | homepage = "https://github.com/RF-Tar-Railt/satori-python" 32 | repository = "https://github.com/RF-Tar-Railt/satori-python" 33 | -------------------------------------------------------------------------------- /.mina/core.toml: -------------------------------------------------------------------------------- 1 | includes = ["src/satori"] 2 | excludes = 
["src/satori/client", "src/satori/server", "src/satori/adapters"] 3 | 4 | [project] 5 | name = "satori-python-core" 6 | dynamic = ["version"] 7 | authors = [ 8 | {name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com"} 9 | ] 10 | dependencies = [ 11 | "loguru", 12 | "yarl", 13 | "typing-extensions", 14 | ] 15 | description = "Satori Protocol SDK for python, specify common part" 16 | license = {text = "MIT"} 17 | readme = "README.md" 18 | requires-python = ">=3.10,<4.0" 19 | classifiers = [ 20 | "Typing :: Typed", 21 | "Development Status :: 4 - Beta", 22 | "License :: OSI Approved :: MIT License", 23 | "Programming Language :: Python :: 3.8", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12", 28 | "Operating System :: OS Independent", 29 | ] 30 | 31 | [project.urls] 32 | homepage = "https://github.com/RF-Tar-Railt/satori-python" 33 | repository = "https://github.com/RF-Tar-Railt/satori-python" 34 | 35 | -------------------------------------------------------------------------------- /.mina/client.toml: -------------------------------------------------------------------------------- 1 | includes = ["src/satori/client"] 2 | raw-dependencies = [ 3 | "satori-python-core >= 0.17.0", 4 | "graia-amnesia >= 0.11.0", 5 | ] 6 | 7 | [project] 8 | name = "satori-python-client" 9 | dynamic = ["version"] 10 | authors = [ 11 | {name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com"} 12 | ] 13 | dependencies = [ 14 | "aiohttp", 15 | "launart", 16 | ] 17 | description = "Satori Protocol SDK for python, specify client part" 18 | license = {text = "MIT"} 19 | readme = "README.md" 20 | requires-python = ">=3.10,<4.0" 21 | classifiers = [ 22 | "Typing :: Typed", 23 | "Development Status :: 4 - Beta", 24 | "License :: OSI Approved :: MIT License", 25 | "Programming Language :: Python :: 3.8", 26 | "Programming Language :: Python :: 3.9", 27 | "Programming 
Language :: Python :: 3.10", 28 | "Programming Language :: Python :: 3.11", 29 | "Programming Language :: Python :: 3.12", 30 | "Operating System :: OS Independent", 31 | ] 32 | 33 | [project.urls] 34 | homepage = "https://github.com/RF-Tar-Railt/satori-python" 35 | repository = "https://github.com/RF-Tar-Railt/satori-python" 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 RF-Tar-Railt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /example/client.py: -------------------------------------------------------------------------------- 1 | from satori import EventType, Upload 2 | from satori.client import Account, App, WebsocketsInfo 3 | from satori.event import Event, MessageEvent 4 | 5 | app = App(WebsocketsInfo(port=12345, path="foo")) 6 | 7 | 8 | @app.register 9 | async def _(account: Account, event: Event): 10 | print(event.sn) # noqa: T201 11 | 12 | 13 | @app.register_on(EventType.MESSAGE_CREATED) 14 | async def on_message(account: Account, event: MessageEvent): 15 | print(event.message) # noqa: T201 16 | if event.user and event.user.id == "9876543210": 17 | print(await account.channel_get(event.channel.id)) # noqa: T201 18 | print(await account.send_message(event.channel, "Hello, World!")) # noqa: T201 19 | print( # noqa: T201 20 | res := await account.upload( 21 | Upload( 22 | b"1234", 23 | name="foo.png", 24 | ) 25 | ) 26 | ) 27 | print(await account.download(res[0])) # noqa: T201 28 | 29 | 30 | @app.lifecycle 31 | async def record(account: Account, state): 32 | print(account, state) # noqa: T201 33 | 34 | 35 | app.run() 36 | -------------------------------------------------------------------------------- /src/satori/event.py: -------------------------------------------------------------------------------- 1 | from satori.model import ( 2 | ArgvInteraction, 3 | ButtonInteraction, 4 | Channel, 5 | Event, 6 | Guild, 7 | LoginPartial, 8 | Member, 9 | MessageObject, 10 | Role, 11 | User, 12 | ) 13 | 14 | 15 | class MessageEvent(Event): 16 | channel: Channel 17 | member: Member 18 | message: MessageObject 19 | user: User 20 | 21 | 22 | class UserEvent(Event): 23 | user: User 24 | 25 | 26 | class GuildEvent(Event): 27 | guild: Guild 28 | 29 | 30 | class GuildMemberEvent(GuildEvent): 31 | user: User 32 | member: Member 33 | 34 | 35 | class GuildRoleEvent(GuildEvent): 36 | role: Role 37 | 38 | 39 | class 
LoginEvent(Event): 40 | login: LoginPartial 41 | 42 | 43 | class ReactionEvent(Event): 44 | channel: Channel 45 | user: User 46 | message: MessageObject 47 | 48 | 49 | class ButtonInteractionEvent(Event): 50 | button: ButtonInteraction 51 | user: User 52 | channel: Channel 53 | 54 | 55 | class ArgvInteractionEvent(Event): 56 | argv: ArgvInteraction 57 | user: User 58 | channel: Channel 59 | 60 | 61 | class InternalEvent(Event): 62 | _type: str 63 | _data: dict 64 | -------------------------------------------------------------------------------- /.mina/server.toml: -------------------------------------------------------------------------------- 1 | includes = ["src/satori/server"] 2 | raw-dependencies = ["satori-python-core >= 0.17.0"] 3 | 4 | [project] 5 | name = "satori-python-server" 6 | dynamic = ["version"] 7 | authors = [ 8 | {name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com"} 9 | ] 10 | dependencies = [ 11 | "aiohttp", 12 | "launart", 13 | "graia-amnesia[uvicorn]", 14 | "starlette[python-multipart]", 15 | "websockets", 16 | "python-multipart", 17 | ] 18 | description = "Satori Protocol SDK for python, specify server part" 19 | license = {text = "MIT"} 20 | readme = "README.md" 21 | requires-python = ">=3.10,<4.0" 22 | classifiers = [ 23 | "Typing :: Typed", 24 | "Development Status :: 4 - Beta", 25 | "License :: OSI Approved :: MIT License", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Programming Language :: Python :: 3.12", 31 | "Operating System :: OS Independent", 32 | ] 33 | 34 | [project.urls] 35 | homepage = "https://github.com/RF-Tar-Railt/satori-python" 36 | repository = "https://github.com/RF-Tar-Railt/satori-python" 37 | -------------------------------------------------------------------------------- /src/satori/server/model.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable 3 | 4 | from starlette.requests import Request as StarletteRequest 5 | from starlette.responses import Response 6 | 7 | from satori.const import Api 8 | from satori.model import Login 9 | 10 | if TYPE_CHECKING: 11 | from .route import RouteCall 12 | 13 | TA = TypeVar("TA", str, Api) 14 | TP = TypeVar("TP") 15 | 16 | 17 | @dataclass 18 | class Request(Generic[TP]): 19 | origin: StarletteRequest 20 | action: str 21 | params: TP 22 | platform: str 23 | self_id: str 24 | 25 | 26 | @runtime_checkable 27 | class Provider(Protocol): 28 | async def get_logins(self) -> list[Login]: ... 29 | 30 | @staticmethod 31 | def proxy_urls() -> list[str]: ... 32 | 33 | def ensure(self, platform: str, self_id: str) -> bool: ... 34 | 35 | async def handle_internal(self, request: Request, path: str) -> Response: ... 36 | 37 | async def handle_proxied(self, prefix: str, url: str) -> Response | None: ... 38 | 39 | 40 | @runtime_checkable 41 | class Router(Protocol): 42 | routes: dict[str, "RouteCall[Any, Any]"] 43 | 44 | 45 | @dataclass 46 | class WebhookEndpoint: 47 | url: str 48 | token: str | None = None 49 | timeout: float | None = None 50 | -------------------------------------------------------------------------------- /src/satori/client/network/util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, overload 2 | 3 | from aiohttp import ClientResponse 4 | 5 | from satori.exception import ( 6 | BadRequestException, 7 | ForbiddenException, 8 | MethodNotAllowedException, 9 | NotFoundException, 10 | ServerException, 11 | UnauthorizedException, 12 | ) 13 | from satori.utils import decode 14 | 15 | 16 | @overload 17 | async def validate_response(resp: ClientResponse) -> dict: ... 
18 | 19 | 20 | @overload 21 | async def validate_response(resp: ClientResponse, noreturn: Literal[True]) -> None: ... 22 | 23 | 24 | async def validate_response(resp: ClientResponse, noreturn=False): 25 | if 200 <= resp.status < 300: 26 | if noreturn: 27 | return 28 | return decode(content) if (content := await resp.text()) else {} 29 | elif resp.status == 400: 30 | raise BadRequestException(await resp.text()) 31 | elif resp.status == 401: 32 | raise UnauthorizedException(await resp.text()) 33 | elif resp.status == 403: 34 | raise ForbiddenException(await resp.text()) 35 | elif resp.status == 404: 36 | raise NotFoundException(await resp.text()) 37 | elif resp.status == 405: 38 | raise MethodNotAllowedException(await resp.text()) 39 | elif resp.status >= 500: 40 | raise ServerException(await resp.text()) 41 | else: 42 | resp.raise_for_status() 43 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Full Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | version: 7 | description: 'Release Version' 8 | required: true 9 | push: 10 | tags: 11 | - v* 12 | 13 | jobs: 14 | release: 15 | runs-on: ubuntu-latest 16 | permissions: 17 | id-token: write 18 | contents: write 19 | steps: 20 | - uses: actions/checkout@v5 21 | 22 | - name: Setup Python environment 23 | uses: ./.github/actions/setup-python 24 | 25 | - name: Get Version 26 | id: version 27 | run: | 28 | echo "VERSION=$(pdm show --version -q)" >> $GITHUB_OUTPUT 29 | if [[ "${{ github.event_name }}" == 'push' ]]; then 30 | echo "TAG_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT 31 | echo "TAG_NAME=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT 32 | else 33 | input_version=${{ github.event.inputs.version }} 34 | echo "TAG_VERSION=${input_version#v}" >> $GITHUB_OUTPUT 35 | echo "TAG_NAME=${input_version}" >> $GITHUB_OUTPUT 36 | fi 37 | 38 | - name: Check 
Version 39 | if: steps.version.outputs.VERSION != steps.version.outputs.TAG_VERSION 40 | run: exit 1 41 | 42 | - name: Build Package 43 | run: | 44 | pdm build 45 | 46 | - name: Publish Package 47 | run: | 48 | pdm publish --no-build 49 | gh release upload --clobber ${{ steps.version.outputs.TAG_NAME }} dist/*.tar.gz dist/*.whl 50 | env: 51 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 52 | -------------------------------------------------------------------------------- /src/satori/server/connection.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | 5 | from loguru import logger 6 | from starlette.websockets import WebSocket, WebSocketDisconnect 7 | 8 | from satori.model import Opcode 9 | from satori.utils import decode, encode 10 | 11 | 12 | class WebsocketConnection: 13 | connection: WebSocket 14 | 15 | def __init__(self, connection: WebSocket): 16 | self.connection = connection 17 | self.close_signal: asyncio.Event = asyncio.Event() 18 | 19 | @property 20 | def alive(self) -> bool: 21 | return not self.close_signal.is_set() 22 | 23 | async def heartbeat(self): 24 | while True: 25 | try: 26 | msg = await asyncio.wait_for(self.connection.receive_text(), timeout=12) 27 | msg = decode(msg) 28 | if not isinstance(msg, dict) or msg.get("op") != Opcode.PING: 29 | continue 30 | await self.connection.send_text(encode({"op": Opcode.PONG})) 31 | except asyncio.TimeoutError: 32 | logger.warning(f"Connection {id(self):x} heartbeat timeout, closing connection.") 33 | await self.connection.close() 34 | await self.connection_closed() 35 | break 36 | except WebSocketDisconnect: 37 | return 38 | 39 | async def connection_closed(self): 40 | self.close_signal.set() 41 | 42 | async def wait_for_available(self): 43 | return 44 | 45 | async def send(self, payload: dict) -> None: 46 | return await self.connection.send_text(encode(payload)) 47 | 
-------------------------------------------------------------------------------- /src/satori/server/adapter.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import TYPE_CHECKING, Union 3 | 4 | from launart import Service 5 | from starlette.responses import Response 6 | from starlette.routing import BaseRoute 7 | 8 | from satori.model import Login, LoginPartial 9 | 10 | from .model import Request 11 | from .route import RouterMixin 12 | from .utils import ctx 13 | 14 | if TYPE_CHECKING: 15 | from . import Server 16 | 17 | 18 | LoginType = Union[Login, LoginPartial] # noqa: UP007 19 | 20 | 21 | class Adapter(Service, RouterMixin): 22 | server: "Server" 23 | 24 | @abstractmethod 25 | def get_platform(self) -> str: 26 | """该方法仅用于自动标识适配器类型 27 | 28 | 若你继承该类并且重写了 `id`, 该方法可以返回任意字符串 29 | """ 30 | 31 | @abstractmethod 32 | def ensure(self, platform: str, self_id: str) -> bool: ... 33 | 34 | @staticmethod 35 | def proxy_urls() -> list[str]: 36 | return [] 37 | 38 | @abstractmethod 39 | async def handle_internal(self, request: Request, path: str) -> Response: ... 40 | 41 | async def handle_proxied(self, prefix: str, url: str) -> Response | None: 42 | async with self.server.session.get(url, ssl=ctx) as resp: 43 | return Response(await resp.read()) 44 | 45 | @abstractmethod 46 | async def get_logins(self) -> list[LoginType]: ... 
47 | 48 | def __init__(self): 49 | super().__init__() 50 | self.routes = {} 51 | 52 | @property 53 | def id(self): 54 | return f"satori-python.adapter.{self.get_platform()}#{id(self)}" 55 | 56 | def ensure_server(self, server: "Server"): 57 | self.server = server 58 | 59 | def get_routes(self) -> list[BaseRoute]: 60 | """return extra routes that will mount to the server root""" 61 | return [] 62 | -------------------------------------------------------------------------------- /src/satori/__init__.py: -------------------------------------------------------------------------------- 1 | from .const import Api as Api 2 | from .const import EventType as EventType 3 | from .element import At as At 4 | from .element import Audio as Audio 5 | from .element import Author as Author 6 | from .element import Bold as Bold 7 | from .element import Button as Button 8 | from .element import Code as Code 9 | from .element import E as E 10 | from .element import Element as Element 11 | from .element import File as File 12 | from .element import Image as Image 13 | from .element import Italic as Italic 14 | from .element import Link as Link 15 | from .element import Message as Message 16 | from .element import Paragraph as Paragraph 17 | from .element import Quote as Quote 18 | from .element import Sharp as Sharp 19 | from .element import Spoiler as Spoiler 20 | from .element import Strikethrough as Strikethrough 21 | from .element import Subscript as Subscript 22 | from .element import Superscript as Superscript 23 | from .element import Text as Text 24 | from .element import Underline as Underline 25 | from .element import Video as Video 26 | from .element import register_element as register_element 27 | from .element import select as select 28 | from .element import transform as transform 29 | from .model import ArgvInteraction as ArgvInteraction 30 | from .model import ButtonInteraction as ButtonInteraction 31 | from .model import Channel as Channel 32 | from .model import 
ChannelType as ChannelType 33 | from .model import Event as Event 34 | from .model import Guild as Guild 35 | from .model import Login as Login 36 | from .model import LoginStatus as LoginStatus 37 | from .model import Member as Member 38 | from .model import MessageObject as MessageObject 39 | from .model import PageDequeResult as PageDequeResult 40 | from .model import PageResult as PageResult 41 | from .model import Role as Role 42 | from .model import Upload as Upload 43 | from .model import User as User 44 | 45 | __version__ = "0.17.7" 46 | -------------------------------------------------------------------------------- /src/satori/client/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from yarl import URL 4 | 5 | 6 | class Config: 7 | @property 8 | def identity(self) -> str: 9 | raise NotImplementedError 10 | 11 | @property 12 | def token(self) -> str | None: 13 | raise NotImplementedError 14 | 15 | @property 16 | def api_base(self) -> URL: 17 | raise NotImplementedError 18 | 19 | 20 | @dataclass 21 | class WebsocketsInfo(Config): 22 | host: str = "localhost" 23 | port: int = 5140 24 | path: str = "" 25 | token: str | None = None 26 | timeout: float | None = None 27 | identity: str = None # type: ignore 28 | api_base: URL = None # type: ignore 29 | 30 | def __post_init__(self): 31 | if self.path and not self.path.startswith("/"): 32 | self.path = f"/{self.path}" 33 | self.identity = f"{self.host}:{self.port}" 34 | self.api_base = URL(f"http://{self.host}:{self.port}{self.path}") / "v1" 35 | 36 | @property 37 | def ws_base(self): 38 | return URL(f"ws://{self.host}:{self.port}{self.path}") / "v1" 39 | 40 | 41 | @dataclass 42 | class WebhookInfo(Config): 43 | host: str = "127.0.0.1" 44 | port: int = 8080 45 | path: str = "v1/events" 46 | token: str | None = None 47 | server_host: str = "localhost" 48 | server_port: int = 5140 49 | server_path: str = "" 50 | timeout: float | 
None = None 51 | identity: str = None # type: ignore 52 | api_base: URL = None # type: ignore 53 | 54 | def __post_init__(self): 55 | if self.path and not self.path.startswith("/"): 56 | self.path = f"/{self.path}" 57 | if self.server_path and not self.server_path.startswith("/"): 58 | self.server_path = f"/{self.server_path}" 59 | self.identity = f"{self.host}:{self.port}{self.path}" 60 | self.api_base = URL(f"http://{self.server_host}:{self.server_port}{self.server_path}") / "v1" 61 | -------------------------------------------------------------------------------- /exam.py: -------------------------------------------------------------------------------- 1 | from satori import At, Link, Sharp, Text, transform 2 | from satori.parser import parse 3 | 4 | print(transform(parse(' foobar'))) 5 | 6 | a = Text("1234") 7 | role = At.role_("admin") 8 | chl = Sharp("abcd") 9 | link = Link("www.baidu.com") 10 | link1 = Link("github.com/RF-Tar-Railt/satori-python")("satori-python") 11 | print(a) 12 | print(role) 13 | print(chl) 14 | print(link) 15 | print(link1) 16 | 17 | from satori import Image, Video 18 | 19 | image = Image.of(url="https://example.com/image.png", name="image.png") 20 | print(image) 21 | image1 = Image.of(raw=b"\x89PNG\r\n\x1a\n...", name="image.png") 22 | print(image1) 23 | print( 24 | repr(Video.unpack( 25 | { 26 | "src": "https://example.com/video.mp4", 27 | "title": "video.mp4", 28 | "width": "123", 29 | "height": "456", 30 | "poster": "https://example.com/poster.png", 31 | } 32 | )) 33 | ) 34 | 35 | from satori import Bold, Italic, Paragraph, Underline 36 | 37 | text = Bold("hello", Italic("world,"), Underline()("Satori!"), Paragraph("This is a paragraph.")) 38 | print(text) 39 | 40 | from satori import Author, Message 41 | 42 | message = Message(forward=True)( 43 | Message(id="123456789"), 44 | Message(id="987654321"), 45 | Message(content=[Author(id="123456789"), "Hello, "]), 46 | Message()(Author(id="123456789"), "World!"), 47 | ) 48 | 
print(message) 49 | 50 | from satori import E 51 | 52 | print(E("", {"id": "123456789"})) 53 | 54 | from satori import register_element 55 | from satori.element import Element 56 | from dataclasses import dataclass 57 | 58 | 59 | @dataclass(repr=False) 60 | class QQPassive(Element): 61 | id: str 62 | 63 | @property 64 | def tag(self) -> str: 65 | return "qq:passive" 66 | 67 | 68 | register_element(QQPassive, "qq:passive") 69 | 70 | print(QQPassive(id="123456789")) 71 | print(transform(parse(""))) 72 | -------------------------------------------------------------------------------- /example/adapter.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | 4 | from launart import Launart, any_completed 5 | 6 | from satori import Api, Channel, ChannelType, Event, User 7 | from satori.model import Login, LoginStatus, MessageObject 8 | from satori.server import Adapter, Request, route 9 | 10 | 11 | class ExampleAdapter(Adapter): 12 | async def handle_internal(self, request: Request, path: str): ... 
13 | 14 | @property 15 | def required(self): 16 | return set() 17 | 18 | @property 19 | def stages(self): 20 | return {"blocking"} 21 | 22 | def get_platform(self) -> str: 23 | return "example" 24 | 25 | def ensure(self, platform: str, self_id: str) -> bool: 26 | return platform == self.get_platform() and self_id == "1234567890" 27 | 28 | async def get_logins(self): 29 | return [Login(0, LoginStatus.ONLINE, "test", "example", User("1234567890"))] 30 | 31 | def __init__(self): 32 | super().__init__() 33 | 34 | @self.route(Api.MESSAGE_CREATE) 35 | async def _(request: Request[route.MessageParam]): 36 | return [MessageObject("1234", request.params["content"])] 37 | 38 | async def publish(self): 39 | seq = 0 40 | while True: 41 | await asyncio.sleep(2) 42 | await self.server.post( 43 | Event( 44 | "message-created", 45 | datetime.now(), 46 | (await self.get_logins())[0], 47 | channel=Channel("345678", ChannelType.TEXT), 48 | user=User("9876543210"), 49 | message=MessageObject(f"msg_{seq}", "test"), 50 | ) 51 | ) 52 | seq += 1 53 | 54 | async def launch(self, manager: Launart): 55 | async with self.stage("blocking"): 56 | event_task = asyncio.create_task(self.publish()) 57 | exit_task = asyncio.create_task(manager.status.wait_for_sigexit()) 58 | await any_completed(event_task, exit_task) 59 | -------------------------------------------------------------------------------- /example/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | 4 | from satori import Api, Channel, ChannelType, Event, Login, LoginStatus, MessageObject, Text, User 5 | from satori.server import Request, Server, route 6 | 7 | server = Server(host="localhost", port=12345, path="foo") 8 | 9 | 10 | class ExampleProvider: 11 | @property 12 | def id(self): 13 | return "example" 14 | 15 | @staticmethod 16 | def proxy_urls(): 17 | return [] 18 | 19 | def ensure(self, platform: str, self_id: str) -> bool: 20 | 
return platform == "example" and self_id == "1234567890" 21 | 22 | async def handle_internal(self, request: Request, path: str): 23 | raise NotImplementedError 24 | 25 | async def handle_proxied(self, prefix: str, url: str): 26 | raise NotImplementedError 27 | 28 | async def get_logins(self): 29 | return [Login(0, LoginStatus.ONLINE, "test", "example", User("1234567890"))] 30 | 31 | async def publisher(self): 32 | seq = 0 33 | while True: 34 | await asyncio.sleep(2) 35 | yield Event( 36 | "message-created", 37 | datetime.now(), 38 | (await self.get_logins())[0], 39 | channel=Channel("345678", ChannelType.TEXT), 40 | user=User("9876543210"), 41 | message=MessageObject(f"msg_{seq}", "123"), 42 | ) 43 | seq += 1 44 | 45 | 46 | server.apply(ExampleProvider()) 47 | 48 | sent = True 49 | 50 | 51 | @server.route(Api.CHANNEL_GET) 52 | async def handle1(request: Request[route.ChannelParam]): 53 | global sent 54 | 55 | async def _(): 56 | await asyncio.sleep(5) 57 | await server.connections[0].connection.close() 58 | 59 | if not sent: 60 | _t = asyncio.create_task(_()) 61 | sent = True 62 | return Channel("1234567890", ChannelType.TEXT, "test").dump() 63 | 64 | 65 | @server.route(Api.MESSAGE_CREATE) 66 | async def handle2(request: Request[route.MessageParam]): 67 | return [MessageObject.from_elements("1234", [Text("example")])] 68 | 69 | 70 | server.run() 71 | -------------------------------------------------------------------------------- /src/satori/adapters/onebot11/events/request.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | 5 | from satori import EventType 6 | from satori.exception import ActionFailed 7 | from satori.model import Channel, ChannelType, Event, Guild, Login, Member, MessageObject, User 8 | 9 | from ..utils import GROUP_AVATAR_URL, USER_AVATAR_URL, OneBotNetwork 10 | from .base import register_event 11 | 12 | 13 | 
@register_event("request.friend") 14 | async def request_friend(login: Login, adapter: OneBotNetwork, raw: dict) -> Event: 15 | info = await adapter.call_api("get_stranger_info", {"user_id": raw["user_id"]}) 16 | user_id = str(raw["user_id"]) 17 | user = User(user_id, info["nickname"], info.get("card"), avatar=USER_AVATAR_URL.format(uin=user_id)) 18 | channel = Channel(f"private:{user_id}", ChannelType.DIRECT) 19 | return Event( 20 | EventType.FRIEND_REQUEST, 21 | datetime.fromtimestamp(raw["time"]), 22 | login, 23 | user=user, 24 | channel=channel, 25 | message=MessageObject(raw["flag"], raw.get("comment", "")), 26 | ) 27 | 28 | 29 | @register_event("request.group.invite") 30 | @register_event("request.group.add") 31 | async def request_group_invite(login: Login, adapter: OneBotNetwork, raw: dict) -> Event: 32 | group_id = str(raw["group_id"]) 33 | try: 34 | group_info = await adapter.call_api("get_group_info", {"group_id": raw["group_id"]}) 35 | except ActionFailed: 36 | group_info = {} 37 | guild = Guild(group_id, group_info.get("group_name"), avatar=GROUP_AVATAR_URL.format(group=group_id)) 38 | channel = Channel(group_id, name=group_info.get("group_name")) 39 | info = await adapter.call_api("get_stranger_info", {"user_id": raw["user_id"]}) 40 | user_id = str(raw["user_id"]) 41 | user = User(user_id, info["nickname"], info.get("card"), avatar=USER_AVATAR_URL.format(uin=user_id)) 42 | return Event( 43 | EventType.GUILD_REQUEST if raw["sub_type"] == "invite" else EventType.GUILD_MEMBER_ADDED, 44 | datetime.fromtimestamp(raw["time"]), 45 | login, 46 | user=user, 47 | member=Member(user, avatar=USER_AVATAR_URL.format(uin=user_id)), 48 | channel=channel, 49 | guild=guild, 50 | message=MessageObject(raw["flag"], raw.get("comment", "")), 51 | ) 52 | -------------------------------------------------------------------------------- /src/satori/const.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | 
class Api(str, Enum): 5 | MESSAGE_CREATE = "message.create" 6 | MESSAGE_UPDATE = "message.update" 7 | MESSAGE_GET = "message.get" 8 | MESSAGE_DELETE = "message.delete" 9 | MESSAGE_LIST = "message.list" 10 | 11 | CHANNEL_GET = "channel.get" 12 | CHANNEL_LIST = "channel.list" 13 | CHANNEL_CREATE = "channel.create" 14 | CHANNEL_UPDATE = "channel.update" 15 | CHANNEL_DELETE = "channel.delete" 16 | CHANNEL_MUTE = "channel.mute" 17 | USER_CHANNEL_CREATE = "user.channel.create" 18 | 19 | GUILD_GET = "guild.get" 20 | GUILD_LIST = "guild.list" 21 | GUILD_APPROVE = "guild.approve" 22 | 23 | GUILD_MEMBER_LIST = "guild.member.list" 24 | GUILD_MEMBER_GET = "guild.member.get" 25 | GUILD_MEMBER_KICK = "guild.member.kick" 26 | GUILD_MEMBER_MUTE = "guild.member.mute" 27 | GUILD_MEMBER_APPROVE = "guild.member.approve" 28 | GUILD_MEMBER_ROLE_SET = "guild.member.role.set" 29 | GUILD_MEMBER_ROLE_UNSET = "guild.member.role.unset" 30 | 31 | GUILD_ROLE_LIST = "guild.role.list" 32 | GUILD_ROLE_CREATE = "guild.role.create" 33 | GUILD_ROLE_UPDATE = "guild.role.update" 34 | GUILD_ROLE_DELETE = "guild.role.delete" 35 | 36 | REACTION_CREATE = "reaction.create" 37 | REACTION_DELETE = "reaction.delete" 38 | REACTION_CLEAR = "reaction.clear" 39 | REACTION_LIST = "reaction.list" 40 | 41 | LOGIN_GET = "login.get" 42 | 43 | USER_GET = "user.get" 44 | FRIEND_LIST = "friend.list" 45 | FRIEND_APPROVE = "friend.approve" 46 | 47 | UPLOAD_CREATE = "upload.create" 48 | 49 | 50 | class EventType(str, Enum): 51 | FRIEND_REQUEST = "friend-request" 52 | GUILD_ADDED = "guild-added" 53 | GUILD_MEMBER_ADDED = "guild-member-added" 54 | GUILD_MEMBER_REMOVED = "guild-member-removed" 55 | GUILD_MEMBER_REQUEST = "guild-member-request" 56 | GUILD_MEMBER_UPDATED = "guild-member-updated" 57 | GUILD_REMOVED = "guild-removed" 58 | GUILD_REQUEST = "guild-request" 59 | GUILD_ROLE_CREATED = "guild-role-created" 60 | GUILD_ROLE_DELETED = "guild-role-deleted" 61 | GUILD_ROLE_UPDATED = "guild-role-updated" 62 | GUILD_UPDATED = 
"guild-updated" 63 | LOGIN_ADDED = "login-added" 64 | LOGIN_REMOVED = "login-removed" 65 | LOGIN_UPDATED = "login-updated" 66 | MESSAGE_CREATED = "message-created" 67 | MESSAGE_DELETED = "message-deleted" 68 | MESSAGE_UPDATED = "message-updated" 69 | REACTION_ADDED = "reaction-added" 70 | REACTION_REMOVED = "reaction-removed" 71 | INTERNAL = "internal" 72 | INTERACTION_BUTTON = "interaction/button" 73 | INTERACTION_COMMAND = "interaction/command" 74 | -------------------------------------------------------------------------------- /src/satori/adapters/console/main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from launart import Launart 4 | from launart.status import Phase 5 | from nonechat import ConsoleSetting, Frontend 6 | from starlette.responses import JSONResponse, Response 7 | 8 | from satori.server import Adapter as BaseAdapter 9 | from satori.server import Request 10 | from satori.server.adapter import LoginType 11 | 12 | from .api import apply 13 | from .backend import SatoriConsoleBackend 14 | 15 | 16 | class ConsoleAdapter(BaseAdapter): 17 | def __init__(self, logger_id: int = -1, **kwargs): 18 | super().__init__() 19 | self.app = Frontend( 20 | SatoriConsoleBackend, 21 | ConsoleSetting(**kwargs), 22 | ) 23 | self.app.backend.set_adapter(self) 24 | self._logger_id = logger_id 25 | apply(self) 26 | 27 | @property 28 | def id(self): 29 | return f"satori-python.adapter.console#{id(self)}" 30 | 31 | def get_platform(self) -> str: 32 | return "console" 33 | 34 | def ensure(self, platform: str, self_id: str) -> bool: 35 | return platform == "console" and self_id in self.app.backend.logins 36 | 37 | async def handle_internal(self, request: Request, path: str) -> Response: 38 | if path.startswith("_api"): 39 | api = path[5:] 40 | data = await request.origin.json() 41 | if api == "send_msg": 42 | await self.app.send_message(**data) 43 | return JSONResponse({}) 44 | if api == "bell": 45 | await 
self.app.toggle_bell() 46 | return JSONResponse({}) 47 | else: 48 | return Response(f"Unknown API: {api}", status_code=404) 49 | async with self.server.session.get(path) as resp: 50 | return Response(await resp.read()) 51 | 52 | async def get_logins(self) -> list[LoginType]: 53 | return list(self.app.backend.logins.values()) 54 | 55 | @property 56 | def required(self) -> set[str]: 57 | return { 58 | "satori-python.server", 59 | } 60 | 61 | @property 62 | def stages(self) -> set[Phase]: 63 | return {"preparing", "blocking", "cleanup"} 64 | 65 | async def launch(self, manager: Launart): 66 | 67 | async with self.stage("preparing"): 68 | ... 69 | 70 | async with self.stage("blocking"): 71 | task = asyncio.create_task(self.app.run_async()) 72 | await manager.status.wait_for_sigexit() 73 | 74 | async with self.stage("cleanup"): 75 | self.app.exit() 76 | if task: 77 | await task 78 | 79 | 80 | Adapter = ConsoleAdapter 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "satori-python" 3 | description = "Satori Protocol SDK for python" 4 | authors = [ 5 | {name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"}, 6 | ] 7 | dynamic = ["version"] 8 | dependencies = [ 9 | "aiohttp>=3.9.3", 10 | "loguru>=0.7.2", 11 | "launart>=0.8.2", 12 | "typing-extensions>=4.7.0", 13 | "graia-amnesia[uvicorn]<0.12.0,>=0.11.0", 14 | "yarl>=1.9.4", 15 | "python-multipart>=0.0.9", 16 | "websockets>=15.0.1", 17 | "starlette>=0.40.0", 18 | ] 19 | requires-python = ">=3.10,<4.0" 20 | readme = "README.md" 21 | license = {text = "MIT"} 22 | classifiers = [ 23 | "Typing :: Typed", 24 | "Development Status :: 4 - Beta", 25 | "License :: OSI Approved :: MIT License", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 
30 | "Programming Language :: Python :: 3.12", 31 | "Operating System :: OS Independent", 32 | ] 33 | 34 | [project.urls] 35 | homepage = "https://github.com/RF-Tar-Railt/satori-python" 36 | repository = "https://github.com/RF-Tar-Railt/satori-python" 37 | 38 | [project.optional-dependencies] 39 | msgspec = [ 40 | "msgspec>=0.19.0", 41 | ] 42 | [build-system] 43 | requires = ["mina-build<0.6,>=0.5.1", "pdm-backend<2.4.0"] 44 | build-backend = "mina.backend" 45 | 46 | [dependency-groups] 47 | dev = [ 48 | "isort>=5.13.2", 49 | "black>=24.4.0", 50 | "ruff>=0.4.1", 51 | "pre-commit>=3.7.0", 52 | "fix-future-annotations>=0.5.0", 53 | "mina-build<0.6,>=0.5.1", 54 | "pdm-mina>=0.3.2", 55 | "nonechat<0.7.0,>=0.6.0", 56 | "uvicorn[standard]>=0.37.0", 57 | ] 58 | 59 | 60 | [tool.pdm.build] 61 | includes = ["src/satori"] 62 | excludes = ["src/satori/adapters/*"] 63 | 64 | [tool.pdm.scripts] 65 | format = { composite = ["isort ./src/ ./example/","black ./src/ ./example/","ruff check"] } 66 | 67 | [tool.pdm.version] 68 | source = "file" 69 | path = "src/satori/__init__.py" 70 | 71 | [tool.black] 72 | line-length = 120 73 | include = '\.pyi?$' 74 | extend-exclude = ''' 75 | ''' 76 | 77 | [tool.isort] 78 | profile = "black" 79 | line_length = 120 80 | skip_gitignore = true 81 | extra_standard_library = ["typing_extensions"] 82 | 83 | [tool.ruff] 84 | line-length = 120 85 | target-version = "py310" 86 | exclude = ["exam_qps.py", "exam1.py", "exam2.py", "src/satori/_vendor/*"] 87 | 88 | [tool.ruff.lint] 89 | select = ["E", "W", "F", "UP", "C", "T", "Q"] 90 | ignore = ["E402", "F403", "F405", "C901", "T201"] 91 | 92 | [tool.pyright] 93 | pythonPlatform = "All" 94 | pythonVersion = "3.10" 95 | typeCheckingMode = "basic" 96 | reportShadowedImports = false 97 | disableBytesTypePromotions = true -------------------------------------------------------------------------------- /src/satori/adapters/milky/events/group.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | 5 | from satori import EventType 6 | from satori.model import Channel, ChannelType, Event, Guild, Member, User 7 | 8 | from ..utils import group_avatar, user_avatar 9 | from .base import register_event 10 | 11 | 12 | @register_event("group_member_increase") 13 | async def group_member_increase(login, net, raw): 14 | data = raw["data"] 15 | guild_id = str(data["group_id"]) 16 | guild = Guild(guild_id, avatar=group_avatar(guild_id)) 17 | channel = Channel(guild_id, ChannelType.TEXT) 18 | user = User(str(data["user_id"]), avatar=user_avatar(data["user_id"])) 19 | member = Member(user, avatar=user.avatar) 20 | operator = None 21 | if data.get("operator_id"): 22 | operator = User(str(data["operator_id"]), avatar=user_avatar(data["operator_id"])) 23 | return Event( 24 | EventType.GUILD_MEMBER_ADDED, 25 | datetime.fromtimestamp(raw["time"]), 26 | login, 27 | guild=guild, 28 | channel=channel, 29 | user=user, 30 | member=member, 31 | operator=operator, 32 | ) 33 | 34 | 35 | @register_event("group_member_decrease") 36 | async def group_member_decrease(login, net, raw): 37 | data = raw["data"] 38 | guild_id = str(data["group_id"]) 39 | guild = Guild(guild_id, avatar=group_avatar(guild_id)) 40 | channel = Channel(guild_id, ChannelType.TEXT) 41 | user = User(str(data["user_id"]), avatar=user_avatar(data["user_id"])) 42 | member = Member(user, avatar=user.avatar) 43 | operator = None 44 | if data.get("operator_id"): 45 | operator = User(str(data["operator_id"]), avatar=user_avatar(data["operator_id"])) 46 | return Event( 47 | EventType.GUILD_MEMBER_REMOVED, 48 | datetime.fromtimestamp(raw["time"]), 49 | login, 50 | guild=guild, 51 | channel=channel, 52 | user=user, 53 | member=member, 54 | operator=operator, 55 | ) 56 | 57 | 58 | @register_event("group_name_change") 59 | async def group_name_change(login, net, raw): 60 
| data = raw["data"] 61 | guild_id = str(data["group_id"]) 62 | guild = Guild(guild_id, name=data.get("new_group_name"), avatar=group_avatar(guild_id)) 63 | channel = Channel(guild_id, ChannelType.TEXT, name=data.get("new_group_name")) 64 | operator = None 65 | if data.get("operator_id"): 66 | operator = User(str(data["operator_id"]), avatar=user_avatar(data["operator_id"])) 67 | return Event( 68 | EventType.GUILD_UPDATED, 69 | datetime.fromtimestamp(raw["time"]), 70 | login, 71 | guild=guild, 72 | channel=channel, 73 | operator=operator, 74 | ) 75 | -------------------------------------------------------------------------------- /src/satori/client/account.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from dataclasses import dataclass, field 5 | from typing_extensions import Generic, TypeVar # noqa: UP035 6 | 7 | from yarl import URL 8 | 9 | from satori.model import Login 10 | 11 | from .protocol import ApiProtocol 12 | 13 | TP = TypeVar("TP", bound="ApiProtocol", default=ApiProtocol, covariant=True) 14 | TP1 = TypeVar("TP1", bound="ApiProtocol", default=ApiProtocol, covariant=True) 15 | 16 | 17 | @dataclass 18 | class ApiInfo: 19 | host: str = "localhost" 20 | port: int = 5140 21 | path: str = "" 22 | token: str | None = None 23 | timeout: float | None = None 24 | api_base: URL = field(init=False) 25 | 26 | def __post_init__(self): 27 | if self.path and not self.path.startswith("/"): 28 | self.path = f"/{self.path}" 29 | self.api_base = URL(f"http://{self.host}:{self.port}{self.path}") / "v1" 30 | 31 | 32 | class Account(Generic[TP]): 33 | def __init__( 34 | self, 35 | login: Login, 36 | config: ApiInfo, 37 | proxy_urls: list[str], 38 | protocol_cls: type[TP] = ApiProtocol, 39 | ): 40 | self.adapter = login.adapter 41 | self.self_info = login 42 | self.config = config 43 | self.proxy_urls = proxy_urls 44 | self.protocol = protocol_cls(self) # type: ignore 45 | 
self.connected = asyncio.Event() 46 | 47 | @property 48 | def platform(self): 49 | return self.self_info.platform or "satori" 50 | 51 | @property 52 | def self_id(self): 53 | return self.self_info.user.id 54 | 55 | def custom(self, config: ApiInfo | None = None, protocol_cls: type[TP1] = ApiProtocol, **kwargs) -> Account[TP1]: 56 | return Account( 57 | self.self_info, 58 | config or (ApiInfo(**kwargs) if kwargs else self.config), 59 | self.proxy_urls, 60 | protocol_cls, 61 | ) 62 | 63 | def ensure_url(self, url: str) -> URL: 64 | """确定链接形式。 65 | 66 | 若链接符合以下条件之一,则返回链接的代理形式 ({host}/{path}/{version}/proxy/{url}): 67 | - 链接以 "internal:" 开头 68 | - 链接开头出现在 proxy_urls 中的某一项 69 | """ 70 | if url.startswith("internal:"): 71 | return self.config.api_base / "proxy" / url.lstrip("/") 72 | for proxy_url in self.proxy_urls: 73 | if url.startswith(proxy_url): 74 | return self.config.api_base / "proxy" / url.lstrip("/") 75 | return URL(url) 76 | 77 | def __repr__(self): 78 | return f"" 79 | 80 | def __getattr__(self, item): 81 | if hasattr(self.protocol, item): 82 | return getattr(self.protocol, item) 83 | raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") 84 | -------------------------------------------------------------------------------- /src/satori/adapters/milky/events/message.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | 5 | from satori import EventType 6 | from satori.model import Channel, ChannelType, Event, Guild, MessageObject, User 7 | 8 | from ..message import decode_message 9 | from ..utils import group_avatar, user_avatar 10 | from .base import register_event 11 | 12 | 13 | @register_event("message_receive") 14 | async def message_receive(login, net, raw): 15 | message = await decode_message(net, raw["data"]) 16 | return Event( 17 | EventType.MESSAGE_CREATED, 18 | datetime.fromtimestamp(raw["time"]), 19 | login, 
20 | channel=message.channel, 21 | guild=message.guild, 22 | member=message.member, 23 | user=message.user, 24 | message=message, 25 | ) 26 | 27 | 28 | @register_event("message_recall") 29 | async def message_recall(login, net, raw): 30 | data = raw["data"] 31 | scene = data["message_scene"] 32 | peer_id = str(data["peer_id"]) 33 | if scene == "group": 34 | channel = Channel(peer_id, ChannelType.TEXT) 35 | guild = Guild(peer_id, avatar=group_avatar(peer_id)) 36 | elif scene == "temp": 37 | channel = Channel(f"private:temp_{peer_id}", ChannelType.DIRECT) 38 | guild = None 39 | else: 40 | channel = Channel(f"private:{peer_id}", ChannelType.DIRECT) 41 | guild = None 42 | user = User(str(data["sender_id"]), avatar=user_avatar(data["sender_id"])) 43 | operator = User(str(data["operator_id"]), avatar=user_avatar(data["operator_id"])) 44 | message = MessageObject(str(data["message_seq"]), "", channel=channel, guild=guild, user=user) 45 | return Event( 46 | EventType.MESSAGE_DELETED, 47 | datetime.fromtimestamp(raw["time"]), 48 | login, 49 | channel=channel, 50 | guild=guild, 51 | user=user, 52 | operator=operator, 53 | message=message, 54 | ) 55 | 56 | 57 | @register_event("group_message_reaction") 58 | async def group_message_reaction(login, net, raw): 59 | data = raw["data"] 60 | guild_id = str(data["group_id"]) 61 | guild = Guild(guild_id, avatar=group_avatar(guild_id)) 62 | channel = Channel(guild_id, ChannelType.TEXT) 63 | user = User(str(data["user_id"]), avatar=user_avatar(data["user_id"])) 64 | face_id = data["face_id"] 65 | message = MessageObject( 66 | str(data["message_seq"]), f"", channel=channel, guild=guild, user=user 67 | ) 68 | if data["is_add"]: 69 | event_type = EventType.REACTION_ADDED 70 | else: 71 | event_type = EventType.REACTION_REMOVED 72 | return Event( 73 | event_type, 74 | datetime.fromtimestamp(raw["time"]), 75 | login, 76 | channel=channel, 77 | guild=guild, 78 | user=user, 79 | message=message, 80 | ) 81 | 
-------------------------------------------------------------------------------- /src/satori/adapters/milky/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import Literal, Protocol 5 | 6 | from satori.model import Channel, ChannelType, Guild, Member, User 7 | 8 | AVATAR_URL = "https://q.qlogo.cn/headimg_dl?dst_uin={uin}&spec=640" 9 | GROUP_AVATAR_URL = "https://p.qlogo.cn/gh/{group}/{group}/640" 10 | 11 | 12 | class MilkyNetwork(Protocol): 13 | async def call_api(self, action: str, params: dict | None = None) -> dict: ... 14 | 15 | 16 | def user_avatar(uin: int | str) -> str: 17 | return AVATAR_URL.format(uin=uin) 18 | 19 | 20 | def group_avatar(group_id: int | str) -> str: 21 | return GROUP_AVATAR_URL.format(group=group_id) 22 | 23 | 24 | def decode_group_channel(group: dict) -> Channel: 25 | return Channel(str(group["group_id"]), ChannelType.TEXT, group.get("group_name")) 26 | 27 | 28 | def decode_private_channel(profile: dict, channel_id: str) -> Channel: 29 | return Channel(channel_id, ChannelType.DIRECT, profile.get("nickname")) 30 | 31 | 32 | def decode_guild(group: dict) -> Guild: 33 | return Guild(str(group["group_id"]), group.get("group_name"), group_avatar(group["group_id"])) 34 | 35 | 36 | def decode_member(member: dict) -> Member: 37 | user_id = str(member["user_id"]) 38 | user = User(user_id, member.get("nickname"), avatar=user_avatar(user_id)) 39 | joined_at = member.get("join_time") 40 | return Member( 41 | user=user, 42 | nick=member.get("card") or member.get("nickname"), 43 | avatar=user_avatar(user_id), 44 | joined_at=datetime.fromtimestamp(joined_at) if joined_at else None, 45 | ) 46 | 47 | 48 | def decode_friend(friend: dict) -> User: 49 | user_id = str(friend["user_id"]) 50 | return User(user_id, friend.get("nickname"), avatar=user_avatar(user_id)) 51 | 52 | 53 | def decode_login_user(login: dict) -> User: 54 | 
user_id = str(login["uin"]) 55 | return User(user_id, login.get("nickname"), avatar=user_avatar(user_id)) 56 | 57 | 58 | def decode_user_profile(profile: dict, user_id: str) -> User: 59 | return User(user_id, profile.get("nickname"), avatar=user_avatar(user_id), nick=profile.get("remark")) 60 | 61 | 62 | def decode_guild_channel_id(data: dict) -> tuple[str | None, str]: 63 | scene = data.get("message_scene") 64 | peer_id = str(data.get("peer_id")) 65 | if scene == "group": 66 | return peer_id, peer_id 67 | if scene == "temp": 68 | return None, f"private:temp_{peer_id}" 69 | return None, f"private:{peer_id}" 70 | 71 | 72 | def get_scene_and_peer(channel_id: str) -> tuple[Literal["friend", "group", "temp"], int]: 73 | if channel_id.startswith("private:temp_"): 74 | return "temp", int(channel_id.removeprefix("private:temp_")) 75 | if channel_id.startswith("private:"): 76 | return "friend", int(channel_id.removeprefix("private:")) 77 | return "group", int(channel_id) 78 | -------------------------------------------------------------------------------- /src/satori/adapters/milky/events/request.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | 5 | from satori import EventType 6 | from satori.model import Channel, ChannelType, Event, Guild, Member, MessageObject, User 7 | 8 | from ..utils import group_avatar, user_avatar 9 | from .base import register_event 10 | 11 | 12 | @register_event("friend_request") 13 | async def friend_request(login, net, raw): 14 | data = raw["data"] 15 | user = User(str(data["initiator_id"]), avatar=user_avatar(data["initiator_id"])) 16 | channel = Channel(f"private:{user.id}", ChannelType.DIRECT) 17 | message_id = f"{data['initiator_uid']}|{1 if data.get('is_filtered') else 0}" 18 | message = MessageObject(message_id, data.get("comment", "")) 19 | return Event( 20 | EventType.FRIEND_REQUEST, 21 | 
datetime.fromtimestamp(raw["time"]), 22 | login, 23 | user=user, 24 | channel=channel, 25 | message=message, 26 | ) 27 | 28 | 29 | @register_event("group_join_request") 30 | async def group_join_request(login, net, raw): 31 | data = raw["data"] 32 | guild_id = str(data["group_id"]) 33 | guild = Guild(guild_id, avatar=group_avatar(guild_id)) 34 | channel = Channel(guild_id) 35 | user = User(str(data["initiator_id"]), avatar=user_avatar(data["initiator_id"])) 36 | member = Member(user, avatar=user_avatar(data["initiator_id"])) 37 | message_id = f"{data['notification_seq']}|join_request|{guild_id}|{1 if data.get('is_filtered') else 0}" 38 | message = MessageObject(message_id, data.get("comment", "")) 39 | return Event( 40 | EventType.GUILD_MEMBER_REQUEST, 41 | datetime.fromtimestamp(raw["time"]), 42 | login, 43 | guild=guild, 44 | channel=channel, 45 | user=user, 46 | member=member, 47 | message=message, 48 | ) 49 | 50 | 51 | @register_event("group_invited_join_request") 52 | async def group_invited_join_request(login, net, raw): 53 | data = raw["data"] 54 | guild_id = str(data["group_id"]) 55 | guild = Guild(guild_id, avatar=group_avatar(guild_id)) 56 | channel = Channel(guild_id) 57 | user = User(str(data["target_user_id"]), avatar=user_avatar(data["target_user_id"])) 58 | member = Member(user, avatar=user.avatar) 59 | message_id = f"{data['notification_seq']}|invited_join_request|{guild_id}|0" 60 | message = MessageObject(message_id, "") 61 | operator = User(str(data["initiator_id"]), avatar=user_avatar(data["initiator_id"])) 62 | return Event( 63 | EventType.GUILD_MEMBER_REQUEST, 64 | datetime.fromtimestamp(raw["time"]), 65 | login, 66 | guild=guild, 67 | channel=channel, 68 | user=user, 69 | member=member, 70 | operator=operator, 71 | message=message, 72 | ) 73 | 74 | 75 | @register_event("group_invitation") 76 | async def group_invitation(login, net, raw): 77 | data = raw["data"] 78 | guild_id = str(data["group_id"]) 79 | guild = Guild(guild_id, 
avatar=group_avatar(guild_id)) 80 | channel = Channel(guild_id) 81 | user = User(str(data["initiator_id"]), avatar=user_avatar(data["initiator_id"])) 82 | message_id = f"{guild_id}|{data['invitation_seq']}" 83 | message = MessageObject(message_id, "") 84 | return Event( 85 | EventType.GUILD_REQUEST, 86 | datetime.fromtimestamp(raw["time"]), 87 | login, 88 | guild=guild, 89 | channel=channel, 90 | user=user, 91 | message=message, 92 | ) 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | .idea/ 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # satori-python 2 | 3 | ![latest release](https://img.shields.io/github/release/RF-Tar-Railt/satori-python) 4 | [![Licence](https://img.shields.io/github/license/RF-Tar-Railt/satori-python)](https://github.com/RF-Tar-Railt/satori-python/blob/main/LICENSE) 5 | [![PyPI](https://img.shields.io/pypi/v/satori-python)](https://pypi.org/project/satori-python) 6 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/satori-python)](https://www.python.org/) 7 | 8 | 基于 [Satori](https://satori.js.org/zh-CN/) 协议的 Python 开发工具包 9 | 10 | ## 协议介绍 11 | 12 | [Satori Protocol](https://satori.js.org/zh-CN/) 13 | 14 | ### 协议端 15 | 16 | 目前提供了 `satori` 协议实现的有: 17 | 18 | - [Chronocat](https://chronocat.vercel.app) 19 | - [nekobox](https://github.com/wyapx/nekobox) 20 | - Koishi (搭配 `@koishijs/plugin-server`) 21 | 22 | ### 使用该 SDK 的框架 23 | 24 | - [`Entari`](https://github.com/ArcletProject/Entari) 25 | 26 | ## 安装 27 | 28 | 安装完整体: 29 | ```shell 30 | pip install satori-python 31 | ``` 32 | 33 | 只安装基础部分: 34 | ```shell 35 | pip install satori-python-core 36 | ``` 37 | 38 | 只安装客户端部分: 39 | ```shell 40 | pip install satori-python-client 41 | ``` 42 | 43 | 只安装服务端部分: 44 | ```shell 45 | pip install satori-python-server 46 | ``` 47 | 48 | ### 官方适配器 49 | 50 | | 适配器 | 安装 | 路径 | 51 | |------------|----------------------------------------------|--------------------------------------------------------------------| 52 | | Satori | `pip install satori-python-adapter-satori` | satori.adapters.satori | 53 | | OneBot V11 | `pip install satori-python-adapter-onebot11` | satori.adapters.onebot11.forward, satori.adapters.onebot11.reverse | 54 | | Console | `pip install satori-python-adapter-console` | satori.adapters.console | 55 | | Milky | `pip install satori-python-adapter-milky` | 
satori.adapters.milky.main, satori.adapters.milky.webhook | 56 | 57 | ### 社区适配器 58 | 59 | | 适配器 | 安装 | 路径 | 60 | |-------------------|-----------------------|--------------| 61 | | nekobox(Lagrange) | `pip install nekobox` | nekobox.main | 62 | 63 | ## 使用 64 | 65 | 客户端: 66 | 67 | ```python 68 | from satori import EventType 69 | from satori.event import MessageEvent 70 | from satori.client import Account, App, WebsocketsInfo 71 | 72 | app = App(WebsocketsInfo(port=5140)) 73 | 74 | @app.register_on(EventType.MESSAGE_CREATED) 75 | async def on_message(account: Account, event: MessageEvent): 76 | if event.user.id == "xxxxxxxxxxx": 77 | await account.send(event, "Hello, World!") 78 | 79 | app.run() 80 | ``` 81 | 82 | 服务端: 83 | 84 | ```python 85 | from satori import Api 86 | from satori.server import Server 87 | 88 | server = Server(port=5140) 89 | 90 | @server.route(Api.MESSAGE_CREATE) 91 | async def on_message_create(*args, **kwargs): 92 | return [{"id": "1234", "content": "example"}] 93 | 94 | server.run() 95 | ``` 96 | 97 | ## 文档 98 | 99 | 请阅读 [仓库文档](./docs.md) 100 | 101 | ## 示例 102 | 103 | - 客户端:[client.py](./example/client.py) 104 | - 服务端:[server.py](./example/server.py) 105 | - 服务端(使用适配器):[server_with_adapter.py](./example/server_with_adapter.py) 106 | - 客户端(webhook):[client_webhook](./example/client_webhook.py) 107 | - 服务端(webhook):[server_webhook](./example/server_webhook.py) 108 | - 适配器:[adapter.py](./example/adapter.py) 109 | 110 | ## 架构 111 | 112 | ```mermaid 113 | graph LR 114 | subgraph Server 115 | server -- run --> asgi 116 | server -- register --> router -- mount --> asgi 117 | server -- apply --> provider -- mount --> asgi 118 | provider -- event,logins --> server 119 | end 120 | subgraph Client 121 | config -- apply --> app -- run --> network 122 | app -- register --> listener 123 | network -- account,event --> listener 124 | listener -- handle --> account -- session --> api 125 | end 126 | 127 | api -- request --> asgi -- response --> api 128 | server 
-- raw-event --> asgi -- websocket/webhook --> network 129 | ``` 130 | -------------------------------------------------------------------------------- /src/satori/adapters/satori/main.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | from launart import Launart, any_completed 4 | from launart.status import Phase 5 | from starlette.datastructures import FormData 6 | from starlette.responses import JSONResponse, Response 7 | 8 | from satori import Api 9 | from satori.client import App, WebsocketsInfo 10 | from satori.exception import ActionFailed 11 | from satori.server import Adapter as BaseAdapter 12 | from satori.server import Request 13 | from satori.server.adapter import LoginType 14 | 15 | 16 | class SatoriAdapter(BaseAdapter): 17 | def __init__( 18 | self, 19 | host: str = "localhost", 20 | port: int = 5140, 21 | path: str = "", 22 | token: str | None = None, 23 | post_upload: bool = False, 24 | ): 25 | super().__init__() 26 | self.app = App(WebsocketsInfo(host, port, path, token), main_app=False) 27 | 28 | @self.app.register 29 | async def _(acc, event): 30 | await self.server.post(event) 31 | 32 | self.routes["internal/*"] = self._handle_request 33 | self.routes |= {api.value: self._handle_request for api in Api.__members__.values()} 34 | if not post_upload: 35 | self.routes.pop(Api.UPLOAD_CREATE.value, None) 36 | 37 | @property 38 | def id(self): 39 | return f"satori-python.adapter.satori#{id(self)}" 40 | 41 | @property 42 | def account(self): 43 | return next(iter(self.app.accounts.values()), None) 44 | 45 | def get_platform(self) -> str: 46 | return "satori" 47 | 48 | def ensure(self, platform: str, self_id: str) -> bool: 49 | if not (acc := self.account): 50 | return False 51 | return acc.platform == platform and acc.self_info.user.id == self_id 52 | 53 | async def _handle_request(self, request: Request): 54 | if not (acc := self.account): 55 | return Response("No account found", 
status_code=404) 56 | if request.action == Api.UPLOAD_CREATE.value: 57 | data = cast(FormData, request.params) 58 | files = { 59 | k: ( 60 | v 61 | if isinstance(v, str) 62 | else {"value": v.file.read(), "content_type": v.content_type, "filename": v.filename} 63 | ) 64 | for k, v in data.items() 65 | } 66 | return await acc.protocol.call_api(request.action, files, multipart=True) 67 | return await acc.protocol.call_api(request.action, request.params) 68 | 69 | async def handle_internal(self, request: Request, path: str) -> Response: 70 | if path.startswith("_api"): 71 | if not (acc := self.account): 72 | return Response("No account found", status_code=404) 73 | try: 74 | return JSONResponse( 75 | await acc.protocol.call_api(path[5:], await request.origin.json(), method=request.origin.method) 76 | ) 77 | except ActionFailed as e: 78 | return Response(str(e), status_code=500) 79 | if acc := self.account: 80 | return Response(await self.account.protocol.download(f"internal:{acc.platform}/{acc.self_id}/{path}")) 81 | async with self.server.session.get(path) as resp: 82 | return Response(await resp.read()) 83 | 84 | async def get_logins(self) -> list[LoginType]: 85 | if not (acc := self.account): 86 | return [] 87 | return [acc.self_info] 88 | 89 | @property 90 | def required(self) -> set[str]: 91 | return { 92 | "satori-python.server", 93 | } 94 | 95 | @property 96 | def stages(self) -> set[Phase]: 97 | return {"preparing", "blocking", "cleanup"} 98 | 99 | async def launch(self, manager: Launart): 100 | manager.add_component(self.app.connections[0]) 101 | 102 | async with self.stage("preparing"): 103 | pass 104 | 105 | async with self.stage("blocking"): 106 | await any_completed( 107 | self.app.connections[0].status.wait_for("blocking-completed"), 108 | manager.status.wait_for_sigexit(), 109 | ) 110 | 111 | async with self.stage("cleanup"): 112 | pass 113 | 114 | 115 | Adapter = SatoriAdapter 116 | 
-------------------------------------------------------------------------------- /src/satori/client/network/webhook.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | 5 | from aiohttp import ClientTimeout, web 6 | from graia.amnesia.builtins.aiohttp import AiohttpClientService 7 | from launart.manager import Launart 8 | from loguru import logger 9 | 10 | from satori.model import Event, LoginStatus, Meta, MetaPayload, Opcode 11 | from satori.utils import decode 12 | 13 | from ..account import Account 14 | from ..config import WebhookInfo as WebhookInfo 15 | from .base import BaseNetwork 16 | from .util import validate_response 17 | 18 | 19 | class WebhookNetwork(BaseNetwork[WebhookInfo]): 20 | required: set[str] = set() 21 | stages: set[str] = {"preparing", "blocking", "cleanup"} 22 | wsgi: web.Application | None = None 23 | 24 | @property 25 | def id(self): 26 | return f"satori/net/wh/{self.config.identity}#{id(self):x}" 27 | 28 | async def handle_request(self, req: web.Request): 29 | header = req.headers 30 | auth = header["Authorization"] 31 | if not auth.startswith("Bearer"): 32 | return web.Response(status=401) 33 | token = auth.split(" ", 1)[1] 34 | if self.config.token and self.config.token != token: 35 | return web.Response(status=401) 36 | op_code = int(header.get("Satori-OpCode", "0")) 37 | body = decode(await req.text()) 38 | if op_code == Opcode.META: 39 | payload = MetaPayload.parse(body) 40 | self.proxy_urls = payload.proxy_urls 41 | for account in self.accounts.values(): 42 | account.proxy_urls = payload.proxy_urls 43 | return web.Response() 44 | if op_code != Opcode.EVENT: 45 | return web.Response(status=202) 46 | # if "X-Platform" in header and "X-Self-ID" in header: 47 | # platform = header["X-Platform"] 48 | # self_id = header["X-Self-ID"] 49 | # elif "Satori-Platform" in header and "Satori-User-ID" in header: 50 | # platform = 
header["Satori-Platform"] 51 | # self_id = header["Satori-User-ID"] 52 | # else: 53 | # return web.Response(status=400) 54 | try: 55 | event = Event.parse(body) 56 | except Exception as e: 57 | if ( 58 | "self_id" in body 59 | or ("login" in body and "self_id" in body["login"]) 60 | or ("login" in body and "user" in body["login"] and "self_id" in body["login"]["user"]) 61 | ): 62 | logger.warning(f"Failed to parse event: {body}\nCaused by {e!r}") 63 | else: 64 | logger.trace(f"Failed to parse event: {body}\nCaused by {e!r}") 65 | return web.Response(status=500, reason=f"Failed to parse event caused by {e!r}") 66 | else: 67 | self.sequence = event.sn 68 | asyncio.create_task(self.app.post(event, self)) 69 | return web.Response() 70 | 71 | @property 72 | def alive(self): 73 | return self.wsgi is not None 74 | 75 | async def wait_for_available(self): 76 | await self.status.wait_for_available() 77 | 78 | async def launch(self, manager: Launart): 79 | async with self.stage("preparing"): 80 | logger.info(f"starting server on {self.config.identity}") 81 | self.wsgi = web.Application(logger=logger) # type: ignore 82 | self.wsgi.router.freeze = lambda: None # monkey patch 83 | self.wsgi.router.add_post(self.config.path, self.handle_request) 84 | runner = web.AppRunner(self.wsgi) 85 | await runner.setup() 86 | site = web.TCPSite(runner, self.config.host, self.config.port) 87 | 88 | async with self.stage("blocking"): 89 | endpoint = self.config.api_base / "meta" 90 | headers = { 91 | "Content-Type": "application/json", 92 | } 93 | aio = Launart.current().get_component(AiohttpClientService) 94 | 95 | async with aio.session.request( 96 | "POST", 97 | endpoint, 98 | json={}, 99 | headers=headers, 100 | timeout=ClientTimeout(total=self.config.timeout or 300), 101 | ) as resp: 102 | data = await validate_response(resp) 103 | meta = Meta.parse(data) 104 | self.proxy_urls = meta.proxy_urls 105 | for login in meta.logins: 106 | if not login.user: 107 | continue 108 | login_sn = 
f"{login.user.id}@{id(self):x}" 109 | account = Account(login, self.config, meta.proxy_urls, self.app.default_api_cls) 110 | logger.info(f"account registered: {account}") 111 | (account.connected.set() if login.status == LoginStatus.ONLINE else account.connected.clear()) 112 | self.app.accounts[login_sn] = account 113 | self.accounts[login_sn] = account 114 | await self.app.account_update(account, LoginStatus.ONLINE) 115 | await site.start() 116 | await manager.status.wait_for_sigexit() 117 | logger.info(f"{self.id} Webhook server exiting...") 118 | self.close_signal.set() 119 | for v in list(self.app.accounts.values()): 120 | if (identity := f"{v.self_id}@{id(self):x}") in self.accounts: 121 | v.connected.clear() 122 | await self.app.account_update(v, LoginStatus.OFFLINE) 123 | del self.app.accounts[identity] 124 | del self.accounts[identity] 125 | 126 | async with self.stage("cleanup"): 127 | await site.stop() 128 | await self.wsgi.shutdown() 129 | await self.wsgi.cleanup() 130 | -------------------------------------------------------------------------------- /src/satori/adapters/onebot11/events/message.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | 5 | from satori import EventType 6 | from satori.model import Channel, ChannelType, Event, Guild, Login, Member, MessageObject, User 7 | 8 | from ..message import decode 9 | from ..utils import GROUP_AVATAR_URL, ROLE_MAPPING, USER_AVATAR_URL, OneBotNetwork 10 | from .base import register_event 11 | 12 | 13 | @register_event("message.private.friend") 14 | @register_event("message.private.other") 15 | async def private_friend(login: Login, net: OneBotNetwork, raw: dict): 16 | sender: dict = raw["sender"] 17 | user = User( 18 | str(sender["user_id"]), sender["nickname"], sender.get("card"), USER_AVATAR_URL.format(uin=sender["user_id"]) 19 | ) 20 | channel = Channel(f"private:{sender['user_id']}", 
ChannelType.DIRECT, sender["nickname"]) 21 | return Event( 22 | EventType.MESSAGE_CREATED, 23 | datetime.now(), 24 | login=login, 25 | user=user, 26 | channel=channel, 27 | message=MessageObject(str(raw["message_id"]), await decode(raw["message"], net)), 28 | ) 29 | 30 | 31 | @register_event("message.private.group") 32 | async def private_group(login: Login, net: OneBotNetwork, raw: dict): 33 | sender: dict = raw["sender"] 34 | user = User( 35 | str(sender["user_id"]), sender["nickname"], sender.get("card"), USER_AVATAR_URL.format(uin=sender["user_id"]) 36 | ) 37 | channel = Channel(f"private:{sender['user_id']}", ChannelType.DIRECT, sender["nickname"]) 38 | group_id = sender["group_id"] if "group_id" in sender else raw.get("group_id") 39 | group_info = await net.call_api("get_group_info", {"group_id": group_id}) if group_id else {} 40 | return Event( 41 | EventType.MESSAGE_CREATED, 42 | datetime.now(), 43 | login=login, 44 | user=user, 45 | member=Member(user, sender["nickname"], USER_AVATAR_URL.format(uin=sender["user_id"])), 46 | guild=( 47 | Guild(str(group_id), group_info.get("group_name"), avatar=GROUP_AVATAR_URL.format(group=group_id)) 48 | if group_id 49 | else None 50 | ), 51 | channel=channel, 52 | message=MessageObject(str(raw["message_id"]), await decode(raw["message"], net)), 53 | ) 54 | 55 | 56 | @register_event("notice.friend_recall") 57 | async def friend_message_recall(login: Login, net: OneBotNetwork, raw: dict): 58 | info = await net.call_api("get_stranger_info", {"user_id": raw["user_id"]}) 59 | user = User(str(raw["user_id"]), info.get("nickname"), info.get("card"), USER_AVATAR_URL.format(uin=raw["user_id"])) 60 | channel = Channel(f"private:{raw['user_id']}", ChannelType.DIRECT, info.get("nickname")) 61 | return Event( 62 | EventType.MESSAGE_DELETED, 63 | datetime.now(), 64 | login=login, 65 | user=user, 66 | channel=channel, 67 | message=MessageObject(str(raw["message_id"]), ""), 68 | ) 69 | 70 | 71 | @register_event("message.group.normal") 
72 | @register_event("message.group.notice") 73 | @register_event("message_sent.group.normal") 74 | async def group(login: Login, net: OneBotNetwork, raw: dict): 75 | sender: dict = raw["sender"] 76 | user = User( 77 | str(sender["user_id"]), sender["nickname"], sender.get("card"), USER_AVATAR_URL.format(uin=sender["user_id"]) 78 | ) 79 | group_info = await net.call_api("get_group_info", {"group_id": raw["group_id"]}) 80 | member_info = await net.call_api( 81 | "get_group_member_info", {"group_id": raw["group_id"], "user_id": sender["user_id"]} 82 | ) 83 | member = Member(user, member_info["card"], USER_AVATAR_URL.format(uin=sender["user_id"])) 84 | guild = Guild( 85 | str(raw["group_id"]), group_info.get("group_name"), avatar=GROUP_AVATAR_URL.format(group=raw["group_id"]) 86 | ) 87 | channel = Channel(str(raw["group_id"]), ChannelType.TEXT, group_info.get("group_name")) 88 | return Event( 89 | EventType.MESSAGE_CREATED, 90 | datetime.now(), 91 | login=login, 92 | user=user, 93 | guild=guild, 94 | channel=channel, 95 | member=member, 96 | role=ROLE_MAPPING[sender["role"]], 97 | message=MessageObject(str(raw["message_id"]), await decode(raw["message"], net)), 98 | ) 99 | 100 | 101 | @register_event("notice.group_recall") 102 | async def group_message_recall(login: Login, net: OneBotNetwork, raw: dict): 103 | group_info = await net.call_api("get_group_info", {"group_id": raw["group_id"]}) 104 | member_info = await net.call_api("get_group_member_info", {"group_id": raw["group_id"], "user_id": raw["user_id"]}) 105 | operator_info = await net.call_api( 106 | "get_group_member_info", {"group_id": raw["group_id"], "user_id": raw["operator_id"]} 107 | ) 108 | user = User( 109 | str(raw["user_id"]), 110 | member_info["nickname"], 111 | member_info.get("card"), 112 | USER_AVATAR_URL.format(uin=raw["user_id"]), 113 | ) 114 | member = Member(user, member_info.get("card"), USER_AVATAR_URL.format(uin=raw["user_id"])) 115 | guild = Guild( 116 | str(raw["group_id"]), 
group_info.get("group_name"), avatar=GROUP_AVATAR_URL.format(group=raw["group_id"]) 117 | ) 118 | channel = Channel(str(raw["group_id"]), ChannelType.TEXT, group_info.get("group_name")) 119 | operator = User( 120 | str(raw["operator_id"]), 121 | operator_info["nickname"], 122 | operator_info.get("card"), 123 | USER_AVATAR_URL.format(uin=raw["operator_id"]), 124 | ) 125 | return Event( 126 | EventType.MESSAGE_DELETED, 127 | datetime.now(), 128 | login=login, 129 | user=user, 130 | guild=guild, 131 | channel=channel, 132 | member=member, 133 | operator=operator, 134 | role=ROLE_MAPPING[member_info["role"]], 135 | message=MessageObject(str(raw["message_id"]), ""), 136 | ) 137 | -------------------------------------------------------------------------------- /src/satori/adapters/console/backend.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import asdict 4 | from datetime import datetime 5 | from typing import TYPE_CHECKING, cast 6 | 7 | from loguru import _colorama, logger 8 | from loguru._handler import Handler 9 | from loguru._logger import Logger 10 | from loguru._simple_sinks import StreamSink 11 | from nonechat import Backend, Frontend 12 | from nonechat.backend import BotAdd 13 | from nonechat.model import DIRECT 14 | from nonechat.model import Event as ConsoleEvent 15 | from nonechat.model import MessageEvent as ConsoleMessageEvent 16 | from nonechat.model import Robot 17 | 18 | from satori.const import EventType 19 | from satori.event import Event 20 | from satori.model import Channel, ChannelType, Guild, Login, LoginStatus, Member, MessageObject, User 21 | 22 | from .message import encode_message 23 | 24 | if TYPE_CHECKING: 25 | from .main import ConsoleAdapter 26 | 27 | 28 | class SatoriConsoleBackend(Backend): 29 | _adapter: ConsoleAdapter 30 | 31 | def __init__(self, app: Frontend): 32 | super().__init__(app) 33 | self.logins = {} 34 | self.sn = 0 35 | 
self._origin_sink: StreamSink | None = None 36 | 37 | def set_adapter(self, adapter: ConsoleAdapter): 38 | self._adapter = adapter 39 | 40 | def on_console_load(self): 41 | if self._adapter._logger_id >= 0: 42 | current_handler: Handler = cast(Logger, logger)._core.handlers[self._adapter._logger_id] 43 | else: 44 | current_handler: Handler = list(cast(Logger, logger)._core.handlers.values())[-1] 45 | if current_handler._colorize and _colorama.should_wrap(self.frontend._fake_output): 46 | stream = _colorama.wrap(self.frontend._fake_output) 47 | else: 48 | stream = self.frontend._fake_output 49 | self._origin_sink = current_handler._sink 50 | current_handler._sink = StreamSink(stream) 51 | 52 | async def add_bot(self, bot: Robot): 53 | if self.storage.add_bot(bot): 54 | for watcher in self.bot_watchers: 55 | watcher.post_message(BotAdd(bot)) 56 | login = Login( 57 | self.sn, 58 | LoginStatus.ONLINE, 59 | "console", 60 | "console", 61 | User( 62 | id=bot.id, 63 | name=bot.nickname, 64 | avatar=f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(bot.avatar):x}.png", 65 | is_bot=True, 66 | ), 67 | features=["guild.plain"], 68 | ) 69 | self.sn += 1 70 | self.logins[bot.id] = login 71 | await self._adapter.server.post(Event(EventType.LOGIN_ADDED, datetime.now(), login)) 72 | 73 | async def on_console_mount(self): 74 | logger.success("Console mounted.") 75 | 76 | async def on_console_unmount(self): 77 | if self._origin_sink is not None: 78 | if self._adapter._logger_id >= 0: 79 | current_handler: Handler = cast(Logger, logger)._core.handlers[self._adapter._logger_id] 80 | else: 81 | current_handler: Handler = list(cast(Logger, logger)._core.handlers.values())[-1] 82 | current_handler._sink = self._origin_sink 83 | self._origin_sink = None 84 | for login in self.logins.values(): 85 | login.status = LoginStatus.OFFLINE 86 | await self._adapter.server.post( 87 | Event( 88 | EventType.LOGIN_REMOVED, 89 | datetime.now(), 90 | login, 91 | ) 92 | ) 93 | 94 | 
logger.success("Console exit.") 95 | logger.warning("Press Ctrl-C for Application exit") 96 | 97 | async def post_event(self, event: ConsoleEvent): 98 | if event.self_id not in self.logins: 99 | logger.warning(f"Received event from unknown bot: {event.self_id}") 100 | return 101 | user = User( 102 | event.user.id, 103 | event.user.nickname, 104 | avatar=f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(event.user.avatar):x}.png", 105 | ) 106 | member = Member(user, nick=user.name, avatar=user.avatar) 107 | if isinstance(event, ConsoleMessageEvent): 108 | message = MessageObject(event.message_id, encode_message(event.message)) 109 | if event.channel == DIRECT: 110 | await self._adapter.server.post( 111 | Event( 112 | EventType.MESSAGE_CREATED, 113 | event.time, 114 | self.logins[event.self_id], 115 | user=user, 116 | channel=Channel( 117 | id=f"private:{user.id}", 118 | type=ChannelType.DIRECT, 119 | name=user.name, 120 | ), 121 | message=message, 122 | ) 123 | ) 124 | else: 125 | guild = Guild( 126 | id=event.channel.id, 127 | name=event.channel.name, 128 | avatar=f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(event.channel.avatar):x}.png", 129 | ) 130 | channel = Channel( 131 | id=event.channel.id, 132 | type=ChannelType.TEXT, 133 | name=event.channel.name, 134 | ) 135 | await self._adapter.server.post( 136 | Event( 137 | EventType.MESSAGE_CREATED, 138 | event.time, 139 | self.logins[event.self_id], 140 | user=user, 141 | member=member, 142 | channel=channel, 143 | guild=guild, 144 | message=message, 145 | ) 146 | ) 147 | else: 148 | await self._adapter.server.post( 149 | Event( 150 | EventType.INTERNAL, 151 | event.time, 152 | self.logins[event.self_id], 153 | _type=event.type, 154 | _data=asdict(event), 155 | ) 156 | ) 157 | -------------------------------------------------------------------------------- /experimental/model.pyi: -------------------------------------------------------------------------------- 1 | from 
collections.abc import AsyncIterable, Awaitable 2 | from dataclasses import dataclass 3 | from datetime import datetime 4 | from enum import IntEnum 5 | from os import PathLike 6 | from typing import IO, Any, Generic, Literal, TypeVar 7 | from collections.abc import Callable, AsyncIterator 8 | from typing_extensions import Self 9 | from typing import TypeAlias 10 | 11 | from satori.element import Element 12 | 13 | 14 | @dataclass(kw_only=True) 15 | class ModelBase: 16 | @classmethod 17 | def parse(cls: type[Self], raw: dict) -> Self: ... 18 | def dump(self) -> dict: ... 19 | 20 | 21 | class ChannelType(IntEnum): 22 | TEXT = 0 23 | DIRECT = 1 24 | CATEGORY = 2 25 | VOICE = 3 26 | 27 | 28 | @dataclass(kw_only=True) 29 | class Channel(ModelBase): 30 | id: str 31 | type: ChannelType = ChannelType.TEXT 32 | name: str | None = None 33 | parent_id: str | None = None 34 | 35 | 36 | @dataclass(kw_only=True) 37 | class Guild(ModelBase): 38 | id: str 39 | name: str | None = None 40 | avatar: str | None = None 41 | 42 | 43 | 44 | @dataclass(kw_only=True) 45 | class User(ModelBase): 46 | id: str 47 | name: str | None = None 48 | nick: str | None = None 49 | avatar: str | None = None 50 | is_bot: bool | None = None 51 | 52 | 53 | 54 | @dataclass(kw_only=True) 55 | class Member(ModelBase): 56 | user: User | None = None 57 | nick: str | None = None 58 | avatar: str | None = None 59 | joined_at: datetime | None = None 60 | 61 | 62 | 63 | @dataclass(kw_only=True) 64 | class Role(ModelBase): 65 | id: str 66 | name: str | None = None 67 | 68 | 69 | 70 | class LoginStatus(IntEnum): 71 | OFFLINE = 0 72 | """离线""" 73 | ONLINE = 1 74 | """在线""" 75 | CONNECT = 2 76 | """正在连接""" 77 | DISCONNECT = 3 78 | """正在断开连接""" 79 | RECONNECT = 4 80 | """正在重新连接""" 81 | 82 | 83 | @dataclass(kw_only=True) 84 | class Login(ModelBase): 85 | sn: int 86 | status: LoginStatus 87 | adapter: str 88 | platform: str 89 | user: User 90 | features: list[str] = ... 91 | 92 | @property 93 | def id(self) -> str: ... 
94 | 95 | 96 | @dataclass(kw_only=True) 97 | class LoginPartial(Login): 98 | platform: str | None = None 99 | user: User | None = None 100 | 101 | 102 | @dataclass(kw_only=True) 103 | class ArgvInteraction(ModelBase): 104 | name: str 105 | arguments: list 106 | options: Any 107 | 108 | 109 | @dataclass(kw_only=True) 110 | class ButtonInteraction(ModelBase): 111 | id: str 112 | 113 | 114 | class Opcode(IntEnum): 115 | EVENT = 0 116 | """事件 (接收)""" 117 | PING = 1 118 | """心跳 (发送)""" 119 | PONG = 2 120 | """心跳回复 (接收)""" 121 | IDENTIFY = 3 122 | """鉴权 (发送)""" 123 | READY = 4 124 | """鉴权成功 (接收)""" 125 | META = 5 126 | """元信息更新 (接收)""" 127 | 128 | 129 | @dataclass(kw_only=True) 130 | class Identify(ModelBase): 131 | token: str | None = None 132 | sn: int | None = None 133 | 134 | 135 | @property 136 | def sequence(self) -> int | None: ... 137 | 138 | @dataclass(kw_only=True) 139 | class Ready(ModelBase): 140 | logins: list[LoginPartial] 141 | proxy_urls: list[str] = ... 142 | 143 | @dataclass(kw_only=True) 144 | class MetaPayload(ModelBase): 145 | """Meta 信令""" 146 | 147 | proxy_urls: list[str] 148 | 149 | 150 | 151 | @dataclass(kw_only=True) 152 | class Meta(ModelBase): 153 | """Meta 数据""" 154 | 155 | logins: list[LoginPartial] 156 | proxy_urls: list[str] = ... 157 | 158 | 159 | 160 | @dataclass(kw_only=True) 161 | class MessageObject(ModelBase): 162 | id: str 163 | content: str 164 | channel: Channel | None = None 165 | guild: Guild | None = None 166 | member: Member | None = None 167 | user: User | None = None 168 | created_at: datetime | None = None 169 | updated_at: datetime | None = None 170 | 171 | @classmethod 172 | def from_elements( 173 | cls, 174 | id: str, 175 | content: list[Element], 176 | channel: Channel | None = None, 177 | guild: Guild | None = None, 178 | member: Member | None = None, 179 | user: User | None = None, 180 | created_at: datetime | None = None, 181 | updated_at: datetime | None = None, 182 | ) -> MessageObject: ... 
183 | 184 | @property 185 | def message(self) -> list[Element]: ... 186 | 187 | @message.setter 188 | def message(self, value: list[Element]): ... 189 | 190 | 191 | @dataclass(kw_only=True) 192 | class MessageReceipt(ModelBase): 193 | id: str 194 | content: str | None = None 195 | 196 | @classmethod 197 | def from_elements( 198 | cls, 199 | id: str, 200 | content: list[Element] | None = None, 201 | ) -> MessageReceipt: ... 202 | 203 | @property 204 | def message(self) -> list[Element] | None: ... 205 | 206 | @message.setter 207 | def message(self, value: list[Element] | None): ... 208 | 209 | 210 | @dataclass(kw_only=True) 211 | class Event(ModelBase): 212 | type: str 213 | timestamp: datetime 214 | login: Login 215 | argv: ArgvInteraction | None = None 216 | button: ButtonInteraction | None = None 217 | channel: Channel | None = None 218 | guild: Guild | None = None 219 | member: Member | None = None 220 | message: MessageObject | None = None 221 | operator: User | None = None 222 | role: Role | None = None 223 | user: User | None = None 224 | 225 | _type: str | None = None 226 | _data: dict | None = None 227 | 228 | sn: int = 0 229 | 230 | 231 | @property 232 | def platform(self) -> str: ... 233 | 234 | @property 235 | def self_id(self) -> str: ... 236 | 237 | 238 | T = TypeVar("T", bound=ModelBase) 239 | 240 | 241 | @dataclass 242 | class PageResult(ModelBase, Generic[T]): 243 | data: list[T] 244 | next: str | None = None 245 | 246 | @classmethod 247 | def parse(cls, raw: dict, parser: Callable[[dict], T] | None = None) -> PageResult[T]: ... 248 | 249 | 250 | @dataclass 251 | class PageDequeResult(PageResult[T]): 252 | prev: str | None = None 253 | 254 | @classmethod 255 | def parse(cls, raw: dict, parser: Callable[[dict], T] | None = None) -> PageDequeResult[T]: ... 
256 | 257 | 258 | class IterablePageResult(Generic[T], AsyncIterable[T], Awaitable[PageResult[T]]): 259 | func: Callable[[str | None], Awaitable[PageResult[T]]] 260 | next_page: str | None 261 | 262 | def __init__(self, func: Callable[[str | None], Awaitable[PageResult[T]]], initial_page: str | None = None): ... 263 | 264 | def __await__(self): ... 265 | 266 | def __aiter__(self) -> AsyncIterator[T]: ... 267 | 268 | 269 | Direction: TypeAlias = Literal["before", "after", "around"] 270 | Order: TypeAlias = Literal["asc", "desc"] 271 | 272 | 273 | @dataclass 274 | class Upload: 275 | file: bytes | IO[bytes] | PathLike 276 | mimetype: str = "image/png" 277 | name: str | None = None 278 | 279 | def dump(self) -> dict[str, Any]: ... 280 | -------------------------------------------------------------------------------- /src/satori/client/network/websocket.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from contextlib import suppress 5 | from typing import cast 6 | 7 | import aiohttp 8 | from launart.manager import Launart 9 | from launart.utilles import any_completed 10 | from loguru import logger 11 | 12 | from satori.model import Event, Identify, LoginStatus, MetaPayload, Opcode, Ready 13 | from satori.utils import decode, encode 14 | 15 | from ..account import Account 16 | from ..config import WebsocketsInfo as WebsocketsInfo 17 | from .base import BaseNetwork 18 | 19 | 20 | class WsNetwork(BaseNetwork[WebsocketsInfo]): 21 | required: set[str] = set() 22 | stages: set[str] = {"preparing", "blocking", "cleanup"} 23 | 24 | @property 25 | def id(self): 26 | return f"satori/net/ws/{self.config.identity}#{id(self):x}" 27 | 28 | connection: aiohttp.ClientWebSocketResponse | None = None 29 | 30 | async def event_parse_task(self, raw: dict): 31 | try: 32 | event = Event.parse(raw) 33 | except Exception as e: 34 | if ( 35 | "self_id" in raw 36 | or ("login" in raw and 
"self_id" in raw["login"]) 37 | or ("login" in raw and "user" in raw["login"] and "self_id" in raw["login"]["user"]) 38 | ): 39 | logger.warning(f"Failed to parse event: {raw}\nCaused by {e!r}") 40 | else: 41 | logger.trace(f"Failed to parse event: {raw}\nCaused by {e!r}") 42 | else: 43 | self.sequence = event.sn 44 | await self.app.post(event, self) 45 | 46 | async def message_receive(self): 47 | if self.connection is None: 48 | raise RuntimeError("connection is not established") 49 | 50 | async for msg in self.connection: 51 | if msg.type in {aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED}: 52 | self.close_signal.set() 53 | return 54 | elif msg.type == aiohttp.WSMsgType.TEXT: 55 | data: dict = decode(cast(str, msg.data)) 56 | if data["op"] == Opcode.EVENT: 57 | asyncio.create_task(self.event_parse_task(data["body"])) 58 | elif data["op"] == Opcode.META: 59 | payload = MetaPayload.parse(data["body"]) 60 | self.proxy_urls = payload.proxy_urls 61 | for account in self.accounts.values(): 62 | account.proxy_urls = payload.proxy_urls.copy() 63 | elif data["op"] > 5: 64 | logger.warning(f"Received unknown event: {data}") 65 | else: 66 | logger.trace(f"Received payload: {data}") 67 | continue 68 | else: 69 | await self.connection_closed() 70 | 71 | async def send(self, payload: dict): 72 | if self.connection is None: 73 | raise RuntimeError("connection is not established") 74 | 75 | await self.connection.send_str(encode(payload)) 76 | 77 | @property 78 | def alive(self): 79 | return self.connection is not None and not self.connection.closed 80 | 81 | async def wait_for_available(self): 82 | await self.status.wait_for_available() 83 | 84 | async def _authenticate(self): 85 | """鉴权连接""" 86 | if not self.connection: 87 | raise RuntimeError("connection is not established") 88 | payload = Identify(token=self.config.token) 89 | if self.sequence > -1: 90 | payload.sn = self.sequence 91 | try: 92 | await self.send({"op": Opcode.IDENTIFY.value, "body": 
payload.dump()}) 93 | except Exception as e: 94 | logger.error(f"Error while sending IDENTIFY event: {e!r}") 95 | return False 96 | 97 | resp = await self.connection.receive() 98 | if resp.type != aiohttp.WSMsgType.TEXT: 99 | logger.error(f"Received unexpected payload: {resp}") 100 | return False 101 | data = decode(cast(str, resp.data)) 102 | if data["op"] != Opcode.READY: 103 | logger.error(f"Received unexpected payload: {data}") 104 | return False 105 | ready = Ready.parse(data["body"]) 106 | self.proxy_urls = ready.proxy_urls 107 | for login in ready.logins: 108 | if not login.user: 109 | continue 110 | login_sn = f"{login.user.id}@{id(self):x}" 111 | if login_sn in self.app.accounts: 112 | account = self.app.accounts[login_sn] 113 | self.accounts[login_sn] = account 114 | if login.status == LoginStatus.ONLINE: 115 | account.connected.set() 116 | else: 117 | account.connected.clear() 118 | account.config = self.config 119 | else: 120 | account = Account(login, self.config, ready.proxy_urls, self.app.default_api_cls) 121 | logger.info(f"account registered: {account}") 122 | (account.connected.set() if login.status == LoginStatus.ONLINE else account.connected.clear()) 123 | self.app.accounts[login_sn] = account 124 | self.accounts[login_sn] = account 125 | await self.app.account_update(account, LoginStatus.ONLINE) 126 | if not self.accounts: 127 | logger.warning(f"No account available for {self.config}") 128 | # return False 129 | return True 130 | 131 | async def _heartbeat(self): 132 | """心跳""" 133 | while True: 134 | with suppress(Exception): 135 | await self.send({"op": 1}) 136 | await asyncio.sleep(9) 137 | 138 | async def daemon(self, manager: Launart, session: aiohttp.ClientSession): 139 | while not manager.status.exiting: 140 | try: 141 | async with session.ws_connect(self.config.ws_base / "events", timeout=300) as self.connection: 142 | logger.debug(f"{self.id} Websocket client connected") 143 | self.close_signal.clear() 144 | result = await 
self._authenticate() 145 | if not result: 146 | await asyncio.sleep(3) 147 | continue 148 | self.close_signal.clear() 149 | close_task = asyncio.create_task(self.close_signal.wait()) 150 | receiver_task = asyncio.create_task(self.message_receive()) 151 | sigexit_task = asyncio.create_task(manager.status.wait_for_sigexit()) 152 | heartbeat_task = asyncio.create_task(self._heartbeat()) 153 | done, pending = await any_completed( 154 | sigexit_task, 155 | close_task, 156 | receiver_task, 157 | heartbeat_task, 158 | ) 159 | if sigexit_task in done: 160 | logger.info(f"{self.id} Websocket client exiting...") 161 | await self.connection.close() 162 | self.close_signal.set() 163 | self.connection = None 164 | for v in list(self.app.accounts.values()): 165 | if (identity := f"{v.self_id}@{id(self):x}") in self.accounts: 166 | v.connected.clear() 167 | await self.app.account_update(v, LoginStatus.OFFLINE) 168 | del self.app.accounts[identity] 169 | del self.accounts[identity] 170 | return 171 | if close_task in done: 172 | receiver_task.cancel() 173 | logger.warning(f"{self.id} Connection closed by server, will reconnect in 5 seconds...") 174 | for k in self.accounts.keys(): 175 | logger.debug(f"Unregistering satori account {k}...") 176 | account = self.app.accounts[k] 177 | account.connected.clear() 178 | await self.app.account_update(account, LoginStatus.RECONNECT) 179 | self.accounts.clear() 180 | await asyncio.sleep(5) 181 | logger.info(f"{self.id} Reconnecting...") 182 | continue 183 | except Exception as e: 184 | logger.error(f"{self.id} Error while connecting: {e}") 185 | await asyncio.sleep(5) 186 | logger.info(f"{self.id} Reconnecting...") 187 | 188 | async def launch(self, manager: Launart): 189 | async with self.stage("preparing"): 190 | session = aiohttp.ClientSession() 191 | 192 | async with self.stage("blocking"): 193 | await self.daemon(manager, session) 194 | 195 | async with self.stage("cleanup"): 196 | await session.close() 197 | 
-------------------------------------------------------------------------------- /src/satori/adapters/onebot11/reverse.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from datetime import datetime 5 | 6 | from launart import Launart, any_completed 7 | from launart.status import Phase 8 | from loguru import logger 9 | from starlette.responses import JSONResponse, Response 10 | from starlette.routing import WebSocketRoute 11 | from starlette.websockets import WebSocket 12 | from yarl import URL 13 | 14 | from satori import Event, EventType, LoginStatus 15 | from satori.exception import ActionFailed 16 | from satori.model import Login, User 17 | from satori.server import Request 18 | from satori.server.adapter import Adapter as BaseAdapter 19 | from satori.utils import decode, encode 20 | 21 | from .api import apply 22 | from .events.base import events 23 | from .utils import USER_AVATAR_URL, onebot11_event_type 24 | 25 | 26 | class _Connection: 27 | def __init__(self, adapter: OneBot11ReverseAdapter, ws: WebSocket): 28 | self.adapter = adapter 29 | self.ws = ws 30 | self.close_signal = asyncio.Event() 31 | self.response_waiters: dict[str, asyncio.Future] = {} 32 | 33 | async def message_receive(self): 34 | async for msg in self.ws.iter_text(): 35 | yield self, decode(msg) 36 | else: 37 | self.close_signal.set() 38 | 39 | async def message_handle(self): 40 | async for connection, data in self.message_receive(): 41 | if echo := data.get("echo"): 42 | if future := self.response_waiters.get(echo): 43 | future.set_result(data) 44 | continue 45 | 46 | async def event_parse_task(data: dict): 47 | event_type = onebot11_event_type(data) 48 | if event_type == "meta_event.lifecycle.connect": 49 | self_id = str(data["self_id"]) 50 | if self_id not in self.adapter.logins: 51 | self_info = await self.call_api("get_login_info") 52 | login = Login( 53 | 0, 54 | LoginStatus.ONLINE, 
55 | "onebot", 56 | platform="onebot", 57 | user=User( 58 | self_id, 59 | (self_info or {})["nickname"], 60 | avatar=USER_AVATAR_URL.format(uin=self_id), 61 | ), 62 | features=["guild.plain"], 63 | ) 64 | self.adapter.logins[self_id] = login 65 | await self.adapter.server.post(Event(EventType.LOGIN_ADDED, datetime.now(), login)) 66 | elif event_type == "meta_event.lifecycle.enable": 67 | logger.warning(f"received lifecycle.enable event that is only supported in http adapter: {data}") 68 | return 69 | elif event_type == "meta_event.lifecycle.disable": 70 | logger.warning(f"received lifecycle.disable event that is only supported in http adapter: {data}") 71 | return 72 | elif event_type == "meta_event.heartbeat": 73 | self_id = str(data["self_id"]) 74 | if self_id not in self.adapter.logins: 75 | self_info = await self.call_api("get_login_info") 76 | login = Login( 77 | 0, 78 | LoginStatus.ONLINE, 79 | "onebot", 80 | platform="onebot", 81 | user=User( 82 | self_id, 83 | (self_info or {})["nickname"], 84 | avatar=USER_AVATAR_URL.format(uin=self_id), 85 | ), 86 | features=["guild.plain"], 87 | ) 88 | self.adapter.logins[self_id] = login 89 | await self.adapter.server.post(Event(EventType.LOGIN_ADDED, datetime.now(), login)) 90 | logger.trace(f"received heartbeat from {self_id}") 91 | else: 92 | self_id = str(data["self_id"]) 93 | if self_id not in self.adapter.logins: 94 | logger.warning(f"received event from unknown self_id: {data}") 95 | return 96 | login = self.adapter.logins[self_id] 97 | handler = events.get(event_type) 98 | if not handler: 99 | event = Event(EventType.INTERNAL, datetime.now(), login, _type=event_type, _data=data) 100 | else: 101 | event = await handler(login, self, data) 102 | if event: 103 | await self.adapter.server.post(event) 104 | 105 | asyncio.create_task(event_parse_task(data)) 106 | 107 | async def call_api(self, action: str, params: dict | None = None) -> dict: 108 | if not self.ws: 109 | raise RuntimeError("connection is not 
established") 110 | 111 | future: asyncio.Future[dict] = asyncio.get_running_loop().create_future() 112 | echo = str(hash(future)) 113 | self.response_waiters[echo] = future 114 | 115 | try: 116 | await self.ws.send_text(encode({"action": action, "params": params or {}, "echo": echo})) 117 | result = await future 118 | finally: 119 | del self.response_waiters[echo] 120 | 121 | if result["status"] != "ok": 122 | raise ActionFailed(f"{result['retcode']}: {result}", result) 123 | 124 | return result.get("data", {}) 125 | 126 | 127 | class OneBot11ReverseAdapter(BaseAdapter): 128 | 129 | def __init__( 130 | self, 131 | prefix: str = "/", 132 | path: str = "onebot/v11", 133 | endpoint: str = "ws", 134 | access_token: str | None = None, 135 | ): 136 | super().__init__() 137 | self.endpoint = URL(prefix) / path / endpoint 138 | self.access_token = access_token 139 | self.logins: dict[str, Login] = {} 140 | self.connections: dict[str, _Connection] = {} 141 | 142 | apply(self, lambda _: self.connections[_], lambda _: self.logins[_]) 143 | 144 | def ensure(self, platform: str, self_id: str) -> bool: 145 | return platform == "onebot" and self_id in self.logins 146 | 147 | async def get_logins(self) -> list[Login]: 148 | logins = list(self.logins.values()) 149 | for index, login in enumerate(logins): 150 | login.sn = index 151 | return logins 152 | 153 | @property 154 | def required(self) -> set[str]: 155 | return {"satori-python.server", "asgi.service/uvicorn"} 156 | 157 | @property 158 | def stages(self) -> set[Phase]: 159 | return {"preparing", "blocking", "cleanup"} 160 | 161 | async def websocket_server_handler(self, ws: WebSocket): 162 | if ws.headers.get("Authorization", "")[7:] != (self.access_token or ""): 163 | return await ws.close(1008, "Authorization Header is invalid") 164 | 165 | if "X-Self-ID" not in ws.headers: 166 | return await ws.close(1008, "Missing X-Self-ID Header") 167 | 168 | account_id = ws.headers["X-Self-ID"] 169 | if account_id in self.connections: 
170 | return await ws.close(1008, "Duplicate X-Self-ID") 171 | 172 | await ws.accept() 173 | connection = _Connection(self, ws) 174 | self.connections[account_id] = connection 175 | 176 | try: 177 | await any_completed(connection.message_handle(), connection.close_signal.wait()) 178 | finally: 179 | del self.connections[account_id] 180 | logger.info(f"Websocket {ws} closed") 181 | self.logins[account_id].status = LoginStatus.OFFLINE 182 | await self.server.post(Event(EventType.LOGIN_REMOVED, datetime.now(), self.logins[account_id])) 183 | await asyncio.sleep(1) 184 | 185 | async def launch(self, manager: Launart): 186 | async with self.stage("preparing"): 187 | pass 188 | 189 | async with self.stage("blocking"): 190 | await manager.status.wait_for_sigexit() 191 | 192 | async with self.stage("cleanup"): 193 | pass 194 | 195 | def get_routes(self): 196 | return [ 197 | WebSocketRoute(str(self.endpoint), self.websocket_server_handler), 198 | ] 199 | 200 | def get_platform(self) -> str: 201 | return "onebot" 202 | 203 | async def handle_internal(self, request: Request, path: str) -> Response: 204 | if path.startswith("_api"): 205 | self_id = request.self_id 206 | return JSONResponse(await self.connections[self_id].call_api(path[5:], await request.origin.json())) 207 | async with self.server.session.get(path) as resp: 208 | return Response(await resp.read()) 209 | 210 | def __str__(self): 211 | return self.id 212 | 213 | 214 | Adapter = OneBot11ReverseAdapter 215 | -------------------------------------------------------------------------------- /src/satori/adapters/console/api.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from satori.const import Api 4 | from satori.model import Channel, ChannelType, Guild, Member, MessageObject, PageDequeResult, PageResult, User 5 | from satori.server import Request 6 | from satori.server.route import ( 7 | ChannelListParam, 8 | ChannelParam, 9 | 
FriendListParam, 10 | GuildGetParam, 11 | GuildListParam, 12 | GuildMemberGetParam, 13 | GuildXXXListParam, 14 | MessageListParam, 15 | MessageOpParam, 16 | MessageParam, 17 | MessageUpdateParam, 18 | UserChannelCreateParam, 19 | UserGetParam, 20 | ) 21 | 22 | from .message import decode_message, encode_message 23 | 24 | if TYPE_CHECKING: 25 | from .main import ConsoleAdapter 26 | 27 | 28 | def apply(adapter: "ConsoleAdapter"): 29 | @adapter.route(Api.USER_GET) 30 | async def user_get(request: Request[UserGetParam]) -> User: 31 | user_id = request.params["user_id"] 32 | ans = await adapter.app.backend.get_user(user_id) 33 | return User( 34 | ans.id, 35 | ans.nickname, 36 | avatar=f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(ans.avatar):x}.png", 37 | ) 38 | 39 | @adapter.route(Api.CHANNEL_GET) 40 | async def channel_get(request: Request[ChannelParam]) -> Channel: 41 | channel_id = request.params["channel_id"] 42 | ans = await adapter.app.backend.get_channel(channel_id) 43 | if not ans: 44 | raise ValueError(f"Channel {channel_id} not found") 45 | return Channel( 46 | ans.id, 47 | ChannelType.TEXT, 48 | ans.name, 49 | ) 50 | 51 | @adapter.route(Api.GUILD_GET) 52 | async def guild_get(request: Request[GuildGetParam]) -> Guild: 53 | guild_id = request.params["guild_id"] 54 | ans = await adapter.app.backend.get_channel(guild_id) 55 | return Guild( 56 | ans.id, 57 | ans.name, 58 | f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(ans.avatar):x}.png", 59 | ) 60 | 61 | @adapter.route(Api.FRIEND_LIST) 62 | async def friend_list(request: Request[FriendListParam]) -> PageResult[User]: 63 | return PageResult( 64 | [ 65 | User( 66 | user.id, 67 | user.nickname, 68 | avatar=f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(user.avatar):x}.png", 69 | ) 70 | for user in await adapter.app.backend.list_users() 71 | ] 72 | ) 73 | 74 | @adapter.route(Api.GUILD_LIST) 75 | async def guild_list(request: Request[GuildListParam]) -> 
PageResult[Guild]: 76 | return PageResult( 77 | [ 78 | Guild( 79 | channel.id, 80 | channel.name, 81 | f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(channel.avatar):x}.png", 82 | ) 83 | for channel in await adapter.app.backend.list_channels() 84 | ] 85 | ) 86 | 87 | @adapter.route(Api.CHANNEL_LIST) 88 | async def channel_list(request: Request[ChannelListParam]) -> PageResult[Channel]: 89 | return PageResult( 90 | [ 91 | Channel( 92 | channel.id, 93 | ChannelType.TEXT, 94 | channel.name, 95 | ) 96 | for channel in await adapter.app.backend.list_channels() 97 | ] 98 | ) 99 | 100 | @adapter.route(Api.GUILD_MEMBER_GET) 101 | async def guild_member_get(request: Request[GuildMemberGetParam]) -> Member: 102 | user_id = request.params["user_id"] 103 | ans = await adapter.app.backend.get_user(user_id) 104 | user = User( 105 | ans.id, 106 | ans.nickname, 107 | avatar=f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(ans.avatar):x}.png", 108 | ) 109 | return Member(user, user.name, user.avatar) 110 | 111 | @adapter.route(Api.GUILD_MEMBER_LIST) 112 | async def guild_member_list(request: Request[GuildXXXListParam]) -> PageResult[Member]: 113 | members = [ 114 | Member( 115 | User( 116 | user.id, 117 | user.nickname, 118 | avatar=f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(user.avatar):x}.png", 119 | ), 120 | user.nickname, 121 | f"https://emoji.aranja.com/static/emoji-data/img-apple-160/{ord(user.avatar):x}.png", 122 | ) 123 | for user in await adapter.app.backend.list_users() 124 | ] 125 | return PageResult(members) 126 | 127 | @adapter.route(Api.USER_CHANNEL_CREATE) 128 | async def user_channel_create(request: Request[UserChannelCreateParam]) -> Channel: 129 | user_id = request.params["user_id"] 130 | user = await adapter.app.backend.get_user(user_id) 131 | channel = Channel( 132 | id=f"private:{user.id}", 133 | type=ChannelType.DIRECT, 134 | name=user.nickname, 135 | ) 136 | return channel 137 | 138 | 
@adapter.route(Api.MESSAGE_CREATE) 139 | async def message_create(request: Request[MessageParam]): 140 | content = request.params["content"] 141 | channel_id = request.params["channel_id"] 142 | 143 | if channel_id.startswith("private:"): 144 | user_id = channel_id.split(":")[1] 145 | user = await adapter.app.backend.get_user(user_id) 146 | target = await adapter.app.backend.create_dm(user) 147 | else: 148 | target = await adapter.app.backend.get_channel(channel_id) 149 | 150 | bot = next((b for b in await adapter.app.backend.list_bots() if b.id == request.self_id), None) 151 | message_id = await adapter.app.send_message(decode_message(content), target, bot) # type: ignore 152 | return [MessageObject(message_id, content)] 153 | 154 | @adapter.route(Api.MESSAGE_GET) 155 | async def message_get(request: Request[MessageOpParam]) -> MessageObject: 156 | message_id = request.params["message_id"] 157 | channel_id = request.params["channel_id"] 158 | if channel_id.startswith("private:"): 159 | user_id = channel_id.split(":")[1] 160 | user = await adapter.app.backend.get_user(user_id) 161 | channel = await adapter.app.backend.create_dm(user) 162 | else: 163 | channel = await adapter.app.backend.get_channel(channel_id) 164 | event = await adapter.app.backend.get_chat(message_id, channel) 165 | if not event: 166 | raise ValueError(f"Message {message_id} not found in channel {channel_id}") 167 | return MessageObject(message_id, encode_message(event.message)) 168 | 169 | @adapter.route(Api.MESSAGE_LIST) 170 | async def message_list(request: Request[MessageListParam]) -> PageDequeResult[MessageObject]: 171 | channel_id = request.params["channel_id"] 172 | if channel_id.startswith("private:"): 173 | user_id = channel_id.split(":")[1] 174 | user = await adapter.app.backend.get_user(user_id) 175 | channel = await adapter.app.backend.create_dm(user) 176 | else: 177 | channel = await adapter.app.backend.get_channel(channel_id) 178 | messages = await 
adapter.app.backend.get_chat_history(channel) 179 | return PageDequeResult( 180 | [ 181 | MessageObject( 182 | event.message_id, 183 | encode_message(event.message), 184 | ) 185 | for i, event in enumerate(messages) 186 | ] 187 | ) 188 | 189 | @adapter.route(Api.MESSAGE_UPDATE) 190 | async def message_update(request: Request[MessageUpdateParam]) -> None: 191 | content = request.params["content"] 192 | message_id = request.params["message_id"] 193 | channel_id = request.params["channel_id"] 194 | if channel_id.startswith("private:"): 195 | user_id = channel_id.split(":")[1] 196 | user = await adapter.app.backend.get_user(user_id) 197 | channel = await adapter.app.backend.create_dm(user) 198 | else: 199 | channel = await adapter.app.backend.get_channel(channel_id) 200 | 201 | await adapter.app.edit_message(message_id, decode_message(content), channel) 202 | 203 | @adapter.route(Api.MESSAGE_DELETE) 204 | async def message_delete(request: Request[MessageOpParam]) -> None: 205 | message_id = request.params["message_id"] 206 | channel_id = request.params["channel_id"] 207 | if channel_id.startswith("private:"): 208 | user_id = channel_id.split(":")[1] 209 | user = await adapter.app.backend.get_user(user_id) 210 | channel = await adapter.app.backend.create_dm(user) 211 | else: 212 | channel = await adapter.app.backend.get_channel(channel_id) 213 | 214 | await adapter.app.recall_message(message_id, channel) 215 | 216 | @adapter.route(Api.LOGIN_GET) 217 | async def login_get(request: Request): 218 | return adapter.app.backend.logins[request.self_id] 219 | 220 | @adapter.route("bell") 221 | async def bell(request: Request): 222 | await adapter.app.toggle_bell() 223 | return 224 | -------------------------------------------------------------------------------- /src/satori/adapters/milky/webhook.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | 5 | import aiohttp 
6 | from launart import Launart 7 | from launart.status import Phase 8 | from loguru import logger 9 | from starlette.requests import Request as StarletteRequest 10 | from starlette.responses import JSONResponse, Response 11 | from starlette.routing import Route 12 | from yarl import URL 13 | 14 | from satori import EventType 15 | from satori.exception import ActionFailed 16 | from satori.model import Event, Login, LoginStatus 17 | from satori.server.adapter import Adapter as BaseAdapter 18 | from satori.server.model import Request 19 | from satori.utils import decode, encode 20 | 21 | from .api import apply 22 | from .events import event_handlers 23 | from .utils import MilkyNetwork, decode_login_user 24 | 25 | DEFAULT_FEATURES = ["guild.plain", "reaction"] 26 | 27 | 28 | class _MilkyNetwork: 29 | def __init__(self, adapter: MilkyWebhookAdapter): # type: ignore[name-defined] 30 | self.adapter = adapter 31 | 32 | async def call_api(self, action: str, params: dict | None = None): 33 | return await self.adapter.call_api(action, params or {}) 34 | 35 | 36 | class MilkyWebhookAdapter(BaseAdapter): 37 | 38 | session: aiohttp.ClientSession | None 39 | 40 | def __init__( 41 | self, 42 | endpoint: str | URL, 43 | *, 44 | token: str | None = None, 45 | headers: dict[str, str] | None = None, 46 | path: str = "/milky", 47 | self_token: str | None = None, 48 | ): 49 | super().__init__() 50 | self.base_url = URL(str(endpoint)) 51 | base_path = self.base_url.path.rstrip("/") 52 | self.api_base = self.base_url.with_path(f"{base_path}/api") 53 | self.token = token 54 | self.headers = headers.copy() if headers else {} 55 | self.session = None 56 | self.logins: dict[str, Login] = {} 57 | self.networks: dict[str, MilkyNetwork] = {} 58 | self.features = list(DEFAULT_FEATURES) 59 | self.webhook_token = self_token if self_token is not None else token 60 | self.webhook_paths = self._normalize_webhook_paths(path) 61 | apply(self, self._get_network, self._get_login) 62 | 63 | def 
get_platform(self) -> str: 64 | return "milky" 65 | 66 | def ensure(self, platform: str, self_id: str) -> bool: 67 | return platform == "milky" and self_id in self.logins 68 | 69 | async def get_logins(self) -> list[Login]: 70 | logins = list(self.logins.values()) 71 | for index, login in enumerate(logins): 72 | login.sn = index 73 | return logins 74 | 75 | @property 76 | def required(self) -> set[str]: 77 | return {"satori-python.server"} 78 | 79 | @property 80 | def stages(self) -> set[Phase]: 81 | return {"preparing", "blocking", "cleanup"} 82 | 83 | async def launch(self, manager: Launart): 84 | async with self.stage("preparing"): 85 | self.session = aiohttp.ClientSession() 86 | 87 | async with self.stage("blocking"): 88 | await manager.status.wait_for_sigexit() 89 | 90 | async with self.stage("cleanup"): 91 | if self.session: 92 | await self.session.close() 93 | self.session = None 94 | await self._handle_disconnect() 95 | 96 | def proxy_urls(self) -> list[str]: 97 | return [] 98 | 99 | async def handle_internal(self, request: Request, path: str) -> Response: 100 | if path.startswith("_api"): 101 | data = await request.origin.json() 102 | return JSONResponse(await self.call_api(path[5:], data)) 103 | if not self.session: 104 | raise RuntimeError("HTTP session not initialized") 105 | url = self.base_url.with_path(path) 106 | headers = self.headers.copy() 107 | if self.token: 108 | headers.setdefault("Authorization", f"Bearer {self.token}") 109 | async with self.session.get(url, headers=headers) as resp: 110 | content = await resp.read() 111 | return Response(content=content, media_type=resp.headers.get("Content-Type")) 112 | 113 | async def call_api(self, action: str, params: dict | None = None) -> dict: 114 | if not self.session: 115 | raise RuntimeError("HTTP session not initialized") 116 | url = self.api_base.with_path(f"{self.api_base.path.rstrip('/')}/{action}") 117 | headers = self.headers.copy() 118 | headers["Content-Type"] = "application/json" 119 | if 
self.token: 120 | headers.setdefault("Authorization", f"Bearer {self.token}") 121 | async with self.session.post(url, data=encode(params or {}), headers=headers) as resp: 122 | resp.raise_for_status() 123 | data = decode(await resp.text()) 124 | if data.get("status") == "failed" or data.get("retcode", 0) != 0: 125 | raise ActionFailed(f"{data.get('retcode')}: {data.get('message')}", data) 126 | return data.get("data") 127 | 128 | def get_routes(self) -> list[Route]: 129 | return [Route(path, self.webhook_endpoint, methods=["POST"]) for path in self.webhook_paths] 130 | 131 | async def handle_event(self, payload: dict): 132 | event_type = payload.get("event_type") 133 | self_id = str(payload.get("self_id")) 134 | if not event_type or not self_id: 135 | return 136 | if self_id not in self.logins: 137 | await self.refresh_login() 138 | if self_id not in self.logins: 139 | logger.warning(f"Ignoring event for unknown self_id {self_id}") 140 | return 141 | login = self.logins[self_id] 142 | handler = event_handlers.get(event_type) 143 | network = self.networks.get(self_id) 144 | if not network: 145 | network = _MilkyNetwork(self) 146 | self.networks[self_id] = network 147 | if handler: 148 | event = await handler(login, network, payload) 149 | else: 150 | body = payload.get("data", {}) 151 | event = Event( 152 | EventType.INTERNAL, 153 | datetime.fromtimestamp(payload.get("time", datetime.now().timestamp())), 154 | login, 155 | _type=event_type, 156 | _data=body, 157 | ) 158 | if event: 159 | await self.server.post(event) 160 | 161 | async def refresh_login(self): 162 | try: 163 | data = await self.call_api("get_login_info", {}) 164 | except Exception as e: 165 | logger.error(f"Failed to fetch milky login info: {e}") 166 | return 167 | if not data: 168 | return 169 | user = decode_login_user(data) 170 | login = Login(0, LoginStatus.ONLINE, "milky", platform="milky", user=user, features=self.features.copy()) 171 | self_id = login.id 172 | previous = 
self.logins.get(self_id) 173 | self.logins[self_id] = login 174 | self.networks[self_id] = _MilkyNetwork(self) 175 | event_type = EventType.LOGIN_ADDED if previous is None else EventType.LOGIN_UPDATED 176 | await self.server.post(Event(event_type, datetime.now(), login)) 177 | 178 | async def _handle_disconnect(self): 179 | for self_id, login in list(self.logins.items()): 180 | login.status = LoginStatus.OFFLINE 181 | await self.server.post(Event(EventType.LOGIN_REMOVED, datetime.now(), login)) 182 | self.logins.pop(self_id, None) 183 | self.networks.pop(self_id, None) 184 | 185 | def _normalize_webhook_paths(self, webhook_path: str) -> tuple[str, ...]: 186 | path = webhook_path or "/" 187 | normalized = path if path.startswith("/") else f"/{path}" 188 | paths: set[str] = {normalized} 189 | stripped = normalized.rstrip("/") 190 | if stripped and stripped != normalized: 191 | paths.add(stripped) 192 | return tuple(sorted(paths)) 193 | 194 | def _get_network(self, self_id: str) -> MilkyNetwork: 195 | network = self.networks.get(self_id) 196 | if not network: 197 | network = _MilkyNetwork(self) 198 | self.networks[self_id] = network 199 | return network 200 | 201 | def _get_login(self, self_id: str) -> Login: 202 | return self.logins[self_id] 203 | 204 | async def webhook_endpoint(self, request: StarletteRequest) -> Response: 205 | if self.webhook_token: 206 | auth_header = request.headers.get("Authorization") 207 | provided = None 208 | if auth_header: 209 | if auth_header.lower().startswith("bearer "): 210 | provided = auth_header[7:] 211 | else: 212 | provided = auth_header 213 | if provided is None: 214 | provided = request.query_params.get("access_token") 215 | if provided != self.webhook_token: 216 | return JSONResponse({"error": "unauthorized"}, status_code=401) 217 | try: 218 | payload = await request.json() 219 | except Exception as e: # pragma: no cover - defensive 220 | logger.error(f"Failed to parse milky webhook payload: {e}") 221 | return 
JSONResponse({"error": "invalid json"}, status_code=400) 222 | if not isinstance(payload, dict): 223 | return JSONResponse({"error": "invalid payload"}, status_code=400) 224 | try: 225 | await self.handle_event(payload) 226 | except Exception as e: # pragma: no cover - defensive 227 | logger.exception("Error while processing milky webhook event", exc_info=e) 228 | return JSONResponse({"error": "internal error"}, status_code=500) 229 | return Response(status_code=204) 230 | 231 | 232 | __all__ = ["MilkyWebhookAdapter"] 233 | -------------------------------------------------------------------------------- /src/satori/adapters/milky/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from datetime import datetime 5 | 6 | import aiohttp 7 | from launart import Launart, any_completed 8 | from launart.status import Phase 9 | from loguru import logger 10 | from starlette.responses import JSONResponse, Response 11 | from yarl import URL 12 | 13 | from satori import EventType 14 | from satori.exception import ActionFailed 15 | from satori.model import Event, Login, LoginStatus 16 | from satori.server.adapter import Adapter as BaseAdapter 17 | from satori.server.model import Request 18 | from satori.utils import decode, encode 19 | 20 | from .api import apply 21 | from .events import event_handlers 22 | from .utils import MilkyNetwork, decode_login_user 23 | 24 | DEFAULT_FEATURES = ["guild.plain", "reaction"] 25 | 26 | 27 | class _MilkyNetwork: 28 | def __init__(self, adapter: MilkyAdapter): # type: ignore[name-defined] 29 | self.adapter = adapter 30 | 31 | async def call_api(self, action: str, params: dict | None = None): 32 | return await self.adapter.call_api(action, params or {}) 33 | 34 | 35 | class MilkyAdapter(BaseAdapter): 36 | 37 | session: aiohttp.ClientSession | None 38 | connection: aiohttp.ClientWebSocketResponse | None 39 | 40 | def __init__( 41 | self, 42 | 
endpoint: str | URL, 43 | *, 44 | token: str | None = None, 45 | token_in_query: bool = False, 46 | headers: dict[str, str] | None = None, 47 | ): 48 | super().__init__() 49 | self.base_url = URL(str(endpoint)) 50 | base_path = self.base_url.path.rstrip("/") 51 | self.api_base = self.base_url.with_path(f"{base_path}/api") 52 | ws_scheme = "wss" if self.base_url.scheme == "https" else "ws" 53 | self.event_url = self.base_url.with_scheme(ws_scheme).with_path(f"{base_path}/event") 54 | if token_in_query and token: 55 | self.event_url = self.event_url.update_query(access_token=token) 56 | self.token = token 57 | self.headers = headers.copy() if headers else {} 58 | self.session = None 59 | self.connection = None 60 | self.close_signal = asyncio.Event() 61 | self.logins: dict[str, Login] = {} 62 | self.networks: dict[str, MilkyNetwork] = {} 63 | self.features = list(DEFAULT_FEATURES) 64 | apply(self, self._get_network, self._get_login) 65 | 66 | def get_platform(self) -> str: 67 | return "milky" 68 | 69 | def ensure(self, platform: str, self_id: str) -> bool: 70 | return platform == "milky" and self_id in self.logins 71 | 72 | async def get_logins(self) -> list[Login]: 73 | logins = list(self.logins.values()) 74 | for index, login in enumerate(logins): 75 | login.sn = index 76 | return logins 77 | 78 | @property 79 | def required(self) -> set[str]: 80 | return {"satori-python.server"} 81 | 82 | @property 83 | def stages(self) -> set[Phase]: 84 | return {"preparing", "blocking", "cleanup"} 85 | 86 | async def launch(self, manager: Launart): 87 | async with self.stage("preparing"): 88 | self.session = aiohttp.ClientSession() 89 | 90 | async with self.stage("blocking"): 91 | await self.connection_daemon(manager, self.session) 92 | 93 | async with self.stage("cleanup"): 94 | if self.connection and not self.connection.closed: 95 | await self.connection.close() 96 | if self.session: 97 | await self.session.close() 98 | self.connection = None 99 | self.session = None 100 | 
await self._handle_disconnect() 101 | 102 | def proxy_urls(self) -> list[str]: 103 | return [] 104 | 105 | async def handle_internal(self, request: Request, path: str) -> Response: 106 | if path.startswith("_api"): 107 | data = await request.origin.json() 108 | return JSONResponse(await self.call_api(path[5:], data)) 109 | if not self.session: 110 | raise RuntimeError("HTTP session not initialized") 111 | url = self.base_url.with_path(path) 112 | headers = self.headers.copy() 113 | if self.token: 114 | headers.setdefault("Authorization", f"Bearer {self.token}") 115 | async with self.session.get(url, headers=headers) as resp: 116 | content = await resp.read() 117 | return Response(content=content, media_type=resp.headers.get("Content-Type")) 118 | 119 | async def call_api(self, action: str, params: dict | None = None) -> dict: 120 | if not self.session: 121 | raise RuntimeError("HTTP session not initialized") 122 | url = self.api_base.with_path(f"{self.api_base.path.rstrip('/')}/{action}") 123 | headers = self.headers.copy() 124 | headers["Content-Type"] = "application/json" 125 | if self.token: 126 | headers.setdefault("Authorization", f"Bearer {self.token}") 127 | async with self.session.post(url, data=encode(params or {}), headers=headers) as resp: 128 | resp.raise_for_status() 129 | data = decode(await resp.text()) 130 | if data.get("status") == "failed" or data.get("retcode", 0) != 0: 131 | raise ActionFailed(f"{data.get('retcode')}: {data.get('message')}", data) 132 | return data.get("data") 133 | 134 | async def connection_daemon(self, manager: Launart, session: aiohttp.ClientSession): 135 | while not manager.status.exiting: 136 | headers = self.headers.copy() 137 | if self.token: 138 | headers.setdefault("Authorization", f"Bearer {self.token}") 139 | try: 140 | self.connection = await session.ws_connect(self.event_url, headers=headers) 141 | except Exception as e: 142 | logger.error(f"Milky adapter websocket connect failed: {e}") 143 | await asyncio.sleep(5) 
144 | continue 145 | logger.info("Milky adapter websocket connected") 146 | self.close_signal.clear() 147 | await self.refresh_login() 148 | receiver_task = asyncio.create_task(self.message_handle()) 149 | close_task = asyncio.create_task(self.close_signal.wait()) 150 | sigexit_task = asyncio.create_task(manager.status.wait_for_sigexit()) 151 | 152 | done, pending = await any_completed(receiver_task, close_task, sigexit_task) 153 | for task in pending: 154 | task.cancel() 155 | await asyncio.gather(*pending, return_exceptions=True) 156 | if sigexit_task in done: 157 | break 158 | logger.warning("Milky adapter websocket closed, retrying in 5 seconds") 159 | await self._handle_disconnect() 160 | await asyncio.sleep(5) 161 | await self._handle_disconnect() 162 | 163 | async def message_handle(self): 164 | assert self.connection is not None 165 | async for msg in self.connection: 166 | if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): 167 | self.close_signal.set() 168 | break 169 | if msg.type != aiohttp.WSMsgType.TEXT: 170 | continue 171 | try: 172 | data = decode(msg.data) 173 | except Exception as e: # pragma: no cover - defensive 174 | logger.error(f"Failed to decode milky event: {e}") 175 | continue 176 | if not isinstance(data, dict): 177 | continue 178 | await self.handle_event(data) 179 | 180 | async def handle_event(self, payload: dict): 181 | event_type = payload.get("event_type") 182 | self_id = str(payload.get("self_id")) 183 | if not event_type or not self_id: 184 | return 185 | if self_id not in self.logins: 186 | await self.refresh_login() 187 | if self_id not in self.logins: 188 | logger.warning(f"Ignoring event for unknown self_id {self_id}") 189 | return 190 | login = self.logins[self_id] 191 | handler = event_handlers.get(event_type) 192 | network = self.networks.get(self_id) 193 | if not network: 194 | network = _MilkyNetwork(self) 195 | self.networks[self_id] = network 196 | if handler: 197 | event = await 
handler(login, network, payload) 198 | else: 199 | body = payload.get("data", {}) 200 | event = Event( 201 | EventType.INTERNAL, 202 | datetime.fromtimestamp(payload.get("time", datetime.now().timestamp())), 203 | login, 204 | _type=event_type, 205 | _data=body, 206 | ) 207 | if event: 208 | await self.server.post(event) 209 | 210 | async def refresh_login(self): 211 | try: 212 | data = await self.call_api("get_login_info", {}) 213 | except Exception as e: 214 | logger.error(f"Failed to fetch milky login info: {e}") 215 | return 216 | if not data: 217 | return 218 | user = decode_login_user(data) 219 | login = Login(0, LoginStatus.ONLINE, "milky", platform="milky", user=user, features=self.features.copy()) 220 | self_id = login.id 221 | previous = self.logins.get(self_id) 222 | self.logins[self_id] = login 223 | self.networks[self_id] = _MilkyNetwork(self) 224 | event_type = EventType.LOGIN_ADDED if previous is None else EventType.LOGIN_UPDATED 225 | await self.server.post(Event(event_type, datetime.now(), login)) 226 | 227 | async def _handle_disconnect(self): 228 | for self_id, login in list(self.logins.items()): 229 | login.status = LoginStatus.OFFLINE 230 | await self.server.post(Event(EventType.LOGIN_REMOVED, datetime.now(), login)) 231 | self.logins.pop(self_id, None) 232 | self.networks.pop(self_id, None) 233 | if self.connection and not self.connection.closed: 234 | await self.connection.close() 235 | self.close_signal.set() 236 | 237 | def _get_network(self, self_id: str) -> MilkyNetwork: 238 | network = self.networks.get(self_id) 239 | if not network: 240 | network = _MilkyNetwork(self) 241 | self.networks[self_id] = network 242 | return network 243 | 244 | def _get_login(self, self_id: str) -> Login: 245 | return self.logins[self_id] 246 | 247 | 248 | __all__ = ["MilkyAdapter"] 249 | -------------------------------------------------------------------------------- /src/satori/adapters/onebot11/forward.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from contextlib import suppress 5 | from datetime import datetime 6 | from typing import cast 7 | 8 | import aiohttp 9 | from aiohttp import ClientSession, ClientWebSocketResponse 10 | from launart import Launart, any_completed 11 | from launart.status import Phase 12 | from loguru import logger 13 | from starlette.responses import JSONResponse, Response 14 | from yarl import URL 15 | 16 | from satori import Event, EventType, LoginStatus 17 | from satori.exception import ActionFailed 18 | from satori.model import Login, User 19 | from satori.server import Request 20 | from satori.server.adapter import Adapter as BaseAdapter 21 | from satori.utils import decode, encode 22 | 23 | from .api import apply 24 | from .events.base import events 25 | from .utils import USER_AVATAR_URL, onebot11_event_type 26 | 27 | 28 | class OneBot11ForwardAdapter(BaseAdapter): 29 | 30 | session: ClientSession 31 | connection: ClientWebSocketResponse | None 32 | 33 | def __init__( 34 | self, 35 | endpoint: str | URL, 36 | access_token: str | None = None, 37 | ): 38 | super().__init__() 39 | self.endpoint = URL(endpoint) 40 | self.access_token = access_token 41 | self.close_signal = asyncio.Event() 42 | self.response_waiters: dict[str, asyncio.Future] = {} 43 | self.logins: dict[str, Login] = {} 44 | 45 | apply(self, lambda _: self, lambda _: self.logins[_]) 46 | 47 | def ensure(self, platform: str, self_id: str) -> bool: 48 | return platform == "onebot" and self_id in self.logins 49 | 50 | async def get_logins(self) -> list[Login]: 51 | logins = list(self.logins.values()) 52 | for index, login in enumerate(logins): 53 | login.sn = index 54 | return logins 55 | 56 | @property 57 | def required(self) -> set[str]: 58 | return {"satori-python.server"} 59 | 60 | @property 61 | def stages(self) -> set[Phase]: 62 | return {"preparing", "blocking", 
"cleanup"} 63 | 64 | async def message_handle(self): 65 | async for connection, data in self.message_receive(): 66 | if echo := data.get("echo"): 67 | if future := self.response_waiters.get(echo): 68 | future.set_result(data) 69 | continue 70 | 71 | async def event_parse_task(data: dict): 72 | event_type = onebot11_event_type(data) 73 | if event_type == "meta_event.lifecycle.connect": 74 | self_id = str(data["self_id"]) 75 | if self_id not in self.logins: 76 | self_info = await self.call_api("get_login_info") 77 | login = Login( 78 | 0, 79 | LoginStatus.ONLINE, 80 | "onebot", 81 | platform="onebot", 82 | user=User( 83 | self_id, 84 | (self_info or {})["nickname"], 85 | avatar=USER_AVATAR_URL.format(uin=self_id), 86 | ), 87 | features=["guild.plain"], 88 | ) 89 | self.logins[self_id] = login 90 | await self.server.post(Event(EventType.LOGIN_ADDED, datetime.now(), login)) 91 | elif event_type == "meta_event.lifecycle.enable": 92 | logger.warning(f"received lifecycle.enable event that is only supported in http adapter: {data}") 93 | return 94 | elif event_type == "meta_event.lifecycle.disable": 95 | logger.warning(f"received lifecycle.disable event that is only supported in http adapter: {data}") 96 | return 97 | elif event_type == "meta_event.heartbeat": 98 | self_id = str(data["self_id"]) 99 | if self_id not in self.logins: 100 | self_info = await self.call_api("get_login_info") 101 | login = Login( 102 | 0, 103 | LoginStatus.ONLINE, 104 | "onebot", 105 | platform="onebot", 106 | user=User( 107 | self_id, 108 | (self_info or {})["nickname"], 109 | avatar=USER_AVATAR_URL.format(uin=self_id), 110 | ), 111 | features=["guild.plain"], 112 | ) 113 | self.logins[self_id] = login 114 | await self.server.post(Event(EventType.LOGIN_ADDED, datetime.now(), login)) 115 | logger.trace(f"received heartbeat from {self_id}") 116 | else: 117 | self_id = str(data["self_id"]) 118 | if self_id not in self.logins: 119 | logger.warning(f"received event from unknown self_id: {data}") 120 
| return 121 | login = self.logins[self_id] 122 | handler = events.get(event_type) 123 | if not handler: 124 | event = Event(EventType.INTERNAL, datetime.now(), login, _type=event_type, _data=data) 125 | else: 126 | event = await handler(login, self, data) 127 | if event: 128 | await self.server.post(event) 129 | 130 | asyncio.create_task(event_parse_task(data)) 131 | 132 | async def message_receive(self): 133 | if self.connection is None: 134 | raise RuntimeError("connection is not established") 135 | 136 | async for msg in self.connection: 137 | if msg.type in {aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED}: 138 | self.close_signal.set() 139 | break 140 | elif msg.type == aiohttp.WSMsgType.TEXT: 141 | data: dict = decode(cast(str, msg.data)) 142 | yield self, data 143 | else: 144 | self.close_signal.set() 145 | 146 | async def connection_daemon(self, manager: Launart, session: ClientSession): 147 | while not manager.status.exiting: 148 | ctx = session.ws_connect( 149 | self.endpoint, 150 | headers=( 151 | {"Authorization": f"Bearer {access_token}"} 152 | if (access_token := self.access_token) is not None 153 | else None 154 | ), 155 | ) 156 | try: 157 | self.connection = await ctx.__aenter__() 158 | except Exception as e: 159 | logger.error(f"{self} Websocket client connection failed: {e}") 160 | logger.debug(f"{self} Will retry in 5 seconds...") 161 | with suppress(AttributeError): 162 | await ctx.__aexit__(None, None, None) 163 | await asyncio.sleep(5) 164 | continue 165 | logger.info(f"{self} Websocket client connected") 166 | self.close_signal.clear() 167 | if self.logins: 168 | for login in self.logins.values(): 169 | login.status = LoginStatus.ONLINE 170 | await self.server.post(Event(EventType.LOGIN_UPDATED, datetime.now(), login)) 171 | close_task = asyncio.create_task(self.close_signal.wait()) 172 | receiver_task = asyncio.create_task(self.message_handle()) 173 | sigexit_task = 
asyncio.create_task(manager.status.wait_for_sigexit()) 174 | 175 | done, pending = await any_completed( 176 | sigexit_task, 177 | close_task, 178 | receiver_task, 179 | ) 180 | if sigexit_task in done: 181 | logger.info(f"{self} Websocket client exiting...") 182 | await self.connection.close() 183 | self.close_signal.set() 184 | self.connection = None 185 | for login in self.logins.values(): 186 | login.status = LoginStatus.OFFLINE 187 | await self.server.post(Event(EventType.LOGIN_REMOVED, datetime.now(), login)) 188 | await asyncio.sleep(1) 189 | return 190 | if close_task in done: 191 | receiver_task.cancel() 192 | logger.warning(f"{self} Connection closed by server, will reconnect in 5 seconds...") 193 | for login in self.logins.values(): 194 | login.status = LoginStatus.RECONNECT 195 | await self.server.post(Event(EventType.LOGIN_UPDATED, datetime.now(), login)) 196 | await asyncio.sleep(5) 197 | logger.info(f"{self} Reconnecting...") 198 | continue 199 | 200 | async def launch(self, manager: Launart): 201 | async with self.stage("preparing"): 202 | self.session = ClientSession() 203 | 204 | async with self.stage("blocking"): 205 | await self.connection_daemon(manager, self.session) 206 | 207 | async with self.stage("cleanup"): 208 | await self.session.close() 209 | self.connection = None 210 | 211 | def get_platform(self) -> str: 212 | return "onebot" 213 | 214 | async def handle_internal(self, request: Request, path: str) -> Response: 215 | if path.startswith("_api"): 216 | return JSONResponse(await self.call_api(path[5:], await request.origin.json())) 217 | async with self.session.get(path) as resp: 218 | return Response(await resp.read()) 219 | 220 | async def call_api(self, action: str, params: dict | None = None) -> dict: 221 | if not self.connection: 222 | raise RuntimeError("connection is not established") 223 | 224 | future: asyncio.Future[dict] = asyncio.get_running_loop().create_future() 225 | echo = str(hash(future)) 226 | 
self.response_waiters[echo] = future 227 | 228 | try: 229 | await self.connection.send_str(encode({"action": action, "params": params or {}, "echo": echo})) 230 | result = await future 231 | finally: 232 | del self.response_waiters[echo] 233 | 234 | if result["status"] != "ok": 235 | raise ActionFailed(f"{result['retcode']}: {result}", result) 236 | 237 | return result.get("data", {}) 238 | 239 | def __str__(self): 240 | return self.id 241 | 242 | 243 | Adapter = OneBot11ForwardAdapter 244 | -------------------------------------------------------------------------------- /src/satori/server/route.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Awaitable, Callable 2 | from typing import Any, Literal, Protocol, TypeAlias, TypeVar, overload 3 | from typing_extensions import NotRequired, TypedDict 4 | 5 | from starlette.datastructures import FormData 6 | 7 | from satori.model import ( 8 | Channel, 9 | Direction, 10 | Guild, 11 | Login, 12 | Member, 13 | MessageObject, 14 | ModelBase, 15 | Order, 16 | PageDequeResult, 17 | PageResult, 18 | Role, 19 | User, 20 | ) 21 | 22 | from .. import Api 23 | from .model import Request 24 | 25 | T = TypeVar("T") 26 | R = TypeVar("R", covariant=True) 27 | 28 | 29 | class RouteCall(Protocol[T, R]): 30 | def __call__(self, request: Request[T]) -> Awaitable[R]: ... 
31 | 32 | 33 | INTERAL: TypeAlias = RouteCall[Any, ModelBase | list[ModelBase] | dict[str, Any] | list[dict[str, Any]] | None] 34 | 35 | 36 | class MessageParam(TypedDict): 37 | channel_id: str 38 | content: str 39 | 40 | 41 | MESSAGE_CREATE: TypeAlias = RouteCall[MessageParam, list[MessageObject] | list[dict[str, Any]]] 42 | 43 | 44 | class MessageOpParam(TypedDict): 45 | channel_id: str 46 | message_id: str 47 | 48 | 49 | MESSAGE_GET: TypeAlias = RouteCall[MessageOpParam, MessageObject | dict[str, Any]] 50 | MESSAGE_DELETE: TypeAlias = RouteCall[MessageOpParam, None] 51 | 52 | 53 | class MessageUpdateParam(TypedDict): 54 | channel_id: str 55 | message_id: str 56 | content: str 57 | 58 | 59 | MESSAGE_UPDATE: TypeAlias = RouteCall[MessageUpdateParam, None] 60 | 61 | 62 | class MessageListParam(TypedDict): 63 | channel_id: str 64 | next: NotRequired[str] 65 | direction: NotRequired[Direction] 66 | limit: NotRequired[int] 67 | order: NotRequired[Order] 68 | 69 | 70 | MESSAGE_LIST: TypeAlias = RouteCall[MessageListParam, PageDequeResult[MessageObject] | dict[str, Any]] 71 | 72 | 73 | class ChannelParam(TypedDict): 74 | channel_id: str 75 | 76 | 77 | CHANNEL_GET: TypeAlias = RouteCall[ChannelParam, Channel | dict[str, Any]] 78 | CHANNEL_DELETE: TypeAlias = RouteCall[ChannelParam, None] 79 | 80 | 81 | class ChannelListParam(TypedDict): 82 | guild_id: str 83 | next: NotRequired[str] 84 | 85 | 86 | CHANNEL_LIST: TypeAlias = RouteCall[ChannelListParam, PageResult[Channel] | dict[str, Any]] 87 | 88 | 89 | class ChannelCreateParam(TypedDict): 90 | guild_id: str 91 | data: dict 92 | 93 | 94 | CHANNEL_CREATE: TypeAlias = RouteCall[ChannelCreateParam, Channel | dict[str, Any]] 95 | 96 | 97 | class ChannelUpdateParam(TypedDict): 98 | channel_id: str 99 | data: dict 100 | 101 | 102 | CHANNEL_UPDATE: TypeAlias = RouteCall[ChannelUpdateParam, None] 103 | 104 | 105 | class ChannelMuteParam(TypedDict): 106 | channel_id: str 107 | duration: float 108 | 109 | 110 | CHANNEL_MUTE: 
TypeAlias = RouteCall[ChannelMuteParam, None] 111 | 112 | 113 | class UserChannelCreateParam(TypedDict): 114 | user_id: str 115 | guild_id: NotRequired[str] 116 | 117 | 118 | ROUTE_USER_CHANNEL_CREATE: TypeAlias = RouteCall[UserChannelCreateParam, Channel | dict[str, Any]] 119 | 120 | 121 | class GuildGetParam(TypedDict): 122 | guild_id: str 123 | 124 | 125 | GUILD_GET: TypeAlias = RouteCall[GuildGetParam, Guild | dict[str, Any]] 126 | 127 | 128 | class GuildListParam(TypedDict): 129 | next: NotRequired[str] 130 | 131 | 132 | GUILD_LIST: TypeAlias = RouteCall[GuildListParam, PageResult[Guild] | dict[str, Any]] 133 | 134 | 135 | class GuildMemberGetParam(TypedDict): 136 | guild_id: str 137 | user_id: str 138 | 139 | 140 | GUILD_MEMBER_GET: TypeAlias = RouteCall[GuildMemberGetParam, Member | dict[str, Any]] 141 | 142 | 143 | class GuildXXXListParam(TypedDict): 144 | guild_id: str 145 | next: NotRequired[str] 146 | 147 | 148 | GUILD_MEMBER_LIST: TypeAlias = RouteCall[GuildXXXListParam, PageResult[Member] | dict[str, Any]] 149 | 150 | 151 | class GuildMemberKickParam(TypedDict): 152 | guild_id: str 153 | user_id: str 154 | permanent: NotRequired[bool] 155 | 156 | 157 | GUILD_MEMBER_KICK: TypeAlias = RouteCall[GuildMemberKickParam, None] 158 | 159 | 160 | class GuildMemberMuteParam(TypedDict): 161 | guild_id: str 162 | user_id: str 163 | duration: float 164 | 165 | 166 | GUILD_MEMBER_MUTE: TypeAlias = RouteCall[GuildMemberMuteParam, None] 167 | 168 | 169 | class GuildMemberRoleParam(TypedDict): 170 | guild_id: str 171 | user_id: str 172 | role_id: str 173 | 174 | 175 | GUILD_MEMBER_ROLE_SET: TypeAlias = RouteCall[GuildMemberRoleParam, None] 176 | GUILD_MEMBER_ROLE_UNSET: TypeAlias = RouteCall[GuildMemberRoleParam, None] 177 | 178 | GUILD_ROLE_LIST: TypeAlias = RouteCall[GuildXXXListParam, PageResult[Role] | dict[str, Any]] 179 | 180 | 181 | class GuildRoleCreateParam(TypedDict): 182 | guild: str 183 | role: dict 184 | 185 | 186 | GUILD_ROLE_CREATE: TypeAlias = 
RouteCall[GuildRoleCreateParam, Role | dict[str, Any]] 187 | 188 | 189 | class GuildRoleUpdateParam(TypedDict): 190 | guild: str 191 | role_id: str 192 | role: dict 193 | 194 | 195 | GUILD_ROLE_UPDATE: TypeAlias = RouteCall[GuildRoleUpdateParam, None] 196 | 197 | 198 | class GuildRoleDeleteParam(TypedDict): 199 | guild: str 200 | role_id: str 201 | 202 | 203 | GUILD_ROLE_DELETE: TypeAlias = RouteCall[GuildRoleDeleteParam, None] 204 | 205 | 206 | class ReactionCreateParam(TypedDict): 207 | channel_id: str 208 | message_id: str 209 | emoji: str 210 | 211 | 212 | REACTION_CREATE: TypeAlias = RouteCall[ReactionCreateParam, None] 213 | 214 | 215 | class ReactionDeleteParam(TypedDict): 216 | channel_id: str 217 | message_id: str 218 | emoji: str 219 | user_id: NotRequired[str] 220 | 221 | 222 | REACTION_DELETE: TypeAlias = RouteCall[ReactionDeleteParam, None] 223 | 224 | 225 | class ReactionClearParam(TypedDict): 226 | channel_id: str 227 | message_id: str 228 | emoji: NotRequired[str] 229 | 230 | 231 | REACTION_CLEAR: TypeAlias = RouteCall[ReactionClearParam, None] 232 | 233 | 234 | class ReactionListParam(TypedDict): 235 | channel_id: str 236 | message_id: str 237 | emoji: str 238 | next: NotRequired[str] 239 | 240 | 241 | REACTION_LIST: TypeAlias = RouteCall[ReactionListParam, PageResult[User] | dict[str, Any]] 242 | LOGIN_GET: TypeAlias = RouteCall[Any, Login | dict[str, Any]] 243 | 244 | 245 | class UserGetParam(TypedDict): 246 | user_id: str 247 | 248 | 249 | USER_GET: TypeAlias = RouteCall[UserGetParam, User | dict[str, Any]] 250 | 251 | 252 | class FriendListParam(TypedDict): 253 | next: NotRequired[str] 254 | 255 | 256 | FRIEND_LIST: TypeAlias = RouteCall[FriendListParam, PageResult[User] | dict[str, Any]] 257 | 258 | 259 | class ApproveParam(TypedDict): 260 | message_id: str 261 | approve: bool 262 | comment: NotRequired[str] 263 | 264 | 265 | APPROVE: TypeAlias = RouteCall[ApproveParam, None] 266 | 267 | 268 | UPLOAD_CREATE: TypeAlias = RouteCall[FormData, 
dict[str, str]] 269 | 270 | 271 | class RouterMixin: 272 | routes: dict[str, RouteCall[Any, Any]] 273 | 274 | @overload 275 | def route(self, path: Literal[Api.MESSAGE_CREATE]) -> Callable[[MESSAGE_CREATE], MESSAGE_CREATE]: ... 276 | 277 | @overload 278 | def route(self, path: Literal[Api.MESSAGE_UPDATE]) -> Callable[[MESSAGE_UPDATE], MESSAGE_UPDATE]: ... 279 | 280 | @overload 281 | def route(self, path: Literal[Api.MESSAGE_GET]) -> Callable[[MESSAGE_GET], MESSAGE_GET]: ... 282 | 283 | @overload 284 | def route(self, path: Literal[Api.MESSAGE_DELETE]) -> Callable[[MESSAGE_DELETE], MESSAGE_DELETE]: ... 285 | 286 | @overload 287 | def route(self, path: Literal[Api.MESSAGE_LIST]) -> Callable[[MESSAGE_LIST], MESSAGE_LIST]: ... 288 | 289 | @overload 290 | def route(self, path: Literal[Api.CHANNEL_GET]) -> Callable[[CHANNEL_GET], CHANNEL_GET]: ... 291 | 292 | @overload 293 | def route(self, path: Literal[Api.CHANNEL_LIST]) -> Callable[[CHANNEL_LIST], CHANNEL_LIST]: ... 294 | 295 | @overload 296 | def route(self, path: Literal[Api.CHANNEL_CREATE]) -> Callable[[CHANNEL_CREATE], CHANNEL_CREATE]: ... 297 | 298 | @overload 299 | def route(self, path: Literal[Api.CHANNEL_UPDATE]) -> Callable[[CHANNEL_UPDATE], CHANNEL_UPDATE]: ... 300 | 301 | @overload 302 | def route(self, path: Literal[Api.CHANNEL_DELETE]) -> Callable[[CHANNEL_DELETE], CHANNEL_DELETE]: ... 303 | 304 | @overload 305 | def route(self, path: Literal[Api.CHANNEL_MUTE]) -> Callable[[CHANNEL_MUTE], CHANNEL_MUTE]: ... 306 | 307 | @overload 308 | def route( 309 | self, path: Literal[Api.USER_CHANNEL_CREATE] 310 | ) -> Callable[[ROUTE_USER_CHANNEL_CREATE], ROUTE_USER_CHANNEL_CREATE]: ... 311 | 312 | @overload 313 | def route(self, path: Literal[Api.GUILD_GET]) -> Callable[[GUILD_GET], GUILD_GET]: ... 314 | 315 | @overload 316 | def route(self, path: Literal[Api.GUILD_LIST]) -> Callable[[GUILD_LIST], GUILD_LIST]: ... 
317 | 318 | @overload 319 | def route(self, path: Literal[Api.GUILD_APPROVE]) -> Callable[[APPROVE], APPROVE]: ... 320 | 321 | @overload 322 | def route(self, path: Literal[Api.GUILD_MEMBER_LIST]) -> Callable[[GUILD_MEMBER_LIST], GUILD_MEMBER_LIST]: ... 323 | 324 | @overload 325 | def route(self, path: Literal[Api.GUILD_MEMBER_GET]) -> Callable[[GUILD_MEMBER_GET], GUILD_MEMBER_GET]: ... 326 | 327 | @overload 328 | def route(self, path: Literal[Api.GUILD_MEMBER_KICK]) -> Callable[[GUILD_MEMBER_KICK], GUILD_MEMBER_KICK]: ... 329 | 330 | @overload 331 | def route(self, path: Literal[Api.GUILD_MEMBER_MUTE]) -> Callable[[GUILD_MEMBER_MUTE], GUILD_MEMBER_MUTE]: ... 332 | 333 | @overload 334 | def route(self, path: Literal[Api.GUILD_MEMBER_APPROVE]) -> Callable[[APPROVE], APPROVE]: ... 335 | 336 | @overload 337 | def route( 338 | self, path: Literal[Api.GUILD_MEMBER_ROLE_SET] 339 | ) -> Callable[[GUILD_MEMBER_ROLE_SET], GUILD_MEMBER_ROLE_SET]: ... 340 | 341 | @overload 342 | def route( 343 | self, path: Literal[Api.GUILD_MEMBER_ROLE_UNSET] 344 | ) -> Callable[[GUILD_MEMBER_ROLE_UNSET], GUILD_MEMBER_ROLE_UNSET]: ... 345 | 346 | @overload 347 | def route(self, path: Literal[Api.GUILD_ROLE_LIST]) -> Callable[[GUILD_ROLE_LIST], GUILD_ROLE_LIST]: ... 348 | 349 | @overload 350 | def route(self, path: Literal[Api.GUILD_ROLE_CREATE]) -> Callable[[GUILD_ROLE_CREATE], GUILD_ROLE_CREATE]: ... 351 | 352 | @overload 353 | def route(self, path: Literal[Api.GUILD_ROLE_UPDATE]) -> Callable[[GUILD_ROLE_UPDATE], GUILD_ROLE_UPDATE]: ... 354 | 355 | @overload 356 | def route(self, path: Literal[Api.GUILD_ROLE_DELETE]) -> Callable[[GUILD_ROLE_DELETE], GUILD_ROLE_DELETE]: ... 357 | 358 | @overload 359 | def route(self, path: Literal[Api.REACTION_CREATE]) -> Callable[[REACTION_CREATE], REACTION_CREATE]: ... 360 | 361 | @overload 362 | def route(self, path: Literal[Api.REACTION_DELETE]) -> Callable[[REACTION_DELETE], REACTION_DELETE]: ... 
363 | 364 | @overload 365 | def route(self, path: Literal[Api.REACTION_CLEAR]) -> Callable[[REACTION_CLEAR], REACTION_CLEAR]: ... 366 | 367 | @overload 368 | def route(self, path: Literal[Api.REACTION_LIST]) -> Callable[[REACTION_LIST], REACTION_LIST]: ... 369 | 370 | @overload 371 | def route(self, path: Literal[Api.LOGIN_GET]) -> Callable[[LOGIN_GET], LOGIN_GET]: ... 372 | 373 | @overload 374 | def route(self, path: Literal[Api.USER_GET]) -> Callable[[USER_GET], USER_GET]: ... 375 | 376 | @overload 377 | def route(self, path: Literal[Api.FRIEND_LIST]) -> Callable[[FRIEND_LIST], FRIEND_LIST]: ... 378 | 379 | @overload 380 | def route(self, path: Literal[Api.FRIEND_APPROVE]) -> Callable[[APPROVE], APPROVE]: ... 381 | 382 | @overload 383 | def route(self, path: Literal[Api.UPLOAD_CREATE]) -> Callable[[UPLOAD_CREATE], UPLOAD_CREATE]: ... 384 | 385 | @overload 386 | def route(self, path: str) -> Callable[[INTERAL], INTERAL]: ... 387 | 388 | def route(self, path: str | Api) -> Callable[[RouteCall], RouteCall]: 389 | """注册一个 Satori 路由 390 | 391 | Args: 392 | path (str | Api): 路由路径;若 path 不属于 Api,则会被认为是内部接口 393 | """ 394 | 395 | def wrapper(func: RouteCall): 396 | if isinstance(path, Api): 397 | self.routes[path.value] = func 398 | else: 399 | self.routes[f"internal/{path}"] = func 400 | return func 401 | 402 | return wrapper 403 | -------------------------------------------------------------------------------- /experimental/_model_msgspec.py: -------------------------------------------------------------------------------- 1 | import mimetypes 2 | from collections.abc import AsyncIterable, Awaitable, Callable 3 | from datetime import datetime 4 | from enum import IntEnum 5 | from os import PathLike 6 | from pathlib import Path 7 | from typing import IO, Any, Generic, Literal, TypeAlias, TypeVar 8 | from typing_extensions import Self 9 | 10 | from msgspec import Struct, field, convert, to_builtins 11 | 12 | from satori.element import Element, transform 13 | from 
satori.parser import Element as RawElement 14 | from satori.parser import parse 15 | 16 | 17 | class ModelBase: 18 | 19 | @classmethod 20 | def parse(cls: type[Self], raw: dict) -> Self: 21 | obj = convert(raw, cls, strict=False) 22 | obj._raw_data = raw 23 | return obj 24 | 25 | def dump(self) -> dict: 26 | _raw_data = getattr(self, "_raw_data", None) 27 | try: 28 | return to_builtins(self) # type: ignore 29 | finally: 30 | if _raw_data is not None: 31 | self._raw_data = _raw_data 32 | 33 | 34 | class ChannelType(IntEnum): 35 | TEXT = 0 36 | DIRECT = 1 37 | CATEGORY = 2 38 | VOICE = 3 39 | 40 | 41 | class Channel(Struct, ModelBase, kw_only=True): 42 | id: str 43 | type: ChannelType = ChannelType.TEXT 44 | name: str | None = None 45 | parent_id: str | None = None 46 | 47 | 48 | class Guild(Struct, ModelBase, kw_only=True): 49 | id: str 50 | name: str | None = None 51 | avatar: str | None = None 52 | 53 | 54 | class User(Struct, ModelBase, kw_only=True): 55 | id: str 56 | name: str | None = None 57 | nick: str | None = None 58 | avatar: str | None = None 59 | is_bot: bool | None = None 60 | 61 | 62 | class Member(Struct, ModelBase, kw_only=True): 63 | user: User | None = None 64 | nick: str | None = None 65 | avatar: str | None = None 66 | joined_at: datetime | None = None 67 | 68 | @classmethod 69 | def parse(cls, raw: dict): 70 | if "joined_at" in raw: 71 | raw["joined_at"] = int(raw["joined_at"]) / 1000 72 | return super().parse(raw) 73 | 74 | def dump(self): 75 | _joined_at = None 76 | if self.joined_at is not None: 77 | _joined_at = self.joined_at 78 | self.joined_at = self.joined_at.timestamp() * 1000 # type: ignore 79 | try: 80 | return super().dump() 81 | finally: 82 | if _joined_at is not None: 83 | self.joined_at = _joined_at 84 | 85 | 86 | class Role(Struct, ModelBase, kw_only=True): 87 | id: str 88 | name: str | None = None 89 | 90 | 91 | class LoginStatus(IntEnum): 92 | OFFLINE = 0 93 | """离线""" 94 | ONLINE = 1 95 | """在线""" 96 | CONNECT = 2 97 | 
"""正在连接""" 98 | DISCONNECT = 3 99 | """正在断开连接""" 100 | RECONNECT = 4 101 | """正在重新连接""" 102 | 103 | 104 | class Login(Struct, ModelBase, kw_only=True): 105 | sn: int 106 | status: LoginStatus 107 | adapter: str 108 | platform: str 109 | user: User 110 | features: list[str] = field(default_factory=list) 111 | 112 | @property 113 | def id(self) -> str: 114 | return self.user.id 115 | 116 | 117 | class LoginPartial(Login): 118 | platform: str | None = None 119 | user: User | None = None 120 | 121 | 122 | class ArgvInteraction(Struct, ModelBase, kw_only=True): 123 | name: str 124 | arguments: list 125 | options: Any 126 | 127 | 128 | class ButtonInteraction(Struct, ModelBase, kw_only=True): 129 | id: str 130 | 131 | 132 | class Opcode(IntEnum): 133 | EVENT = 0 134 | """事件 (接收)""" 135 | PING = 1 136 | """心跳 (发送)""" 137 | PONG = 2 138 | """心跳回复 (接收)""" 139 | IDENTIFY = 3 140 | """鉴权 (发送)""" 141 | READY = 4 142 | """鉴权成功 (接收)""" 143 | META = 5 144 | """元信息更新 (接收)""" 145 | 146 | 147 | class Identify(Struct, ModelBase, kw_only=True): 148 | token: str | None = None 149 | sn: int | None = None 150 | 151 | @property 152 | def sequence(self) -> int | None: 153 | return self.sn 154 | 155 | 156 | class Ready(Struct, ModelBase, kw_only=True): 157 | logins: list[LoginPartial] 158 | proxy_urls: list[str] = field(default_factory=list) 159 | 160 | 161 | class MetaPayload(Struct, ModelBase, kw_only=True): 162 | """Meta 信令""" 163 | 164 | proxy_urls: list[str] 165 | 166 | 167 | class Meta(Struct, ModelBase, kw_only=True): 168 | """Meta 数据""" 169 | 170 | logins: list[LoginPartial] 171 | proxy_urls: list[str] = field(default_factory=list) 172 | 173 | 174 | class MessageObject(Struct, ModelBase, kw_only=True): 175 | id: str 176 | content: str 177 | channel: Channel | None = None 178 | guild: Guild | None = None 179 | member: Member | None = None 180 | user: User | None = None 181 | created_at: datetime | None = None 182 | updated_at: datetime | None = None 183 | 184 | @classmethod 185 | def 
from_elements( 186 | cls, 187 | id: str, 188 | content: list[Element], 189 | channel: Channel | None = None, 190 | guild: Guild | None = None, 191 | member: Member | None = None, 192 | user: User | None = None, 193 | created_at: datetime | None = None, 194 | updated_at: datetime | None = None, 195 | ): 196 | content = "".join(str(i) for i in content) # type: ignore 197 | data = locals().copy() 198 | data.pop("cls", None) 199 | data.pop("__class__", None) 200 | obj = cls(**data) # type: ignore 201 | obj._parsed_message = content 202 | return obj 203 | 204 | @property 205 | def message(self) -> list[Element]: 206 | if hasattr(self, "_parsed_message"): 207 | return self._parsed_message 208 | self._parsed_message = transform(parse(self.content)) 209 | return self._parsed_message 210 | 211 | @message.setter 212 | def message(self, value: list[Element]): 213 | self._parsed_message = value 214 | self.content = "".join(str(i) for i in value) 215 | 216 | @classmethod 217 | def parse(cls, raw: dict): 218 | if "elements" in raw and "content" not in raw: 219 | content = [RawElement(*item.values()) for item in raw["elements"]] 220 | raw["content"] = "".join(str(i) for i in content) 221 | if "created_at" in raw: 222 | raw["created_at"] = int(raw["created_at"]) / 1000 223 | if "updated_at" in raw: 224 | raw["updated_at"] = int(raw["updated_at"]) / 1000 225 | return super().parse(raw) 226 | 227 | 228 | def dump(self): 229 | _created_at = None 230 | if self.created_at is not None: 231 | _created_at = self.created_at 232 | self.created_at = self.created_at.timestamp() * 1000 # type: ignore 233 | _updated_at = None 234 | if self.updated_at is not None: 235 | _updated_at = self.updated_at 236 | self.updated_at = self.updated_at.timestamp() * 1000 # type: ignore 237 | try: 238 | return super().dump() 239 | finally: 240 | if _created_at is not None: 241 | self.created_at = _created_at 242 | if _updated_at is not None: 243 | self.updated_at = _updated_at 244 | 245 | 246 | class 
MessageReceipt(Struct, ModelBase, kw_only=True): 247 | id: str 248 | content: str | None = None 249 | 250 | @classmethod 251 | def from_elements( 252 | cls, 253 | id: str, 254 | content: list[Element] | None = None, 255 | ): 256 | return cls(id=id, content="".join(str(i) for i in content) if content else None) 257 | 258 | @property 259 | def message(self) -> list[Element] | None: 260 | return transform(parse(self.content)) if self.content else None 261 | 262 | @message.setter 263 | def message(self, value: list[Element] | None): 264 | self.content = "".join(str(i) for i in value) if value else None 265 | 266 | @classmethod 267 | def parse(cls, raw: dict): 268 | if "elements" in raw and "content" not in raw: 269 | content = [RawElement(*item.values()) for item in raw["elements"]] 270 | raw["content"] = "".join(str(i) for i in content) 271 | return super().parse(raw) 272 | 273 | 274 | class Event(Struct, ModelBase, kw_only=True): 275 | type: str 276 | timestamp: datetime 277 | login: Login 278 | argv: ArgvInteraction | None = None 279 | button: ButtonInteraction | None = None 280 | channel: Channel | None = None 281 | guild: Guild | None = None 282 | member: Member | None = None 283 | message: MessageObject | None = None 284 | operator: User | None = None 285 | role: Role | None = None 286 | user: User | None = None 287 | 288 | _type: str | None = None 289 | _data: dict | None = None 290 | 291 | sn: int = 0 292 | 293 | @classmethod 294 | def parse(cls, raw: dict): 295 | if "timestamp" in raw: 296 | raw["timestamp"] = int(raw["timestamp"]) / 1000 297 | return super().parse(raw) 298 | 299 | @property 300 | def platform(self): 301 | return self.login.platform 302 | 303 | @property 304 | def self_id(self): 305 | return self.login.id 306 | 307 | def dump(self): 308 | _timestamp = None 309 | if self.timestamp is not None: 310 | _timestamp = self.timestamp 311 | self.timestamp = self.timestamp.timestamp() * 1000 # type: ignore 312 | try: 313 | return super().dump() 314 | 
finally: 315 | if _timestamp is not None: 316 | self.timestamp = _timestamp 317 | 318 | 319 | T = TypeVar("T", bound=ModelBase) 320 | 321 | 322 | class PageResult(ModelBase, Generic[T]): 323 | data: list[T] 324 | next: str | None = None 325 | 326 | @classmethod 327 | def parse(cls, raw: dict, parser: Callable[[dict], T] | None = None) -> "PageResult[T]": 328 | data = [(parser or ModelBase.parse)(item) for item in raw["data"]] 329 | return cls(data, raw.get("next")) # type: ignore 330 | 331 | def dump(self): 332 | res: dict = {"data": [item.dump() for item in self.data]} 333 | if self.next: 334 | res["next"] = self.next 335 | return res 336 | 337 | 338 | class PageDequeResult(PageResult[T]): 339 | prev: str | None = None 340 | 341 | @classmethod 342 | def parse(cls, raw: dict, parser: Callable[[dict], T] | None = None) -> "PageDequeResult[T]": 343 | data = [(parser or ModelBase.parse)(item) for item in raw["data"]] 344 | return cls(data, raw.get("next"), raw.get("prev")) # type: ignore 345 | 346 | def dump(self): 347 | res: dict = {"data": [item.dump() for item in self.data]} 348 | if self.next: 349 | res["next"] = self.next 350 | if self.prev: 351 | res["prev"] = self.prev 352 | return res 353 | 354 | 355 | class IterablePageResult(Generic[T], AsyncIterable[T], Awaitable[PageResult[T]]): 356 | def __init__( 357 | self, func: Callable[[str | None], Awaitable[PageResult[T]]], initial_page: str | None = None 358 | ): 359 | self.func = func 360 | self.next_page = initial_page 361 | 362 | def __await__(self): 363 | return self.func(self.next_page).__await__() 364 | 365 | def __aiter__(self): 366 | async def _gen(): 367 | while True: 368 | result = await self.func(self.next_page) 369 | for item in result.data: 370 | yield item 371 | self.next_page = result.next 372 | if not self.next_page: 373 | break 374 | 375 | return _gen() 376 | 377 | 378 | Direction: TypeAlias = Literal["before", "after", "around"] 379 | Order: TypeAlias = Literal["asc", "desc"] 380 | 381 | 382 | 
class Upload: 383 | file: bytes | IO[bytes] | PathLike 384 | mimetype: str = "image/png" 385 | name: str | None = None 386 | 387 | def __post_init__(self): 388 | if isinstance(self.file, PathLike): 389 | self.mimetype = mimetypes.guess_type(str(self.file))[0] or self.mimetype 390 | self.name = Path(self.file).name 391 | 392 | def dump(self): 393 | file = self.file 394 | 395 | if isinstance(file, PathLike): 396 | file = open(file, "rb") 397 | 398 | return {"value": file, "filename": self.name, "content_type": self.mimetype} 399 | -------------------------------------------------------------------------------- /src/satori/adapters/milky/api.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | 5 | from satori import Api 6 | from satori.model import Channel, ChannelType, Login, PageDequeResult, PageResult 7 | from satori.server import Adapter, Request 8 | from satori.server.route import ( 9 | ApproveParam, 10 | ChannelListParam, 11 | ChannelMuteParam, 12 | ChannelParam, 13 | ChannelUpdateParam, 14 | FriendListParam, 15 | GuildGetParam, 16 | GuildListParam, 17 | GuildMemberGetParam, 18 | GuildMemberKickParam, 19 | GuildMemberMuteParam, 20 | GuildXXXListParam, 21 | MessageListParam, 22 | MessageOpParam, 23 | MessageParam, 24 | ReactionCreateParam, 25 | ReactionDeleteParam, 26 | UserChannelCreateParam, 27 | UserGetParam, 28 | ) 29 | 30 | from .message import MilkyMessageEncoder, decode_message 31 | from .utils import ( 32 | MilkyNetwork, 33 | decode_friend, 34 | decode_group_channel, 35 | decode_guild, 36 | decode_member, 37 | decode_private_channel, 38 | decode_user_profile, 39 | get_scene_and_peer, 40 | ) 41 | 42 | 43 | def apply( 44 | adapter: Adapter, 45 | net_getter: Callable[[str], MilkyNetwork], 46 | login_getter: Callable[[str], Login], 47 | ): 48 | @adapter.route(Api.LOGIN_GET) 49 | async def login_get(request: Request): 50 | return 
login_getter(request.self_id) 51 | 52 | @adapter.route(Api.MESSAGE_CREATE) 53 | async def message_create(request: Request[MessageParam]): 54 | net = net_getter(request.self_id) 55 | login = login_getter(request.self_id) 56 | encoder = MilkyMessageEncoder(login, net, request.params["channel_id"]) 57 | return await encoder.send(request.params["content"]) 58 | 59 | @adapter.route(Api.MESSAGE_GET) 60 | async def message_get(request: Request[MessageOpParam]): 61 | net = net_getter(request.self_id) 62 | scene, peer_id = get_scene_and_peer(request.params["channel_id"]) 63 | result = await net.call_api( 64 | "get_message", 65 | { 66 | "message_scene": scene, 67 | "peer_id": peer_id, 68 | "message_seq": int(request.params["message_id"]), 69 | }, 70 | ) 71 | if not result or "message" not in result: 72 | raise RuntimeError("Failed to get message") 73 | return await decode_message(net, result["message"]) 74 | 75 | @adapter.route(Api.MESSAGE_DELETE) 76 | async def message_delete(request: Request[MessageOpParam]): 77 | net = net_getter(request.self_id) 78 | scene, peer_id = get_scene_and_peer(request.params["channel_id"]) 79 | message_seq = int(request.params["message_id"]) 80 | if scene == "group": 81 | await net.call_api("recall_group_message", {"group_id": peer_id, "message_seq": message_seq}) 82 | else: 83 | await net.call_api("recall_private_message", {"user_id": peer_id, "message_seq": message_seq}) 84 | return 85 | 86 | @adapter.route(Api.MESSAGE_LIST) 87 | async def message_list(request: Request[MessageListParam]): 88 | net = net_getter(request.self_id) 89 | params = request.params 90 | direction = params.get("direction", "before") 91 | if direction != "before": 92 | raise RuntimeError("Milky adapter only supports direction='before'") 93 | scene, peer_id = get_scene_and_peer(params["channel_id"]) 94 | result = await net.call_api( 95 | "get_history_messages", 96 | { 97 | "message_scene": scene, 98 | "peer_id": peer_id, 99 | "start_message_seq": int(params["next"]) if 
params.get("next") else None, # type: ignore 100 | "limit": params.get("limit"), 101 | }, 102 | ) 103 | if not result: 104 | return PageDequeResult([]) 105 | messages = [await decode_message(net, item) for item in result.get("messages", [])] 106 | next_seq = result.get("next_message_seq") 107 | return PageDequeResult(messages, str(next_seq) if next_seq is not None else None) 108 | 109 | @adapter.route(Api.CHANNEL_GET) 110 | async def channel_get(request: Request[ChannelParam]): 111 | net = net_getter(request.self_id) 112 | channel_id = request.params["channel_id"] 113 | scene, peer_id = get_scene_and_peer(channel_id) 114 | if scene == "group": 115 | result = await net.call_api("get_group_info", {"group_id": peer_id}) 116 | if not result: 117 | raise RuntimeError("Failed to get group info") 118 | return decode_group_channel(result["group"]) 119 | profile = await net.call_api("get_user_profile", {"user_id": peer_id}) 120 | if not profile: 121 | raise RuntimeError("Failed to get user profile") 122 | return decode_private_channel(profile, channel_id) 123 | 124 | @adapter.route(Api.CHANNEL_LIST) 125 | async def channel_list(request: Request[ChannelListParam]): 126 | net = net_getter(request.self_id) 127 | guild_id = int(request.params["guild_id"]) 128 | result = await net.call_api("get_group_info", {"group_id": guild_id}) 129 | if not result: 130 | raise RuntimeError("Failed to get group info") 131 | channel = decode_group_channel(result["group"]) 132 | return PageResult([channel]) 133 | 134 | @adapter.route(Api.USER_CHANNEL_CREATE) 135 | async def user_channel_create(request: Request[UserChannelCreateParam]): 136 | return Channel(f"private:{request.params['user_id']}", ChannelType.DIRECT) 137 | 138 | @adapter.route(Api.CHANNEL_UPDATE) 139 | async def channel_update(request: Request[ChannelUpdateParam]): 140 | net = net_getter(request.self_id) 141 | data = request.params["data"] 142 | channel_id = request.params["channel_id"] 143 | scene, peer_id = 
get_scene_and_peer(channel_id) 144 | if scene != "group": 145 | raise RuntimeError("Only group channels support update") 146 | if "name" in data: 147 | await net.call_api("set_group_name", {"group_id": peer_id, "new_group_name": data["name"]}) 148 | return 149 | 150 | @adapter.route(Api.CHANNEL_MUTE) 151 | async def channel_mute(request: Request[ChannelMuteParam]): 152 | net = net_getter(request.self_id) 153 | scene, peer_id = get_scene_and_peer(request.params["channel_id"]) 154 | if scene != "group": 155 | raise RuntimeError("Only group channels support mute") 156 | await net.call_api( 157 | "set_group_whole_mute", 158 | {"group_id": peer_id, "is_mute": request.params["duration"] > 0}, 159 | ) 160 | return 161 | 162 | @adapter.route(Api.GUILD_GET) 163 | async def guild_get(request: Request[GuildGetParam]): 164 | net = net_getter(request.self_id) 165 | guild_id = int(request.params["guild_id"]) 166 | result = await net.call_api("get_group_info", {"group_id": guild_id}) 167 | if not result: 168 | raise RuntimeError("Failed to get group info") 169 | return decode_guild(result["group"]) 170 | 171 | @adapter.route(Api.GUILD_LIST) 172 | async def guild_list(request: Request[GuildListParam]): 173 | net = net_getter(request.self_id) 174 | result = await net.call_api("get_group_list", {}) 175 | groups = [decode_guild(item) for item in result.get("groups", [])] if result else [] 176 | return PageResult(groups) 177 | 178 | @adapter.route(Api.GUILD_MEMBER_GET) 179 | async def guild_member_get(request: Request[GuildMemberGetParam]): 180 | net = net_getter(request.self_id) 181 | result = await net.call_api( 182 | "get_group_member_info", 183 | {"group_id": int(request.params["guild_id"]), "user_id": int(request.params["user_id"])}, 184 | ) 185 | if not result: 186 | raise RuntimeError("Failed to get group member") 187 | return decode_member(result["member"]) 188 | 189 | @adapter.route(Api.GUILD_MEMBER_LIST) 190 | async def guild_member_list(request: Request[GuildXXXListParam]): 
191 | net = net_getter(request.self_id) 192 | result = await net.call_api("get_group_member_list", {"group_id": int(request.params["guild_id"])}) 193 | members = [decode_member(item) for item in result.get("members", [])] if result else [] 194 | return PageResult(members) 195 | 196 | @adapter.route(Api.GUILD_MEMBER_KICK) 197 | async def guild_member_kick(request: Request[GuildMemberKickParam]): 198 | net = net_getter(request.self_id) 199 | await net.call_api( 200 | "kick_group_member", 201 | { 202 | "group_id": int(request.params["guild_id"]), 203 | "user_id": int(request.params["user_id"]), 204 | "reject_add_request": request.params.get("permanent", False), 205 | }, 206 | ) 207 | return 208 | 209 | @adapter.route(Api.GUILD_MEMBER_MUTE) 210 | async def guild_member_mute(request: Request[GuildMemberMuteParam]): 211 | net = net_getter(request.self_id) 212 | await net.call_api( 213 | "set_group_member_mute", 214 | { 215 | "group_id": int(request.params["guild_id"]), 216 | "user_id": int(request.params["user_id"]), 217 | "duration": int(request.params["duration"] / 1000), 218 | }, 219 | ) 220 | return 221 | 222 | @adapter.route(Api.GUILD_MEMBER_APPROVE) 223 | async def guild_member_approve(request: Request[ApproveParam]): 224 | net = net_getter(request.self_id) 225 | message_id = request.params["message_id"] 226 | notification_seq, notification_type, group_id, is_filtered = message_id.split("|") 227 | params = { 228 | "notification_seq": int(notification_seq), 229 | "notification_type": notification_type, 230 | "group_id": int(group_id), 231 | "is_filtered": bool(int(is_filtered)), 232 | } 233 | if request.params["approve"]: 234 | await net.call_api("accept_group_request", params) 235 | else: 236 | params["reason"] = request.params.get("comment") 237 | await net.call_api("reject_group_request", params) 238 | return 239 | 240 | @adapter.route(Api.GUILD_APPROVE) 241 | async def guild_approve(request: Request[ApproveParam]): 242 | net = net_getter(request.self_id) 243 | 
group_id, invitation_seq = request.params["message_id"].split("|") 244 | payload = {"group_id": int(group_id), "invitation_seq": int(invitation_seq)} 245 | if request.params["approve"]: 246 | await net.call_api("accept_group_invitation", payload) 247 | else: 248 | await net.call_api("reject_group_invitation", payload) 249 | return 250 | 251 | @adapter.route(Api.REACTION_CREATE) 252 | async def reaction_create(request: Request[ReactionCreateParam]): 253 | net = net_getter(request.self_id) 254 | scene, peer_id = get_scene_and_peer(request.params["channel_id"]) 255 | if scene != "group": 256 | raise RuntimeError("Reactions only supported in group channels") 257 | await net.call_api( 258 | "send_group_message_reaction", 259 | { 260 | "group_id": peer_id, 261 | "message_seq": int(request.params["message_id"]), 262 | "reaction": request.params["emoji"], 263 | "is_add": True, 264 | }, 265 | ) 266 | return 267 | 268 | @adapter.route(Api.REACTION_DELETE) 269 | async def reaction_delete(request: Request[ReactionDeleteParam]): 270 | net = net_getter(request.self_id) 271 | scene, peer_id = get_scene_and_peer(request.params["channel_id"]) 272 | if scene != "group": 273 | raise RuntimeError("Reactions only supported in group channels") 274 | await net.call_api( 275 | "send_group_message_reaction", 276 | { 277 | "group_id": peer_id, 278 | "message_seq": int(request.params["message_id"]), 279 | "reaction": request.params["emoji"], 280 | "is_add": False, 281 | }, 282 | ) 283 | return 284 | 285 | @adapter.route(Api.USER_GET) 286 | async def user_get(request: Request[UserGetParam]): 287 | net = net_getter(request.self_id) 288 | user_id = request.params["user_id"] 289 | profile = await net.call_api("get_user_profile", {"user_id": int(user_id)}) 290 | if not profile: 291 | raise RuntimeError("Failed to get user profile") 292 | return decode_user_profile(profile, user_id) 293 | 294 | @adapter.route(Api.FRIEND_LIST) 295 | async def friend_list(request: Request[FriendListParam]): 296 | 
net = net_getter(request.self_id) 297 | result = await net.call_api("get_friend_list", {}) 298 | friends = [decode_friend(item) for item in result.get("friends", [])] if result else [] 299 | return PageResult(friends) 300 | 301 | @adapter.route(Api.FRIEND_APPROVE) 302 | async def friend_approve(request: Request[ApproveParam]): 303 | net = net_getter(request.self_id) 304 | initiator_uid, is_filtered = request.params["message_id"].split("|") 305 | payload = {"initiator_uid": initiator_uid, "is_filtered": bool(int(is_filtered))} 306 | if request.params["approve"]: 307 | await net.call_api("accept_friend_request", payload) 308 | else: 309 | payload["reason"] = request.params.get("comment") 310 | await net.call_api("reject_friend_request", payload) 311 | return 312 | 313 | @adapter.route("*") 314 | async def internal_api(request: Request[dict]): 315 | net = net_getter(request.self_id) 316 | return await net.call_api(request.action.removeprefix("internal/"), request.params) 317 | -------------------------------------------------------------------------------- /src/satori/adapters/onebot11/message.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from dataclasses import dataclass, field 5 | from datetime import datetime 6 | from pathlib import Path 7 | from typing import Any, Literal, TypedDict 8 | from urllib.parse import urlparse 9 | 10 | from satori.element import Custom, E, Element 11 | from satori.model import Login, MessageObject 12 | from satori.parser import Element as RawElement 13 | from satori.parser import parse 14 | 15 | from .utils import USER_AVATAR_URL, OneBotNetwork 16 | 17 | 18 | class MessageSegment(TypedDict): 19 | type: str 20 | data: dict[str, Any] 21 | 22 | 23 | def uri_to_path(uri): 24 | parsed = urlparse(uri) 25 | path_str = parsed.path 26 | 27 | # 在 Windows 上处理驱动器字母 28 | if path_str.startswith("/") and len(path_str) > 2 and path_str[2] == ":": 29 | # 
删除开头的 '/',Windows 路径如 /C:/Users 需要转换为 C:/Users 30 | path_str = path_str[1:] 31 | 32 | return Path(path_str).resolve() 33 | 34 | 35 | # def escape(text: str, inline: bool = False) -> str: 36 | # result = text.replace("&", "&").replace("[", "[").replace("]", "]") 37 | # if inline: 38 | # result = result.replace(",", ",") 39 | # result = re.sub( 40 | # r"(\ud83c[\udf00-\udfff])|(\ud83d[\udc00-\ude4f\ude80-\udeff])|[\u2600-\u2B55]", " ", result 41 | # ) 42 | # return result 43 | # 44 | # 45 | # def unescape(text: str) -> str: 46 | # return text.replace("[", "[").replace("]", "]").replace(",", ",").replace("&", "&") 47 | 48 | 49 | b64_cap = re.compile(r"^data:([\w/.+-]+);base64,") 50 | 51 | 52 | @dataclass 53 | class State: 54 | type: Literal["message", "reply", "forward"] 55 | children: list[MessageSegment] = field(default_factory=list) 56 | author: dict[str, Any] = field(default_factory=dict) 57 | 58 | 59 | class OneBot11MessageEncoder: 60 | def __init__(self, login: Login, net: OneBotNetwork, channel_id: str): 61 | self.net = net 62 | self.login = login 63 | self.channel_id = channel_id 64 | self.children: list[MessageSegment] = [] 65 | self.stack = [State("message")] 66 | self.results: list[MessageObject] = [] 67 | 68 | async def send_forward(self): 69 | if not self.stack[0].children: 70 | return 71 | if self.channel_id.startswith("private:"): 72 | resp = await self.net.call_api( 73 | "send_private_forward_msg", 74 | { 75 | "user_id": int(self.channel_id[8:]), 76 | "messages": self.stack[0].children, 77 | }, 78 | ) 79 | else: 80 | resp = await self.net.call_api( 81 | "send_group_forward_msg", 82 | { 83 | "group_id": int(self.channel_id), 84 | "messages": self.stack[0].children, 85 | }, 86 | ) 87 | if resp: 88 | self.results.append(MessageObject(resp["message_id"], "")) 89 | 90 | async def flush(self): 91 | if not self.children: 92 | return 93 | 94 | while True: 95 | first = self.children[0] 96 | if first["type"] != "text": 97 | break 98 | first["data"]["text"] = 
first["data"]["text"].lstrip() 99 | if first["data"]["text"]: 100 | break 101 | self.children.pop(0) 102 | 103 | while True: 104 | last = self.children[-1] 105 | if last["type"] != "text": 106 | break 107 | last["data"]["text"] = last["data"]["text"].rstrip() 108 | if last["data"]["text"]: 109 | break 110 | self.children.pop() 111 | 112 | slot = self.stack[0] 113 | type_, author = slot.type, slot.author 114 | if not self.children and "message_id" not in author: 115 | return 116 | if type_ == "forward": 117 | if "message_id" in author: 118 | self.stack[1].children.append( 119 | { 120 | "type": "node", 121 | "data": { 122 | "id": author["message_id"], 123 | }, 124 | } 125 | ) 126 | else: 127 | self.stack[1].children.append( 128 | { 129 | "type": "node", 130 | "data": { 131 | "name": author.get( 132 | "name", 133 | (self.login.user.name or self.login.user.id) if self.login.user else "", 134 | ), 135 | "uin": author.get("id", self.login.user.id if self.login.user else 0), 136 | "content": self.children, 137 | "time": int(datetime.now().timestamp()), 138 | }, 139 | } 140 | ) 141 | 142 | self.children = [] 143 | return 144 | 145 | if self.channel_id.startswith("private:"): 146 | resp = await self.net.call_api( 147 | "send_private_msg", 148 | { 149 | "user_id": int(self.channel_id[8:]), 150 | "message": self.children, 151 | }, 152 | ) 153 | else: 154 | resp = await self.net.call_api( 155 | "send_group_msg", 156 | { 157 | "group_id": int(self.channel_id), 158 | "message": self.children, 159 | }, 160 | ) 161 | if resp: 162 | self.results.append(MessageObject(resp["message_id"], "")) 163 | self.children = [] 164 | 165 | async def _send_file(self, attrs: dict[str, Any]): 166 | src = attrs.get("src") or attrs["url"] 167 | if src.startswith("file:"): 168 | file = uri_to_path(src) 169 | name = file.name 170 | elif mat := b64_cap.match(src): 171 | file = f"base64://{src[len(mat[0]):]}" 172 | name = attrs.get("title") 173 | else: 174 | file = src 175 | name = attrs.get("title") or 
src.split("/")[-1][:32] 176 | if self.channel_id.startswith("private:"): 177 | await self.net.call_api( 178 | "upload_private_file", 179 | { 180 | "user_id": int(self.channel_id[8:]), 181 | "file": str(file), 182 | "name": name, 183 | }, 184 | ) 185 | else: 186 | await self.net.call_api( 187 | "upload_group_file", 188 | { 189 | "group_id": int(self.channel_id), 190 | "file": str(file), 191 | "name": name, 192 | }, 193 | ) 194 | self.results.append(MessageObject("", "")) 195 | 196 | async def send(self, content: str): 197 | await self.render(parse(content)) 198 | await self.flush() 199 | return self.results 200 | 201 | async def render(self, elements: list[RawElement]): 202 | for element in elements: 203 | await self.visit(element) 204 | 205 | async def visit(self, element: RawElement): 206 | type_, attrs, _children = element.type, element.attrs, element.children 207 | if type_ == "text": 208 | self.children.append({"type": "text", "data": {"text": attrs["text"]}}) 209 | elif type_ == "br": 210 | self.children.append({"type": "text", "data": {"text": "\n"}}) 211 | elif type_ == "p": 212 | prev = self.children[-1] if self.children else None 213 | if prev and prev["type"] == "text": 214 | if not prev["data"]["text"].endswith("\n"): 215 | prev["data"]["text"] += "\n" 216 | else: 217 | self.children.append({"type": "text", "data": {"text": "\n"}}) 218 | await self.render(_children) 219 | self.children.append({"type": "text", "data": {"text": "\n"}}) 220 | elif type_ == "at": 221 | if "type" in attrs and attrs["type"] == "all": 222 | self.children.append({"type": "at", "data": {"qq": "all"}}) 223 | else: 224 | self.children.append({"type": "at", "data": {"qq": str(attrs["id"]), "name": attrs.get("name")}}) 225 | elif type_ == "sharp": 226 | if "id" in attrs: 227 | self.children.append({"type": "text", "data": {"text": attrs["id"]}}) 228 | elif type_ == "onebot:face": 229 | self.children.append({"type": "face", "data": {"id": int(attrs["id"])}}) 230 | elif type_ == "a": 
231 | await self.render(_children) 232 | if "href" in attrs: 233 | self.children.append({"type": "text", "data": {"text": f" ({attrs['href']})"}}) 234 | elif type_ in ("video", "audio", "img", "image"): 235 | if type_ in ("video", "audio"): 236 | await self.flush() 237 | if type_ == "audio": 238 | type_ = "record" 239 | elif type_ == "img": 240 | type_ = "image" 241 | _data = { 242 | "cache": 1 if "cache" in attrs and attrs["cache"] else 0, 243 | "file": attrs.get("src") or attrs.get("url"), 244 | } 245 | if mat := b64_cap.match(_data["file"]): 246 | _data["file"] = f"base64://{_data['file'][len(mat[0]):]}" 247 | self.children.append({"type": type_, "data": _data}) 248 | elif type_ == "file": 249 | await self.flush() 250 | await self._send_file(attrs) 251 | elif type_ == "onebot:music": 252 | await self.flush() 253 | self.children.append({"type": "music", "data": attrs}) 254 | elif type_ == "onebot:poke": 255 | await self.flush() 256 | self.children.append({"type": "poke", "data": attrs}) 257 | elif type_ == "onebot:gift": 258 | await self.flush() 259 | self.children.append({"type": "gift", "data": attrs}) 260 | elif type_ == "onebot:share": 261 | await self.flush() 262 | self.children.append({"type": "share", "data": attrs}) 263 | elif type_ == "onebot:json": 264 | await self.flush() 265 | self.children.append({"type": "json", "data": attrs}) 266 | elif type_ == "onebot:xml": 267 | await self.flush() 268 | self.children.append({"type": "xml", "data": attrs}) 269 | elif type_ == "author": 270 | self.stack[0].author.update(attrs) 271 | elif type_ == "quote": 272 | await self.flush() 273 | self.children.append({"type": "reply", "data": attrs}) 274 | elif type_ == "message": 275 | await self.flush() 276 | if "forward" in attrs: 277 | self.stack.insert(0, State("forward")) 278 | await self.render(_children) 279 | await self.flush() 280 | self.stack.pop(0) 281 | await self.send_forward() 282 | elif "id" in attrs: 283 | self.stack[0].author["message_id"] = 
str(attrs["id"]) 284 | else: 285 | payload = {} 286 | if "name" in attrs: 287 | payload["name"] = attrs["name"] 288 | if "nickname" in attrs: 289 | payload["name"] = attrs["nickname"] 290 | if "username" in attrs: 291 | payload["name"] = attrs["username"] 292 | if "id" in attrs: 293 | payload["id"] = int(attrs["id"]) 294 | if "user_id" in attrs: 295 | payload["id"] = int(attrs["user_id"]) 296 | if "time" in attrs: 297 | payload["time"] = int(attrs["time"]) 298 | self.stack[0].author.update(payload) 299 | await self.render(_children) 300 | await self.flush() 301 | else: 302 | await self.render(_children) 303 | 304 | 305 | async def _decode(content: list[MessageSegment], net: OneBotNetwork) -> list[Element]: 306 | result = [] 307 | for seg in content: 308 | seg_type = seg["type"] 309 | seg_data = seg["data"] 310 | if seg_type == "text": 311 | result.append(E.text(seg_data["text"])) 312 | elif seg_type == "at": 313 | qq = seg_data["qq"] 314 | if qq == "all": 315 | result.append(E.at_all()) 316 | else: 317 | result.append(E.at(str(qq), name=seg_data.get("name"))) 318 | elif seg_type == "image": 319 | result.append(E.image(seg_data.get("url") or seg_data.get("file"))) 320 | elif seg_type == "record": 321 | result.append(E.audio(seg_data.get("url") or seg_data.get("file"))) 322 | elif seg_type == "video": 323 | result.append(E.video(seg_data.get("url") or seg_data.get("file"))) 324 | elif seg_type == "file": 325 | result.append(E.file(seg_data.get("url") or seg_data.get("file"))) 326 | elif seg_type == "reply": 327 | if msg := (await net.call_api("get_msg", {"message_id": seg_data["id"]})): 328 | author = E.author( 329 | str(msg["sender"]["user_id"]), 330 | msg["sender"]["nickname"], 331 | USER_AVATAR_URL.format(uin=msg["sender"]["user_id"]), 332 | ) 333 | result.append(E.quote(seg_data["id"], content=[author, *(await _decode(msg["message"], net))])) 334 | else: 335 | result.append(Custom(f"onebot:{seg_type}", seg_data)) 336 | return result 337 | 338 | 339 | async def 
decode(content: list[MessageSegment], net: OneBotNetwork) -> str: 340 | return "".join(str(x) for x in await _decode(content, net)) 341 | --------------------------------------------------------------------------------