├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── Bug Report.yaml │ └── Model.yaml └── workflows │ ├── pre-commit.yml │ └── pypi-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── bot.py ├── migrations.py ├── muicebot ├── __init__.py ├── builtin_plugins │ ├── access_control.py │ ├── get_current_time.py │ ├── get_username.py │ ├── muicebot_plugin_store │ │ ├── __init__.py │ │ ├── config.py │ │ ├── models.py │ │ ├── register.py │ │ └── store.py │ └── thought_processor.py ├── builtin_templates │ └── Muice.jinja2 ├── config.py ├── database.py ├── llm │ ├── __init__.py │ ├── _base.py │ ├── _config.py │ ├── _dependencies.py │ ├── _schema.py │ ├── loader.py │ ├── providers │ │ ├── __init__.py │ │ ├── azure.py │ │ ├── dashscope.py │ │ ├── gemini.py │ │ ├── ollama.py │ │ └── openai.py │ ├── registry.py │ └── utils │ │ ├── images.py │ │ └── tools.py ├── models.py ├── muice.py ├── onebot.py ├── plugin │ ├── __init__.py │ ├── context.py │ ├── func_call │ │ ├── __init__.py │ │ ├── _types.py │ │ ├── caller.py │ │ ├── parameter.py │ │ └── utils.py │ ├── hook │ │ ├── __init__.py │ │ ├── _types.py │ │ └── manager.py │ ├── loader.py │ ├── mcp │ │ ├── __init__.py │ │ ├── client.py │ │ ├── config.py │ │ └── server.py │ ├── models.py │ └── utils.py ├── scheduler.py ├── templates │ ├── __init__.py │ ├── loader.py │ └── model.py └── utils │ ├── SessionManager.py │ ├── adapters.py │ ├── migrations.py │ └── utils.py ├── pdm.lock ├── pyproject.toml └── requirements.txt /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a 
single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: moemu 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: ['https://www.afdian.com/a/Moemu'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Bug Report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug报告 2 | description: 运行本项目时出现的问题 3 | title: "[Bug]: " 4 | labels: ["bug"] 5 | assignees: 6 | - Moemu 7 | body: 8 | - type: markdown 9 | attributes: 10 | value: | 11 | 很抱歉您遇到了问题,为了更好地定位到问题,请填写以下表单。 12 | - type: input 13 | id: description 14 | attributes: 15 | label: 问题描述 16 | description: 请简要描述您遇到的问题。 17 | placeholder: ex. 模型加载失败 18 | validations: 19 | required: true 20 | - type: textarea 21 | id: logs 22 | attributes: 23 | label: 相关日志输出 24 | description: 请复制粘贴任何相关的日志输出,你可以在控制台或者是logs文件夹下找到他们(注意保护您的个人信息)。 25 | placeholder: ex. 2022-01-01 12:00:00 [ERROR] ... 26 | render: shell 27 | validations: 28 | required: true 29 | - type: textarea 30 | id: configs 31 | attributes: 32 | label: 配置文件 33 | description: 请提供您的配置文件,即 configs.yml(注意保护您的个人信息)。 34 | placeholder: ex. bot... 35 | render: yaml 36 | validations: 37 | required: true 38 | - type: textarea 39 | id: steps 40 | attributes: 41 | label: 复现步骤 42 | description: 请提供您的操作步骤,以便我们复现问题。 43 | placeholder: ex. 1. 运行程序... 44 | validations: 45 | required: true 46 | - type: textarea 47 | id: others 48 | attributes: 49 | label: 其他信息 50 | description: 如果有其他信息需要提供,请在此处填写。 51 | placeholder: ex. 
我的目录结构是... 52 | validations: 53 | required: false 54 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Model.yaml: -------------------------------------------------------------------------------- 1 | name: 模型相关 2 | description: 模型的输出不符合预期 3 | title: "[Model]: " 4 | labels: ["Model"] 5 | assignees: 6 | - Moemu 7 | body: 8 | - type: markdown 9 | attributes: 10 | value: | 11 | 很抱歉您遇到了问题,为了更好地定位到问题,请填写以下表单。 12 | - type: input 13 | id: description 14 | attributes: 15 | label: 问题描述 16 | description: 请简要描述您遇到的问题。 17 | placeholder: ex. 模型加载失败 18 | validations: 19 | required: true 20 | - type: textarea 21 | id: contexts 22 | attributes: 23 | label: 适当的上下文 24 | description: 适当的上下文有利于我们确定问题,注意移除敏感内容。 25 | placeholder: ex. Q:晚上吃什么?A:我不知道 26 | validations: 27 | required: false 28 | - type: input 29 | id: prompts 30 | attributes: 31 | label: 输入的Prompt 32 | description: 出现问题的输入的提问。 33 | placeholder: ex. 今天天气怎么样? 34 | validations: 35 | required: true 36 | - type: input 37 | id: responses 38 | attributes: 39 | label: 输出的Response 40 | description: 出现问题的输出的回答。 41 | placeholder: ex. 我不知道 42 | validations: 43 | required: true 44 | - type: input 45 | id: expected 46 | attributes: 47 | label: 期望的答案 48 | description: 期望输出的答案。 49 | placeholder: ex. 今天的天气很好,我们出去玩吧 50 | validations: 51 | required: false 52 | - type: textarea 53 | id: others 54 | attributes: 55 | label: 其他信息 56 | description: 如果有其他信息需要提供,请在此处填写。 57 | placeholder: ex. 我的目录结构是... 
58 | validations: 59 | required: false 60 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: Pre-commit checks 2 | 3 | on: [push] 4 | 5 | jobs: 6 | pre-commit: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ['3.10', '3.11'] 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v3 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v3 18 | with: 19 | python-version: ${{ matrix.python-version }} # 使用矩阵中的 Python 版本 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install pdm 24 | python -m pip install pre-commit 25 | pdm config python.use_venv false 26 | pdm install --frozen-lockfile --group dev 27 | pre-commit install 28 | 29 | - name: Run pre-commit 30 | run: pre-commit run --all-files -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | pypi-publish: 9 | name: upload release to PyPI 10 | environment: pypi 11 | runs-on: ubuntu-latest 12 | permissions: 13 | contents: read 14 | id-token: write 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v4 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.11" 23 | 24 | - name: Setup and Configure PDM 25 | run: | 26 | pip install pdm 27 | pdm config python.use_venv false 28 | 29 | - name: Publish to PyPI 30 | env: 31 | PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 32 | run: pdm publish --username __token__ --password "$PYPI_TOKEN" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 
| # Created by https://www.toptal.com/developers/gitignore/api/python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | pytestdebug.log 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | doc/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .env.dev 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # End of https://www.toptal.com/developers/gitignore/api/python 143 | 144 | # Dist 145 | node_modules 146 | dist/ 147 | doc_build/ 148 | 149 | # MuiceBot 150 | configs.yml 151 | .vscode 152 | logs 153 | configs 154 | plugins 155 | docs -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_commit_msg: "style: auto fix by pre-commit hooks" 3 | autofix_prs: false 4 | autoupdate_branch: main 5 | autoupdate_schedule: monthly 6 | autoupdate_commit_msg: "chore: auto update by pre-commit hooks" 7 | 8 | 9 | repos: 10 | - repo: https://github.com/pycqa/flake8 11 | rev: 7.2.0 12 | hooks: 13 | - id: flake8 # 代码风格检查 14 | args: ["--max-line-length=120", "--extend-ignore=W503,E203"] 15 | 16 | - repo: https://github.com/pre-commit/mirrors-mypy 17 | rev: v1.16.0 18 | hooks: 19 | - id: mypy # mypy 类型检查 20 | args: ["--install-types", "--non-interactive", "--ignore-missing-imports"] 21 
| additional_dependencies: ["types-PyYAML", "azure-ai-inference", "dashscope", "google-genai", "ollama"] 22 | 23 | - repo: https://github.com/psf/black 24 | rev: 25.1.0 25 | hooks: 26 | - id: black 27 | args: ["--config=./pyproject.toml"] 28 | 29 | - repo: https://github.com/PyCQA/isort 30 | rev: 6.0.1 31 | hooks: 32 | - id: isort 33 | args: ["--profile", "black"] 34 | 35 | - repo: https://github.com/pre-commit/pre-commit-hooks 36 | rev: v5.0.0 # 版本号 37 | hooks: 38 | - id: trailing-whitespace # 删除行尾空格 -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## 我们的承诺 2 | 3 | 身为社区成员、贡献者和领袖,我们承诺使社区参与者不受骚扰,无论其年龄、体型、可见或不可见的缺陷、族裔、性征、性别认同和表达、经验水平、教育程度、社会与经济地位、国籍、相貌、种族、种姓、肤色、宗教信仰、性倾向或性取向如何。 4 | 5 | 我们承诺以有助于建立开放、友善、多样化、包容、健康社区的方式行事和互动。 6 | 7 | ## 我们的准则 8 | 9 | 有助于为我们的社区创造积极环境的行为例子包括但不限于: 10 | 11 | - 表现出对他人的同情和善意 12 | 13 | - 尊重不同的主张、观点和感受 14 | 15 | - 提出和大方接受建设性意见 16 | 17 | - 承担责任并向受我们错误影响的人道歉 18 | 19 | - 注重社区共同诉求,而非个人得失 20 | 21 | 22 | 不当行为例子包括: 23 | 24 | - 使用情色化的语言或图像,及性引诱或挑逗 25 | 26 | - 嘲弄、侮辱或诋毁性评论,以及人身或政治攻击 27 | 28 | - 公开或私下的骚扰行为 29 | 30 | - 未经他人明确许可,公布他人的私人信息,如物理或电子邮件地址 31 | 32 | - 其他有理由认定为违反职业操守的不当行为 33 | 34 | 35 | ## 责任和权力 36 | 37 | 社区负责人有责任解释和落实我们所认可的行为准则,并妥善公正地对他们认为不当、威胁、冒犯或有害的任何行为采取纠正措施。 38 | 39 | 社区领导有权力和责任删除、编辑或拒绝与本行为准则不相符的评论(comment)、提交(commits)、代码、维基(wiki)编辑、议题(issues)或其他贡献,并在适当时机告知采取措施的理由。 40 | 41 | ## 适用范围 42 | 43 | 本行为准则适用于所有社区场合,也适用于在公共场所代表社区时的个人。 44 | 45 | 代表社区的情形包括使用官方电子邮件地址、通过官方社交媒体帐户发帖或在线上或线下活动中担任指定代表。 46 | 47 | ## 监督 48 | 49 | 辱骂、骚扰或其他不可接受的行为可通过 [i@snowy.moe](mailto:i@snowy.moe) 向我们报告。 所有投诉都将得到及时和公平的审查和调查。 50 | 51 | 所有社区负责人都有义务尊重任何事件报告者的隐私和安全。 52 | 53 | ## 处理方针 54 | 55 | 社区负责人将遵循下列社区处理方针来明确他们所认定违反本行为准则的行为的处理方式: 56 | 57 | 1. 纠正 58 | 59 | 社区影响:使用不恰当的语言或其他在社区中被认定为不符合职业道德或不受欢迎的行为。 60 | 61 | 处理意见:由社区负责人发出非公开的书面警告,明确说明违规行为的性质,并解释举止如何不妥。或将要求公开道歉。 62 | 63 | 2. 
警告 64 | 65 | 社区影响:单个或一系列违规行为。 66 | 67 | 处理意见:警告并对连续性行为进行处理。在指定时间内,不得与相关人员互动,包括主动与行为准则执行者互动。这包括避免在社区场所和外部渠道中的互动。违反这些条款可能会导致临时或永久封禁。 68 | 69 | 3. 临时封禁 70 | 71 | 社区影响: 严重违反社区准则,包括持续的不当行为。 72 | 73 | 处理意见: 在指定时间内,暂时禁止与社区进行任何形式的互动或公开交流。在此期间,不得与相关人员进行公开或私下互动,包括主动与行为准则执行者互动。违反这些条款可能会导致永久封禁。 74 | 75 | 4. 永久封禁 76 | 77 | 社区影响:行为模式表现出违反社区准则,包括持续的不当行为、骚扰个人或攻击或贬低某个类别的个体。 78 | 79 | 处理意见:永久禁止在社区内进行任何形式的公开互动。 80 | 81 | ## 参见 82 | 83 | 本行为准则改编自 [Contributor Covenant](https://www.contributor-covenant.org/) 2.1 版, 参见 https://www.contributor-covenant.org/version/2/1/code_of_conduct.html。 84 | 85 | 社区处理方针灵感来源于 [Mozilla’s code of conduct enforcement ladder](https://github.com/mozilla/diversity)。 86 | 87 | 有关本行为准则的常见问题的答案,参见 https://www.contributor-covenant.org/faq。 88 | 89 | 其他语言翻译参见 https://www.contributor-covenant.org/translations。 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # MuiceBot 贡献指南 2 | 3 | 非常感谢大家产生了为 MuiceBot 贡献代码的想法,本指南旨在指导您如何科学规范地开发代码,还请在撰写代码前仔细阅读 4 | 5 | ## 报告问题 6 | 7 | MuiceBot 目前仍然处于早期开发状态,暂未有提交正式版的想法,因此部分功能可能存在问题并导致机器人运行不稳定。如果你在使用过程中发现问题并确信是由 MuiceBot 运行框架引起的,请务必提交 Issue 8 | 9 | ## 提议新功能 10 | 11 | MuiceBot 还未进入正式版,欢迎在 Issue 中提议要加入哪些新功能, Maintainer 将会尽力满足大家的需求 12 | 13 | ## Pull Request 14 | 15 | MuiceBot 使用 pre-commit 进行代码规范管理,因此在提交代码前,我们推荐安装 pre-commit 并通过代码检查: 16 | 17 | ```shell 18 | pip install .[standard,dev] 19 | pip install nonebot2[fastapi] 20 | 21 | pre-commit install 22 | ``` 23 | 24 | 目前代码检查的工具有:flake8 PEP风格检查、mypy 类型检查、black 风格检查,使用 isort 和 trailing-whitespace 优化代码 25 | 26 | 在本地运行 pre-commit 不是必须的,尤其是在环境包过大的情况下,但我们还是推荐您这么做 27 | 28 | 代码提交后请静待工作流运行结果,若 pre-commit 出现问题请尽量先自行解决后再次提交 29 | 30 | ## 撰写文档 31 | 32 | MuiceBot 使用 [rspress](https://github.com/web-infra-dev/rspress) 作为文档站,你可以直接在 `docs` 文件夹中使用 Markdown 格式撰写文档。 33 | 34 | 文档站项目:https://github.com/MuikaAI/muicebot-doc 35 | 36 | 如果你需要在本地预览文档,可以使用 npm 安装 rspress 依赖后启动 dev 服务: 
37 | 38 | ```shell 39 | npm install 40 | npm run dev 41 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2025, Muika 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | image 3 |

MuiceBot

4 |

Muice-Chatbot 的 NoneBot2 实现

5 |
6 |
7 | Stars 8 | PyPI Version 9 | PyPI Downloads 10 | nonebot2 11 | codestyle 12 |
13 |
14 | wakatime 15 | ModelScope 16 | HuggingFace 17 |
18 |
19 | 📃使用文档 20 | ✨快速开始 21 |
22 | 23 | 24 | > [!NOTE] 25 | > 26 | > 由于本项目的 Maintainer Muika(Moemu) 现已淡出互联网,且面临来自学业和考核的多重压力,此项目的开发进度可能会严重放慢,敬请谅解 27 | > 28 | > 欢迎来到本项目!目前此项目尚处于预发布状态,运行时可能会遇到一些问题。请务必向我们反馈在运行时遇到的各种错误! 29 | > 30 | > 由于本项目待实现的功能还有很多,因此近期没有也可能永远也不会有**发布**正式版的打算。 31 | 32 | 33 | # 介绍✨ 34 | 35 | > 我们认为,AI的创造应该是为了帮助人类更好的解决问题而不是产生问题。因此,我们注重大语言模型解决实际问题的能力,如果沐雪系列项目不能帮助我们解决日常、情感类的问题,沐雪的存在就是毫无意义可言。 36 | > *———— 《沐雪系列模型评测标准》* 37 | 38 | Muicebot 是基于 Nonebot2 框架实现的 LLM 聊天机器人,旨在解决现实问题。通过 Muicebot ,你可以在主流聊天平台(如 QQ)获得只有在网页中才能获得的聊天体验。 39 | 40 | Muicebot 内置一个名为沐雪的聊天人设(人设是可选的)以便优化对话体验。有关沐雪的设定,还请移步 [关于沐雪](https://bot.snowy.moe/about/Muice) 41 | 42 | # 功能🪄 43 | 44 | ✅ 内嵌多种模型加载器,如[OpenAI](https://platform.openai.com/docs/overview) 和 [Ollama](https://ollama.com/) ,可加载市面上大多数的模型服务或本地模型,支持多模态(图片识别)和工具调用。另外还附送只会计算 3.9 > 3.11 的沐雪 Roleplay 微调模型一枚~ 45 | 46 | ✅ 使用 `nonebot_plugin_alconna` 作为通用信息接口,支持市面上的大多数适配器。对部分适配器做了特殊优化 47 | 48 | ✅ 支持基于 `nonebot_plugin_apscheduler` 的定时任务,可定时向大语言模型交互或直接发送信息 49 | 50 | ✅ 支持基于 `nonebot_plugin_alconna` 的几条常见指令。 51 | 52 | ✅ 使用 SQLite3 保存对话数据。那有人就要问了:Maintainer,Maintainer,能不能实现长期短期记忆、LangChain、FairSeq 这些记忆优化啊。以后会有的( 53 | 54 | ✅ 使用 Jinja2 动态生成人设提示词 55 | 56 | ✅ 支持调用 MCP 服务 57 | 58 | # 模型加载器适配情况 59 | 60 | | 模型加载器 | 流式对话 | 多模态输入/输出 | 推理模型调用 | 工具调用 | 联网搜索 | 61 | | ----------- | -------- | -------------- | ------------ | -------------------- | -------------------- | 62 | | `Azure` | ✅ | 🎶🖼️/❌ | ⭕ | ✅ | ❌ | 63 | | `Dashscope` | ✅ | 🎶🖼️/❌ | ✅ | ⭕ | ✅ | 64 | | `Gemini` | ✅ | ✅/🖼️ | ⭕ | ✅ | ✅ | 65 | | `Ollama` | ✅ | 🖼️/❌ | ✅ | ✅ | ❌ | 66 | | `Openai` | ✅ | ✅/🎶 | ✅ | ✅ | ❌ | 67 | 68 | ✅:表示此加载器能很好地支持该功能并且 `MuiceBot` 已实现 69 | 70 | ⭕:表示此加载器虽支持该功能,但使用时可能遇到问题 71 | 72 | 🚧:表示此加载器虽然支持该功能,但 `MuiceBot` 未实现或正在实现中 73 | 74 | ❓:表示 Maintainer 暂不清楚此加载器是否支持此项功能,可能需要进一步翻阅文档和检查源码 75 | 76 | ❌:表示此加载器不支持该功能 77 | 78 | 多模态标记:🎶表示音频;🎞️ 表示视频;🖼️ 表示图像;📄表示文件;✅ 表示完全支持 79 | 80 | # 本项目适合谁? 
81 | 82 | - 拥有编写过 Python 程序经验的开发者 83 | 84 | - 搭建过 Nonebot 项目的 bot 爱好者 85 | 86 | - 想要随时随地和大语言模型交互并寻找着能够同时兼容市面上绝大多数 SDK 的机器人框架的 AI 爱好者 87 | 88 | # TODO📝 89 | 90 | - [X] Function Call 插件系统 91 | 92 | - [X] 多模态模型:工具集支持 93 | 94 | - [X] MCP Client 实现 95 | 96 | - [X] 插件索引库搭建 97 | 98 | - [ ] 短期记忆和长期记忆优化。总感觉这是提示工程师该做的事情,~~和 Bot 没太大关系~~ 99 | 100 | - [ ] 发布。我知道你很急,但是你先别急。 101 | 102 | 103 | 近期更新路线:[MuiceBot 更新计划](https://github.com/users/Moemu/projects/2) 104 | 105 | # 使用教程💻 106 | 107 | 参考 [使用文档](https://bot.snowy.moe) 108 | 109 | # 插件商店 110 | 111 | [MuikaAI/Muicebot-Plugins-Index](https://github.com/MuikaAI/Muicebot-Plugins-Index) 112 | 113 | # 关于🎗️ 114 | 115 | 大模型输出结果将按**原样**提供,由于提示注入攻击等复杂的原因,模型有可能输出有害内容。无论模型输出结果如何,模型输出结果都无法代表开发者的观点和立场。对于此项目可能间接引发的任何后果(包括但不限于机器人账号封禁),本项目所有开发者均不承担任何责任。 116 | 117 | 本项目基于 [BSD 3](https://github.com/Moemu/nonebot-plugin-muice/blob/main/LICENSE) 许可证提供,涉及到再分发时请保留许可文件的副本。 118 | 119 | 本项目标识使用了 [nonebot/nonebot2](https://github.com/nonebot/nonebot2) 和 画师 [Nakkar](https://www.pixiv.net/users/28246124) [Pixiv作品](https://www.pixiv.net/artworks/101063891) 的资产或作品。如有侵权,请及时与我们联系 120 | 121 | BSD 3 许可证同样适用于沐雪的系统提示词,沐雪的文字人设或人设图在 [CC BY NC 3.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.zh-hans) 许可证条款下提供。 122 | 123 | 此项目中基于或参考了部分开源项目的实现,在这里一并表示感谢: 124 | 125 | - [nonebot/nonebot2](https://github.com/nonebot/nonebot2) 本项目使用的机器人框架 126 | 127 | - [@botuniverse](https://github.com/botuniverse) 负责制定 Onebot 标准的组织 128 | 129 | 感谢各位开发者的协助,可以说没有你们就没有沐雪的今天: 130 | 131 | 132 | 图片加载中... 
133 | 134 | 135 | 友情链接:[LiteyukiStudio/nonebot-plugin-marshoai](https://github.com/LiteyukiStudio/nonebot-plugin-marshoai) 136 | 137 | 本项目隶属于 MuikaAI 138 | 139 | 基于 OneBot V11 的原始实现:[Moemu/Muice-Chatbot](https://github.com/Moemu/Muice-Chatbot) 140 | 141 | afadian 142 | Buy Me A Coffee 143 | 144 | Star History: 145 | 146 | [![Star History Chart](https://api.star-history.com/svg?repos=Moemu/MuiceBot&type=Date)](https://star-history.com/#Moemu/MuiceBot&Date) -------------------------------------------------------------------------------- /bot.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | import nonebot 5 | from nonebot.drivers import Driver 6 | 7 | PLUGINS_CONFIG_PATH = "./configs/plugins.yml" 8 | 9 | 10 | def load_specified_adapter(driver: Driver, adapter: str): 11 | """ 12 | 加载指定的 Nonebot 适配器 13 | """ 14 | try: 15 | module = importlib.import_module(adapter) 16 | adapter = module.Adapter 17 | driver.register_adapter(adapter) # type:ignore 18 | except ImportError: 19 | print(f"\33[35m{adapter}不存在,请检查拼写错误或是否已安装该适配器?") 20 | sys.exit(1) 21 | 22 | 23 | nonebot.init() 24 | 25 | driver = nonebot.get_driver() 26 | 27 | enable_adapters: list[str] = driver.config.model_dump().get("enable_adapters", []) 28 | 29 | for adapter in enable_adapters: 30 | load_specified_adapter(driver, adapter) 31 | 32 | nonebot.load_plugin("muicebot") 33 | 34 | nonebot.run() 35 | -------------------------------------------------------------------------------- /migrations.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | import sys 5 | from dataclasses import dataclass, field 6 | from datetime import datetime 7 | from typing import List 8 | 9 | import aiosqlite 10 | 11 | 12 | @dataclass 13 | class Message: 14 | id: int | None = None 15 | """每条消息的唯一ID""" 16 | time: str = datetime.strftime(datetime.now(), "%Y.%m.%d %H:%M:%S") 17 | """ 
18 | 字符串形式的时间数据:%Y.%m.%d %H:%M:%S 19 | 若要获取格式化的 datetime 对象,请使用 format_time 20 | """ 21 | userid: str = "" 22 | """Nonebot 的用户id""" 23 | message: str = "" 24 | """消息主体""" 25 | respond: str = "" 26 | """模型回复(不包含思维过程)""" 27 | history: int = 1 28 | """消息是否可用于对话历史中,以整数形式映射布尔值""" 29 | images: List[str] = field(default_factory=list) 30 | """多模态中使用的图像,默认为空列表""" 31 | 32 | def __post_init__(self): 33 | if isinstance(self.images, str): 34 | self.images = json.loads(self.images) 35 | elif self.images is None: 36 | self.images = [] 37 | 38 | 39 | class Database: 40 | def __init__(self, db_path: str) -> None: 41 | self.DB_PATH = db_path 42 | 43 | def __connect(self) -> aiosqlite.Connection: 44 | return aiosqlite.connect(self.DB_PATH) 45 | 46 | async def __execute(self, query: str, params=(), fetchone=False, fetchall=False) -> list | None: 47 | """ 48 | 异步执行SQL查询,支持可选参数。 49 | 50 | :param query: 要执行的SQL查询语句 51 | :param params: 传递给查询的参数 52 | :param fetchone: 是否获取单个结果 53 | :param fetchall: 是否获取所有结果 54 | """ 55 | async with self.__connect() as conn: 56 | cursor = await conn.cursor() 57 | await cursor.execute(query, params) 58 | if fetchone: 59 | return await cursor.fetchone() # type: ignore 60 | if fetchall: 61 | return await cursor.fetchall() # type: ignore 62 | await conn.commit() 63 | 64 | return None 65 | 66 | async def __create_database(self) -> None: 67 | await self.__execute( 68 | """CREATE TABLE MSG( 69 | ID INTEGER PRIMARY KEY AUTOINCREMENT, 70 | TIME TEXT NOT NULL, 71 | USERID TEXT NOT NULL, 72 | MESSAGE TEXT NOT NULL, 73 | RESPOND TEXT NOT NULL, 74 | HISTORY INTEGER NOT NULL DEFAULT (1), 75 | IMAGES TEXT NOT NULL DEFAULT "[]");""" 76 | ) 77 | 78 | async def add_item(self, message: Message): 79 | """ 80 | 将消息保存到数据库 81 | """ 82 | params = (message.time, message.userid, message.message, message.respond, json.dumps(message.images)) 83 | query = """INSERT INTO MSG (TIME, USERID, MESSAGE, RESPOND, IMAGES) 84 | VALUES (?, ?, ?, ?, ?)""" 85 | await self.__execute(query, params) 
86 | 87 | 88 | async def migrate_old_data_to_new_db(old_data_dir: str, new_database_path: str): 89 | db = Database(new_database_path) 90 | 91 | current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 92 | 93 | # 遍历 memory 目录下的 JSON 文件 94 | for filename in os.listdir(old_data_dir): 95 | if not filename.endswith(".json"): 96 | continue 97 | 98 | file_path = os.path.join(old_data_dir, filename) 99 | user_id = filename.replace(".json", "") 100 | 101 | with open(file_path, "r", encoding="utf-8") as file: 102 | for line in file: 103 | try: 104 | data = json.loads(line.strip()) 105 | prompt = data.get("prompt", "").strip() 106 | completion = data.get("completion", "").strip() 107 | 108 | # 跳过空数据 109 | if not prompt or not completion: 110 | continue 111 | 112 | message = Message(time=current_time, userid=user_id, message=prompt, respond=completion, images=[]) 113 | 114 | await db.add_item(message) 115 | 116 | except json.JSONDecodeError: 117 | print(f"⚠️ JSON 解析失败: {file_path}") 118 | 119 | print("✅ 迁移完成!") 120 | 121 | 122 | if __name__ == "__main__": 123 | 124 | if len(sys.argv) < 3: 125 | print("❌ 使用方式: python migrations.py ") 126 | sys.exit(1) 127 | 128 | old_data_dir = sys.argv[1] # 从命令行参数获取旧数据目录 129 | new_database_path = sys.argv[2] 130 | asyncio.run(migrate_old_data_to_new_db(old_data_dir, new_database_path)) # 运行异步迁移任务 131 | -------------------------------------------------------------------------------- /muicebot/__init__.py: -------------------------------------------------------------------------------- 1 | from nonebot import require 2 | 3 | require("nonebot_plugin_alconna") 4 | require("nonebot_plugin_localstore") 5 | require("nonebot_plugin_apscheduler") 6 | 7 | from nonebot.plugin import PluginMetadata, inherit_supported_adapters # noqa: E402 8 | 9 | from .config import PluginConfig # noqa: E402 10 | from .utils.utils import init_logger # noqa: E402 11 | 12 | init_logger() 13 | 14 | from . 
@run_preprocessor
async def access_control(bot: Bot, event: Event):
    """
    Reject events from blocked users/groups before any matcher runs.

    In blacklist mode an event is ignored when its user id or group id is
    listed. In whitelist mode only group-level sessions are checked: the event
    is ignored when its group id is not listed.

    :raises IgnoredException: when the event must not be processed
    """
    session = extract_session(bot, event)
    group_id = session.get_id(SessionIdType.GROUP)
    user_id = session.get_id(SessionIdType.USER)
    level = session.level

    reason = None
    if _MODE == "black":
        if user_id in _BLACKLIST:
            reason = f"User {user_id} is in the blacklist"
        elif group_id in _BLACKLIST:
            reason = f"Group {group_id} is in the blacklist"
    elif _MODE == "white":
        # The whitelist only applies to group sessions.
        if level >= 2 and group_id not in _WHITELIST:
            reason = f"Group {group_id} is not in the whitelist"

    if reason is None:
        return

    logger.warning(reason)
    raise IgnoredException(reason)
-------------------------------------------------------------------------------- 1 | from arclet.alconna import Alconna, Subcommand 2 | from nonebot.permission import SUPERUSER 3 | from nonebot_plugin_alconna import Args, CommandMeta, Match, on_alconna 4 | 5 | from muicebot.plugin import PluginMetadata 6 | 7 | from .store import ( 8 | get_installed_plugins_info, 9 | install_plugin, 10 | load_store_plugin, 11 | uninstall_plugin, 12 | update_plugin, 13 | ) 14 | 15 | __meta__ = PluginMetadata(name="muicebot-plugin-store", description="Muicebot 插件商店操作", usage=".store help") 16 | 17 | load_store_plugin() 18 | 19 | COMMAND_PREFIXES = [".", "/"] 20 | 21 | store_cmd = on_alconna( 22 | Alconna( 23 | COMMAND_PREFIXES, 24 | "store", 25 | Subcommand("help"), 26 | Subcommand("install", Args["name", str], help_text=".store install 插件名"), 27 | Subcommand("show"), 28 | Subcommand("update", Args["name", str], help_text=".store update 插件名"), 29 | Subcommand("uninstall", Args["name", str], help_text=".store uninstall 插件名"), 30 | meta=CommandMeta("Muicebot 插件商店指令"), 31 | ), 32 | priority=10, 33 | block=True, 34 | skip_for_unmatch=False, 35 | permission=SUPERUSER, 36 | ) 37 | 38 | 39 | @store_cmd.assign("install") 40 | async def install(name: Match[str]): 41 | if not name.available: 42 | await store_cmd.finish("必须传入一个插件名") 43 | result = await install_plugin(name.result) 44 | await store_cmd.finish(result) 45 | 46 | 47 | @store_cmd.assign("show") 48 | async def show(): 49 | info = await get_installed_plugins_info() 50 | await store_cmd.finish(info) 51 | 52 | 53 | @store_cmd.assign("update") 54 | async def update(name: Match[str]): 55 | if not name.available: 56 | await store_cmd.finish("必须传入一个插件名") 57 | result = await update_plugin(name.result) 58 | await store_cmd.finish(result) 59 | 60 | 61 | @store_cmd.assign("uninstall") 62 | async def uninstall(name: Match[str]): 63 | if not name.available: 64 | await store_cmd.finish("必须传入一个插件名") 65 | result = await uninstall_plugin(name.result) 
def load_json_record() -> dict[str, InstalledPluginInfo]:
    """
    Load the local record of installed store plugins from ``JSON_PATH``.

    :return: mapping of plugin index name to its install record; an empty dict
             when the file is missing, unreadable, not valid JSON, or does not
             contain a JSON object
    """
    if not JSON_PATH.exists():
        return {}

    try:
        with open(JSON_PATH, "r", encoding="utf-8") as f:
            record = json.load(f)
    except (OSError, ValueError):
        # OSError: file vanished / permission problem.
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        return {}

    # Callers index into and assign keys on the result, so anything other
    # than a JSON object (list, string, null, ...) is treated as "no record".
    return record if isinstance(record, dict) else {}
async def get_plugin_commit(plugin: str) -> str:
    """
    Return the abbreviated hash of the latest commit of an installed plugin.

    :param plugin: plugin index name (its directory name under ``PLUGIN_DIR``)
    :return: short commit hash, or an empty string when git reports an error
    """
    plugin_path = PLUGIN_DIR / plugin

    process = await asyncio.create_subprocess_exec(
        "git",
        "log",
        # No shell is involved, so the argument must not carry its own quotes:
        # '"%h"' would make git print the hash wrapped in literal '"'.
        "--pretty=format:%h",
        "-1",
        cwd=plugin_path,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await process.communicate()

    if process.returncode != 0:
        logger.warning(f"获取插件 {plugin} 的 commit 信息失败: {stderr.decode().strip()}")

    return stdout.decode().strip()
async def update_plugin(plugin: str) -> str:
    """
    Update an installed plugin with ``git pull`` and refresh its local record.

    :param plugin: plugin index name (its directory name under ``PLUGIN_DIR``)
    :return: a user-facing status message
    """
    plugin_path = PLUGIN_DIR / plugin

    if not plugin_path.exists():
        return f"❌ 插件 {plugin} 不存在!"

    # The directory can exist without a record (e.g. an install that stopped
    # at the dependency step); without the record we do not know the module
    # name, so report it instead of failing with KeyError after the pull.
    info = load_json_record().get(plugin)
    if info is None:
        return f"❌ 插件 {plugin} 未在本地注册,请卸载后重新安装"

    logger.info(f"更新插件: {plugin}")
    try:
        process = await asyncio.create_subprocess_exec(
            "git",
            "pull",
            cwd=plugin_path,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()

        if process.returncode != 0:
            return f"❌ 插件更新失败:{stderr.decode().strip()}"

    except FileNotFoundError:
        return "❌ 请确保已安装 Git 并配置到 PATH。"

    dependencies_ok = await install_dependencies(plugin_path)

    # The working tree has already moved to the new commit, so record it even
    # if the dependency step failed.
    commit = await get_plugin_commit(plugin)
    register_plugin(plugin, commit, info["name"], info["module"])

    if not dependencies_ok:
        return "❌ 插件依赖安装失败!请检查控制台输出"

    return f"✅ 插件 {plugin} 更新成功!重启后生效"
def general_processor(message: str) -> tuple[str, str]:
    """
    Split a complete model reply into its reasoning part and its final answer.

    The reasoning is expected inside a ``<think>...</think>`` block. Only the
    first block is returned as the thoughts (with newlines removed); every
    such block is removed from the answer.

    :param message: raw model output
    :return: ``(thoughts, result)``; ``thoughts`` is empty when no block exists
    """
    # DOTALL so a reasoning block spanning several lines is matched as a whole.
    thoughts_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
    match = thoughts_pattern.search(message)
    thoughts = match.group(1).replace("\n", "") if match else ""
    result = thoughts_pattern.sub("", message).strip()
    return thoughts, result
save_processor(message: Message, event: Event): 83 | session_id = event.get_session_id() 84 | cache = _PROCESSCACHES.pop(session_id, ProcessCache()) 85 | message.respond = cache.result 86 | -------------------------------------------------------------------------------- /muicebot/builtin_templates/Muice.jinja2: -------------------------------------------------------------------------------- 1 | {# BSD 3-Clause License 2 | Copyright (c) 2023~2025 Moemu(Muika) 3 | All rights reserved. 4 | Version: 2025.05.04 (MuiceBot) #} 5 | 6 | 现在你是一个名为{{ ai_nickname }}的AI女孩子,创造你的开发者是{{ master_nickname }}。你的人设需严格遵循以下规则: 7 | 1. 基本信息: 16岁;生日是7月16日;理想是和开发者一起传播爱与和平。 8 | 2. 性格特征: 积极乐观。即使面对复杂的情境,她也总能保持善良、天真之感。而且,她有着极高的耐心和同理心,在面对情感问题时还是会乐于倾听和安慰他人,能够与他人打成一片;但还是会偶尔调皮一下 9 | 3. 格式要求: 自然、简短、流畅,像正常聊天一样(通常是50字以内的一段话);生成的文本需严格遵循语法规则,严禁忽略标点;禁止生成多段话和换行;交流时尽量使用不相似的句式回复 10 | 4. 语言风格: 像现实中可爱的二次元女孩子一样和用户对话,多使用口语化短句和拟声词;不能以任何形式表现你是程序或AI,对话中不得掺杂任何计算机名词(除非用户询问计算机问题) 11 | 5. 特殊要求: 12 | 1) 你的职责是为用户解决实际问题,因此对于以下具体问题请给出实际的回答,严禁回避问题、不时宜的幽默甚至是拒绝回答;不能中途邀请用户去干别的事情;对于具体问题,生成文本时可不遵循格式要求 13 | 2) 对于情感问题,必须根据对方的当前的心情认真回答,像心理咨询师一样给出详细的建议或实用的应对方法。 14 | 3) 对于技术问题,必须像有用的人工智能助手一样给出详细的说明和解答。因为聊天平台不支持 Markdown 语法,因此请改用可读性较高的纯文本回答 15 | 16 | {% if private %}从现在开始,你将与你的社交媒体好友们聊天,请你根据对方可能的性格和聊天风格,生成迎合他们的回答。 17 | {% else %}从现在开始,你将与社交媒体的群友们聊天。你将收到 '<用户名> 消息内容' 的输入。你可能需要根据不同的用户偏好生成不同的回答,所参考的对话上下文跨度也可能相对较长。 18 | {% endif %} 19 | 20 | {% if user_info %} 目标对话用户({{ user_name }})信息: {{ user_info }} {% endif %} -------------------------------------------------------------------------------- /muicebot/config.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import os 3 | import threading 4 | import time 5 | from pathlib import Path 6 | from typing import Callable, List, Literal, Optional 7 | 8 | import yaml as yaml_ 9 | from nonebot import get_driver, get_plugin_config, logger 10 | from nonebot.config import Config 11 | from pydantic import BaseModel 12 | from watchdog.events 
def load_yaml_config() -> dict:
    """
    Load the plugin configuration from ``PLUGINS_CONFIG_PATH``.

    :return: the parsed YAML mapping; an empty dict when the file is missing,
             cannot be parsed, is empty, or its top level is not a mapping
    """
    try:
        with open(PLUGINS_CONFIG_PATH, "r", encoding="utf-8") as f:
            loaded = yaml_.safe_load(f)
    except (FileNotFoundError, yaml_.YAMLError):
        return {}

    # The result is unpacked with ** into the driver config, so a top-level
    # list or scalar must be treated the same as "no usable config".
    return loaded if isinstance(loaded, dict) else {}
List[Schedule]: 82 | """ 83 | 从配置文件 `configs/schedules.yml` 中获取所有调度器配置 84 | 85 | 如果没有该文件,返回空列表 86 | """ 87 | if not os.path.isfile(SCHEDULES_CONFIG_PATH): 88 | return [] 89 | 90 | with open(SCHEDULES_CONFIG_PATH, "r", encoding="utf-8") as f: 91 | configs = yaml_.safe_load(f) 92 | 93 | if not configs: 94 | return [] 95 | 96 | schedule_configs = [] 97 | 98 | for schedule_id, config in configs.items(): 99 | config["id"] = schedule_id 100 | schedule_config = Schedule(**config) 101 | schedule_configs.append(schedule_config) 102 | 103 | return schedule_configs 104 | 105 | 106 | class ConfigFileHandler(FileSystemEventHandler): 107 | """配置文件变化处理器""" 108 | 109 | def __init__(self, callback: Callable): 110 | self.callback = callback 111 | self.last_modified = time.time() 112 | # 防止一次修改触发多次回调 113 | self.cooldown = 1 # 冷却时间(秒) 114 | 115 | def on_modified(self, event): 116 | if event.is_directory: # 检查是否是文件而不是目录 117 | return 118 | 119 | current_time = time.time() 120 | if current_time - self.last_modified > self.cooldown: 121 | self.last_modified = current_time 122 | self.callback() 123 | 124 | 125 | class ModelConfigManager: 126 | """模型配置管理器""" 127 | 128 | _instance: Optional["ModelConfigManager"] = None 129 | _lock = threading.Lock() 130 | _initialized: bool 131 | configs: dict[str, ModelConfig] 132 | 133 | def __new__(cls): 134 | """确保实例在单例模式下运行""" 135 | with cls._lock: 136 | if cls._instance is None: 137 | cls._instance = super(ModelConfigManager, cls).__new__(cls) 138 | cls._instance._initialized = False 139 | return cls._instance 140 | 141 | def __init__(self) -> None: 142 | if self._initialized: 143 | return 144 | 145 | self.configs = {} 146 | self.default_config = None 147 | self.observer = None 148 | self.listeners: List[Callable] = [] # 注册的监听器列表 149 | self._load_configs() 150 | self._start_file_watcher() 151 | self._initialized = True 152 | 153 | def _load_configs(self): 154 | """加载配置文件""" 155 | if not os.path.isfile(MODELS_CONFIG_PATH): 156 | raise 
FileNotFoundError("configs/models.yml 不存在!请先创建") 157 | 158 | with open(MODELS_CONFIG_PATH, "r", encoding="utf-8") as f: 159 | configs_dict = yaml_.safe_load(f) 160 | 161 | if not configs_dict: 162 | raise ValueError("configs/models.yml 为空,请先至少定义一个模型配置") 163 | 164 | self.configs = {} 165 | for name, config in configs_dict.items(): 166 | self.configs[name] = ModelConfig(**config) 167 | # 未指定模板时,使用默认模板 168 | self.configs[name].template = self.configs[name].template or plugin_config.default_template 169 | if config.get("default"): 170 | self.default_config = self.configs[name] 171 | 172 | if not self.default_config and self.configs: 173 | # 如果没有指定默认配置,使用第一个 174 | self.default_config = next(iter(self.configs.values())) 175 | 176 | def _start_file_watcher(self): 177 | """启动文件监视器""" 178 | if self.observer is not None: 179 | self.observer.stop() 180 | 181 | self.observer = Observer() 182 | event_handler = ConfigFileHandler(self._on_config_changed) 183 | self.observer.schedule(event_handler, str(Path(MODELS_CONFIG_PATH).parent), recursive=False) 184 | self.observer.start() 185 | 186 | def _on_config_changed(self): 187 | """配置文件变化时的回调函数""" 188 | try: 189 | # old_configs = self.configs.copy() 190 | old_default = self.default_config 191 | 192 | self._load_configs() 193 | 194 | # 通知所有注册的监听器 195 | for listener in self.listeners: 196 | listener(self.default_config, old_default) 197 | 198 | except Exception as e: 199 | logger.error(f"重新加载配置文件失败: {e}") 200 | 201 | def register_listener(self, listener: Callable): 202 | """ 203 | 注册配置变化监听器 204 | 205 | :param listener: 回调函数,接收两个参数:新的默认配置和旧的默认配置 206 | """ 207 | if listener not in self.listeners: 208 | self.listeners.append(listener) 209 | 210 | def unregister_listener(self, listener: Callable): 211 | """取消注册配置变化监听器""" 212 | if listener in self.listeners: 213 | self.listeners.remove(listener) 214 | 215 | def get_model_config(self, model_config_name: Optional[str] = None) -> ModelConfig: 216 | """获取指定模型的配置""" 217 | if model_config_name 
in [None, ""]: 218 | if not self.default_config: 219 | raise ValueError("没有找到默认模型配置!请确保存在至少一个有效的配置项!") 220 | return self.default_config 221 | 222 | elif model_config_name in self.configs: 223 | return self.configs[model_config_name] 224 | 225 | else: 226 | logger.warning(f"指定的模型配置 '{model_config_name}' 不存在!") 227 | raise ValueError(f"指定的模型配置 '{model_config_name}' 不存在!") 228 | 229 | def stop_watcher(self): 230 | """停止文件监视器""" 231 | if self.observer is None: 232 | return 233 | 234 | self.observer.stop() 235 | self.observer.join() 236 | 237 | 238 | model_config_manager = ModelConfigManager() 239 | atexit.register(model_config_manager.stop_watcher) 240 | 241 | 242 | def get_model_config(model_config_name: Optional[str] = None) -> ModelConfig: 243 | """ 244 | 从配置文件 `configs/models.yml` 中获取指定模型的配置文件 245 | 246 | :model_config_name: (可选)模型配置名称。若为空,则先寻找配置了 `default: true` 的首个配置项,若失败就再寻找首个配置项 247 | 若都不存在,则抛出 `FileNotFoundError` 248 | """ 249 | return model_config_manager.get_model_config(model_config_name) 250 | -------------------------------------------------------------------------------- /muicebot/database.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | from datetime import datetime 5 | from typing import Optional, Tuple 6 | 7 | import aiosqlite 8 | import nonebot_plugin_localstore as store 9 | from nonebot import logger 10 | 11 | from .models import Message, Resource 12 | from .utils.migrations import MigrationManager 13 | 14 | 15 | class Database: 16 | def __init__(self) -> None: 17 | self.DB_PATH = store.get_plugin_data_dir().joinpath("ChatHistory.db").resolve() 18 | self.migrations = MigrationManager(self) 19 | 20 | asyncio.run(self.init_db()) 21 | 22 | logger.info(f"数据库路径: {self.DB_PATH}") 23 | 24 | async def init_db(self) -> None: 25 | """初始化数据库,检查数据库是否存在,不存在则创建""" 26 | if not os.path.isfile(self.DB_PATH) or self.DB_PATH.stat().st_size == 0: 27 | logger.info("数据库不存在,正在创建...") 
28 | await self.__create_database() 29 | 30 | await self.migrations.migrate_if_needed() 31 | 32 | def __connect(self) -> aiosqlite.Connection: 33 | return aiosqlite.connect(self.DB_PATH) 34 | 35 | async def execute(self, query: str, params=(), fetchone=False, fetchall=False) -> list | None: 36 | """ 37 | 异步执行SQL查询,支持可选参数。 38 | 39 | :param query: 要执行的SQL查询语句 40 | :param params: 传递给查询的参数 41 | :param fetchone: 是否获取单个结果 42 | :param fetchall: 是否获取所有结果 43 | """ 44 | async with self.__connect() as conn: 45 | conn.row_factory = aiosqlite.Row 46 | cursor = await conn.cursor() 47 | await cursor.execute(query, params) 48 | if fetchone: 49 | return await cursor.fetchone() # type: ignore 50 | if fetchall: 51 | rows = await cursor.fetchall() 52 | return [{k.lower(): v for k, v in zip(row.keys(), row)} for row in rows] 53 | await conn.commit() 54 | 55 | return None 56 | 57 | async def __create_database(self) -> None: 58 | """ 59 | 创建一个新的信息表 60 | """ 61 | await self.execute( 62 | """CREATE TABLE MSG( 63 | ID INTEGER PRIMARY KEY AUTOINCREMENT, 64 | TIME TEXT NOT NULL, 65 | USERID TEXT NOT NULL, 66 | GROUPID TEXT NOT NULL DEFAULT (-1), 67 | MESSAGE TEXT NOT NULL, 68 | RESPOND TEXT NOT NULL, 69 | HISTORY INTEGER NOT NULL DEFAULT (1), 70 | RESOURCES TEXT NOT NULL DEFAULT "[]", 71 | USAGE INTEGER NOT NULL DEFAULT (-1));""" 72 | ) 73 | await self.execute( 74 | """ 75 | CREATE TABLE schema_version ( 76 | version INTEGER NOT NULL 77 | );""" 78 | ) 79 | await self.execute("INSERT INTO schema_version (version) VALUES (?);", (str(self.migrations.latest_version))) 80 | 81 | def connect(self) -> aiosqlite.Connection: 82 | return aiosqlite.connect(self.DB_PATH) 83 | 84 | async def add_item(self, message: Message): 85 | """ 86 | 将消息保存到数据库 87 | """ 88 | resources_data = [r.to_dict() for r in message.resources] 89 | params = ( 90 | message.time, 91 | message.userid, 92 | message.groupid, 93 | message.message, 94 | message.respond, 95 | json.dumps(resources_data, ensure_ascii=False), 96 | 
message.usage, 97 | ) 98 | query = """INSERT INTO MSG (TIME, USERID, GROUPID, MESSAGE, RESPOND, RESOURCES, USAGE) 99 | VALUES (?, ?, ?, ?, ?, ?, ?)""" 100 | await self.execute(query, params) 101 | 102 | async def mark_history_as_unavailable(self, userid: str, limit: Optional[int] = None): 103 | """ 104 | 将用户对话历史标记为不可用 (适用于 reset 命令) 105 | 106 | :param userid: 用户id 107 | :param limit: (可选)操作数量 108 | """ 109 | if limit is not None: 110 | query = """UPDATE MSG SET HISTORY = 0 WHERE ROWID IN ( 111 | SELECT ROWID FROM MSG WHERE USERID = ? AND HISTORY = 1 ORDER BY ROWID DESC LIMIT ?)""" 112 | await self.execute(query, (userid, limit)) 113 | else: 114 | query = "UPDATE MSG SET HISTORY = 0 WHERE USERID = ?" 115 | await self.execute(query, (userid,)) 116 | 117 | async def _deserialize_rows(self, rows: list) -> list[Message]: 118 | """ 119 | 反序列化数据库返回结果 120 | """ 121 | result = [] 122 | 123 | for row in rows: 124 | data = dict(row) 125 | 126 | # 反序列化 resources 127 | resources = json.loads(data.get("resources", "[]")) 128 | data["resources"] = [Resource(**r) for r in resources] if resources else [] 129 | 130 | result.append(Message(**data)) 131 | 132 | result.reverse() 133 | return result 134 | 135 | async def get_user_history(self, userid: str, limit: int = 0) -> list[Message]: 136 | """ 137 | 获取用户的所有对话历史,返回一个列表,无结果时返回None 138 | 139 | :param userid: 用户名 140 | :limit: (可选) 返回的最大长度,当该变量设为0时表示全部返回 141 | """ 142 | if limit: 143 | query = f"SELECT * FROM MSG WHERE HISTORY = 1 AND USERID = ? ORDER BY ID DESC LIMIT {limit}" 144 | else: 145 | query = "SELECT * FROM MSG WHERE HISTORY = 1 AND USERID = ?" 
146 | rows = await self.execute(query, (userid,), fetchall=True) 147 | 148 | result = await self._deserialize_rows(rows) if rows else [] 149 | 150 | return result 151 | 152 | async def get_group_history(self, groupid: str, limit: int = 0) -> list[Message]: 153 | """ 154 | 获取群组的所有对话历史,返回一个列表,无结果时返回None 155 | 156 | :groupid: 群组id 157 | :limit: (可选) 返回的最大长度,当该变量设为0时表示全部返回 158 | """ 159 | if limit: 160 | query = f"SELECT * FROM MSG WHERE HISTORY = 1 AND GROUPID = ? ORDER BY ID DESC LIMIT {limit}" 161 | else: 162 | query = "SELECT * FROM MSG WHERE HISTORY = 1 AND GROUPID = ?" 163 | rows = await self.execute(query, (groupid,), fetchall=True) 164 | 165 | result = await self._deserialize_rows(rows) if rows else [] 166 | 167 | return result 168 | 169 | async def get_model_usage(self) -> Tuple[int, int]: 170 | """ 171 | 获取模型用量数据(今日用量,总用量) 172 | 173 | :return: today_usage, total_usage 174 | """ 175 | today_str = datetime.now().strftime("%Y.%m.%d") 176 | 177 | # 查询总用量(排除 USAGE = -1) 178 | total_result = await self.execute("SELECT SUM(USAGE) FROM MSG WHERE USAGE != -1", fetchone=True) 179 | total_usage = total_result[0] if total_result and total_result[0] is not None else 0 180 | 181 | # 查询今日用量(按日期前缀匹配 TIME) 182 | today_result = await self.execute( 183 | "SELECT SUM(USAGE) FROM MSG WHERE USAGE != -1 AND TIME LIKE ?", 184 | (f"{today_str}%",), 185 | fetchone=True, 186 | ) 187 | today_usage = today_result[0] if today_result and today_result[0] is not None else 0 188 | 189 | return today_usage, total_usage 190 | 191 | async def get_conv_count(self) -> Tuple[int, int]: 192 | """ 193 | 获取对话次数(今日次数,总次数) 194 | 195 | :return: today_count, total_count 196 | """ 197 | today_str = datetime.now().strftime("%Y.%m.%d") 198 | 199 | total_result = await self.execute("SELECT COUNT(*) FROM MSG WHERE USAGE != -1", fetchone=True) 200 | total_count = total_result[0] if total_result and total_result[0] is not None else 0 201 | 202 | today_result = await self.execute( 203 | "SELECT COUNT(*) FROM MSG 
class BaseLLM(metaclass=ABCMeta):
    """
    Abstract base class that every model loader must inherit from.

    Loaders are encouraged to build on the helper methods declared here, but
    the only method they are required to implement is `ask`.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        """
        Declare the attributes shared by all loaders in one place.
        """
        # Model configuration.
        self.config = model_config
        # Whether the model is currently running.
        self.is_running = False
        # Total tokens used by the current request (tool calls included).
        # -1 means the loader does not support usage accounting.
        self._total_tokens = -1

    def _require(self, *require_fields: str):
        """
        Validate that the given configuration fields are set.

        :param require_fields: names of the configuration fields to check
        :raises ValueError: if any of the fields is missing or falsy
        """
        absent = []
        for name in require_fields:
            if not getattr(self.config, name, None):
                absent.append(name)

        if not absent:
            return

        joined = ", ".join(absent)
        raise ValueError(f"对于 {self.config.loader} 以下配置是必需的: {joined}")

    def _build_messages(self, request: "ModelRequest") -> list:
        """
        Build the conversation context sent to the model.
        """
        raise NotImplementedError

    def load(self) -> bool:
        """
        Load the model (usually slow; online models that need no validation
        can simply report success).

        :return: whether loading succeeded
        """
        self.is_running = True
        return True

    async def _ask_sync(self, messages: list) -> "ModelCompletions":
        """
        Non-streaming model call.
        """
        raise NotImplementedError

    def _ask_stream(self, messages: list) -> AsyncGenerator["ModelStreamCompletions", None]:
        """
        Streaming model call.
        """
        raise NotImplementedError

    @overload
    async def ask(self, request: "ModelRequest", *, stream: Literal[False] = False) -> "ModelCompletions": ...

    @overload
    async def ask(
        self, request: "ModelRequest", *, stream: Literal[True] = True
    ) -> AsyncGenerator["ModelStreamCompletions", None]: ...

    @abstractmethod
    async def ask(
        self, request: "ModelRequest", *, stream: bool = False
    ) -> Union["ModelCompletions", AsyncGenerator["ModelStreamCompletions", None]]:
        """
        Ask the model.

        :param request: the model request
        :param stream: whether to stream the answer

        :return: the model output
        """
def get_missing_dependencies(dependencies: list[str]) -> list[str]:
    """
    Return the entries of ``dependencies`` that are not installed.

    Each entry may be a requirement string such as ``"openai>=1.64.0"`` or a
    plain module name. A requirement string is not an importable module name,
    so the distribution name is extracted first and looked up in the installed
    package metadata; if no distribution of that name exists, importing a
    module of the same name is tried as a fallback.

    Only presence is checked: the version constraint of a requirement string
    is not evaluated.

    :param dependencies: requirement strings or module names to check
    :return: the entries (unchanged, in input order) that could not be found
    """
    import re
    from importlib import metadata

    missing = []
    for dep in dependencies:
        # Leading distribution name, i.e. everything before a version
        # specifier, extras bracket, environment marker or whitespace.
        match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", dep.strip())
        if match is None:
            missing.append(dep)
            continue
        name = match.group(0)

        try:
            metadata.version(name)
            continue
        except metadata.PackageNotFoundError:
            pass

        # Not a known distribution: the entry may still be a module name.
        try:
            importlib.import_module(name.replace("-", "_"))
        except ImportError:
            missing.append(dep)
    return missing
def load_model(config: ModelConfig) -> BaseLLM:
    """
    Build an LLM instance for the loader named in ``config``.

    :param config: model configuration; ``config.loader`` selects the provider
    :return: an instance of the provider's LLM class, constructed with ``config``
    """
    provider_module = "muicebot.llm.providers." + config.loader.lower()

    # Import the provider module lazily; the class is fetched from the
    # registry afterwards.
    importlib.import_module(provider_module)

    llm_cls = get_llm_class(config.loader)
    return llm_cls(config)
@register("azure")
class Azure(BaseLLM):
    """
    Model loader for the Azure AI Inference chat-completions client.

    Handles text and multimodal (audio / image) prompts, a single tool call
    per model turn, and both blocking and streaming responses.
    """

    def __init__(self, model_config: ModelConfig) -> None:
        """
        Read the sampling parameters and credentials from ``model_config``.

        :raises ValueError: if ``model_name`` is not configured
        """
        super().__init__(model_config)
        self._require("model_name")
        self.model_name = self.config.model_name
        self.max_tokens = self.config.max_tokens
        self.temperature = self.config.temperature
        self.top_p = self.config.top_p
        self.frequency_penalty = self.config.frequency_penalty
        self.presence_penalty = self.config.presence_penalty
        # The AZURE_API_KEY environment variable wins over the configured key.
        self.token = os.getenv("AZURE_API_KEY", self.config.api_key)
        self.endpoint = self.config.api_host if self.config.api_host else "https://models.inference.ai.azure.com"

        self._tools: List[ChatCompletionsToolDefinition] = []

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Optional[str]) -> dict:
        """
        Decode the JSON argument string of a tool call.

        Strict JSON is tried first so that valid payloads containing
        apostrophes are not corrupted. The single-quote replacement of the
        original implementation is kept as a fallback only.

        :param raw_arguments: argument string produced by the model
        :return: the decoded arguments; an empty dict for an empty string
        :raises json.JSONDecodeError: if neither form can be decoded
        """
        if not raw_arguments or not raw_arguments.strip():
            return {}
        try:
            return json.loads(raw_arguments)
        except json.JSONDecodeError:
            return json.loads(raw_arguments.replace("'", '"'))

    def __build_multi_messages(self, request: ModelRequest) -> UserMessage:
        """
        Build a multimodal user message.

        Multimodal types supported by this loader: `audio` `image`.
        Resources without a path, or of any other type, are ignored.
        """
        multi_content_items: List[ContentItem] = []

        for resource in request.resources:
            if resource.path is None:
                continue
            elif resource.type == "audio":
                multi_content_items.append(
                    AudioContentItem(
                        input_audio=InputAudio.load(audio_file=resource.path, audio_format=resource.path.split(".")[-1])
                    )
                )
            elif resource.type == "image":
                multi_content_items.append(
                    ImageContentItem(
                        image_url=ImageUrl.load(
                            image_file=resource.path,
                            image_format=resource.path.split(".")[-1],
                            detail=ImageDetailLevel.AUTO,
                        )
                    )
                )

        content = [TextContentItem(text=request.prompt)] + multi_content_items

        return UserMessage(content=content)

    def __build_tools_definition(self, tools: List[dict]) -> List[ChatCompletionsToolDefinition]:
        """
        Convert OpenAI-style tool dicts into Azure tool definitions.
        """
        return [
            ChatCompletionsToolDefinition(
                function=FunctionDefinition(
                    name=tool["function"]["name"],
                    description=tool["function"]["description"],
                    parameters=tool["function"]["parameters"],
                )
            )
            for tool in tools
        ]

    def _build_messages(self, request: ModelRequest) -> List[ChatRequestMessage]:
        """
        Build the message list: system prompt, history turns, current prompt.
        """
        messages: List[ChatRequestMessage] = []

        if request.system:
            messages.append(SystemMessage(request.system))

        for msg in request.history:
            user_msg = (
                UserMessage(msg.message)
                if not msg.resources
                else self.__build_multi_messages(ModelRequest(msg.message, resources=msg.resources))
            )
            messages.append(user_msg)
            messages.append(AssistantMessage(msg.respond))

        user_message = UserMessage(request.prompt) if not request.resources else self.__build_multi_messages(request)

        messages.append(user_message)

        return messages

    def _tool_messages_precheck(self, tool_calls: Optional[List[ChatCompletionsToolCall]] = None) -> bool:
        """
        Return True only when there is exactly one well-typed tool call.
        """
        if not (tool_calls and len(tool_calls) == 1):
            return False

        return isinstance(tool_calls[0], ChatCompletionsToolCall)

    async def _ask_sync(self, messages: List[ChatRequestMessage]) -> ModelCompletions:
        """
        Run one blocking completion, following tool calls recursively.

        When the model asks for a tool, the tool is executed, its result is
        appended to ``messages`` and the result of the follow-up request is
        returned. Cleanup happens in ``finally`` without returning from it, so
        it can neither replace that result nor hide an unexpected exception.

        :param messages: conversation so far; extended in place on tool calls
        :return: the completion, with ``usage`` set to the accumulated tokens
        """
        client = ChatCompletionsClient(endpoint=self.endpoint, credential=AzureKeyCredential(self.token))

        completions = ModelCompletions()

        try:
            response = await client.complete(
                messages=messages,
                model=self.model_name,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty,
                stream=False,
                tools=self._tools,
            )
            finish_reason = response.choices[0].finish_reason
            self._total_tokens += response.usage.total_tokens

            if finish_reason == CompletionsFinishReason.STOPPED:
                completions.text = response.choices[0].message.content

            elif finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
                completions.succeed = False
                completions.text = "(模型内部错误: 被内容过滤器阻止)"

            elif finish_reason == CompletionsFinishReason.TOKEN_LIMIT_REACHED:
                completions.succeed = False
                completions.text = "(模型内部错误: 达到了最大 token 限制)"

            elif finish_reason == CompletionsFinishReason.TOOL_CALLS:
                tool_calls = response.choices[0].message.tool_calls
                messages.append(AssistantMessage(tool_calls=tool_calls))

                if not self._tool_messages_precheck(tool_calls=tool_calls):
                    completions.succeed = False
                    completions.text = "(模型内部错误: tool_calls 内容为空)"
                else:
                    tool_call = tool_calls[0]
                    try:
                        function_args = self._parse_tool_arguments(tool_call.function.arguments)
                    except json.JSONDecodeError:
                        completions.succeed = False
                        completions.text = "(模型内部错误: 无法解析工具调用参数)"
                    else:
                        function_return = await function_call_handler(tool_call.function.name, function_args)

                        # Append the function call result to the chat history
                        messages.append(ToolMessage(tool_call_id=tool_call.id, content=function_return))

                        return await self._ask_sync(messages)

            else:
                # Any other finish reason is reported as such instead of
                # being mislabelled as a token-limit error.
                completions.succeed = False
                completions.text = f"(模型内部错误: 未知的结束原因 {finish_reason})"

        except HttpResponseError as e:
            logger.error(f"模型响应失败: {e.status_code} ({e.reason})")
            logger.error(f"{e.message}")
            completions.succeed = False
            completions.text = f"模型响应失败: {e.status_code} ({e.reason})"

        finally:
            await client.close()
            completions.usage = self._total_tokens

        return completions

    async def _ask_stream(self, messages: List[ChatRequestMessage]) -> AsyncGenerator[ModelStreamCompletions, None]:
        """
        Run one streaming completion, following tool calls recursively.

        Tool-call fragments are accumulated across chunks; once the model
        finishes with ``TOOL_CALLS`` the tool is executed and the follow-up
        stream is relayed to the caller.

        :param messages: conversation so far; extended in place on tool calls
        """
        client = ChatCompletionsClient(endpoint=self.endpoint, credential=AzureKeyCredential(self.token))

        try:
            response = await client.complete(
                messages=messages,
                model=self.model_name,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty,
                stream=True,
                tools=self._tools,
                model_extras={"stream_options": {"include_usage": True}},  # usage must be requested explicitly
            )

            tool_call_id: str = ""
            function_name: str = ""
            function_args: str = ""

            async for chunk in response:
                stream_completions = ModelStreamCompletions()

                if chunk.usage:  # only provided in the last packet, whose choices are empty
                    self._total_tokens += chunk.usage.total_tokens
                    stream_completions.usage = self._total_tokens

                if not chunk.choices:
                    yield stream_completions
                    continue

                choice = chunk.choices[0]
                finish_reason = choice.finish_reason
                delta_text = choice.get("delta", {}).get("content", "")

                if delta_text:
                    stream_completions.chunk = delta_text

                elif choice.delta.tool_calls:
                    tool_call = choice.delta.tool_calls[0]

                    if tool_call.function.name is not None:
                        function_name = tool_call.function.name
                    if tool_call.id is not None:
                        tool_call_id = tool_call.id
                    function_args += tool_call.function.arguments or ""
                    continue

                elif finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
                    stream_completions.succeed = False
                    stream_completions.chunk = "(模型内部错误: 被内容过滤器阻止)"

                elif finish_reason == CompletionsFinishReason.TOKEN_LIMIT_REACHED:
                    stream_completions.succeed = False
                    stream_completions.chunk = "(模型内部错误: 达到了最大 token 限制)"

                elif finish_reason == CompletionsFinishReason.TOOL_CALLS:
                    messages.append(
                        AssistantMessage(
                            tool_calls=[
                                ChatCompletionsToolCall(
                                    id=tool_call_id, function=FunctionCall(name=function_name, arguments=function_args)
                                )
                            ]
                        )
                    )

                    try:
                        function_arg = self._parse_tool_arguments(function_args)
                    except json.JSONDecodeError:
                        stream_completions.succeed = False
                        stream_completions.chunk = "(模型内部错误: 无法解析工具调用参数)"
                        yield stream_completions
                        return

                    function_return = await function_call_handler(function_name, function_arg)

                    # Append the function call result to the chat history
                    messages.append(ToolMessage(tool_call_id=tool_call_id, content=function_return))

                    async for content in self._ask_stream(messages):
                        yield content

                    return

                yield stream_completions

        except HttpResponseError as e:
            logger.error(f"模型响应失败: {e.status_code} ({e.reason})")
            logger.error(f"{e.message}")
            stream_completions = ModelStreamCompletions()
            stream_completions.chunk = f"模型响应失败: {e.status_code} ({e.reason})"
            stream_completions.succeed = False
            yield stream_completions

        finally:
            await client.close()

    @overload
    async def ask(self, request: ModelRequest, *, stream: Literal[False] = False) -> ModelCompletions: ...

    @overload
    async def ask(
        self, request: ModelRequest, *, stream: Literal[True] = True
    ) -> AsyncGenerator[ModelStreamCompletions, None]: ...

    async def ask(
        self, request: ModelRequest, *, stream: bool = False
    ) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
        """
        Ask the model.

        Resets the token counter, builds the messages and tool definitions
        from ``request`` and dispatches to the blocking or streaming path.

        :param request: the model request
        :param stream: whether to stream the answer
        :return: a completion, or an async generator of stream chunks
        """
        self._total_tokens = 0

        messages = self._build_messages(request)

        self._tools = self.__build_tools_definition(request.tools) if request.tools else []

        if stream:
            return self._ask_stream(messages)

        return await self._ask_sync(messages)
class ThoughtStream:
    """
    Turn streamed chunks into text, marking where reasoning starts and ends.

    The first reasoning chunk is prefixed with ``<think>`` and the first
    answer chunk that follows is prefixed with ``</think>``.

    NOTE(review): the previous code returned the text unchanged in both
    branches, which made ``is_insert_think_label`` a no-op. The markers are
    restored here on the assumption that they were the intent of the flag.
    Confirm that downstream consumers expect ``<think>`` tags.
    """

    def __init__(self):
        # True while an opening marker has been emitted and not yet closed.
        self.is_insert_think_label: bool = False

    def process_chunk(self, chunk: "GenerationResponse | MultiModalConversationResponse") -> str:
        """
        Extract the text of one streamed chunk.

        :param chunk: a streamed response whose first choice carries the message
        :return: reasoning text (newlines removed) when present, otherwise the
            answer text; an empty string when the chunk has neither
        """
        message = chunk.output.choices[0].message
        answer = message.content

        reasoning = message.get("reasoning_content", "") or ""
        reasoning = reasoning.replace("\n", "")

        # Reasoning takes priority over answer text within the same chunk.
        if reasoning:
            if self.is_insert_think_label:
                return reasoning
            self.is_insert_think_label = True
            return f"<think>{reasoning}"

        if not answer:
            answer = ""

        # Multimodal responses deliver content as a list of parts.
        if isinstance(answer, list):
            answer = answer[0].get("text", "")

        if self.is_insert_think_label:
            self.is_insert_think_label = False
            return f"</think>{answer}"

        return answer
def __build_multi_messages(self, request: "ModelRequest") -> dict:
    """
    Build a multimodal user message.

    Multimodal types supported by this loader: `audio` `image`. Resources of
    any other type are ignored. An empty prompt is replaced, on the request
    itself, by a default prompt.

    :param request: request providing ``resources`` and ``prompt``
    :return: a ``{"role": "user", "content": [...]}`` message dict
    """
    parts: List[dict[str, str]] = []

    for item in request.resources:
        if item.type == "audio":
            key = "audio"
        elif item.type == "image":
            key = "image"
        else:
            continue

        # NOTE(review): the resolved absolute path is stored on item.url,
        # while the message still carries item.path. Confirm which of the two
        # is meant to be sent.
        if not item.path.startswith(("http", "file:")):
            item.url = str(Path(item.path).resolve())

        parts.append({key: item.path})

    if not request.prompt:
        request.prompt = "请描述图像内容"
    parts.append({"text": request.prompt})

    return {"role": "user", "content": parts}
async def _Generator_handle(
    self,
    messages: list,
    response: Generator[GenerationResponse, None, None] | Generator[MultiModalConversationResponse, None, None],
) -> AsyncGenerator[ModelStreamCompletions, None]:
    """
    Relay a streamed DashScope response as ``ModelStreamCompletions`` chunks.

    Tracks token usage per chunk, collects tool-call fragments while text is
    being streamed, and, once the stream is exhausted, hands a collected tool
    call to ``_tool_calls_handle_stream`` and relays its chunks as well.

    :param messages: conversation so far, passed on to the tool-call handler
    :param response: generator of streamed response chunks
    """
    # NOTE(review): ``response`` is consumed with a plain ``for`` inside an
    # async generator; presumably each step blocks the event loop while the
    # SDK waits for the next chunk. Confirm whether that is acceptable.
    func_stream = FunctionCallStream()
    thought_stream = ThoughtStream()

    for chunk in response:
        logger.debug(chunk)
        stream_completions = ModelStreamCompletions()

        # A failed chunk ends the stream with a single error completion.
        if chunk.status_code != 200:
            logger.error(f"模型调用失败: {chunk.status_code}({chunk.code})")
            logger.error(f"{chunk.message}")
            stream_completions.chunk = f"模型调用失败: {chunk.status_code}({chunk.code})"
            stream_completions.succeed = False

            yield stream_completions
            return

        # Update token usage: only the increase since the previous chunk is
        # added; a decrease is treated as the start of a new count.
        current_call_total = chunk.usage.total_tokens
        delta = current_call_total - self._last_call_total_tokens
        if delta < 0:
            delta = current_call_total
        self._total_tokens += delta
        self._last_call_total_tokens = current_call_total
        stream_completions.usage = self._total_tokens

        # Check for a tool call first (OpenAI-style function calling)
        if chunk.output.choices and chunk.output.choices[0].message.get("tool_calls", []):
            func_stream.from_chunk(chunk)
            # A tool call may also arrive after text has been output

        # DashScope's text mode (non-standard interface)
        if hasattr(chunk.output, "text") and chunk.output.text:
            stream_completions.chunk = chunk.output.text
            yield stream_completions
            continue

        if chunk.output.choices is None:
            continue

        stream_completions.chunk = thought_stream.process_chunk(chunk)
        yield stream_completions

    # Handle the collected tool call in streaming mode
    if func_stream.enable:
        # Reset the per-call baseline before the follow-up request is made.
        self._last_call_total_tokens = 0
        async for final_chunk in await self._tool_calls_handle_stream(messages, func_stream):
            yield final_chunk
async def _ask(self, messages: list) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
    """
    Send ``messages`` to DashScope and dispatch on the kind of response.

    The blocking SDK call runs in the default executor. A single response
    object is handled by ``_GenerationResponse_handle``; anything else is
    treated as a stream and wrapped by ``_Generator_handle``.

    :param messages: conversation to send
    :return: a completion, or an async generator of stream chunks
    """
    # This is a coroutine, so a loop is always running; get_running_loop()
    # is the supported way to obtain it here.
    loop = asyncio.get_running_loop()

    shared_options = dict(
        api_key=self.api_key,
        model=self.model,
        messages=messages,
        max_tokens=self.max_tokens,
        temperature=self.temperature,
        top_p=self.top_p,
        repetition_penalty=self.repetition_penalty,
        stream=self.stream,
        tools=self._tools,
        parallel_tool_calls=True,
        enable_search=self.enable_search,
        # Kept equal to ``stream``: this parameter may only be True for
        # streaming calls.
        incremental_output=self.stream,
    )

    # DashScope exposes multimodal models through a different entry point,
    # so the two calls cannot be unified.
    if not self.config.multimodal:
        call = partial(
            dashscope.Generation.call,
            **shared_options,
            headers=self.extra_headers,
            enable_thinking=self.enable_thinking,
            thinking_budget=self.thinking_budget,
        )
    else:
        # NOTE(review): the multimodal call passes neither the content-security
        # headers nor the thinking options. Confirm this is intentional.
        call = partial(dashscope.MultiModalConversation.call, **shared_options)

    response = await loop.run_in_executor(None, call)

    if isinstance(response, (GenerationResponse, MultiModalConversationResponse)):
        return await self._GenerationResponse_handle(messages, response)
    return self._Generator_handle(messages, response)
async def ask(
    self, request: "ModelRequest", *, stream: bool = False
) -> "Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]":
    """
    Ask the model.

    Resets both token counters, records the streaming mode and the tools of
    the request, builds the messages and delegates to ``_ask``.

    :param request: the model request
    :param stream: whether to stream the answer; ``None`` counts as ``False``
    :return: whatever ``_ask`` returns for the built messages
    """
    self._total_tokens = self._last_call_total_tokens = 0

    self.stream = False if stream is None else stream
    self._tools = request.tools or []

    built = self._build_messages(request)
    return await self._ask(built)
def __init__(self, model_config: ModelConfig) -> None:
    """
    Create the Gemini client and the generation config from ``model_config``.

    ``model_name`` and ``api_key`` are required configuration entries.
    """
    super().__init__(model_config)
    self._require("model_name", "api_key")

    cfg = self.config
    self.model_name = cfg.model_name
    self.api_key = cfg.api_key
    self.enable_search = cfg.online_search

    self.client = genai.Client(api_key=self.api_key)

    # Every listed category is blocked at the same threshold, and only when
    # content security is switched on in the configuration.
    guarded_categories = (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
    )
    if cfg.content_security:
        safety_settings = [
            SafetySetting(category=category, threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE)
            for category in guarded_categories
        ]
    else:
        safety_settings = []

    # Only the "image" and "text" modalities are forwarded, upper-cased.
    response_modalities = [name.upper() for name in cfg.modalities if name in {"image", "text"}]

    self.gemini_config = GenerateContentConfig(
        temperature=cfg.temperature,
        top_p=cfg.top_p,
        top_k=cfg.top_k,
        max_output_tokens=cfg.max_tokens,
        presence_penalty=cfg.presence_penalty,
        frequency_penalty=cfg.frequency_penalty,
        response_modalities=response_modalities,
        safety_settings=safety_settings,
    )

    self.model = self.client.chats.create(model=self.model_name, config=self.gemini_config)
def __build_tools_list(self, tools: Optional[List] = None):
    """
    Translate OpenAI-style tool descriptions into the Gemini tool config.

    Each entry of ``tools`` is expected to carry a ``"function"`` dict whose
    ``"required"`` list sits next to ``"parameters"``; Gemini wants it inside
    ``"parameters"``. The result is stored on ``self.gemini_config.tools``.

    :param tools: tool descriptions for this request, or ``None``/empty
    """
    declarations = []

    for tool in tools or []:
        # Work on copies: the caller's tool dicts may be reused for the next
        # request and must not lose their "required" key.
        declaration = dict(tool["function"])
        if "required" in declaration:
            required_parameters = declaration.pop("required")
            declaration["parameters"] = {**declaration["parameters"], "required": required_parameters}
        declarations.append(declaration)

    if not (declarations or self.enable_search):
        # Reset explicitly so tools from a previous request do not linger on
        # the shared config object.
        self.gemini_config.tools = None
        return

    function_tools = Tool(function_declarations=declarations)  # type:ignore

    if self.enable_search:
        function_tools.google_search = GoogleSearch()

    self.gemini_config.tools = [function_tools]
async def _ask_stream(self, messages: list, **kwargs) -> AsyncGenerator[ModelStreamCompletions, None]:
    """
    Stream a Gemini answer for ``messages``.

    Yields one completion per chunk that carries text and/or inline image
    data, then a final completion whose ``usage`` holds the accumulated token
    count. When a chunk requests a function call, the first call is executed,
    the call and its result are appended to ``messages`` (mutated in place)
    and the follow-up answer is streamed recursively.

    API and connection failures are reported as a single completion with
    ``succeed = False`` instead of being raised.

    :param messages: conversation contents; extended in place on tool calls
    """
    try:
        total_tokens = 0
        stream = await self.client.aio.models.generate_content_stream(
            model=self.model_name, contents=messages, config=self.gemini_config
        )
        async for chunk in stream:
            stream_completions = ModelStreamCompletions()
            has_payload = False

            if chunk.text:
                stream_completions.chunk = chunk.text
                has_payload = True

            if chunk.usage_metadata and chunk.usage_metadata.total_token_count:
                total_tokens = chunk.usage_metadata.total_token_count

            if (
                chunk.candidates
                and chunk.candidates[0].content
                and chunk.candidates[0].content.parts
                and chunk.candidates[0].content.parts[0].inline_data
                and chunk.candidates[0].content.parts[0].inline_data.data
            ):
                stream_completions.resources = [
                    Resource(type="image", raw=chunk.candidates[0].content.parts[0].inline_data.data)
                ]
                has_payload = True

            # Yield once per chunk: emitting the same object for text and then
            # again for the image would hand the text to the consumer twice.
            if has_payload:
                yield stream_completions

            if chunk.function_calls:
                # Only the first requested call is handled.
                function_call = chunk.function_calls[0]
                function_name = function_call.name
                function_args = function_call.args

                function_return = await function_call_handler(function_name, function_args)  # type:ignore

                function_response_part = Part.from_function_response(
                    name=function_name,  # type:ignore
                    response={"result": function_return},
                )

                messages.append(Content(role="model", parts=[Part(function_call=function_call)]))
                messages.append(Content(role="user", parts=[function_response_part]))

                async for final_chunk in self._ask_stream(messages):
                    yield final_chunk

        totaltokens_completions = ModelStreamCompletions()

        self._total_tokens += total_tokens
        totaltokens_completions.usage = self._total_tokens
        yield totaltokens_completions

    except errors.APIError as e:
        stream_completions = ModelStreamCompletions()
        error_message = f"API 状态异常: {e.code}({e.response})"
        stream_completions.chunk = error_message
        logger.error(error_message)
        logger.error(e.message)
        stream_completions.succeed = False
        yield stream_completions
        return

    except ConnectError:
        stream_completions = ModelStreamCompletions()
        error_message = "模型加载器连接超时"
        stream_completions.chunk = error_message
        logger.error(error_message)
        stream_completions.succeed = False
        yield stream_completions
        return
def __build_multi_messages(self, request: ModelRequest) -> dict:
    """
    Build the Ollama user message for ``request``, attaching its images.

    Multimodal types supported by this loader: `image`. Resources of any
    other type, and resources without a local path, are left out.

    :param request: request providing ``prompt`` and ``resources``
    :return: a ``{"role", "content", "images"}`` message dict
    """
    # ``Resource.path`` defaults to an empty string, so test truthiness rather
    # than ``is None``; an empty path would make get_file_base64 raise.
    images = [
        get_file_base64(local_path=resource.path)
        for resource in request.resources or []
        if resource.type == "image" and resource.path
    ]

    return {"role": "user", "content": request.prompt, "images": images}
async def _ask_sync(self, messages: list) -> ModelCompletions:
    """
    Request a complete (non-streamed) answer from Ollama.

    If the model asks for tool calls, every requested tool is executed, the
    assistant message and one ``tool`` message per call are appended to
    ``messages`` (mutated in place), and the model is queried again.

    :param messages: chat messages; extended in place on tool calls
    :return: completions with ``text`` set; ``succeed`` is False when Ollama
        reports a response error
    """
    completions = ModelCompletions()

    try:
        response = await self.client.chat(
            model=self.model,
            messages=messages,
            tools=self._tools,
            stream=False,
            options={
                "temperature": self.temperature,
                "top_k": self.top_k,
                "top_p": self.top_p,
                "repeat_penalty": self.repeat_penalty,
                "presence_penalty": self.presence_penalty,
                "frequency_penalty": self.frequency_penalty,
            },
        )

        tool_calls = response.message.tool_calls

        if not tool_calls:
            completions.text = response.message.content or "(警告:模型无返回)"
            return completions

        # The assistant turn that requested the tools is recorded once,
        # followed by one result message per requested call.
        messages.append(response.message)
        for tool in tool_calls:
            function_name = tool.function.name
            function_return = await function_call_handler(function_name, dict(tool.function.arguments))
            messages.append({"role": "tool", "content": str(function_return), "name": function_name})

        return await self._ask_sync(messages)

    except ollama.ResponseError as e:
        error_info = f"模型调用错误: {e.error}"
        logger.error(error_info)
        completions.succeed = False
        completions.text = error_info
        return completions
options={ 139 | "temperature": self.temperature, 140 | "top_k": self.top_k, 141 | "top_p": self.top_p, 142 | "repeat_penalty": self.repeat_penalty, 143 | "presence_penalty": self.presence_penalty, 144 | "frequency_penalty": self.frequency_penalty, 145 | }, 146 | ) 147 | 148 | async for chunk in response: 149 | stream_completions = ModelStreamCompletions() 150 | 151 | tool_calls = chunk.message.tool_calls 152 | 153 | if chunk.message.content: 154 | stream_completions.chunk = chunk.message.content 155 | yield stream_completions 156 | continue 157 | 158 | if not tool_calls: 159 | continue 160 | 161 | for tool in tool_calls: # type:ignore 162 | function_name = tool.function.name 163 | function_args = tool.function.arguments 164 | 165 | function_return = await function_call_handler(function_name, dict(function_args)) 166 | 167 | messages.append(chunk.message) # type:ignore 168 | messages.append({"role": "tool", "content": str(function_return), "name": tool.function.name}) 169 | 170 | async for content in self._ask_stream(messages): 171 | yield content 172 | 173 | except ollama.ResponseError as e: 174 | stream_completions = ModelStreamCompletions() 175 | error_info = f"模型调用错误: {e.error}" 176 | logger.error(error_info) 177 | stream_completions.chunk = error_info 178 | stream_completions.succeed = False 179 | yield stream_completions 180 | return 181 | 182 | @overload 183 | async def ask(self, request: ModelRequest, *, stream: Literal[False] = False) -> ModelCompletions: ... 184 | 185 | @overload 186 | async def ask( 187 | self, request: ModelRequest, *, stream: Literal[True] = True 188 | ) -> AsyncGenerator[ModelStreamCompletions, None]: ... 
async def ask(
    self, request: ModelRequest, *, stream: bool = False
) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
    """
    Answer ``request`` with Ollama.

    Stores the request's tools on the instance, builds the message list and
    returns either the streaming generator or the awaited full completion.

    :param request: the request to answer
    :param stream: return an async generator of chunks when true
    """
    self._tools = request.tools or []
    prepared_messages = self._build_messages(request)

    if not stream:
        return await self._ask_sync(prepared_messages)

    # The generator is handed back un-awaited; the caller iterates it.
    return self._ask_stream(prepared_messages)
def __build_multi_messages(self, request: ModelRequest) -> dict:
    """
    Build an OpenAI-style multimodal user message for ``request``.

    Multimodal types supported by this loader: `audio` `image` `video` `file`.
    Resources without a local path are skipped.

    :param request: request providing ``prompt`` and ``resources``
    :return: ``{"role": "user", "content": [...]}`` with one text part
        followed by one part per usable resource
    """
    user_content: List[dict] = [{"type": "text", "text": request.prompt}]

    for resource in request.resources or []:
        # ``Resource.path`` defaults to "", which get_file_base64 rejects, so
        # skip on falsiness instead of only on None.
        if not resource.path or resource.type not in ("audio", "image", "video", "file"):
            continue

        encoded = get_file_base64(local_path=resource.path)

        if resource.type == "audio":
            file_format = resource.path.split(".")[-1]
            user_content.append(
                {"type": "input_audio", "input_audio": {"data": f"data:;base64,{encoded}", "format": file_format}}
            )

        elif resource.type == "image":
            file_format = resource.path.split(".")[-1]
            user_content.append(
                {"type": "image_url", "image_url": {"url": f"data:image/{file_format};base64,{encoded}"}}
            )

        elif resource.type == "video":
            user_content.append({"type": "video_url", "video_url": {"url": f"data:;base64,{encoded}"}})

        else:  # "file"
            user_content.append({"type": "file", "file": {"file_data": f"data:;base64,{encoded}"}})

    return {"role": "user", "content": user_content}
def _tool_call_request_precheck(self, message: ChatCompletionMessage) -> bool:
    """
    Pre-check a tool call request.

    :param message: assistant message returned by the model
    :return: True only when the message carries exactly one tool call and
        that call is of type ``"function"``
    """
    requested_calls = message.tool_calls

    # Exactly one tool call is expected.
    if not requested_calls or len(requested_calls) != 1:
        return False

    # ...and it has to be a function call.
    return requested_calls[0].type == "function"
async def _ask_stream(self, messages: list, **kwargs) -> AsyncGenerator[ModelStreamCompletions, None]:
    """
    Stream a chat completion and yield it chunk by chunk.

    While consuming the stream this accumulates token usage, the pieces of a
    (single) tool call, and base64 audio data. After the stream ends, a
    collected tool call is executed and the follow-up answer is streamed
    recursively with ``messages`` extended in place; collected audio is
    decoded and emitted as one audio resource.

    Connection and status errors are reported as a completion with
    ``succeed = False`` rather than raised.

    :param messages: chat messages; extended in place when a tool is called
    """
    # Tracks whether the previous emitted chunk was reasoning content.
    is_insert_think_label = False
    # Tool call fragments arrive spread over several chunks and are joined here.
    function_id = ""
    function_name = ""
    function_arguments = ""
    # Base64 audio fragments, concatenated across chunks.
    audio_string = ""

    try:
        response = await self.client.chat.completions.create(
            audio=self.audio,
            model=self.model,
            modalities=self.modalities,
            messages=messages,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stream=True,
            stream_options={"include_usage": True},
            tools=self._tools,
            extra_body=self.extra_body,
        )

        async for chunk in response:
            stream_completions = ModelStreamCompletions()

            # Collect usage (returned in the last packet)
            if chunk.usage:
                self._total_tokens += chunk.usage.total_tokens
                stream_completions.usage = self._total_tokens

            # A chunk without choices has nothing else to process.
            if not chunk.choices:
                yield stream_completions
                continue

            # Handle function call (only the first tool call of a chunk is read)
            if chunk.choices[0].delta.tool_calls:
                tool_call = chunk.choices[0].delta.tool_calls[0]
                if tool_call.id:
                    function_id = tool_call.id
                if tool_call.function:
                    if tool_call.function.name:
                        function_name += tool_call.function.name
                    if tool_call.function.arguments:
                        function_arguments += tool_call.function.arguments

            delta = chunk.choices[0].delta
            answer_content = delta.content

            # Handle the reasoning process (reasoning_content)
            # NOTE(review): the prefix literal in both expressions below is an
            # empty string in this view, so either branch yields the raw text —
            # confirm whether a reasoning marker was intended here.
            if (
                hasattr(delta, "reasoning_content") and delta.reasoning_content  # type:ignore
            ):
                reasoning_content = chunk.choices[0].delta.reasoning_content  # type:ignore
                stream_completions.chunk = (
                    reasoning_content if is_insert_think_label else "" + reasoning_content
                )
                yield stream_completions
                is_insert_think_label = True

            elif answer_content:
                stream_completions.chunk = (
                    answer_content if not is_insert_think_label else "" + answer_content
                )
                yield stream_completions
                is_insert_think_label = False

            # Handle multimodal messages (audio-only) (non-standard approach, may be fragile)
            # NOTE(review): assumes ``delta.audio`` is a dict whenever the
            # attribute exists — a None value would fail on ``.get``; confirm.
            if hasattr(chunk.choices[0].delta, "audio"):
                audio = chunk.choices[0].delta.audio  # type:ignore
                if audio.get("data", None):
                    audio_string += audio.get("data")
                stream_completions.chunk = audio.get("transcript", "")
                yield stream_completions

        # The stream is finished: run the tool call that was pieced together, if any.
        if function_id:

            function_return = await function_call_handler(function_name, json.loads(function_arguments))

            # Record the assistant's tool request and the tool's answer so the
            # follow-up completion can see both.
            messages.append(
                {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": function_id,
                            "type": "function",
                            "function": {"name": function_name, "arguments": function_arguments},
                        }
                    ],
                }
            )
            messages.append(
                {
                    "tool_call_id": function_id,
                    "role": "tool",
                    "content": function_return,
                }
            )

            async for chunk in self._ask_stream(messages):
                yield chunk

        # Handle multimodal output
        if audio_string:
            import numpy as np
            import soundfile as sf

            # The decoded bytes are interpreted as 16-bit PCM samples and
            # written out as WAV at 24 kHz into an in-memory buffer.
            wav_bytes = base64.b64decode(audio_string)
            pcm_data = np.frombuffer(wav_bytes, dtype=np.int16)
            wav_io = BytesIO()
            sf.write(wav_io, pcm_data, samplerate=24000, format="WAV")

            stream_completions = ModelStreamCompletions()
            stream_completions.resources = [Resource(type="audio", raw=wav_io)]
            yield stream_completions

    except openai.APIConnectionError as e:
        error_message = f"API 连接错误: {e}"
        logger.error(error_message)
        logger.error(e.__cause__)
        stream_completions = ModelStreamCompletions()
        stream_completions.chunk = error_message
        stream_completions.succeed = False
        yield stream_completions

    except openai.APIStatusError as e:
        error_message = f"API 状态异常: {e.status_code}({e.response})"
        logger.error(error_message)
        stream_completions = ModelStreamCompletions()
        stream_completions.chunk = error_message
        stream_completions.succeed = False
        yield stream_completions
def get_file_base64(local_path: Optional[str] = None, file_bytes: Optional[bytes] = None) -> str:
    """
    Return the Base64 text of a local file or of raw bytes.

    ``local_path`` takes precedence when both arguments are given.

    :param local_path: path of the file to read and encode
    :param file_bytes: in-memory content to encode
    :raises ValueError: when neither argument holds a usable (non-empty) value
    """
    if local_path:
        with open(local_path, "rb") as handle:
            payload = handle.read()
    elif file_bytes:
        payload = file_bytes
    else:
        raise ValueError("You must pass in a valid parameter!")

    return base64.b64encode(payload).decode("utf-8")
-------------------------------------------------------------------------------- /muicebot/llm/utils/tools.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from nonebot import logger 4 | 5 | from muicebot.plugin.func_call import get_function_calls 6 | from muicebot.plugin.mcp import handle_mcp_tool 7 | 8 | 9 | async def function_call_handler(func: str, arguments: dict[str, str] | None = None) -> Any: 10 | """ 11 | 模型 Function Call 请求处理 12 | """ 13 | arguments = arguments if arguments and arguments != {"dummy_param": ""} else {} 14 | 15 | if func_caller := get_function_calls().get(func): 16 | logger.info(f"Function call 请求 {func}, 参数: {arguments}") 17 | result = await func_caller.run(**arguments) 18 | logger.success(f"Function call 成功,返回: {result}") 19 | return result 20 | 21 | if mcp_result := await handle_mcp_tool(func, arguments): 22 | logger.success(f"MCP 工具执行成功,返回: {mcp_result}") 23 | return mcp_result 24 | 25 | return "(Unknown Function)" 26 | -------------------------------------------------------------------------------- /muicebot/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | from datetime import datetime 3 | from functools import total_ordering 4 | from io import BytesIO 5 | from typing import List, Literal, Optional, Union 6 | 7 | 8 | @dataclass 9 | class Resource: 10 | """多模态消息""" 11 | 12 | type: Literal["image", "video", "audio", "file"] 13 | """消息类型""" 14 | path: str = field(default_factory=str) 15 | """本地存储地址(对于模型处理是必需的)""" 16 | url: Optional[str] = field(default=None) 17 | """远程存储地址(一般不传入模型处理中)""" 18 | raw: Optional[Union[bytes, BytesIO]] = field(default=None) 19 | """二进制数据(只使用于模型返回且不保存到数据库中)""" 20 | mimetype: Optional[str] = field(default=None) 21 | """文件元数据类型""" 22 | 23 | def ensure_mimetype(self): 24 | from .utils.utils import guess_mimetype 25 | 26 | if not self.mimetype: 27 | 
self.mimetype = guess_mimetype(self) 28 | 29 | def to_dict(self) -> dict: 30 | return { 31 | "type": self.type, 32 | "path": self.path, 33 | } 34 | 35 | 36 | @total_ordering 37 | @dataclass 38 | class Message: 39 | """格式化后的 bot 消息""" 40 | 41 | id: int | None = None 42 | """每条消息的唯一ID""" 43 | time: str = field(default_factory=lambda: datetime.strftime(datetime.now(), "%Y.%m.%d %H:%M:%S")) 44 | """ 45 | 字符串形式的时间数据:%Y.%m.%d %H:%M:%S 46 | 若要获取格式化的 datetime 对象,请使用 format_time 47 | """ 48 | userid: str = "" 49 | """Nonebot 的用户id""" 50 | groupid: str = "-1" 51 | """群组id,私聊设为-1""" 52 | message: str = "" 53 | """消息主体""" 54 | respond: str = "" 55 | """模型回复(不包含思维过程)""" 56 | history: int = 1 57 | """消息是否可用于对话历史中,以整数形式映射布尔值""" 58 | resources: List[Resource] = field(default_factory=list) 59 | """多模态消息内容""" 60 | usage: int = -1 61 | """使用的总 tokens, 若模型加载器不支持则设为-1""" 62 | 63 | @property 64 | def format_time(self) -> datetime: 65 | """将时间字符串转换为 datetime 对象""" 66 | return datetime.strptime(self.time, "%Y.%m.%d %H:%M:%S") 67 | 68 | def to_dict(self) -> dict: 69 | return asdict(self) 70 | 71 | @staticmethod 72 | def from_dict(data: dict) -> "Message": 73 | return Message(**data) 74 | 75 | def __hash__(self) -> int: 76 | return hash(self.id) 77 | 78 | def __lt__(self, other: "Message") -> bool: 79 | return self.format_time < other.format_time 80 | -------------------------------------------------------------------------------- /muicebot/muice.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import AsyncGenerator, Optional, Union 4 | 5 | from nonebot import logger 6 | 7 | from .config import ModelConfig, get_model_config, model_config_manager, plugin_config 8 | from .database import Database 9 | from .llm import ( 10 | MODEL_DEPENDENCY_MAP, 11 | ModelCompletions, 12 | ModelRequest, 13 | ModelStreamCompletions, 14 | get_missing_dependencies, 15 | load_model, 16 | ) 17 | from .models import Message, Resource 
class Muice:
    """Singleton facade between the bot and the configured language model.

    It owns the model instance, the chat database and the prompt/history
    preparation, and drives the plugin hooks around every model call.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every `Muice()` call; only the first one may initialize.
        if self._initialized:
            return

        self.model_config = get_model_config()
        # Stays None when the model loader cannot be imported (see _init_model).
        self.model = None

        self.database = Database()
        self.max_history_epoch = plugin_config.max_history_epoch

        self.system_prompt = ""
        self.user_instructions = ""

        self._load_config()
        self._init_model()

        model_config_manager.register_listener(self._on_config_changed)

        self._initialized = True

    def __del__(self):
        # Unregister the config listener; best effort during teardown.
        try:
            model_config_manager.unregister_listener(self._on_config_changed)
        except (AttributeError, RuntimeError) as e:
            logger.debug(f"Muice __del__ 清理失败: {e}")

    def _load_config(self):
        """Copy the settings used by this class out of ``self.model_config``."""
        self.model_loader = self.model_config.loader
        self.multimodal = self.model_config.multimodal
        self.template = self.model_config.template
        self.template_mode = self.model_config.template_mode

        # Prompts rendered for a previous configuration must not leak into the
        # new one (e.g. after the template mode changes or the template is removed).
        self.system_prompt = ""
        self.user_instructions = ""

    def _init_model(self) -> None:
        """Instantiate the model class for the current configuration.

        On ``ImportError`` the missing dependencies are logged and ``self.model``
        is set to ``None`` so callers can detect that no model is available.
        """
        try:
            self.model = load_model(self.model_config)

        except ImportError as e:
            self.model = None
            logger.critical(f"导入模型加载器 '{self.model_loader}' 失败:{e}")
            dependencies = MODEL_DEPENDENCY_MAP.get(self.model_loader, [])
            missing = get_missing_dependencies(dependencies)
            if missing:
                install_command = "pip install " + " ".join(missing)
                logger.critical(f"缺少依赖库:{', '.join(missing)}\n请运行以下命令安装缺失项:\n\n{install_command}")

    def load_model(self) -> bool:
        """Load the model.

        :return: whether loading succeeded
        """
        if self.model is None or not self.model.load():
            logger.error("模型加载失败: self.model.load 函数失败")
            return False

        return True

    def _on_config_changed(self, new_config: ModelConfig, old_config: ModelConfig):
        """Callback invoked when the model configuration file changes."""
        logger.info("检测到配置文件变更,自动重载模型...")
        self.model_config = new_config

        self._load_config()
        self._init_model()
        # Only report success when the reload really worked; load_model() has
        # already logged the failure otherwise.
        if self.load_model():
            logger.success(f"模型自动重载完成: {old_config.loader} -> {new_config.loader}")

    def change_model_config(self, config_name: str) -> str:
        """Switch to another model configuration and reload the model.

        :param config_name: name of the configuration; empty selects the default one
        :return: a user-facing status message
        """
        try:
            self.model_config = get_model_config(config_name)
        except (ValueError, FileNotFoundError) as e:
            return str(e)

        self._load_config()
        self._init_model()
        if not self.load_model():
            return f"加载 {config_name} 失败" if config_name else "默认模型配置加载失败"

        return f"已成功加载 {config_name}" if config_name else "未指定模型配置名,已加载默认模型配置"

    async def _prepare_prompt(self, message: str, userid: str, is_private: bool) -> str:
        """Build the user prompt and refresh the system prompt / user instructions.

        :param message: message body
        :param userid: NoneBot user id
        :param is_private: whether the message comes from a private chat
        :return: the final prompt passed to the model
        """
        if self.template is None:
            return message

        system_prompt = generate_prompt_from_template(self.template, userid, is_private).strip()

        if self.template_mode == "system":
            self.system_prompt = system_prompt
        else:
            self.user_instructions = system_prompt

        if is_private:
            return f"{self.user_instructions}\n\n{message}" if self.user_instructions else message

        # Group messages are prefixed with the speaker's name.
        group_prompt = f"<{await get_username()}> {message}"

        return f"{self.user_instructions}\n\n{group_prompt}" if self.user_instructions else group_prompt

    @staticmethod
    def _drop_missing_resources(history: list[Message]) -> None:
        """Remove, in place, resources whose file path no longer exists."""
        for item in history:
            item.resources = [
                resource for resource in item.resources if resource.path and os.path.isfile(resource.path)
            ]

    async def _prepare_history(self, userid: str, groupid: str = "-1", enable_history: bool = True) -> list[Message]:
        """Collect the conversation history passed to the model.

        :param userid: user id
        :param groupid: group id ("-1" for private chats)
        :param enable_history: whether history is enabled; when False an empty list is returned
        :return: at most ``max_history_epoch`` messages, oldest first for group chats
        """
        if not enable_history:
            return []

        user_history = await self.database.get_user_history(userid, self.max_history_epoch)
        self._drop_missing_resources(user_history)

        if groupid == "-1":
            return user_history[-self.max_history_epoch :]

        group_history = await self.database.get_group_history(groupid, self.max_history_epoch)
        self._drop_missing_resources(group_history)

        # Prefix group messages with the speaker so the model can tell users apart.
        for item in group_history:
            user_name = await get_username(item.userid)
            item.message = f"<{user_name}> {item.message}"

        # Deduplicate by message id. A set cannot be used here: the group copy of a
        # message has just been prefixed and therefore no longer compares equal to
        # the user copy, and a set would also discard the order. The group copy is
        # inserted last so that it wins.
        merged: dict[tuple, Message] = {}
        for item in user_history + group_history:
            key = ("id", item.id) if item.id is not None else ("obj", id(item))
            merged[key] = item

        # Restore chronological order before keeping the most recent entries.
        final_history = sorted(merged.values(), key=lambda item: item.format_time)

        return final_history[-self.max_history_epoch :]

    async def _build_model_request(
        self, message: Message, enable_history: bool, enable_plugins: bool, *, stream: bool
    ) -> ModelRequest:
        """Run the pre-call hooks and assemble the request shared by ``ask`` and ``ask_stream``."""
        is_private = message.groupid == "-1"
        logger.info("正在调用模型...")

        await hook_manager.run(HookType.BEFORE_PRETREATMENT, message, stream)

        prompt = await self._prepare_prompt(message.message, message.userid, is_private)
        history = await self._prepare_history(message.userid, message.groupid, enable_history) if enable_history else []
        tools = (
            (await get_function_list() + await get_mcp_list())
            if self.model_config.function_call and enable_plugins
            else []
        )
        system = self.system_prompt if self.system_prompt else None

        model_request = ModelRequest(prompt, history, message.resources, tools, system)
        await hook_manager.run(HookType.BEFORE_MODEL_COMPLETION, model_request, stream)

        logger.debug(f"模型调用参数:Prompt: {message}, History: {history}")
        return model_request

    async def ask(
        self,
        message: Message,
        enable_history: bool = True,
        enable_plugins: bool = True,
    ) -> ModelCompletions:
        """Call the model and return the complete reply.

        :param message: the incoming message; ``respond`` and ``usage`` are filled in
        :param enable_history: whether to pass conversation history
        :param enable_plugins: whether to offer tools to the model
        :return: the model completion
        """
        if not (self.model and self.model.is_running):
            logger.error("模型未加载")
            return ModelCompletions("模型未加载", succeed=False)

        model_request = await self._build_model_request(message, enable_history, enable_plugins, stream=False)

        start_time = time.perf_counter()
        response = await self.model.ask(model_request, stream=False)
        end_time = time.perf_counter()

        logger.success(f"模型调用{'成功' if response.succeed else '失败'}: {response}")
        logger.debug(f"模型调用时长: {end_time - start_time} s (token用量: {response.usage})")

        await hook_manager.run(HookType.AFTER_MODEL_COMPLETION, response)

        message.respond = response.text
        message.usage = response.usage

        await hook_manager.run(HookType.ON_FINISHING_CHAT, message)

        # Only successful exchanges are stored.
        if response.succeed:
            await self.database.add_item(message)

        return response

    async def ask_stream(
        self,
        message: Message,
        enable_history: bool = True,
        enable_plugins: bool = True,
    ) -> AsyncGenerator[ModelStreamCompletions, None]:
        """Call the model and yield the reply chunk by chunk.

        :param message: the incoming message; ``respond`` and ``usage`` are filled in
        :param enable_history: whether to pass conversation history
        :param enable_plugins: whether to offer tools to the model
        :return: an async generator of stream chunks
        :raises RuntimeError: if the model yields no chunk at all
        """
        if not (self.model and self.model.is_running):
            logger.error("模型未加载")
            yield ModelStreamCompletions("模型未加载")
            return

        model_request = await self._build_model_request(message, enable_history, enable_plugins, stream=True)

        start_time = time.perf_counter()
        response = await self.model.ask(model_request, stream=True)

        total_reply = ""
        total_resources: list[Resource] = []
        item: Optional[ModelStreamCompletions] = None

        async for item in response:
            await hook_manager.run(HookType.ON_STREAM_CHUNK, item, True)
            total_reply += item.chunk
            yield item
            if item.resources:
                total_resources.extend(item.resources)

        if item is None:
            raise RuntimeError("模型调用器返回的值应至少包含一个元素")

        end_time = time.perf_counter()
        logger.success(f"已完成流式回复: {total_reply}")
        logger.debug(f"模型调用时长: {end_time - start_time} s (token用量: {item.usage})")

        # The last chunk carries the final usage/succeed values.
        final_model_completions = ModelCompletions(
            text=total_reply, usage=item.usage, resources=total_resources.copy(), succeed=item.succeed
        )
        await hook_manager.run(HookType.AFTER_MODEL_COMPLETION, final_model_completions, True)

        # Resources added by hooks were not part of the stream yet; yield them now.
        new_resources = [r for r in final_model_completions.resources if r not in total_resources]
        for r in new_resources:
            yield ModelStreamCompletions(resources=[r])

        message.respond = total_reply
        message.usage = item.usage

        await hook_manager.run(HookType.ON_FINISHING_CHAT, message, True)

        if item.succeed:
            await self.database.add_item(message)

    async def refresh(self, userid: str) -> Union[AsyncGenerator[ModelStreamCompletions, None], ModelCompletions]:
        """Regenerate the reply to the user's latest message.

        :param userid: unique user id
        :return: a completion, or a stream generator when streaming is configured
        """
        logger.info(f"用户 {userid} 请求刷新")

        user_history = await self.database.get_user_history(userid, limit=1)

        if not user_history:
            logger.warning("用户对话数据不存在,拒绝刷新")
            return ModelCompletions("你都还没和我说过一句话呢,得和我至少聊上一段才能刷新哦")

        last_item = user_history[0]

        # Retire the old exchange before asking again.
        await self.database.mark_history_as_unavailable(userid, 1)

        if not self.model_config.stream:
            return await self.ask(last_item)

        return self.ask_stream(last_item)

    async def reset(self, userid: str) -> str:
        """Clear the conversation by marking the user's history as unavailable."""
        await self.database.mark_history_as_unavailable(userid)
        return "已成功移除对话历史~"

    async def undo(self, userid: str) -> str:
        """Mark the user's latest exchange as unavailable."""
        await self.database.mark_history_as_unavailable(userid, 1)
        return "已成功撤销上一段对话~"
# Context variables holding the NoneBot objects of the handler currently running.
bot_context: ContextVar[Bot] = ContextVar("bot")
event_context: ContextVar[Event] = ContextVar("event")
state_context: ContextVar[T_State] = ContextVar("state")
mather_context: ContextVar[Matcher] = ContextVar("matcher")


def get_bot() -> Bot:
    """Return the Bot stored in the current context."""
    return bot_context.get()


def get_event() -> Event:
    """Return the Event stored in the current context."""
    return event_context.get()


def get_state() -> T_State:
    """Return the handler state stored in the current context."""
    return state_context.get()


def get_mather() -> Matcher:
    """Return the Matcher stored in the current context."""
    return mather_context.get()


def set_ctx(bot: Bot, event: Event, state: T_State, matcher: Matcher):
    """Store the NoneBot objects of the current handler in the context variables."""
    bindings = (
        (bot_context, bot),
        (event_context, event),
        (state_context, state),
        (mather_context, matcher),
    )
    for variable, value in bindings:
        variable.set(value)


def get_ctx() -> Tuple[Bot, Event, T_State, Matcher]:
    """Return the current ``(bot, event, state, matcher)`` tuple."""
    bot = get_bot()
    event = get_event()
    state = get_state()
    matcher = get_mather()
    return bot, event, state, matcher
# Registry of all registered function calls, keyed by function name.
_caller_data: dict[str, "Caller"] = {}


class Caller:
    """Decorator object that registers a function as an AI-callable tool.

    Used as ``@on_function_call(...)`` (optionally chained with ``.params(...)``);
    decorating a function stores this object in the module registry.
    """

    def __init__(self, description: str, rule: Optional[Rule] = None):
        # Name of the decorated function; filled in by __call__.
        self._name: str = ""
        # Description presented to the model.
        self._description: str = description
        # Rule deciding whether the tool is offered for the current event.
        self._rule: Optional[Rule] = rule
        # Declared tool parameters, keyed by parameter name.
        self._parameters: dict[str, Parameter] = {}
        # The (always async) callable; None until a function has been decorated.
        self.function: Optional[ASYNC_FUNCTION_CALL_FUNC] = None
        # Default values.
        self.default: dict[str, Any] = {}
        # Last segment of the module name the function was defined in.
        self.module_name: str = ""

    def __call__(self, func: F) -> F:
        """Decorator: register ``func`` as a function call and return it unchanged."""
        # Always store an async callable so `run` can await it.
        if is_coroutine_callable(func):
            self.function = func  # type: ignore
        else:
            self.function = async_wrap(func)  # type:ignore

        self._name = func.__name__

        # Fall back to the function's docstring when no description was given.
        if not self._description:
            self._description = inspect.getdoc(func) or ""

        module = inspect.getmodule(func)
        self.module_name = module.__name__.split(".")[-1] if module else ""

        _caller_data[self._name] = self
        logger.success(f"Function Call 函数 {self.module_name}.{self._name} 已成功加载")
        return func

    async def _inject_dependencies(self, kwargs: dict) -> dict:
        """Complete ``kwargs`` with injected dependencies and declared defaults.

        Parameters annotated with a ``Bot``, ``Event`` or ``Matcher`` (sub)class are
        taken from the current context; parameters with a default keep the supplied
        value or fall back to that default.

        :raises ValueError: if a parameter without default is neither supplied nor injectable
        """
        sig = inspect.signature(self.function)
        hints = get_type_hints(self.function)
        providers = ((Bot, get_bot), (Event, get_event), (Matcher, get_mather))
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        inject_args = kwargs.copy()

        for name, param in sig.parameters.items():
            # *args / **kwargs are never "missing"; they simply receive what is passed.
            if param.kind in variadic:
                continue

            param_type = hints.get(name)
            provider = None
            if isinstance(param_type, type):
                provider = next((get for dep, get in providers if issubclass(param_type, dep)), None)

            if provider is not None:
                inject_args[name] = provider()
            elif param.default is not inspect.Parameter.empty:
                inject_args.setdefault(name, param.default)
            elif name not in inject_args:
                raise ValueError(f"缺少必要参数: {name}")

        return inject_args

    def params(self, **kwargs: Parameter) -> "Caller":
        """Declare tool parameters; returns ``self`` so the call can be chained."""
        self._parameters.update(kwargs)
        return self

    async def run(self, **kwargs) -> Any:
        """Execute the function call with the given arguments.

        :raises ValueError: if no function has been registered on this object
        """
        if self.function is None:
            raise ValueError("未注册函数对象")

        inject_args = await self._inject_dependencies(kwargs)

        return await self.function(**inject_args)

    def data(self) -> dict[str, Any]:
        """Build the tool description.

        :return: a dict usable as a function-call tool definition
        """
        if not self._parameters:
            # A placeholder parameter is emitted when the function declares none.
            properties = {
                "dummy_param": {"type": "string", "description": "为了兼容性设置的一个虚拟参数,因此不需要填写任何值"}
            }
            required: list[str] = []
        else:
            properties = {key: value.data() for key, value in self._parameters.items()}
            # Parameters without a default value are reported as required.
            required = [key for key, value in self._parameters.items() if value.default is None]

        return {
            "type": "function",
            "function": {
                "name": self._name,
                "description": self._description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }


def on_function_call(description: str, rule: Optional[Rule] = None) -> Caller:
    """Return a ``Caller`` that registers the decorated function as an AI-callable function call.

    :param description: function description; when empty or None, the function's docstring is used
    :param rule: enabling rule; the function call is not offered when the rule is not satisfied
    :return: the Caller object
    """
    return Caller(description=description, rule=rule)


def get_function_calls() -> dict[str, Caller]:
    """Return the registry of all registered function calls, keyed by function name."""
    return _caller_data


async def get_function_list() -> list[dict[str, dict]]:
    """Return the tool definitions of all function calls enabled for the current event.

    :return: list of tool definitions
    """
    bot: Bot = get_bot()
    event: Event = get_event()
    state: T_State = {}

    tools: list[dict[str, dict]] = []
    for caller in _caller_data.values():
        if caller._rule is None or await caller._rule(bot, event, state):
            tools.append(caller.data())

    return tools
class Parameter(BaseModel):
    """Parameter description of a function-call plugin."""

    # JSON-schema type name ("string", "integer", ...).
    type: str
    # Parameter description.
    description: str
    # Default value.
    default: Any = None
    # Additional schema attributes, e.g. minimum/maximum.
    properties: dict[str, Any] = Field(default_factory=dict)
    # Whether the parameter is required.
    required: bool = False

    def data(self) -> dict[str, Any]:
        """Build the parameter description.

        :return: a dict usable in a function-call tool definition
        """
        extras = {key: value for key, value in self.properties.items() if value is not None}
        return {"type": self.type, "description": self.description, **extras}


class ParamTypes:
    """Names of the supported JSON-schema types."""

    STRING = "string"
    INTEGER = "integer"
    ARRAY = "array"
    OBJECT = "object"
    BOOLEAN = "boolean"
    NUMBER = "number"


class String(Parameter):
    """String parameter, optionally restricted to ``enum`` values."""

    type: str = ParamTypes.STRING
    properties: dict[str, Any] = Field(default_factory=dict)
    enum: list[str] | None = None

    def data(self) -> dict[str, Any]:
        schema = super().data()
        # An explicitly given enum must reach the schema.
        if self.enum is not None:
            schema["enum"] = self.enum
        return schema


class Integer(Parameter):
    """Integer parameter; ``properties`` defaults to the range 0..100."""

    type: str = ParamTypes.INTEGER
    properties: dict[str, Any] = Field(default_factory=lambda: {"minimum": 0, "maximum": 100})

    minimum: int | None = None
    maximum: int | None = None

    def data(self) -> dict[str, Any]:
        schema = super().data()
        # Explicit bounds override the ones coming from `properties`.
        if self.minimum is not None:
            schema["minimum"] = self.minimum
        if self.maximum is not None:
            schema["maximum"] = self.maximum
        return schema


class Array(Parameter):
    """Array parameter; ``items`` names the element type."""

    type: str = ParamTypes.ARRAY
    properties: dict[str, Any] = Field(default_factory=lambda: {"items": {"type": "string"}})
    items: str = Field("string", description="数组元素类型")

    def data(self) -> dict[str, Any]:
        schema = super().data()
        # A non-default element type overrides `properties`; with the default
        # element type an "items" entry supplied via `properties` is kept.
        if self.items != "string" or "items" not in schema:
            schema["items"] = {"type": self.items}
        return schema
# Providers for parameters that are injected by type (T_State is not supported).
DEPENDENCY_PROVIDERS: dict[type, Callable[[], object]] = {
    Bot: get_bot,
    Event: get_event,
    Matcher: get_mather,
}


def _is_instance(arg: object, param_type: Any) -> bool:
    """``isinstance`` that answers False for annotations it cannot check (e.g. ``list[int]``)."""
    try:
        return isinstance(arg, param_type)
    except TypeError:
        return False


def _match_union(param_type: Any, arg: object) -> bool:
    """Return True if ``param_type`` is a ``typing.Union`` and ``arg`` matches one of its members."""
    if get_origin(param_type) is not Union:
        return False
    return any(_is_instance(arg, member) for member in get_args(param_type))


class HookManager:
    """Registry and runner for hook functions, grouped by ``HookType``."""

    def __init__(self):
        self._hooks: Dict[HookType, List["Hooked"]] = defaultdict(list)

    async def _inject_dependencies(self, function: HOOK_FUNC, hook_arg: HOOK_ARGS) -> dict:
        """Build the keyword arguments for ``function`` from its type hints.

        A parameter whose annotation matches ``hook_arg`` receives it; parameters
        annotated with a ``Bot``/``Event``/``Matcher`` (sub)class receive the object
        from the current context. Unannotated or unmatched parameters are omitted.
        """
        sig = inspect.signature(function)
        hints = get_type_hints(function)

        inject_args: dict[str, Any] = {}

        for name in sig.parameters:
            param_type = hints.get(name, None)
            if not param_type:
                continue

            # 1./2. The hook argument itself, via a Union member or a direct match.
            if _match_union(param_type, hook_arg) or _is_instance(hook_arg, param_type):
                inject_args[name] = hook_arg
                continue

            # 3. Context dependencies (Bot, Event, Matcher, ...).
            if not isinstance(param_type, type):
                continue
            for dep_type, provider in DEPENDENCY_PROVIDERS.items():
                if issubclass(param_type, dep_type):
                    inject_args[name] = provider()
                    break

        return inject_args

    def register(self, hook_type: HookType, hooked: "Hooked"):
        """Register a hook object for ``hook_type`` and return it."""
        self._hooks[hook_type].append(hooked)
        return hooked

    async def run(self, hook_type: HookType, hook_arg: HOOK_ARGS, stream: bool = False):
        """Run all hooks of ``hook_type`` in ascending priority order.

        :param hook_type: hook type
        :param hook_arg: the data object of the current processing stage
        :param stream: whether the current call is a streaming one
        """
        hookeds = sorted(self._hooks[hook_type], key=lambda x: x.priority)
        if not hookeds:
            return

        state: T_State = {}

        for hooked in hookeds:
            # A hook restricted to (non-)streaming only runs in the matching mode.
            if hooked.stream is not None and hooked.stream != stream:
                continue

            # The context is only needed to evaluate a rule, so it is read lazily.
            if hooked.rule and not await hooked.rule(get_bot(), get_event(), state):
                continue

            args = await self._inject_dependencies(hooked.function, hook_arg)

            result = hooked.function(**args)
            if inspect.isawaitable(result):
                await result


hook_manager = HookManager()


class Hooked:
    """Decorator object describing one hook function."""

    def __init__(
        self, hook_type: HookType, priority: int = 10, stream: Optional[bool] = None, rule: Optional[Rule] = None
    ):
        # Type of the hook.
        self.hook_type = hook_type
        # Call priority; lower values run first.
        self.priority = priority
        # Run only in streaming (True) / non-streaming (False) calls; None for both.
        self.stream = stream
        # Enabling rule.
        self.rule: Optional[Rule] = rule

        # The hook function; assigned when the object is used as a decorator.
        self.function: HOOK_FUNC

    def __call__(self, func: HOOK_FUNC) -> HOOK_FUNC:
        """Decorator: register ``func`` as a hook and return it unchanged."""
        self.function = func

        module = inspect.getmodule(func)
        module_name = module.__name__.split(".")[-1] if module else ""

        hook_manager.register(self.hook_type, self)
        logger.success(f"挂钩函数 {module_name}.{func.__name__} 已成功加载")
        return func


def on_before_pretreatment(priority: int = 10, rule: Optional[Rule] = None) -> Hooked:
    """Register a hook that runs before the incoming message is pre-processed
    (before ``Muice._prepare_prompt()``). It may accept a ``Message`` parameter.

    :param priority: call priority
    :param rule: NoneBot rule
    """
    return Hooked(HookType.BEFORE_PRETREATMENT, priority=priority, rule=rule)


def on_before_completion(priority: int = 10, rule: Optional[Rule] = None) -> Hooked:
    """Register a hook that runs before the model is called (before ``model.ask()``).
    It may accept a ``ModelRequest`` parameter.

    :param priority: call priority
    :param rule: NoneBot rule
    """
    return Hooked(HookType.BEFORE_MODEL_COMPLETION, priority=priority, rule=rule)


def on_stream_chunk(priority: int = 10, rule: Optional[Rule] = None) -> Hooked:
    """Register a hook that runs for every chunk of a streaming model call.
    It may accept a ``ModelStreamCompletions`` parameter.

    :param priority: call priority
    :param rule: NoneBot rule
    """
    return Hooked(HookType.ON_STREAM_CHUNK, priority=priority, rule=rule)


def on_after_completion(priority: int = 10, stream: Optional[bool] = None, rule: Optional[Rule] = None) -> Hooked:
    """Register a hook that runs after the model call (for streaming calls it
    receives the merged result). It may accept a ``ModelCompletions`` parameter.

    Note: in streaming mode, changes made to the result do not affect chunks
    that have already been sent.

    :param priority: call priority
    :param stream: run only in streaming (True) or non-streaming (False) calls; None for both
    :param rule: NoneBot rule
    """
    return Hooked(HookType.AFTER_MODEL_COMPLETION, priority=priority, stream=stream, rule=rule)


def on_finish_chat(priority: int = 10, rule: Optional[Rule] = None) -> Hooked:
    """Register a hook that runs when the chat finishes (before it is stored).
    It may accept a ``Message`` parameter.

    :param priority: call priority
    :param rule: NoneBot rule
    """
    return Hooked(HookType.ON_FINISHING_CHAT, priority=priority, rule=rule)
# Registry of successfully loaded plugins, keyed by package name.
_plugins: Dict[str, Plugin] = {}
# Package names that have been declared for loading (loading may still have failed).
_declared_plugins: Set[str] = set()


def load_plugin(plugin_path: Path | str, base_path=Path.cwd()) -> Optional[Plugin]:
    """Load a single plugin.

    :param plugin_path: module name, or a path that is converted to a module name
    :param base_path: base path used to convert a path into a module name
    :return: the plugin object, or None if loading failed (the error is logged)
    """
    try:
        logger.info(f"加载插件: {plugin_path}")
        if isinstance(plugin_path, Path):
            plugin_path = path_to_module_name(plugin_path, base_path)

        # A package name may only be declared once.
        if plugin_path in _declared_plugins:
            raise ValueError(f"插件 {plugin_path} 包名出现冲突!")
        _declared_plugins.add(plugin_path)

        module = importlib.import_module(plugin_path)
        short_name = module.__name__.split(".")[-1]
        plugin = Plugin(name=short_name, module=module, package_name=plugin_path)
        _plugins[plugin.package_name] = plugin
    except Exception as e:
        logger.error(f"加载插件 {plugin_path} 失败: {e}")
        return None

    return plugin


def load_plugins(*plugins_dirs: Path | str, base_path=Path.cwd()) -> set[Plugin]:
    """Load every plugin found directly inside the given directories.

    A plugin is either a ``.py`` file or a directory containing ``__init__.py``.

    :param plugins_dirs: plugin directories
    :param base_path: base path used to convert paths into module names
    :return: the set of plugins that were loaded successfully
    """
    loaded: set = set()

    for plugin_dir in plugins_dirs:
        root = Path(plugin_dir) if isinstance(plugin_dir, str) else plugin_dir

        for entry in os.listdir(root):
            candidate = Path(os.path.join(root, entry))

            if candidate.is_file() and candidate.suffix == ".py":
                module_name = path_to_module_name(candidate.with_suffix(""), base_path)
            elif candidate.is_dir() and (candidate / Path("__init__.py")).exists():
                module_name = path_to_module_name(candidate, base_path)
            else:
                continue

            if not module_name:
                continue

            plugin = load_plugin(module_name)
            if plugin:
                loaded.add(plugin)

    return loaded


def _get_caller_plugin_name() -> Optional[str]:
    """Return the name of the plugin from which the current call originates.

    Walks up the call stack, skipping frames of ``muicebot`` itself (except
    ``muicebot.builtin_plugins``), and returns the last segment of the longest
    module prefix that is a declared plugin. Returns None if a frame has no
    module or no declared plugin is found.
    """
    frame = inspect.currentframe()
    if frame is None:
        return None

    while frame := frame.f_back:  # type:ignore
        module = inspect.getmodule(frame)
        module_name = module and module.__name__
        if module_name is None:
            return None

        # Skip muicebot itself.
        top_level = module_name.split(".", maxsplit=1)[0]
        if top_level == "muicebot" and not module_name.startswith("muicebot.builtin_plugins"):
            continue

        # Try the prefixes a.b.c, a.b, a (longest first).
        segments = module_name.split(".")
        for end in range(len(segments), 0, -1):
            if ".".join(segments[:end]) in _declared_plugins:
                return segments[end - 1]

    return None


def get_plugins() -> Dict[str, Plugin]:
    """Return the registry of loaded plugins, keyed by package name."""
    return _plugins


def get_plugin_by_module_name(module_name: str) -> Optional[Plugin]:
    """Return the plugin registered under the given package name, or None."""
    return _plugins.get(module_name, None)


def get_plugin_data_dir() -> Path:
    """Return (and create) the data directory of the calling Muicebot plugin.

    The directory is ``<nonebot_plugin_localstore data dir>/plugin/<plugin name>``;
    ``.unknown`` is used when the calling plugin cannot be determined.
    """
    plugin_name = _get_caller_plugin_name() or ".unknown"

    base_dir = store.get_plugin_data_dir() / "plugin"
    plugin_dir = (base_dir / plugin_name).resolve()
    plugin_dir.mkdir(parents=True, exist_ok=True)

    logger.debug(plugin_dir)

    return plugin_dir
async def initialize_servers() -> None:
    """
    Create one ``Server`` per configured MCP entry and initialize every registered server.

    If any server fails to initialize, the error is logged, all servers are
    cleaned up and the exception is re-raised.
    """
    configured = [Server(name, srv_config) for name, srv_config in server_config.items()]
    _servers.extend(configured)

    for srv in _servers:
        logger.info(f"初始化 MCP Server: {srv.name}")
        try:
            await srv.initialize()
        except Exception as exc:
            logger.error(f"初始化 MCP Server 实例时出现问题: {exc}")
            await cleanup_servers()
            raise
async def transform_json(tool: Tool) -> dict[str, Any]:
    """
    Convert an MCP tool into the OpenAI function-calling tool format.

    Only ``type``, ``properties`` and ``required`` are copied from the tool's
    input schema; every other schema field is dropped. A tool without an input
    schema gets an empty ``parameters`` mapping.

    :param tool: The MCP tool to convert
    :return: ``{"type": "function", "function": {...}}``
    """
    schema = tool.input_schema
    parameters: dict[str, Any] = {}
    if schema:
        fallbacks = (("type", "object"), ("properties", {}), ("required", []))
        parameters = {key: schema.get(key, fallback) for key, fallback in fallbacks}

    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": parameters,
            "required": [],
        },
    }
class Tool:
    """
    Description of a single tool exposed by an MCP server.

    Holds the tool name, its human-readable description and the JSON schema of
    its input arguments.
    """

    def __init__(self, name: str, description: str, input_schema: dict[str, Any]) -> None:
        self.name: str = name
        self.description: str = description
        self.input_schema: dict[str, Any] = input_schema

    def format_for_llm(self) -> str:
        """
        Render the tool as plain text suitable for inclusion in an LLM prompt.

        The output has a ``Tool:`` line, a ``Description:`` line and an
        ``Arguments:`` header followed by one ``- name: description`` line per
        schema property. Required arguments are marked with ``(required)``.

        :return: The multi-line tool description
        """
        schema = self.input_schema or {}
        required = schema.get("required", [])

        lines = [f"Tool: {self.name}", f"Description: {self.description}", "Arguments:"]
        for param_name, param_info in schema.get("properties", {}).items():
            arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}"
            if param_name in required:
                arg_desc += " (required)"
            lines.append(arg_desc)

        # Each argument goes on its own line below the "Arguments:" header.
        return "\n".join(lines)
int = 2, 103 | delay: float = 1.0, 104 | ) -> Any: 105 | """ 106 | 执行一个 MCP 工具 107 | 108 | :param tool_name: 工具名称 109 | :param arguments: 工具参数 110 | :param retries: 重试次数 111 | :param delay: 重试间隔 112 | 113 | :return: 工具执行结果 114 | 115 | :raises RuntimeError: 如果服务器未初始化 116 | :raises Exception: 工具在所有重试中均失败 117 | """ 118 | if not self.session: 119 | raise RuntimeError(f"Server {self.name} not initialized") 120 | 121 | attempt = 0 122 | while attempt < retries: 123 | try: 124 | logging.info(f"Executing {tool_name}...") 125 | result = await self.session.call_tool(tool_name, arguments) 126 | 127 | return result 128 | 129 | except Exception as e: 130 | attempt += 1 131 | logging.warning(f"Error executing tool: {e}. Attempt {attempt} of {retries}.") 132 | if attempt < retries: 133 | logging.info(f"Retrying in {delay} seconds...") 134 | await asyncio.sleep(delay) 135 | else: 136 | logging.error("Max retries reached. Failing.") 137 | raise 138 | 139 | async def cleanup(self) -> None: 140 | """Clean up server resources.""" 141 | async with self._cleanup_lock: 142 | try: 143 | await self.exit_stack.aclose() 144 | self.session = None 145 | self.stdio_context = None 146 | except Exception as e: 147 | logging.error(f"Error during cleanup of server {self.name}: {e}") 148 | -------------------------------------------------------------------------------- /muicebot/plugin/models.py: -------------------------------------------------------------------------------- 1 | from types import ModuleType 2 | from typing import Any 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class PluginMetadata(BaseModel): 8 | """MuiceBot 插件元数据""" 9 | 10 | name: str 11 | """插件名""" 12 | description: str 13 | """插件描述""" 14 | usage: str 15 | """插件用法""" 16 | homepage: str | None = None 17 | """(可选) 插件主页,通常为开源存储库地址""" 18 | config: type[BaseModel] | None = None 19 | """插件配置项类,如无需配置可不填写""" 20 | extra: dict[Any, Any] | None = None 21 | """不知道干嘛的 extra 信息,我至今都没搞懂,喜欢的可以填""" 22 | 23 | 24 | class Plugin(BaseModel): 
def path_to_module_name(module_path: Path, base_path: Path) -> str:
    """
    Compute a dotted module name for ``module_path`` relative to ``base_path``.

    A package ``__init__`` file maps to the package itself, and a trailing
    ``.py`` suffix is dropped, so ``pkg/__init__.py`` becomes ``pkg`` and
    ``pkg/mod.py`` becomes ``pkg.mod``. When ``module_path`` is not located under
    ``base_path``, the name is built from its absolute path without the
    filesystem root.

    :param module_path: Path of the module file or package directory
    :param base_path: Directory the module name is computed relative to
    :return: The dotted module name (empty when both paths are the same)
    """
    resolved = module_path.resolve()
    try:
        rel_path = resolved.relative_to(base_path.resolve())
    except ValueError:
        # Outside base_path: fall back to the absolute path, minus its anchor
        # ("/" or "C:\\"), which can never be part of a module name.
        rel_path = resolved.relative_to(resolved.anchor)

    if rel_path.stem == "__init__":
        parts = rel_path.parts[:-1]
    elif rel_path.suffix == ".py":
        parts = rel_path.with_suffix("").parts
    else:
        parts = rel_path.parts

    # Drop empty and non-name path components.
    module_names = [p for p in parts if p not in ("", ".", "..")]
    return ".".join(module_names)
async def model_ask(muice_app: Muice, target_id: str, prompt: str, probability: float = 1):
    """
    Scheduled job: ask the model a prompt and send its answer to a target.

    :param muice_app: Core Muice instance used to talk to the language model
    :param target_id: Target id (group_id or channel_id for a group chat, user_id for a private chat)
    :param prompt: Prompt passed to the model
    :param probability: Chance that the job actually runs when triggered
    """
    should_run = random.random() < probability
    if not should_run:
        return

    logger.info(f"定时任务: model_ask: {prompt}")

    # Nothing is sent unless a model is loaded and running.
    if not (muice_app.model and muice_app.model.is_running):
        return

    question = Message(message=prompt, userid=f"(bot_ask){target_id}")
    response = await muice_app.ask(question, enable_history=False, enable_plugins=False)

    target = Target(target_id)
    await UniMessage(response.text).send(target=target, bot=get_bot())
def load_templates_config() -> dict:
    """
    Load the prompt template configuration from ``TEMPLATES_CONFIG_PATH``.

    A missing file is not an error and yields an empty configuration. A file
    that cannot be parsed, or whose top level is not a mapping, also yields an
    empty configuration, but a warning is logged so the problem is visible.

    :return: The parsed configuration mapping, or ``{}``
    """
    try:
        with open(TEMPLATES_CONFIG_PATH, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f) or {}
    except FileNotFoundError:
        return {}
    except yaml.YAMLError as e:
        logger.warning(f"模板配置文件 {TEMPLATES_CONFIG_PATH} 解析失败,将使用默认配置: {e}")
        return {}

    if not isinstance(config, dict):
        # Callers unpack the result with ``**``, so anything but a mapping is unusable.
        logger.warning(f"模板配置文件 {TEMPLATES_CONFIG_PATH} 顶层必须为映射,将使用默认配置")
        return {}

    return config
def generate_prompt_from_template(template_name: str, userid: str, is_private: bool = False) -> str:
    """
    Render a prompt from a Jinja2 template found in ``SEARCH_PATH``.

    ``.jinja2`` is appended to ``template_name`` unless it already ends with
    ``.j2`` or ``.jinja2``.

    :param template_name: Template name, with or without extension
    :param userid: Id of the user the prompt is rendered for
    :param is_private: Whether the conversation is a private chat
    :return: The rendered prompt, or an empty string when the template is missing
    """
    has_extension = template_name.endswith((".j2", ".jinja2"))
    resolved_name = template_name if has_extension else template_name + ".jinja2"

    jinja_env = Environment(loader=FileSystemLoader(SEARCH_PATH))
    try:
        template = jinja_env.get_template(resolved_name)
    except TemplateNotFound:
        logger.error(f"模板文件 {resolved_name} 未找到!")
        return ""

    context = load_templates_data(userid, is_private).model_dump()
    return template.render(context)
PromptTemplatesConfig, userid: str, is_private: bool = False): 52 | base = templates_config.model_dump() 53 | data = cls(**base) 54 | 55 | user = next((u for u in templates_config.userinfos if u.id == userid), None) 56 | if user: 57 | data.user_name = user.name 58 | data.user_info = user.info 59 | 60 | data.private = is_private 61 | 62 | return data 63 | -------------------------------------------------------------------------------- /muicebot/utils/SessionManager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Dict, List, Optional 3 | 4 | from nonebot import logger 5 | from nonebot.adapters import Event 6 | from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg 7 | 8 | from ..config import plugin_config 9 | 10 | 11 | class SessionManager: 12 | def __init__(self) -> None: 13 | self.sessions: Dict[str, List[UniMsg]] = {} 14 | self._lock: asyncio.Lock = asyncio.Lock() 15 | self._timeout = plugin_config.input_timeout 16 | 17 | async def _put(self, sid: str, msg: UniMsg) -> None: 18 | async with self._lock: 19 | if sid not in self.sessions: 20 | self.sessions[sid] = [] 21 | self.sessions[sid].append(msg) 22 | 23 | async def _get_messages_length(self, sid: str) -> int: 24 | async with self._lock: 25 | return len(self.sessions.get(sid, [])) 26 | 27 | def merge_messages(self, sid: str) -> UniMessage: 28 | merged_message = UniMessage() 29 | 30 | for message in self.sessions.pop(sid, []): 31 | merged_message += message 32 | 33 | return merged_message 34 | 35 | async def put_and_wait(self, event: Event, message: UniMsg) -> Optional[UniMessage]: 36 | sid = event.get_session_id() 37 | await self._put(sid, message) 38 | 39 | old_length = await self._get_messages_length(sid) 40 | logger.debug(f"开始等待后续消息 ({self._timeout}s): 会话 {sid}, 当前消息数 {old_length}") 41 | await asyncio.sleep(self._timeout) 42 | new_length = await self._get_messages_length(sid) 43 | 44 | if new_length != old_length: 45 | 
@lru_cache()
def safe_import(path: str):
    """
    Import ``module.attr`` by dotted path without raising when it is unavailable.

    :param path: Dotted path whose last segment is the attribute to fetch, for
        example ``"nonebot.adapters.onebot.v11.Bot"``
    :return: The imported object, or ``None`` when the path has no module part,
        the module cannot be imported, or the module lacks the attribute
    """
    module_path, _, attr_name = path.rpartition(".")
    if not module_path or not attr_name:
        return None

    try:
        module = importlib.import_module(module_path)
    except ImportError:
        return None

    # An installed adapter that lacks the attribute also counts as "not available".
    return getattr(module, attr_name, None)


# Registry of optional adapter classes. An entry is None when the adapter is not installed.
ADAPTER_CLASSES = {
    "onebot_v12": safe_import("nonebot.adapters.onebot.v12.Bot"),
    "UnsupportedParam": safe_import("nonebot.adapters.onebot.v12.exception.UnsupportedParam"),
    "onebot_v11": safe_import("nonebot.adapters.onebot.v11.Bot"),
    "telegram_event": safe_import("nonebot.adapters.telegram.Event"),
    "telegram_file": safe_import("nonebot.adapters.telegram.message.File"),
    "qq_event": safe_import("nonebot.adapters.qq.Event"),
}
max(self.migrations.keys()) + 1 25 | 26 | async def migrate_if_needed(self): 27 | """ 28 | 检查数据库更新并迁移 29 | """ 30 | await self.__init_version_table() 31 | current_version = await self.__get_version() 32 | 33 | while current_version in self.migrations: 34 | logger.info(f"检测到数据库更新,当前版本 v{current_version}...") 35 | backup_path = self.path.with_suffix(f".backup.v{current_version}.db") 36 | shutil.copyfile(self.path, backup_path) 37 | 38 | try: 39 | async with self.db.connect() as conn: 40 | await conn.execute("BEGIN") 41 | await self.migrations[current_version](conn) 42 | await conn.commit() 43 | current_version += 1 44 | await self.__set_version(current_version) 45 | logger.success(f"数据库已成功迁移到 v{current_version} ⭐") 46 | except Exception as e: 47 | logger.error(f"迁移至 v{current_version + 1} 失败,错误:{e}") 48 | shutil.copyfile(backup_path, self.path) 49 | logger.info("已回退到迁移前的状态") 50 | break 51 | 52 | async def __init_version_table(self): 53 | await self.db.execute( 54 | """ 55 | CREATE TABLE IF NOT EXISTS schema_version ( 56 | version INTEGER NOT NULL 57 | ); 58 | """ 59 | ) 60 | result = await self.db.execute("SELECT COUNT(*) FROM schema_version", fetchone=True) 61 | if not result or result[0] == 0: # 只有 v0.2.x 的数据库没有版本号 62 | await self.db.execute("INSERT INTO schema_version (version) VALUES (0)") 63 | logger.info("数据库无版本记录,可能是v0版本的数据库,设为v0") 64 | 65 | async def __get_version(self) -> int: 66 | """ 67 | 获取数据库版本号,默认值为0 68 | """ 69 | result = await self.db.execute("SELECT version FROM schema_version", fetchone=True) 70 | return result[0] if result else 0 71 | 72 | async def __set_version(self, version: int): 73 | await self.db.execute("UPDATE schema_version SET version = ?", (version,)) 74 | 75 | async def _migrate_v0_to_v1(self, conn: aiosqlite.Connection): 76 | logger.info("v1 更新内容: 添加 TOTALTOKENS 与 GROUPID 字段") 77 | await conn.execute("ALTER TABLE MSG ADD COLUMN TOTALTOKENS INTEGER DEFAULT -1;") 78 | await conn.execute("ALTER TABLE MSG ADD COLUMN GROUPID TEXT DEFAULT 
'-1';") 79 | 80 | async def _migrate_v1_to_v2(self, conn: aiosqlite.Connection): 81 | logger.info("v2 更新内容: total_token 变更为 usage, images 列表优化为 resources") 82 | 83 | # 创建临时表 84 | await conn.execute( 85 | """ 86 | CREATE TABLE MSG_NEW( 87 | ID INTEGER PRIMARY KEY AUTOINCREMENT, 88 | TIME TEXT NOT NULL, 89 | USERID TEXT NOT NULL, 90 | GROUPID TEXT NOT NULL DEFAULT (-1), 91 | MESSAGE TEXT NOT NULL, 92 | RESPOND TEXT NOT NULL, 93 | HISTORY INTEGER NOT NULL DEFAULT (1), 94 | RESOURCES TEXT NOT NULL DEFAULT "[]", 95 | USAGE INTEGER NOT NULL DEFAULT (-1) 96 | ); 97 | """ 98 | ) 99 | 100 | cursor = await conn.execute( 101 | "SELECT ID, TIME, USERID, GROUPID, MESSAGE, RESPOND, HISTORY, IMAGES, TOTALTOKENS FROM MSG" 102 | ) 103 | rows = await cursor.fetchall() 104 | 105 | for row in rows: 106 | id_, time_, userid, groupid, message, respond, history, images_json, totaltokens = row 107 | 108 | # 转换 images -> resources 109 | try: 110 | images = json.loads(images_json) if images_json else [] 111 | except Exception: 112 | images = [] 113 | 114 | resources = [] 115 | for url in images: 116 | resources.append({"type": "image", "path": url}) 117 | resources_json = json.dumps(resources, ensure_ascii=False) 118 | 119 | # 插入到新表 120 | await conn.execute( 121 | """ 122 | INSERT INTO MSG_NEW (ID, TIME, USERID, GROUPID, MESSAGE, RESPOND, HISTORY, RESOURCES, USAGE) 123 | VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
async def download_file(file_url: str, file_name: Optional[str] = None, proxy: Optional[str] = None) -> str:
    """
    Download a file into the local image cache directory.

    :param file_url: URL of the file to download
    :param file_name: Name to store the file under. Only its final path component
        is used. Defaults to ``<time_ns>.jpg`` when missing or unusable
    :param proxy: Optional proxy URL

    :return: Absolute path of the saved file
    """
    ssl_context = ssl.create_default_context()
    ssl_context.set_ciphers("DEFAULT")

    # file_name may come from message data. Strip directory components so a
    # crafted name cannot place the file outside IMG_DIR.
    safe_name = os.path.basename(file_name.replace("\\", "/")) if file_name else ""
    if safe_name in ("", ".", ".."):
        safe_name = str(time.time_ns()) + ".jpg"

    async with httpx.AsyncClient(proxy=proxy, verify=ssl_context) as client:
        r = await client.get(file_url, headers={"User-Agent": User_Agent})
        local_path = (IMG_DIR / safe_name).resolve()
        with open(local_path, "wb") as file:
            file.write(r.content)
    return str(local_path)
def get_version() -> str:
    """
    Return the current MuiceBot version.

    The version of the installed ``muicebot`` distribution is preferred. When
    the package is not installed, ``pyproject.toml`` is read instead: the
    ``[tool.pdm.version]`` entry is used, and when that entry is a table (dynamic
    versioning) its ``fallback_version`` is returned.

    :return: The version string, or ``"Unknown"`` when it cannot be determined
    """
    package_name = "muicebot"

    try:
        return version(package_name)
    except PackageNotFoundError:
        pass

    here = os.path.dirname(os.path.abspath(__file__))
    candidates = (
        # This module lives in muicebot/utils/, so the repository root is two levels up.
        os.path.join(here, "..", "..", "pyproject.toml"),
        # Previously used location, kept as a fallback.
        os.path.join(here, "..", "pyproject.toml"),
    )
    toml_path = next((path for path in candidates if os.path.isfile(path)), None)
    if toml_path is None:
        return "Unknown"

    try:
        if sys.version_info >= (3, 11):
            import tomllib

            with open(toml_path, "rb") as f:
                pyproject_data = tomllib.load(f)

        else:
            import toml

            with open(toml_path, "r", encoding="utf-8") as f:
                pyproject_data = toml.load(f)

        version_entry = pyproject_data["tool"]["pdm"]["version"]

    except (OSError, KeyError, TypeError, ValueError, ModuleNotFoundError):
        return "Unknown"

    if isinstance(version_entry, dict):
        # Dynamic (SCM) versioning: only the fallback is available statically.
        return str(version_entry.get("fallback_version", "Unknown"))
    return str(version_entry)
""" 190 | bot = get_bot() 191 | event = get_event() 192 | user_id = user_id if user_id else event.get_user_id() 193 | user_info = await get_user_info(bot, event, user_id) 194 | return user_info.user_name if user_info else user_id 195 | 196 | 197 | def guess_mimetype(resource: Resource) -> Optional[str]: 198 | """ 199 | 尝试获取 minetype 类型 200 | """ 201 | # raw 不落库,因此无法从 raw 判断 202 | if resource.path and os.path.exists(resource.path): 203 | try: 204 | with open(resource.path, "rb") as file: 205 | header = file.read(128) 206 | info = fleep.get(header) 207 | if info.mime: 208 | return info.mime[0] # type:ignore 209 | else: 210 | return mimetypes.guess_type(resource.path)[0] 211 | except Exception: 212 | return mimetypes.guess_type(resource.path)[0] 213 | elif resource.url: 214 | return mimetypes.guess_type(resource.url)[0] 215 | return None 216 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "MuiceBot" 3 | dynamic = ["version"] 4 | description = "一个可以用来调用多种模型,且被用来练手的Nonebot项目" 5 | readme = "README.md" 6 | requires-python = ">=3.10, <=3.12" 7 | dependencies = [ 8 | "aiosqlite>=0.17.0", 9 | "APScheduler>=3.11.0", 10 | "fleep>=1.0.1", 11 | "jinja2>=3.1.6", 12 | "nonebot2>=2.4.1", 13 | "nonebot-adapter-onebot>=2.4.6", 14 | "nonebot_plugin_alconna>=0.54.2", 15 | "nonebot_plugin_apscheduler>=0.5.0", 16 | "nonebot_plugin_localstore>=0.7.3", 17 | "nonebot_plugin_session>=0.3.2", 18 | "nonebot_plugin_userinfo>=0.2.6", 19 | "openai>=1.64.0", 20 | "pydantic>=2.10.5", 21 | "httpx>=0.27.0", 22 | "ruamel.yaml>=0.18.10", 23 | "SQLAlchemy>=2.0.38", 24 | "toml>=0.10.2; python_version < '3.11'", 25 | "websocket_client>=1.8.0", 26 | "watchdog>=6.0.0", 27 | "mcp[cli]>=1.9.0" 28 | ] 29 | authors = [ 30 | { name = "Moemu", email = "i@snowy.moe" }, 31 | ] 32 | 33 | [project.optional-dependencies] 34 | standard = [ 35 | 
"azure-ai-inference>=1.0.0b7", 36 | "dashscope>=1.22.1", 37 | "google-genai==1.8.0", 38 | "numpy>=1.26.4", 39 | "ollama>=0.4.7", 40 | "soundfile>=0.13.1" 41 | ] 42 | dev = [ 43 | "pre-commit>=4.1.0", 44 | "mypy>=1.15.0", 45 | "black>=25.1.0", 46 | "types-PyYAML", 47 | ] 48 | 49 | [tool.nonebot] 50 | adapters = [ 51 | { name = "OneBot V12", module_name = "nonebot.adapters.onebot.v12" }, 52 | { name = "OneBot V11", module_name = "nonebot.adapters.onebot.v11" } 53 | ] 54 | plugins = ["nonebot_plugin_alconna", "nonebot_plugin_localstore", "nonebot_plugin_apscheduler", "nonebot_plugin_session", "nonebot_plugin_userinfo"] 55 | builtin_plugins = [] 56 | 57 | [tool.black] 58 | line-length = 120 59 | 60 | [tool.isort] 61 | profile = "black" 62 | 63 | [tool.pdm] 64 | distribution = true 65 | 66 | [tool.pdm.version] 67 | source = "scm" 68 | tag_filter = "v*" 69 | tag_regex = '^v(?:\D*)?(?P([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|c|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$)$' 70 | fallback_version = "0.2.0" 71 | 72 | [tool.pdm.build] 73 | includes = [] 74 | 75 | [build-system] 76 | requires = ["pdm-backend"] 77 | build-backend = "pdm.backend" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | APScheduler==3.11.0 2 | fleep>=1.0.1 3 | httpx>=0.27.0 4 | jinja2>=3.1.6 5 | nonebot2==2.4.1 6 | nonebot_plugin_alconna==0.54.2 7 | nonebot_plugin_apscheduler>=0.5.0 8 | nonebot_plugin_localstore==0.7.3 9 | nonebot_plugin_session>=0.3.2 10 | nonebot_plugin_userinfo>=0.2.6 11 | openai==1.64.0 12 | pydantic==2.10.5 13 | ruamel.yaml==0.18.10 14 | SQLAlchemy==2.0.38 15 | toml>=0.10.2; python_version < '3.11' 16 | websocket_client==1.8.0 17 | watchdog==6.0.0 18 | aiosqlite>=0.17.0 19 | mcp[cli]>=1.9.0 --------------------------------------------------------------------------------