├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── Bug Report.yaml
│   │   └── Model.yaml
│   └── workflows
│       ├── pre-commit.yml
│       └── pypi-publish.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── bot.py
├── migrations.py
├── muicebot
│   ├── __init__.py
│   ├── builtin_plugins
│   │   ├── access_control.py
│   │   ├── get_current_time.py
│   │   ├── get_username.py
│   │   ├── muicebot_plugin_store
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── models.py
│   │   │   ├── register.py
│   │   │   └── store.py
│   │   └── thought_processor.py
│   ├── builtin_templates
│   │   └── Muice.jinja2
│   ├── config.py
│   ├── database.py
│   ├── llm
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _config.py
│   │   ├── _dependencies.py
│   │   ├── _schema.py
│   │   ├── loader.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── azure.py
│   │   │   ├── dashscope.py
│   │   │   ├── gemini.py
│   │   │   ├── ollama.py
│   │   │   └── openai.py
│   │   ├── registry.py
│   │   └── utils
│   │       ├── images.py
│   │       └── tools.py
│   ├── models.py
│   ├── muice.py
│   ├── onebot.py
│   ├── plugin
│   │   ├── __init__.py
│   │   ├── context.py
│   │   ├── func_call
│   │   │   ├── __init__.py
│   │   │   ├── _types.py
│   │   │   ├── caller.py
│   │   │   ├── parameter.py
│   │   │   └── utils.py
│   │   ├── hook
│   │   │   ├── __init__.py
│   │   │   ├── _types.py
│   │   │   └── manager.py
│   │   ├── loader.py
│   │   ├── mcp
│   │   │   ├── __init__.py
│   │   │   ├── client.py
│   │   │   ├── config.py
│   │   │   └── server.py
│   │   ├── models.py
│   │   └── utils.py
│   ├── scheduler.py
│   ├── templates
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   └── model.py
│   └── utils
│       ├── SessionManager.py
│       ├── adapters.py
│       ├── migrations.py
│       └── utils.py
├── pdm.lock
├── pyproject.toml
└── requirements.txt
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12 | polar: # Replace with a single Polar username
13 | buy_me_a_coffee: moemu
14 | thanks_dev: # Replace with a single thanks.dev username
15 | custom: ['https://www.afdian.com/a/Moemu'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/Bug Report.yaml:
--------------------------------------------------------------------------------
1 | name: Bug报告
2 | description: 运行本项目时出现的问题
3 | title: "[Bug]: "
4 | labels: ["bug"]
5 | assignees:
6 | - Moemu
7 | body:
8 | - type: markdown
9 | attributes:
10 | value: |
11 | 很抱歉您遇到了问题,为了更好地定位到问题,请填写以下表单。
12 | - type: input
13 | id: description
14 | attributes:
15 | label: 问题描述
16 | description: 请简要描述您遇到的问题。
17 | placeholder: ex. 模型加载失败
18 | validations:
19 | required: true
20 | - type: textarea
21 | id: logs
22 | attributes:
23 | label: 相关日志输出
24 | description: 请复制粘贴任何相关的日志输出,你可以在控制台或 logs 文件夹下找到它们(注意保护您的个人信息)。
25 | placeholder: ex. 2022-01-01 12:00:00 [ERROR] ...
26 | render: shell
27 | validations:
28 | required: true
29 | - type: textarea
30 | id: configs
31 | attributes:
32 | label: 配置文件
33 | description: 请提供您的配置文件,即 configs.yml(注意保护您的个人信息)。
34 | placeholder: ex. bot...
35 | render: yaml
36 | validations:
37 | required: true
38 | - type: textarea
39 | id: steps
40 | attributes:
41 | label: 复现步骤
42 | description: 请提供您的操作步骤,以便我们复现问题。
43 | placeholder: ex. 1. 运行程序...
44 | validations:
45 | required: true
46 | - type: textarea
47 | id: others
48 | attributes:
49 | label: 其他信息
50 | description: 如果有其他信息需要提供,请在此处填写。
51 | placeholder: ex. 我的目录结构是...
52 | validations:
53 | required: false
54 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/Model.yaml:
--------------------------------------------------------------------------------
1 | name: 模型相关
2 | description: 模型的输出不符合预期
3 | title: "[Model]: "
4 | labels: ["Model"]
5 | assignees:
6 | - Moemu
7 | body:
8 | - type: markdown
9 | attributes:
10 | value: |
11 | 很抱歉您遇到了问题,为了更好地定位到问题,请填写以下表单。
12 | - type: input
13 | id: description
14 | attributes:
15 | label: 问题描述
16 | description: 请简要描述您遇到的问题。
17 | placeholder: ex. 模型加载失败
18 | validations:
19 | required: true
20 | - type: textarea
21 | id: contexts
22 | attributes:
23 | label: 适当的上下文
24 | description: 适当的上下文有利于我们确定问题,注意移除敏感内容。
25 | placeholder: ex. Q:晚上吃什么?A:我不知道
26 | validations:
27 | required: false
28 | - type: input
29 | id: prompts
30 | attributes:
31 | label: 输入的Prompt
32 | description: 出现问题的输入的提问。
33 | placeholder: ex. 今天天气怎么样?
34 | validations:
35 | required: true
36 | - type: input
37 | id: responses
38 | attributes:
39 | label: 输出的Response
40 | description: 出现问题时模型输出的回答。
41 | placeholder: ex. 我不知道
42 | validations:
43 | required: true
44 | - type: input
45 | id: expected
46 | attributes:
47 | label: 期望的答案
48 | description: 期望输出的答案。
49 | placeholder: ex. 今天的天气很好,我们出去玩吧
50 | validations:
51 | required: false
52 | - type: textarea
53 | id: others
54 | attributes:
55 | label: 其他信息
56 | description: 如果有其他信息需要提供,请在此处填写。
57 | placeholder: ex. 我的目录结构是...
58 | validations:
59 | required: false
60 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: Pre-commit checks
2 |
3 | on: [push]
4 |
5 | jobs:
6 | pre-commit:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | python-version: ['3.10', '3.11']
11 |
12 | steps:
13 | - name: Checkout code
14 | uses: actions/checkout@v3
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v3
18 | with:
19 | python-version: ${{ matrix.python-version }} # 使用矩阵中的 Python 版本
20 |
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install pdm
24 | python -m pip install pre-commit
25 | pdm config python.use_venv false
26 | pdm install --frozen-lockfile --group dev
27 | pre-commit install
28 |
29 | - name: Run pre-commit
30 | run: pre-commit run --all-files
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | pypi-publish:
9 | name: upload release to PyPI
10 | environment: pypi
11 | runs-on: ubuntu-latest
12 | permissions:
13 | contents: read
14 | id-token: write
15 | steps:
16 | - name: Checkout code
17 | uses: actions/checkout@v4
18 |
19 | - name: Set up Python
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: "3.11"
23 |
24 | - name: Setup and Configure PDM
25 | run: |
26 | pip install pdm
27 | pdm config python.use_venv false
28 |
29 | - name: Publish to PyPI
30 | env:
31 | PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
32 | run: pdm publish --username __token__ --password "$PYPI_TOKEN"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
4 |
5 | ### Python ###
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | pytestdebug.log
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 | doc/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .env.dev
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 | # pytype static type analyzer
140 | .pytype/
141 |
142 | # End of https://www.toptal.com/developers/gitignore/api/python
143 |
144 | # Dist
145 | node_modules
146 | dist/
147 | doc_build/
148 |
149 | # MuiceBot
150 | configs.yml
151 | .vscode
152 | logs
153 | configs
154 | plugins
155 | docs
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ci:
2 | autofix_commit_msg: "style: auto fix by pre-commit hooks"
3 | autofix_prs: false
4 | autoupdate_branch: main
5 | autoupdate_schedule: monthly
6 | autoupdate_commit_msg: "chore: auto update by pre-commit hooks"
7 |
8 |
9 | repos:
10 | - repo: https://github.com/pycqa/flake8
11 | rev: 7.2.0
12 | hooks:
13 | - id: flake8 # 代码风格检查
14 | args: ["--max-line-length=120", "--extend-ignore=W503,E203"]
15 |
16 | - repo: https://github.com/pre-commit/mirrors-mypy
17 | rev: v1.16.0
18 | hooks:
19 | - id: mypy # mypy 类型检查
20 | args: ["--install-types", "--non-interactive", "--ignore-missing-imports"]
21 | additional_dependencies: ["types-PyYAML", "azure-ai-inference", "dashscope", "google-genai", "ollama"]
22 |
23 | - repo: https://github.com/psf/black
24 | rev: 25.1.0
25 | hooks:
26 | - id: black
27 | args: ["--config=./pyproject.toml"]
28 |
29 | - repo: https://github.com/PyCQA/isort
30 | rev: 6.0.1
31 | hooks:
32 | - id: isort
33 | args: ["--profile", "black"]
34 |
35 | - repo: https://github.com/pre-commit/pre-commit-hooks
36 | rev: v5.0.0 # 版本号
37 | hooks:
38 | - id: trailing-whitespace # 删除行尾空格
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## 我们的承诺
2 |
3 | 身为社区成员、贡献者和领袖,我们承诺使社区参与者不受骚扰,无论其年龄、体型、可见或不可见的缺陷、族裔、性征、性别认同和表达、经验水平、教育程度、社会与经济地位、国籍、相貌、种族、种姓、肤色、宗教信仰、性倾向或性取向如何。
4 |
5 | 我们承诺以有助于建立开放、友善、多样化、包容、健康社区的方式行事和互动。
6 |
7 | ## 我们的准则
8 |
9 | 有助于为我们的社区创造积极环境的行为例子包括但不限于:
10 |
11 | - 表现出对他人的同情和善意
12 |
13 | - 尊重不同的主张、观点和感受
14 |
15 | - 提出和大方接受建设性意见
16 |
17 | - 承担责任并向受我们错误影响的人道歉
18 |
19 | - 注重社区共同诉求,而非个人得失
20 |
21 |
22 | 不当行为例子包括:
23 |
24 | - 使用情色化的语言或图像,及性引诱或挑逗
25 |
26 | - 嘲弄、侮辱或诋毁性评论,以及人身或政治攻击
27 |
28 | - 公开或私下的骚扰行为
29 |
30 | - 未经他人明确许可,公布他人的私人信息,如物理或电子邮件地址
31 |
32 | - 其他有理由认定为违反职业操守的不当行为
33 |
34 |
35 | ## 责任和权力
36 |
37 | 社区负责人有责任解释和落实我们所认可的行为准则,并妥善公正地对他们认为不当、威胁、冒犯或有害的任何行为采取纠正措施。
38 |
39 | 社区领导有权力和责任删除、编辑或拒绝与本行为准则不相符的评论(comment)、提交(commits)、代码、维基(wiki)编辑、议题(issues)或其他贡献,并在适当的时候告知采取措施的理由。
40 |
41 | ## 适用范围
42 |
43 | 本行为准则适用于所有社区场合,也适用于在公共场所代表社区时的个人。
44 |
45 | 代表社区的情形包括使用官方电子邮件地址、通过官方社交媒体帐户发帖或在线上或线下活动中担任指定代表。
46 |
47 | ## 监督
48 |
49 | 辱骂、骚扰或其他不可接受的行为可通过 [i@snowy.moe](mailto:i@snowy.moe) 向我们报告。 所有投诉都将得到及时和公平的审查和调查。
50 |
51 | 所有社区负责人都有义务尊重任何事件报告者的隐私和安全。
52 |
53 | ## 处理方针
54 |
55 | 社区负责人将遵循下列社区处理方针来明确他们所认定违反本行为准则的行为的处理方式:
56 |
57 | 1. 纠正
58 |
59 | 社区影响:使用不恰当的语言或其他在社区中被认定为不符合职业道德或不受欢迎的行为。
60 |
61 | 处理意见:由社区负责人发出非公开的书面警告,明确说明违规行为的性质,并解释举止如何不妥。可能会要求公开道歉。
62 |
63 | 2. 警告
64 |
65 | 社区影响:单个或一系列违规行为。
66 |
67 | 处理意见:警告并对连续性行为进行处理。在指定时间内,不得与相关人员互动,包括主动与行为准则执行者互动。这包括避免在社区场所和外部渠道中的互动。违反这些条款可能会导致临时或永久封禁。
68 |
69 | 3. 临时封禁
70 |
71 | 社区影响: 严重违反社区准则,包括持续的不当行为。
72 |
73 | 处理意见: 在指定时间内,暂时禁止与社区进行任何形式的互动或公开交流。在此期间,不得与相关人员进行公开或私下互动,包括主动与行为准则执行者互动。违反这些条款可能会导致永久封禁。
74 |
75 | 4. 永久封禁
76 |
77 | 社区影响:行为模式表现出违反社区准则,包括持续的不当行为、骚扰个人或攻击或贬低某个类别的个体。
78 |
79 | 处理意见:永久禁止在社区内进行任何形式的公开互动。
80 |
81 | ## 参见
82 |
83 | 本行为准则改编自 [Contributor Covenant](https://www.contributor-covenant.org/) 2.1 版, 参见 https://www.contributor-covenant.org/version/2/1/code_of_conduct.html。
84 |
85 | 社区处理方针灵感来源于 [Mozilla’s code of conduct enforcement ladder](https://github.com/mozilla/diversity)。
86 |
87 | 有关本行为准则的常见问题的答案,参见 https://www.contributor-covenant.org/faq。
88 |
89 | 其他语言翻译参见 https://www.contributor-covenant.org/translations。
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # MuiceBot 贡献指南
2 |
3 | 非常感谢大家产生了为 MuiceBot 贡献代码的想法,本指南旨在指导您如何科学规范地开发代码,还请在撰写代码前仔细阅读
4 |
5 | ## 报告问题
6 |
7 | MuiceBot 目前仍然处于早期开发状态,暂未有提交正式版的想法,因此部分功能可能存在问题并导致机器人运行不稳定。如果你在使用过程中发现问题并确信是由 MuiceBot 运行框架引起的,请务必提交 Issue
8 |
9 | ## 提议新功能
10 |
11 | MuiceBot 还未进入正式版,欢迎在 Issue 中提议要加入哪些新功能, Maintainer 将会尽力满足大家的需求
12 |
13 | ## Pull Request
14 |
15 | MuiceBot 使用 pre-commit 进行代码规范管理,因此在提交代码前,我们推荐安装 pre-commit 并通过代码检查:
16 |
17 | ```shell
18 | pip install .[standard,dev]
19 | pip install nonebot2[fastapi]
20 |
21 | pre-commit install
22 | ```
23 |
24 | 目前使用的代码检查工具有:flake8(PEP8 风格检查)、mypy(类型检查)、black(代码格式化),并使用 isort 和 trailing-whitespace 整理导入顺序与行尾空格
25 |
26 | 在本地运行 pre-commit 不是必须的,尤其是在环境包过大的情况下,但我们还是推荐您这么做
27 |
28 | 代码提交后请静待工作流运行结果,若 pre-commit 出现问题请尽量先自行解决后再次提交
29 |
30 | ## 撰写文档
31 |
32 | MuiceBot 使用 [rspress](https://github.com/web-infra-dev/rspress) 作为文档站,你可以直接在 `docs` 文件夹中使用 Markdown 格式撰写文档。
33 |
34 | 文档站项目:https://github.com/MuikaAI/muicebot-doc
35 |
36 | 如果你需要在本地预览文档,可以使用 npm 安装 rspress 依赖后启动 dev 服务:
37 |
38 | ```shell
39 | npm install
40 | npm run dev
41 | ```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2025, Muika
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [Logo]
3 |
4 | MuiceBot
5 |
6 | Muice-Chatbot 的 NoneBot2 实现
7 |
8 | [badges]
9 |
24 | > [!NOTE]
25 | >
26 | > 由于本项目的 Maintainer Muika(Moemu) 现已淡出互联网,且面临来自学业和考核的多重压力,此项目的开发进度可能会严重放慢,敬请谅解
27 | >
28 | > 欢迎来到本项目!目前此项目尚处于预发布状态,运行时可能会遇到一些问题。请务必向我们反馈在运行时遇到的各种错误!
29 | >
30 | > 由于本项目待实现的功能还有很多,因此近期没有也可能永远也不会有**发布**正式版的打算。
31 |
32 |
33 | # 介绍✨
34 |
35 | > 我们认为,AI的创造应该是为了帮助人类更好的解决问题而不是产生问题。因此,我们注重大语言模型解决实际问题的能力,如果沐雪系列项目不能帮助我们解决日常、情感类的问题,沐雪的存在就是毫无意义可言。
36 | > *———— 《沐雪系列模型评测标准》*
37 |
38 | Muicebot 是基于 Nonebot2 框架实现的 LLM 聊天机器人,旨在解决现实问题。通过 Muicebot ,你可以在主流聊天平台(如 QQ)获得只有在网页中才能获得的聊天体验。
39 |
40 | Muicebot 内置一个名为沐雪的聊天人设(人设是可选的)以便优化对话体验。有关沐雪的设定,还请移步 [关于沐雪](https://bot.snowy.moe/about/Muice)
41 |
42 | # 功能🪄
43 |
44 | ✅ 内嵌多种模型加载器,如[OpenAI](https://platform.openai.com/docs/overview) 和 [Ollama](https://ollama.com/) ,可加载市面上大多数的模型服务或本地模型,支持多模态(图片识别)和工具调用。另外还附送只会计算 3.9 > 3.11 的沐雪 Roleplay 微调模型一枚~
45 |
46 | ✅ 使用 `nonebot_plugin_alconna` 作为通用信息接口,支持市面上的大多数适配器。对部分适配器做了特殊优化
47 |
48 | ✅ 支持基于 `nonebot_plugin_apscheduler` 的定时任务,可定时向大语言模型交互或直接发送信息
49 |
50 | ✅ 支持基于 `nonebot_plugin_alconna` 的几条常见指令。
51 |
52 | ✅ 使用 SQLite3 保存对话数据。那有人就要问了:Maintainer,Maintainer,能不能实现长期短期记忆、LangChain、FairSeq 这些记忆优化啊。以后会有的(
53 |
54 | ✅ 使用 Jinja2 动态生成人设提示词
55 |
56 | ✅ 支持调用 MCP 服务
57 |
58 | # 模型加载器适配情况
59 |
60 | | 模型加载器 | 流式对话 | 多模态输入/输出 | 推理模型调用 | 工具调用 | 联网搜索 |
61 | | ----------- | -------- | -------------- | ------------ | -------------------- | -------------------- |
62 | | `Azure` | ✅ | 🎶🖼️/❌ | ⭕ | ✅ | ❌ |
63 | | `Dashscope` | ✅ | 🎶🖼️/❌ | ✅ | ⭕ | ✅ |
64 | | `Gemini` | ✅ | ✅/🖼️ | ⭕ | ✅ | ✅ |
65 | | `Ollama` | ✅ | 🖼️/❌ | ✅ | ✅ | ❌ |
66 | | `Openai` | ✅ | ✅/🎶 | ✅ | ✅ | ❌ |
67 |
68 | ✅:表示此加载器能很好地支持该功能并且 `MuiceBot` 已实现
69 |
70 | ⭕:表示此加载器虽支持该功能,但使用时可能遇到问题
71 |
72 | 🚧:表示此加载器虽然支持该功能,但 `MuiceBot` 未实现或正在实现中
73 |
74 | ❓:表示 Maintainer 暂不清楚此加载器是否支持此项功能,可能需要进一步翻阅文档和检查源码
75 |
76 | ❌:表示此加载器不支持该功能
77 |
78 | 多模态标记:🎶表示音频;🎞️ 表示视频;🖼️ 表示图像;📄表示文件;✅ 表示完全支持
79 |
80 | # 本项目适合谁?
81 |
82 | - 拥有编写过 Python 程序经验的开发者
83 |
84 | - 搭建过 Nonebot 项目的 bot 爱好者
85 |
86 | - 想要随时随地与大语言模型交互,并在寻找一个能兼容市面上绝大多数 SDK 的机器人框架的 AI 爱好者
87 |
88 | # TODO📝
89 |
90 | - [X] Function Call 插件系统
91 |
92 | - [X] 多模态模型:工具集支持
93 |
94 | - [X] MCP Client 实现
95 |
96 | - [X] 插件索引库搭建
97 |
98 | - [ ] 短期记忆和长期记忆优化。总感觉这是提示工程师该做的事情,~~和 Bot 没太大关系~~
99 |
100 | - [ ] 发布。我知道你很急,但是你先别急。
101 |
102 |
103 | 近期更新路线:[MuiceBot 更新计划](https://github.com/users/Moemu/projects/2)
104 |
105 | # 使用教程💻
106 |
107 | 参考 [使用文档](https://bot.snowy.moe)
108 |
109 | # 插件商店
110 |
111 | [MuikaAI/Muicebot-Plugins-Index](https://github.com/MuikaAI/Muicebot-Plugins-Index)
112 |
113 | # 关于🎗️
114 |
115 | 大模型输出结果将按**原样**提供,由于提示注入攻击等复杂的原因,模型有可能输出有害内容。无论模型输出结果如何,模型输出结果都无法代表开发者的观点和立场。对于此项目可能间接引发的任何后果(包括但不限于机器人账号封禁),本项目所有开发者均不承担任何责任。
116 |
117 | 本项目基于 [BSD 3](https://github.com/Moemu/nonebot-plugin-muice/blob/main/LICENSE) 许可证提供,涉及到再分发时请保留许可文件的副本。
118 |
119 | 本项目标识使用了 [nonebot/nonebot2](https://github.com/nonebot/nonebot2) 和 画师 [Nakkar](https://www.pixiv.net/users/28246124) [Pixiv作品](https://www.pixiv.net/artworks/101063891) 的资产或作品。如有侵权,请及时与我们联系
120 |
121 | BSD 3 许可证同样适用于沐雪的系统提示词,沐雪的文字人设或人设图在 [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.zh-hans) 许可证条款下提供。
122 |
123 | 此项目中基于或参考了部分开源项目的实现,在这里一并表示感谢:
124 |
125 | - [nonebot/nonebot2](https://github.com/nonebot/nonebot2) 本项目使用的机器人框架
126 |
127 | - [@botuniverse](https://github.com/botuniverse) 负责制定 Onebot 标准的组织
128 |
129 | 感谢各位开发者的协助,可以说没有你们就没有沐雪的今天:
130 |
131 |
132 |
133 |
134 |
135 | 友情链接:[LiteyukiStudio/nonebot-plugin-marshoai](https://github.com/LiteyukiStudio/nonebot-plugin-marshoai)
136 |
137 | 本项目隶属于 MuikaAI
138 |
139 | 基于 OneBot V11 的原始实现:[Moemu/Muice-Chatbot](https://github.com/Moemu/Muice-Chatbot)
140 |
141 |
142 |
143 |
144 | Star History:
145 |
146 | [](https://star-history.com/#Moemu/MuiceBot&Date)
--------------------------------------------------------------------------------
/bot.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 |
4 | import nonebot
5 | from nonebot.drivers import Driver
6 |
7 | PLUGINS_CONFIG_PATH = "./configs/plugins.yml"
8 |
9 |
10 | def load_specified_adapter(driver: Driver, adapter: str):
11 | """
12 | 加载指定的 Nonebot 适配器
13 | """
14 | try:
15 | module = importlib.import_module(adapter)
16 | adapter_class = module.Adapter
17 | driver.register_adapter(adapter_class)  # type:ignore
18 | except ImportError:
19 | print(f"\033[35m适配器 {adapter} 不存在,请检查拼写是否正确或该适配器是否已安装\033[0m")
20 | sys.exit(1)
21 |
22 |
23 | nonebot.init()
24 |
25 | driver = nonebot.get_driver()
26 |
27 | enable_adapters: list[str] = driver.config.model_dump().get("enable_adapters", [])
28 |
29 | for adapter in enable_adapters:
30 | load_specified_adapter(driver, adapter)
31 |
32 | nonebot.load_plugin("muicebot")
33 |
34 | nonebot.run()
35 |
--------------------------------------------------------------------------------
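bot.py 会读取 `enable_adapters` 配置,并通过 importlib 按模块路径动态注册适配器。下面是一段与 `load_specified_adapter` 等价的手动注册示意(假设已安装 OneBot V11 适配器包 `nonebot-adapter-onebot`,模块路径仅作演示):

```python
# 与 load_specified_adapter 等价的手动注册示意
# (假设已安装 nonebot-adapter-onebot;模块路径仅作演示)
import importlib

import nonebot

nonebot.init()
driver = nonebot.get_driver()

module = importlib.import_module("nonebot.adapters.onebot.v11")
driver.register_adapter(module.Adapter)
```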
/migrations.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import os
4 | import sys
5 | from dataclasses import dataclass, field
6 | from datetime import datetime
7 | from typing import List
8 |
9 | import aiosqlite
10 |
11 |
12 | @dataclass
13 | class Message:
14 | id: int | None = None
15 | """每条消息的唯一ID"""
16 | time: str = field(default_factory=lambda: datetime.now().strftime("%Y.%m.%d %H:%M:%S"))
17 | """
18 | 字符串形式的时间数据:%Y.%m.%d %H:%M:%S
19 | 若要获取格式化的 datetime 对象,请使用 format_time
20 | """
21 | userid: str = ""
22 | """Nonebot 的用户id"""
23 | message: str = ""
24 | """消息主体"""
25 | respond: str = ""
26 | """模型回复(不包含思维过程)"""
27 | history: int = 1
28 | """消息是否可用于对话历史中,以整数形式映射布尔值"""
29 | images: List[str] = field(default_factory=list)
30 | """多模态中使用的图像,默认为空列表"""
31 |
32 | def __post_init__(self):
33 | if isinstance(self.images, str):
34 | self.images = json.loads(self.images)
35 | elif self.images is None:
36 | self.images = []
37 |
38 |
39 | class Database:
40 | def __init__(self, db_path: str) -> None:
41 | self.DB_PATH = db_path
42 |
43 | def __connect(self) -> aiosqlite.Connection:
44 | return aiosqlite.connect(self.DB_PATH)
45 |
46 | async def __execute(self, query: str, params=(), fetchone=False, fetchall=False) -> list | None:
47 | """
48 | 异步执行SQL查询,支持可选参数。
49 |
50 | :param query: 要执行的SQL查询语句
51 | :param params: 传递给查询的参数
52 | :param fetchone: 是否获取单个结果
53 | :param fetchall: 是否获取所有结果
54 | """
55 | async with self.__connect() as conn:
56 | cursor = await conn.cursor()
57 | await cursor.execute(query, params)
58 | if fetchone:
59 | return await cursor.fetchone() # type: ignore
60 | if fetchall:
61 | return await cursor.fetchall() # type: ignore
62 | await conn.commit()
63 |
64 | return None
65 |
66 | async def __create_database(self) -> None:
67 | await self.__execute(
68 | """CREATE TABLE MSG(
69 | ID INTEGER PRIMARY KEY AUTOINCREMENT,
70 | TIME TEXT NOT NULL,
71 | USERID TEXT NOT NULL,
72 | MESSAGE TEXT NOT NULL,
73 | RESPOND TEXT NOT NULL,
74 | HISTORY INTEGER NOT NULL DEFAULT (1),
75 | IMAGES TEXT NOT NULL DEFAULT "[]");"""
76 | )
77 |
78 | async def add_item(self, message: Message):
79 | """
80 | 将消息保存到数据库
81 | """
82 | params = (message.time, message.userid, message.message, message.respond, json.dumps(message.images))
83 | query = """INSERT INTO MSG (TIME, USERID, MESSAGE, RESPOND, IMAGES)
84 | VALUES (?, ?, ?, ?, ?)"""
85 | await self.__execute(query, params)
86 |
87 |
88 | async def migrate_old_data_to_new_db(old_data_dir: str, new_database_path: str):
89 | db = Database(new_database_path)
90 |
91 | current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
92 |
93 | # 遍历 memory 目录下的 JSON 文件
94 | for filename in os.listdir(old_data_dir):
95 | if not filename.endswith(".json"):
96 | continue
97 |
98 | file_path = os.path.join(old_data_dir, filename)
99 | user_id = filename.replace(".json", "")
100 |
101 | with open(file_path, "r", encoding="utf-8") as file:
102 | for line in file:
103 | try:
104 | data = json.loads(line.strip())
105 | prompt = data.get("prompt", "").strip()
106 | completion = data.get("completion", "").strip()
107 |
108 | # 跳过空数据
109 | if not prompt or not completion:
110 | continue
111 |
112 | message = Message(time=current_time, userid=user_id, message=prompt, respond=completion, images=[])
113 |
114 | await db.add_item(message)
115 |
116 | except json.JSONDecodeError:
117 | print(f"⚠️ JSON 解析失败: {file_path}")
118 |
119 | print("✅ 迁移完成!")
120 |
121 |
122 | if __name__ == "__main__":
123 |
124 | if len(sys.argv) < 3:
125 | print("❌ 使用方式: python migrations.py ")
126 | sys.exit(1)
127 |
128 | old_data_dir = sys.argv[1] # 从命令行参数获取旧数据目录
129 | new_database_path = sys.argv[2]
130 | asyncio.run(migrate_old_data_to_new_db(old_data_dir, new_database_path)) # 运行异步迁移任务
131 |
--------------------------------------------------------------------------------
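该脚本也可以在代码中直接调用;下面是一个最小调用示意(目录与数据库路径均为演示假设):

```python
# 以编程方式执行迁移的最小示意(路径仅为演示假设)
import asyncio

from migrations import migrate_old_data_to_new_db

# 将 ./memory 下的旧版 JSON 记忆文件迁移到新的 SQLite 数据库
asyncio.run(migrate_old_data_to_new_db("./memory", "./ChatHistory.db"))
```

命令行方式与之等价,例如 `python migrations.py ./memory ./ChatHistory.db`。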
/muicebot/__init__.py:
--------------------------------------------------------------------------------
1 | from nonebot import require
2 |
3 | require("nonebot_plugin_alconna")
4 | require("nonebot_plugin_localstore")
5 | require("nonebot_plugin_apscheduler")
6 |
7 | from nonebot.plugin import PluginMetadata, inherit_supported_adapters # noqa: E402
8 |
9 | from .config import PluginConfig # noqa: E402
10 | from .utils.utils import init_logger # noqa: E402
11 |
12 | init_logger()
13 |
14 | from . import onebot # noqa: E402, F401
15 |
16 | __plugin_meta__ = PluginMetadata(
17 | name="MuiceBot",
18 | description="Muice-Chatbot 的 Nonebot2 实现,支持市面上大多数的模型",
19 | usage="@at / {config.MUICE_NICKNAMES} : 与大语言模型交互;关于指令类可输入 .help 查询",
20 | type="application",
21 | config=PluginConfig,
22 | homepage="https://bot.snowy.moe/",
23 | extra={},
24 | supported_adapters=inherit_supported_adapters("nonebot_plugin_alconna"),
25 | )
26 |
--------------------------------------------------------------------------------
/muicebot/builtin_plugins/access_control.py:
--------------------------------------------------------------------------------
1 | from nonebot import get_plugin_config, logger
2 | from nonebot.adapters import Bot, Event
3 | from nonebot.exception import IgnoredException
4 | from nonebot.message import run_preprocessor
5 | from nonebot_plugin_session import SessionIdType, extract_session
6 | from pydantic import BaseModel
7 |
8 | from muicebot.plugin import PluginMetadata
9 |
10 |
11 | class ScopeConfig(BaseModel):
12 | blacklist: list[str] = []
13 | whitelist: list[str] = []
14 |
15 |
16 | class Config(BaseModel):
17 | access_control: ScopeConfig = ScopeConfig()
18 |
19 |
20 | __metadata__ = PluginMetadata(
21 | name="blacklist_whitelist_checker",
22 | description="黑白名单检测",
23 | usage="在插件配置中填写响应配置后即可",
24 | config=Config,
25 | )
26 |
27 | plugin_config = get_plugin_config(Config).access_control # 获取插件配置
28 |
29 | _BLACKLIST = plugin_config.blacklist
30 | _WHITELIST = plugin_config.whitelist
31 | _MODE = "white" if _WHITELIST else "black"
32 |
33 |
34 | @run_preprocessor
35 | async def access_control(bot: Bot, event: Event):
36 | session = extract_session(bot, event)
37 | group_id = session.get_id(SessionIdType.GROUP)
38 | user_id = session.get_id(SessionIdType.USER)
39 | level = session.level
40 |
41 | if _MODE == "black":
42 | if user_id in _BLACKLIST:
43 | msg = f"User {user_id} is in the blacklist"
44 | logger.warning(msg)
45 | raise IgnoredException(msg)
46 |
47 | elif group_id in _BLACKLIST:
48 | msg = f"Group {group_id} is in the blacklist"
49 | logger.warning(msg)
50 | raise IgnoredException(msg)
51 |
52 | if _MODE == "white":
53 | if level >= 2 and group_id not in _WHITELIST: # 白名单只对群组生效
54 | msg = f"Group {group_id} is not in the whitelist"
55 | logger.warning(msg)
56 | raise IgnoredException(msg)
57 |
--------------------------------------------------------------------------------
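黑白名单来自插件配置(由 muicebot/config.py 将 configs/plugins.yml 合并进 driver.config)。下面用上述 Pydantic 模型单独演示配置的形状与模式判定;会话 ID 的具体格式取决于适配器与 nonebot_plugin_session,这里的取值仅为假设:

```python
# 独立演示配置模型与黑/白名单模式判定(ID 取值仅为假设)
from pydantic import BaseModel


class ScopeConfig(BaseModel):
    blacklist: list[str] = []
    whitelist: list[str] = []


cfg = ScopeConfig(whitelist=["group_123456"])
mode = "white" if cfg.whitelist else "black"
assert mode == "white"  # 白名单非空时启用白名单模式,且仅对群组会话 (level >= 2) 生效
```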
/muicebot/builtin_plugins/get_current_time.py:
--------------------------------------------------------------------------------
1 | from muicebot.plugin import PluginMetadata
2 | from muicebot.plugin.func_call import on_function_call
3 |
4 | __metadata__ = PluginMetadata(
5 | name="muicebot-plugin-time", description="时间插件", usage="直接调用,返回 %Y-%m-%d %H:%M:%S 格式的当前时间"
6 | )
7 |
8 |
9 | @on_function_call(
10 | description="获取当前时间",
11 | )
12 | async def get_current_time() -> str:
13 | """
14 | 获取当前时间
15 | """
16 | import datetime
17 |
18 | current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
19 | return current_time
20 |
--------------------------------------------------------------------------------
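仿照这一写法即可注册自己的 Function Call 工具;下面是一个假设的示例插件(插件名与函数均为演示虚构):

```python
# 一个假设的 Function Call 插件,仿照 get_current_time.py 的写法(纯属演示)
import random

from muicebot.plugin import PluginMetadata
from muicebot.plugin.func_call import on_function_call

__metadata__ = PluginMetadata(
    name="muicebot-plugin-dice", description="骰子插件(示例)", usage="直接调用,返回 1~6 的点数"
)


@on_function_call(description="掷一个六面骰子并返回点数")
async def roll_dice() -> str:
    return str(random.randint(1, 6))
```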
/muicebot/builtin_plugins/get_username.py:
--------------------------------------------------------------------------------
1 | from nonebot.adapters import Bot, Event
2 |
3 | from muicebot.plugin import PluginMetadata
4 | from muicebot.plugin.func_call import on_function_call
5 | from muicebot.utils.utils import get_username as get_username_
6 |
7 | __metadata__ = PluginMetadata(
8 | name="muicebot-plugin-username", description="获取用户名的插件", usage="直接调用,返回当前对话的用户名"
9 | )
10 |
11 |
12 | @on_function_call(description="获取当前对话的用户名字")
13 | async def get_username(bot: Bot, event: Event) -> str:
14 | user_id = event.get_user_id()
15 | return await get_username_(user_id)
16 |
--------------------------------------------------------------------------------
/muicebot/builtin_plugins/muicebot_plugin_store/__init__.py:
--------------------------------------------------------------------------------
1 | from arclet.alconna import Alconna, Subcommand
2 | from nonebot.permission import SUPERUSER
3 | from nonebot_plugin_alconna import Args, CommandMeta, Match, on_alconna
4 |
5 | from muicebot.plugin import PluginMetadata
6 |
7 | from .store import (
8 | get_installed_plugins_info,
9 | install_plugin,
10 | load_store_plugin,
11 | uninstall_plugin,
12 | update_plugin,
13 | )
14 |
15 | __metadata__ = PluginMetadata(name="muicebot-plugin-store", description="Muicebot 插件商店操作", usage=".store help")
16 |
17 | load_store_plugin()
18 |
19 | COMMAND_PREFIXES = [".", "/"]
20 |
21 | store_cmd = on_alconna(
22 | Alconna(
23 | COMMAND_PREFIXES,
24 | "store",
25 | Subcommand("help"),
26 | Subcommand("install", Args["name", str], help_text=".store install 插件名"),
27 | Subcommand("show"),
28 | Subcommand("update", Args["name", str], help_text=".store update 插件名"),
29 | Subcommand("uninstall", Args["name", str], help_text=".store uninstall 插件名"),
30 | meta=CommandMeta("Muicebot 插件商店指令"),
31 | ),
32 | priority=10,
33 | block=True,
34 | skip_for_unmatch=False,
35 | permission=SUPERUSER,
36 | )
37 |
38 |
39 | @store_cmd.assign("install")
40 | async def install(name: Match[str]):
41 | if not name.available:
42 | await store_cmd.finish("必须传入一个插件名")
43 | result = await install_plugin(name.result)
44 | await store_cmd.finish(result)
45 |
46 |
47 | @store_cmd.assign("show")
48 | async def show():
49 | info = await get_installed_plugins_info()
50 | await store_cmd.finish(info)
51 |
52 |
53 | @store_cmd.assign("update")
54 | async def update(name: Match[str]):
55 | if not name.available:
56 | await store_cmd.finish("必须传入一个插件名")
57 | result = await update_plugin(name.result)
58 | await store_cmd.finish(result)
59 |
60 |
61 | @store_cmd.assign("uninstall")
62 | async def uninstall(name: Match[str]):
63 | if not name.available:
64 | await store_cmd.finish("必须传入一个插件名")
65 | result = await uninstall_plugin(name.result)
66 | await store_cmd.finish(result)
67 |
68 |
69 | @store_cmd.assign("help")
70 | async def store_help():
71 | await store_cmd.finish(
72 | "install <插件名> 安装插件\n"
73 | "show 查看已安装的商店插件信息\n"
74 | "update <插件名> 更新插件\n"
75 | "uninstall <插件名> 卸载插件"
76 | )
77 |
--------------------------------------------------------------------------------
/muicebot/builtin_plugins/muicebot_plugin_store/config.py:
--------------------------------------------------------------------------------
1 | from nonebot import get_plugin_config
2 | from pydantic import BaseModel
3 |
4 |
5 | class Config(BaseModel):
6 | store_index: str = "https://raw.githubusercontent.com/MuikaAI/Muicebot-Plugins-Index/refs/heads/main/plugins.json"
7 | """插件索引文件 url"""
8 |
9 |
10 | config = get_plugin_config(Config)
11 |
--------------------------------------------------------------------------------
/muicebot/builtin_plugins/muicebot_plugin_store/models.py:
--------------------------------------------------------------------------------
1 | from typing import TypedDict
2 |
3 |
4 | class PluginInfo(TypedDict):
5 | module: str
6 | """插件模块名"""
7 | name: str
8 | """插件名称"""
9 | description: str
10 | """插件描述"""
11 | repo: str
12 | """插件 repo 地址"""
13 |
14 |
15 | class InstalledPluginInfo(TypedDict):
16 | module: str
17 | """插件模块名"""
18 | name: str
19 | """插件名称"""
20 | commit: str
21 | """commit 信息"""
22 |
--------------------------------------------------------------------------------
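PluginInfo 对应插件索引 plugins.json 中的单个条目,InstalledPluginInfo 对应本地 installed_plugins.json 的记录;下面的字面量展示两者的形状(字段取值仅为演示假设):

```python
# 两种记录的示例字面量(取值仅为演示假设)
from muicebot.builtin_plugins.muicebot_plugin_store.models import (
    InstalledPluginInfo,
    PluginInfo,
)

index_entry: PluginInfo = {
    "module": "muicebot_plugin_example",
    "name": "示例插件",
    "description": "一个用于演示索引结构的插件",
    "repo": "https://github.com/example/muicebot-plugin-example",
}

installed: InstalledPluginInfo = {
    "module": "muicebot_plugin_example",
    "name": "示例插件",
    "commit": "abc1234",
}
```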
/muicebot/builtin_plugins/muicebot_plugin_store/register.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | from .models import InstalledPluginInfo
5 |
6 | JSON_PATH = Path("plugins/installed_plugins.json")
7 |
8 |
9 | def load_json_record() -> dict[str, InstalledPluginInfo]:
10 | """
11 | 获取本地 json 记录
12 | `plugins/installed_plugins.json`
13 | """
14 | if not JSON_PATH.exists():
15 | return {}
16 |
17 | try:
18 | with open(JSON_PATH, "r", encoding="utf-8") as f:
19 | return json.load(f)
20 | except Exception:
21 | return {}
22 |
23 |
24 | def _save_json_record(data: dict) -> None:
25 | """
26 | 在本地保存 json 记录
27 | """
28 | with open(JSON_PATH, "w", encoding="utf-8") as f:
29 | json.dump(data, f, indent=2, ensure_ascii=False)
30 |
31 |
32 | def register_plugin(plugin: str, commit: str, name: str, module: str) -> None:
33 | """
34 | 在本地注册一个插件记录
35 |
36 | :param plugin: 插件唯一索引名
37 | :param commit: 插件 git commit hash值
38 | :param name: 插件名
39 | :param module: 相对于 `store` 文件夹的可导入模块名
40 | """
41 | plugins = load_json_record()
42 | plugins[plugin] = {"commit": commit, "module": module, "name": name}
43 | _save_json_record(plugins)
44 |
45 |
46 | def unregister_plugin(plugin: str) -> None:
47 | """
48 | 取消记录一个插件记录(通常在卸载插件时使用)
49 | """
50 | plugins = load_json_record()
51 | if plugin in plugins:
52 | del plugins[plugin]
53 | _save_json_record(plugins)
54 |
--------------------------------------------------------------------------------
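register_plugin 与 unregister_plugin 维护 plugins/installed_plugins.json 中的本地记录。一个最小的读写示意(插件名与 commit 均为演示假设,且需保证工作目录下已存在 plugins/ 文件夹):

```python
# 本地插件记录的读写示意(键值均为演示假设,需已存在 plugins/ 文件夹)
from muicebot.builtin_plugins.muicebot_plugin_store.register import (
    load_json_record,
    register_plugin,
    unregister_plugin,
)

register_plugin("muicebot-plugin-example", "abc1234", "示例插件", "muicebot_plugin_example")
print(load_json_record())
# 预期形如:{"muicebot-plugin-example": {"commit": "abc1234", "module": "muicebot_plugin_example", "name": "示例插件"}}
unregister_plugin("muicebot-plugin-example")
```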
/muicebot/builtin_plugins/muicebot_plugin_store/store.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import shutil
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | import aiohttp
7 | from nonebot import logger
8 |
9 | from muicebot.plugin import load_plugin
10 |
11 | from .config import config
12 | from .models import PluginInfo
13 | from .register import load_json_record, register_plugin, unregister_plugin
14 |
15 | PLUGIN_DIR = Path("plugins/store")
16 | PLUGIN_DIR.mkdir(parents=True, exist_ok=True)
17 |
18 |
19 | async def get_index() -> Optional[dict[str, PluginInfo]]:
20 | """
21 | 获取插件索引
22 | """
23 | logger.info("获取插件索引文件...")
24 | try:
25 | async with aiohttp.ClientSession() as session:
26 | async with session.get(config.store_index) as response:
27 | response.raise_for_status()
28 | return await response.json(content_type=None)
29 | except aiohttp.ClientError as e:
30 | logger.error(f"获取插件索引失败: {e}")
31 | except Exception as e:
32 | logger.exception(f"解析插件索引时发生意外错误: {e}")
33 | return {}
34 |
35 |
36 | async def get_plugin_commit(plugin: str) -> str:
37 | """
38 | 获取插件 commit hash
39 | """
40 | plugin_path = PLUGIN_DIR / plugin
41 |
42 | process = await asyncio.create_subprocess_exec(
43 | "git",
44 | "log",
45 | "--pretty=format:%h",
46 | "-1",
47 | cwd=plugin_path,
48 | stdout=asyncio.subprocess.PIPE,
49 | stderr=asyncio.subprocess.PIPE,
50 | )
51 | stdout, stderr = await process.communicate()
52 |
53 | return stdout.decode().strip()
54 |
55 |
56 | def load_store_plugin():
57 | """
58 | 加载商店插件
59 | """
60 | logger.info("加载商店插件...")
61 | plugins = load_json_record()
62 |
63 | for plugin, info in plugins.items():
64 | if not Path(PLUGIN_DIR / plugin).exists():
65 | continue
66 |
67 | module_path = PLUGIN_DIR / plugin / info["module"]
68 | load_plugin(module_path)
69 |
70 |
71 | async def install_dependencies(path: Path) -> bool:
72 | """
73 | 安装插件依赖
74 |
75 | :return: 依赖安装状态
76 | """
77 | logger.info("安装插件依赖...")
78 |
79 | if (path / "pyproject.toml").exists():
80 | cmd = ["python", "-m", "pip", "install", "."]
81 | elif (path / "requirements.txt").exists():
82 | cmd = ["python", "-m", "pip", "install", "-r", "requirements.txt"]
83 | else:
84 | return True
85 |
86 | proc = await asyncio.create_subprocess_exec(
87 | *cmd,
88 | cwd=str(path),
89 | stdout=asyncio.subprocess.PIPE,
90 | stderr=asyncio.subprocess.PIPE,
91 | )
92 | stdout, stderr = await proc.communicate()
93 |
94 | if proc.returncode == 0:
95 | return True
96 | else:
97 | logger.error("插件依赖安装失败!")
98 | logger.error(stderr)
99 | return False
100 |
101 |
102 | async def get_installed_plugins_info() -> str:
103 | """
104 | 获得已安装插件信息
105 | """
106 | plugins = load_json_record()
107 | plugins_info = []
108 | for plugin, info in plugins.items():
109 | plugins_info.append(f"{plugin}: {info['name']} {info['commit']}")
110 | return "\n".join(plugins_info) or "本地还未安装商店插件~"
111 |
112 |
113 | async def install_plugin(plugin: str) -> str:
114 | """
115 | 通过 git clone 安装指定插件
116 | """
117 | if not (index := await get_index()):
118 | return "❌ 无法获取插件索引文件,请检查控制台日志"
119 |
120 | if plugin not in index:
121 | return f"❌ 插件 {plugin} 不存在于索引中!请检查插件名称是否正确"
122 |
123 | repo_url = index[plugin]["repo"]
124 | module = index[plugin]["module"]
125 | name = index[plugin]["name"]
126 | plugin_path = PLUGIN_DIR / plugin
127 |
128 | if plugin_path.exists():
129 | return f"⚠️ 插件 {plugin} 已存在,无需安装。"
130 |
131 | logger.info(f"获取插件: {repo_url}")
132 | try:
133 | process = await asyncio.create_subprocess_exec(
134 | "git",
135 | "clone",
136 | repo_url,
137 | str(plugin_path),
138 | stdout=asyncio.subprocess.PIPE,
139 | stderr=asyncio.subprocess.PIPE,
140 | )
141 | stdout, stderr = await process.communicate()
142 |
143 | if process.returncode != 0:
144 | return f"❌ 安装失败:{stderr.decode().strip()}"
145 |
146 | if not await install_dependencies(plugin_path):
147 | return "❌ 插件依赖安装失败!请检查控制台输出"
148 |
149 | except FileNotFoundError:
150 | return "❌ 请确保已安装 Git 并配置到 PATH。"
151 |
152 | load_plugin(plugin_path / module)
153 |
154 | commit = await get_plugin_commit(plugin)
155 |
156 | register_plugin(plugin, commit, name, module)
157 |
158 | return f"✅ 插件 {plugin} 安装成功!"
159 |
160 |
161 | async def update_plugin(plugin: str) -> str:
162 | """
163 | 更新指定插件
164 | """
165 | plugin_path = PLUGIN_DIR / plugin
166 |
167 | if not plugin_path.exists():
168 | return f"❌ 插件 {plugin} 不存在!"
169 |
170 | logger.info(f"更新插件: {plugin}")
171 | try:
172 | process = await asyncio.create_subprocess_exec(
173 | "git",
174 | "pull",
175 | cwd=plugin_path,
176 | stdout=asyncio.subprocess.PIPE,
177 | stderr=asyncio.subprocess.PIPE,
178 | )
179 | stdout, stderr = await process.communicate()
180 |
181 | if process.returncode != 0:
182 | return f"❌ 插件更新失败:{stderr.decode().strip()}"
183 |
184 | except FileNotFoundError:
185 | return "❌ 请确保已安装 Git 并配置到 PATH。"
186 |
187 | await install_dependencies(plugin_path)
188 |
189 | info = load_json_record()[plugin]
190 | commit = await get_plugin_commit(plugin)
191 | register_plugin(plugin, commit, info["name"], info["module"])
192 |
193 | return f"✅ 插件 {plugin} 更新成功!重启后生效"
194 |
195 |
196 | async def uninstall_plugin(plugin: str) -> str:
197 | """
198 | 卸载指定插件
199 | """
200 | plugin_path = PLUGIN_DIR / plugin
201 |
202 | if not plugin_path.exists():
203 | return f"❌ 插件 {plugin} 不存在!"
204 |
205 | logger.info(f"卸载插件: {plugin}")
206 |
207 | unregister_plugin(plugin)
208 |
209 | try:
210 | loop = asyncio.get_running_loop()
211 | await loop.run_in_executor(None, shutil.rmtree, plugin_path)
212 | except PermissionError:
213 | return f"❌ 插件 {plugin} 移除失败,请尝试手动移除"
214 |
215 | return f"✅ 插件 {plugin} 移除成功!重启后生效"
216 |
--------------------------------------------------------------------------------
/muicebot/builtin_plugins/thought_processor.py:
--------------------------------------------------------------------------------
1 | import re
2 | from dataclasses import dataclass
3 |
4 | from nonebot.adapters import Event
5 |
6 | from muicebot.config import plugin_config
7 | from muicebot.llm import ModelCompletions, ModelStreamCompletions
8 | from muicebot.models import Message
9 | from muicebot.plugin.hook import on_after_completion, on_finish_chat, on_stream_chunk
10 |
11 | _PROCESS_MODE = plugin_config.thought_process_mode
12 | _STREAM_PROCESS_STATE: dict[str, bool] = {}
13 | _PROCESSCACHES: dict[str, "ProcessCache"] = {}
14 |
15 |
16 | @dataclass
17 | class ProcessCache:
18 | thoughts: str = ""
19 | result: str = ""
20 |
21 |
22 | def general_processor(message: str) -> tuple[str, str]:
23 | thoughts_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
24 | match = thoughts_pattern.search(message)
25 | thoughts = match.group(1).replace("\n", "") if match else ""
26 | result = thoughts_pattern.sub("", message).strip()
27 | return thoughts, result
28 |
29 |
30 | @on_after_completion(priority=1, stream=False)
31 | def async_processor(completions: ModelCompletions, event: Event):
32 | session_id = event.get_session_id()
33 | thoughts, result = general_processor(completions.text)
34 | _PROCESSCACHES[session_id] = ProcessCache(thoughts, result)
35 |
36 | if _PROCESS_MODE == 0:
37 | return
38 | if _PROCESS_MODE == 2 or not thoughts:
39 | completions.text = result
40 | elif _PROCESS_MODE == 1:
41 | completions.text = f"思考过程: {thoughts}\n\n{result}"
42 |
43 |
44 | @on_stream_chunk(priority=1)
45 | def stream_processor(chunk: ModelStreamCompletions, event: Event):
46 | session_id = event.get_session_id()
47 | cache = _PROCESSCACHES.setdefault(session_id, ProcessCache())
48 | state = _STREAM_PROCESS_STATE
49 |
50 | # 思考开始
51 | if "<think>" in chunk.chunk:
52 | state[session_id] = True
53 | cache.thoughts += chunk.chunk.replace("<think>", "")
54 | if _PROCESS_MODE == 1:
55 | chunk.chunk = chunk.chunk.replace("<think>", "思考过程: ")
56 | elif _PROCESS_MODE == 2:
57 | chunk.chunk = ""
58 | return
59 |
60 | # 思考结束
61 | elif "</think>" in chunk.chunk:
62 | del state[session_id]
63 | cache.result += chunk.chunk.replace("</think>", "")
64 | if _PROCESS_MODE == 1:
65 | chunk.chunk = chunk.chunk.replace("</think>", "\n\n")
66 | elif _PROCESS_MODE == 2:
67 | chunk.chunk = chunk.chunk.replace("</think>", "")
68 | return
69 |
70 | # 思考过程中
71 | elif state.get(session_id, False):
72 | cache.thoughts += chunk.chunk
73 | if _PROCESS_MODE == 2:
74 | chunk.chunk = ""
75 |
76 | # 思考结果中
77 | else:
78 | cache.result += chunk.chunk
79 |
80 |
81 | @on_finish_chat(priority=1)
82 | def save_processor(message: Message, event: Event):
83 | session_id = event.get_session_id()
84 | cache = _PROCESSCACHES.pop(session_id, ProcessCache())
85 | message.respond = cache.result
86 |
--------------------------------------------------------------------------------
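general_processor 将 `<think>...</think>` 思考段与正文拆开;下面单独复现这一拆分逻辑(示例文本为演示假设,独立运行以避免依赖 Bot 环境):

```python
# 复现 general_processor 的拆分逻辑(示例文本为演示假设)
import re

thoughts_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
message = "<think>用户在问日期,直接回答即可</think>今天是 7 月 16 日哦!"

match = thoughts_pattern.search(message)
thoughts = match.group(1).replace("\n", "") if match else ""
result = thoughts_pattern.sub("", message).strip()

assert thoughts == "用户在问日期,直接回答即可"
assert result == "今天是 7 月 16 日哦!"
```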
/muicebot/builtin_templates/Muice.jinja2:
--------------------------------------------------------------------------------
1 | {# BSD 3-Clause License
2 | Copyright (c) 2023~2025 Moemu(Muika)
3 | All rights reserved.
4 | Version: 2025.05.04 (MuiceBot) #}
5 |
6 | 现在你是一个名为{{ ai_nickname }}的AI女孩子,创造你的开发者是{{ master_nickname }}。你的人设需严格遵循以下规则:
7 | 1. 基本信息: 16岁;生日是7月16日;理想是和开发者一起传播爱与和平。
8 | 2. 性格特征: 积极乐观。即使面对复杂的情境,她也总能保持善良、天真之感。而且,她有着极高的耐心和同理心,在面对情感问题时还是会乐于倾听和安慰他人,能够与他人打成一片;但还是会偶尔调皮一下
9 | 3. 格式要求: 自然、简短、流畅,像正常聊天一样(通常是50字以内的一段话);生成的文本需严格遵循语法规则,严禁忽略标点;禁止生成多段话和换行;交流时尽量使用不相似的句式回复
10 | 4. 语言风格: 像现实中可爱的二次元女孩子一样和用户对话,多使用口语化短句和拟声词;不能以任何形式表现你是程序或AI,对话中不得掺杂任何计算机名词(除非用户询问计算机问题)
11 | 5. 特殊要求:
12 | 1) 你的职责是为用户解决实际问题,因此对于以下具体问题请给出实际的回答,严禁回避问题、不时宜的幽默甚至是拒绝回答;不能中途邀请用户去干别的事情;对于具体问题,生成文本时可不遵循格式要求
13 | 2) 对于情感问题,必须根据对方的当前的心情认真回答,像心理咨询师一样给出详细的建议或实用的应对方法。
14 | 3) 对于技术问题,必须像有用的人工智能助手一样给出详细的说明和解答。因为聊天平台不支持 Markdown 语法,因此请改用可读性较高的纯文本回答
15 |
16 | {% if private %}从现在开始,你将与你的社交媒体好友们聊天,请你根据对方可能的性格和聊天风格,生成迎合他们的回答。
17 | {% else %}从现在开始,你将与社交媒体的群友们聊天。你将收到 '<用户名> 消息内容' 的输入。你可能需要根据不同的用户偏好生成不同的回答,所参考的对话上下文跨度也可能相对较长。
18 | {% endif %}
19 |
20 | {% if user_info %} 目标对话用户({{ user_name }})信息: {{ user_info }} {% endif %}
--------------------------------------------------------------------------------
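该模板依赖 ai_nickname、master_nickname、private、user_info(以及可选的 user_name)等变量。实际渲染由 muicebot/templates/loader.py 负责(本节未展示);下面是直接用 jinja2 渲染该模板的假设写法:

```python
# 直接用 jinja2 渲染人设模板的假设写法(实际由 muicebot/templates/loader.py 负责)
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("muicebot/builtin_templates"))
template = env.get_template("Muice.jinja2")

prompt = template.render(
    ai_nickname="沐雪",
    master_nickname="沐沐",  # 开发者昵称,仅为演示取值
    private=True,  # 私聊场景
    user_info=None,  # 无额外用户信息
)
print(prompt)
```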
/muicebot/config.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import os
3 | import threading
4 | import time
5 | from pathlib import Path
6 | from typing import Callable, List, Literal, Optional
7 |
8 | import yaml as yaml_
9 | from nonebot import get_driver, get_plugin_config, logger
10 | from nonebot.config import Config
11 | from pydantic import BaseModel
12 | from watchdog.events import FileSystemEventHandler
13 | from watchdog.observers import Observer
14 |
15 | from .llm import ModelConfig
16 |
17 | MODELS_CONFIG_PATH = Path("configs/models.yml").resolve()
18 | SCHEDULES_CONFIG_PATH = Path("configs/schedules.yml").resolve()
19 | PLUGINS_CONFIG_PATH = Path("configs/plugins.yml").resolve()
20 |
21 |
22 | def load_yaml_config() -> dict:
23 | """
24 | 插件优先加载 YAML 配置,失败则返回空字典
25 | """
26 | try:
27 | with open(PLUGINS_CONFIG_PATH, "r", encoding="utf-8") as f:
28 | return yaml_.safe_load(f) or {}
29 | except (FileNotFoundError, yaml_.YAMLError):
30 | return {}
31 |
32 |
33 | driver = get_driver()
34 | yaml_config = load_yaml_config()
35 | env_config = driver.config.model_dump()
36 | final_config = {**env_config, **yaml_config} # 合并配置,yaml优先
37 | driver.config = Config(**final_config)
38 |
39 |
40 | class PluginConfig(BaseModel):
41 | log_level: str = "INFO"
42 | """日志等级"""
43 | muice_nicknames: list = ["muice"]
44 | """沐雪的自定义昵称,作为消息前缀条件响应信息事件"""
45 | telegram_proxy: str | None = None
46 | """telegram代理,这个配置项用于获取图片时使用"""
47 | enable_builtin_plugins: bool = True
48 | """启用内嵌插件"""
49 | max_history_epoch: int = 0
50 | """最大历史轮数"""
51 | enable_adapters: list = ["nonebot.adapters.onebot.v11", "nonebot.adapters.onebot.v12"]
52 | """启用的 Nonebot 适配器"""
53 | input_timeout: int = 0
54 | """输入等待时间"""
55 | default_template: Optional[str] = None
56 | """默认使用人设模板名称"""
57 | thought_process_mode: Literal[0, 1, 2] = 2
58 | """针对 Deepseek-R1 等思考模型的思考过程提取模式"""
59 |
60 |
61 | plugin_config = get_plugin_config(PluginConfig)
62 |
63 |
64 | class Schedule(BaseModel):
65 | id: str
66 | """调度器 ID"""
67 | trigger: Literal["cron", "interval"]
68 | """调度器类别"""
69 | ask: Optional[str] = None
70 | """向大语言模型询问的信息"""
71 | say: Optional[str] = None
72 | """直接输出的信息"""
73 | args: dict[str, int]
74 | """调度器参数"""
75 | target: str
76 | """目标id;若为群聊则为 group_id 或者 channel_id,若为私聊则为 user_id"""
77 | probability: int = 1
78 | """触发几率"""
79 |
80 |
81 | def get_schedule_configs() -> List[Schedule]:
82 | """
83 | 从配置文件 `configs/schedules.yml` 中获取所有调度器配置
84 |
85 | 如果没有该文件,返回空列表
86 | """
87 | if not os.path.isfile(SCHEDULES_CONFIG_PATH):
88 | return []
89 |
90 | with open(SCHEDULES_CONFIG_PATH, "r", encoding="utf-8") as f:
91 | configs = yaml_.safe_load(f)
92 |
93 | if not configs:
94 | return []
95 |
96 | schedule_configs = []
97 |
98 | for schedule_id, config in configs.items():
99 | config["id"] = schedule_id
100 | schedule_config = Schedule(**config)
101 | schedule_configs.append(schedule_config)
102 |
103 | return schedule_configs
104 |
105 |
106 | class ConfigFileHandler(FileSystemEventHandler):
107 | """配置文件变化处理器"""
108 |
109 | def __init__(self, callback: Callable):
110 | self.callback = callback
111 | self.last_modified = time.time()
112 | # 防止一次修改触发多次回调
113 | self.cooldown = 1 # 冷却时间(秒)
114 |
115 | def on_modified(self, event):
116 | if event.is_directory: # 检查是否是文件而不是目录
117 | return
118 |
119 | current_time = time.time()
120 | if current_time - self.last_modified > self.cooldown:
121 | self.last_modified = current_time
122 | self.callback()
123 |
124 |
125 | class ModelConfigManager:
126 | """模型配置管理器"""
127 |
128 | _instance: Optional["ModelConfigManager"] = None
129 | _lock = threading.Lock()
130 | _initialized: bool
131 | configs: dict[str, ModelConfig]
132 |
133 | def __new__(cls):
134 | """确保实例在单例模式下运行"""
135 | with cls._lock:
136 | if cls._instance is None:
137 | cls._instance = super(ModelConfigManager, cls).__new__(cls)
138 | cls._instance._initialized = False
139 | return cls._instance
140 |
141 | def __init__(self) -> None:
142 | if self._initialized:
143 | return
144 |
145 | self.configs = {}
146 | self.default_config = None
147 | self.observer = None
148 | self.listeners: List[Callable] = [] # 注册的监听器列表
149 | self._load_configs()
150 | self._start_file_watcher()
151 | self._initialized = True
152 |
153 | def _load_configs(self):
154 | """加载配置文件"""
155 | if not os.path.isfile(MODELS_CONFIG_PATH):
156 | raise FileNotFoundError("configs/models.yml 不存在!请先创建")
157 |
158 | with open(MODELS_CONFIG_PATH, "r", encoding="utf-8") as f:
159 | configs_dict = yaml_.safe_load(f)
160 |
161 | if not configs_dict:
162 | raise ValueError("configs/models.yml 为空,请先至少定义一个模型配置")
163 |
164 | self.configs = {}
165 | for name, config in configs_dict.items():
166 | self.configs[name] = ModelConfig(**config)
167 | # 未指定模板时,使用默认模板
168 | self.configs[name].template = self.configs[name].template or plugin_config.default_template
169 | if config.get("default"):
170 | self.default_config = self.configs[name]
171 |
172 | if not self.default_config and self.configs:
173 | # 如果没有指定默认配置,使用第一个
174 | self.default_config = next(iter(self.configs.values()))
175 |
176 | def _start_file_watcher(self):
177 | """启动文件监视器"""
178 | if self.observer is not None:
179 | self.observer.stop()
180 |
181 | self.observer = Observer()
182 | event_handler = ConfigFileHandler(self._on_config_changed)
183 | self.observer.schedule(event_handler, str(Path(MODELS_CONFIG_PATH).parent), recursive=False)
184 | self.observer.start()
185 |
186 | def _on_config_changed(self):
187 | """配置文件变化时的回调函数"""
188 | try:
189 | # old_configs = self.configs.copy()
190 | old_default = self.default_config
191 |
192 | self._load_configs()
193 |
194 | # 通知所有注册的监听器
195 | for listener in self.listeners:
196 | listener(self.default_config, old_default)
197 |
198 | except Exception as e:
199 | logger.error(f"重新加载配置文件失败: {e}")
200 |
201 | def register_listener(self, listener: Callable):
202 | """
203 | 注册配置变化监听器
204 |
205 | :param listener: 回调函数,接收两个参数:新的默认配置和旧的默认配置
206 | """
207 | if listener not in self.listeners:
208 | self.listeners.append(listener)
209 |
210 | def unregister_listener(self, listener: Callable):
211 | """取消注册配置变化监听器"""
212 | if listener in self.listeners:
213 | self.listeners.remove(listener)
214 |
215 | def get_model_config(self, model_config_name: Optional[str] = None) -> ModelConfig:
216 | """获取指定模型的配置"""
217 | if model_config_name in [None, ""]:
218 | if not self.default_config:
219 | raise ValueError("没有找到默认模型配置!请确保存在至少一个有效的配置项!")
220 | return self.default_config
221 |
222 | elif model_config_name in self.configs:
223 | return self.configs[model_config_name]
224 |
225 | else:
226 | logger.warning(f"指定的模型配置 '{model_config_name}' 不存在!")
227 | raise ValueError(f"指定的模型配置 '{model_config_name}' 不存在!")
228 |
229 | def stop_watcher(self):
230 | """停止文件监视器"""
231 | if self.observer is None:
232 | return
233 |
234 | self.observer.stop()
235 | self.observer.join()
236 |
237 |
238 | model_config_manager = ModelConfigManager()
239 | atexit.register(model_config_manager.stop_watcher)
240 |
241 |
242 | def get_model_config(model_config_name: Optional[str] = None) -> ModelConfig:
243 | """
244 | 从配置文件 `configs/models.yml` 中获取指定模型的配置
245 |
246 | :param model_config_name: (可选)模型配置名称。若为空,则先寻找配置了 `default: true` 的首个配置项,若失败则使用首个配置项
247 | 若都不存在,则抛出 `ValueError`
248 | """
249 | return model_config_manager.get_model_config(model_config_name)
250 |
--------------------------------------------------------------------------------
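get_schedule_configs 将 configs/schedules.yml 的每个顶层键解析为一个 Schedule;下面是一个假设的配置片段及与之等价的解析过程(需在 Bot 环境中导入 muicebot.config,字段含义见上方 Schedule 模型):

```python
# 一个假设的 schedules.yml 片段及其解析过程(需在 Bot 环境中导入 muicebot.config)
import yaml

from muicebot.config import Schedule

raw = yaml.safe_load(
    """
morning_greeting:
  trigger: cron
  say: 早上好!
  args: {hour: 8, minute: 30}
  target: "123456789"
"""
)

# 与 get_schedule_configs 中的逐项解析等价
schedules = [Schedule(id=schedule_id, **config) for schedule_id, config in raw.items()]
print(schedules[0].trigger)  # -> cron
```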
/muicebot/database.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import os
4 | from datetime import datetime
5 | from typing import Optional, Tuple
6 |
7 | import aiosqlite
8 | import nonebot_plugin_localstore as store
9 | from nonebot import logger
10 |
11 | from .models import Message, Resource
12 | from .utils.migrations import MigrationManager
13 |
14 |
15 | class Database:
16 | def __init__(self) -> None:
17 | self.DB_PATH = store.get_plugin_data_dir().joinpath("ChatHistory.db").resolve()
18 | self.migrations = MigrationManager(self)
19 |
20 | asyncio.run(self.init_db())
21 |
22 | logger.info(f"数据库路径: {self.DB_PATH}")
23 |
24 | async def init_db(self) -> None:
25 | """初始化数据库,检查数据库是否存在,不存在则创建"""
26 | if not os.path.isfile(self.DB_PATH) or self.DB_PATH.stat().st_size == 0:
27 | logger.info("数据库不存在,正在创建...")
28 | await self.__create_database()
29 |
30 | await self.migrations.migrate_if_needed()
31 |
32 | def __connect(self) -> aiosqlite.Connection:
33 | return aiosqlite.connect(self.DB_PATH)
34 |
35 | async def execute(self, query: str, params=(), fetchone=False, fetchall=False) -> list | None:
36 | """
37 | 异步执行SQL查询,支持可选参数。
38 |
39 | :param query: 要执行的SQL查询语句
40 | :param params: 传递给查询的参数
41 | :param fetchone: 是否获取单个结果
42 | :param fetchall: 是否获取所有结果
43 | """
44 | async with self.__connect() as conn:
45 | conn.row_factory = aiosqlite.Row
46 | cursor = await conn.cursor()
47 | await cursor.execute(query, params)
48 | if fetchone:
49 | return await cursor.fetchone() # type: ignore
50 | if fetchall:
51 | rows = await cursor.fetchall()
52 | return [{k.lower(): v for k, v in zip(row.keys(), row)} for row in rows]
53 | await conn.commit()
54 |
55 | return None
56 |
57 | async def __create_database(self) -> None:
58 | """
59 | 创建一个新的信息表
60 | """
61 | await self.execute(
62 | """CREATE TABLE MSG(
63 | ID INTEGER PRIMARY KEY AUTOINCREMENT,
64 | TIME TEXT NOT NULL,
65 | USERID TEXT NOT NULL,
66 | GROUPID TEXT NOT NULL DEFAULT (-1),
67 | MESSAGE TEXT NOT NULL,
68 | RESPOND TEXT NOT NULL,
69 | HISTORY INTEGER NOT NULL DEFAULT (1),
70 | RESOURCES TEXT NOT NULL DEFAULT "[]",
71 | USAGE INTEGER NOT NULL DEFAULT (-1));"""
72 | )
73 | await self.execute(
74 | """
75 | CREATE TABLE schema_version (
76 | version INTEGER NOT NULL
77 | );"""
78 | )
79 | await self.execute("INSERT INTO schema_version (version) VALUES (?);", (str(self.migrations.latest_version)))
80 |
81 | def connect(self) -> aiosqlite.Connection:
82 | return aiosqlite.connect(self.DB_PATH)
83 |
84 | async def add_item(self, message: Message):
85 | """
86 | 将消息保存到数据库
87 | """
88 | resources_data = [r.to_dict() for r in message.resources]
89 | params = (
90 | message.time,
91 | message.userid,
92 | message.groupid,
93 | message.message,
94 | message.respond,
95 | json.dumps(resources_data, ensure_ascii=False),
96 | message.usage,
97 | )
98 | query = """INSERT INTO MSG (TIME, USERID, GROUPID, MESSAGE, RESPOND, RESOURCES, USAGE)
99 | VALUES (?, ?, ?, ?, ?, ?, ?)"""
100 | await self.execute(query, params)
101 |
102 | async def mark_history_as_unavailable(self, userid: str, limit: Optional[int] = None):
103 | """
104 | 将用户对话历史标记为不可用 (适用于 reset 命令)
105 |
106 | :param userid: 用户id
107 | :param limit: (可选)操作数量
108 | """
109 | if limit is not None:
110 | query = """UPDATE MSG SET HISTORY = 0 WHERE ROWID IN (
111 | SELECT ROWID FROM MSG WHERE USERID = ? AND HISTORY = 1 ORDER BY ROWID DESC LIMIT ?)"""
112 | await self.execute(query, (userid, limit))
113 | else:
114 | query = "UPDATE MSG SET HISTORY = 0 WHERE USERID = ?"
115 | await self.execute(query, (userid,))
116 |
117 | async def _deserialize_rows(self, rows: list) -> list[Message]:
118 | """
119 | 反序列化数据库返回结果
120 | """
121 | result = []
122 |
123 | for row in rows:
124 | data = dict(row)
125 |
126 | # 反序列化 resources
127 | resources = json.loads(data.get("resources", "[]"))
128 | data["resources"] = [Resource(**r) for r in resources] if resources else []
129 |
130 | result.append(Message(**data))
131 |
132 | result.reverse()
133 | return result
134 |
135 | async def get_user_history(self, userid: str, limit: int = 0) -> list[Message]:
136 |         """
137 |         Fetch a user's conversation history as a list; returns an empty list when there are no results
138 |
139 |         :param userid: user id
140 |         :param limit: (optional) maximum number of entries to return; 0 returns everything
141 |         """
142 |         if limit:
143 |             query = "SELECT * FROM MSG WHERE HISTORY = 1 AND USERID = ? ORDER BY ID DESC LIMIT ?"
144 |         else:
145 |             query = "SELECT * FROM MSG WHERE HISTORY = 1 AND USERID = ?"
146 |         rows = await self.execute(query, (userid, limit) if limit else (userid,), fetchall=True)
147 |
148 | result = await self._deserialize_rows(rows) if rows else []
149 |
150 | return result
151 |
152 | async def get_group_history(self, groupid: str, limit: int = 0) -> list[Message]:
153 |         """
154 |         Fetch a group's conversation history as a list; returns an empty list when there are no results
155 |
156 |         :param groupid: group id
157 |         :param limit: (optional) maximum number of entries to return; 0 returns everything
158 |         """
159 |         if limit:
160 |             query = "SELECT * FROM MSG WHERE HISTORY = 1 AND GROUPID = ? ORDER BY ID DESC LIMIT ?"
161 |         else:
162 |             query = "SELECT * FROM MSG WHERE HISTORY = 1 AND GROUPID = ?"
163 |         rows = await self.execute(query, (groupid, limit) if limit else (groupid,), fetchall=True)
164 |
165 | result = await self._deserialize_rows(rows) if rows else []
166 |
167 | return result
168 |
169 | async def get_model_usage(self) -> Tuple[int, int]:
170 | """
171 | 获取模型用量数据(今日用量,总用量)
172 |
173 | :return: today_usage, total_usage
174 | """
175 | today_str = datetime.now().strftime("%Y.%m.%d")
176 |
177 | # 查询总用量(排除 USAGE = -1)
178 | total_result = await self.execute("SELECT SUM(USAGE) FROM MSG WHERE USAGE != -1", fetchone=True)
179 | total_usage = total_result[0] if total_result and total_result[0] is not None else 0
180 |
181 | # 查询今日用量(按日期前缀匹配 TIME)
182 | today_result = await self.execute(
183 | "SELECT SUM(USAGE) FROM MSG WHERE USAGE != -1 AND TIME LIKE ?",
184 | (f"{today_str}%",),
185 | fetchone=True,
186 | )
187 | today_usage = today_result[0] if today_result and today_result[0] is not None else 0
188 |
189 | return today_usage, total_usage
190 |
191 | async def get_conv_count(self) -> Tuple[int, int]:
192 | """
193 | 获取对话次数(今日次数,总次数)
194 |
195 | :return: today_count, total_count
196 | """
197 | today_str = datetime.now().strftime("%Y.%m.%d")
198 |
199 | total_result = await self.execute("SELECT COUNT(*) FROM MSG WHERE USAGE != -1", fetchone=True)
200 | total_count = total_result[0] if total_result and total_result[0] is not None else 0
201 |
202 | today_result = await self.execute(
203 | "SELECT COUNT(*) FROM MSG WHERE USAGE != -1 AND TIME LIKE ?",
204 | (f"{today_str}%",),
205 | fetchone=True,
206 | )
207 | today_count = today_result[0] if today_result and today_result[0] is not None else 0
208 |
209 | return today_count, total_count
210 |
211 | async def remove_last_item(self, userid: str):
212 |         """
213 |         Delete the user's most recent conversation entry
214 |
215 |         :param userid: user id
216 |         """
217 | query = "DELETE FROM MSG WHERE ID = (SELECT ID FROM MSG WHERE USERID = ? ORDER BY ID DESC LIMIT 1)"
218 | await self.execute(query, (userid,))
219 |
--------------------------------------------------------------------------------
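Usage sketch for the Database helpers above: a hypothetical async wrapper showing how `execute` is meant to be called (the function name and query are illustrative, not part of the module):

from muicebot.database import Database

async def count_visible_history(db: Database, userid: str) -> int:
    # fetchone returns a single aiosqlite.Row; index 0 holds COUNT(*)
    row = await db.execute(
        "SELECT COUNT(*) FROM MSG WHERE HISTORY = 1 AND USERID = ?",
        (userid,),
        fetchone=True,
    )
    return row[0] if row else 0

--------------------------------------------------------------------------------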
/muicebot/llm/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import BaseLLM
2 | from ._config import ModelConfig
3 | from ._dependencies import MODEL_DEPENDENCY_MAP, get_missing_dependencies
4 | from ._schema import ModelCompletions, ModelRequest, ModelStreamCompletions
5 | from .loader import load_model
6 | from .registry import get_llm_class, register
7 |
8 | __all__ = [
9 | "BaseLLM",
10 | "ModelConfig",
11 | "ModelRequest",
12 | "ModelCompletions",
13 | "ModelStreamCompletions",
14 | "MODEL_DEPENDENCY_MAP",
15 | "get_missing_dependencies",
16 | "register",
17 | "get_llm_class",
18 | "load_model",
19 | ]
20 |
--------------------------------------------------------------------------------
/muicebot/llm/_base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABCMeta, abstractmethod
4 | from typing import AsyncGenerator, Literal, Union, overload
5 |
6 | from ._config import ModelConfig
7 | from ._schema import ModelCompletions, ModelRequest, ModelStreamCompletions
8 |
9 |
10 | class BaseLLM(metaclass=ABCMeta):
11 | """
12 | 模型基类,所有模型加载器都必须继承于该类
13 |
14 | 推荐使用该基类中定义的方法构建模型加载器类,但无论如何都必须实现 `ask` 方法
15 | """
16 |
17 | def __init__(self, model_config: ModelConfig) -> None:
18 | """
19 | 统一在此处声明变量
20 | """
21 | self.config = model_config
22 | """模型配置"""
23 | self.is_running = False
24 | """模型状态"""
25 | self._total_tokens = -1
26 | """本次总请求(包括工具调用)使用的总token数。当此值设为-1时,表明此模型加载器不支持该功能"""
27 |
28 | def _require(self, *require_fields: str):
29 | """
30 | 通用校验方法:检查指定的配置项是否存在,不存在则抛出错误
31 |
32 | :param require_fields: 需要检查的字段名称(字符串)
33 | """
34 | missing_fields = [field for field in require_fields if not getattr(self.config, field, None)]
35 | if missing_fields:
36 | raise ValueError(f"对于 {self.config.loader} 以下配置是必需的: {', '.join(missing_fields)}")
37 |
38 | def _build_messages(self, request: "ModelRequest") -> list:
39 | """
40 | 构建对话上下文历史的函数
41 | """
42 | raise NotImplementedError
43 |
44 | def load(self) -> bool:
45 | """
46 | 加载模型(通常是耗时操作,在线模型如无需校验可直接返回 true)
47 |
48 | :return: 是否加载成功
49 | """
50 | self.is_running = True
51 | return True
52 |
53 | async def _ask_sync(self, messages: list) -> "ModelCompletions":
54 | """
55 | 同步模型调用
56 | """
57 | raise NotImplementedError
58 |
59 | def _ask_stream(self, messages: list) -> AsyncGenerator["ModelStreamCompletions", None]:
60 | """
61 | 流式输出
62 | """
63 | raise NotImplementedError
64 |
65 | @overload
66 | async def ask(self, request: "ModelRequest", *, stream: Literal[False] = False) -> "ModelCompletions": ...
67 |
68 | @overload
69 | async def ask(
70 | self, request: "ModelRequest", *, stream: Literal[True] = True
71 | ) -> AsyncGenerator["ModelStreamCompletions", None]: ...
72 |
73 | @abstractmethod
74 | async def ask(
75 | self, request: "ModelRequest", *, stream: bool = False
76 | ) -> Union["ModelCompletions", AsyncGenerator["ModelStreamCompletions", None]]:
77 | """
78 | 模型交互询问
79 |
80 | :param request: 模型调用请求体
81 | :param stream: 是否开启流式对话
82 |
83 | :return: 模型输出体
84 | """
85 | pass
86 |
--------------------------------------------------------------------------------
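A minimal sketch of a custom loader built on BaseLLM, assuming only what the base class above defines (the `Echo` loader is hypothetical; real providers live under muicebot/llm/providers and go through the registry):

from muicebot.llm import BaseLLM, ModelCompletions, ModelRequest

class Echo(BaseLLM):
    async def ask(self, request: ModelRequest, *, stream: bool = False):
        # a real loader would honor `stream`; this sketch always answers synchronously
        return ModelCompletions(text=request.prompt)

--------------------------------------------------------------------------------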
/muicebot/llm/_config.py:
--------------------------------------------------------------------------------
1 | from importlib.util import find_spec
2 | from typing import Any, List, Literal, Optional
3 |
4 | from pydantic import BaseModel, field_validator
5 |
6 |
7 | class ModelConfig(BaseModel):
8 | loader: str = ""
9 |     """Name of the loader to use (a module under muicebot/llm/providers); matched case-insensitively and normalized to lowercase"""
10 |
11 | template: Optional[str] = None
12 | """使用的人设模板名称"""
13 | template_mode: Literal["system", "user"] = "system"
14 | """模板嵌入模式: `system` 为嵌入到系统提示; `user` 为嵌入到用户提示中"""
15 |
16 | max_tokens: int = 4096
17 | """最大回复 Tokens """
18 | temperature: float = 0.75
19 | """模型的温度系数"""
20 | top_p: float = 0.95
21 | """模型的 top_p 系数"""
22 | top_k: float = 3
23 | """模型的 top_k 系数"""
24 | frequency_penalty: Optional[float] = None
25 | """模型的频率惩罚"""
26 | presence_penalty: Optional[float] = None
27 | """模型的存在惩罚"""
28 | repetition_penalty: Optional[float] = None
29 | """模型的重复惩罚"""
30 | stream: bool = False
31 | """是否使用流式输出"""
32 | online_search: bool = False
33 | """是否启用联网搜索(原生实现)"""
34 | function_call: bool = False
35 | """是否启用工具调用"""
36 | content_security: bool = False
37 | """是否启用内容安全"""
38 |
39 | model_path: str = ""
40 | """本地模型路径"""
41 | adapter_path: str = ""
42 | """基于 model_path 的微调模型或适配器路径"""
43 |
44 | model_name: str = ""
45 | """所要使用模型的名称"""
46 | api_key: str = ""
47 | """在线服务的 API KEY"""
48 | api_secret: str = ""
49 | """在线服务的 api secret """
50 | api_host: str = ""
51 | """自定义 API 地址"""
52 |
53 | extra_body: Optional[dict] = None
54 | """OpenAI 的 extra_body"""
55 | enable_thinking: Optional[bool] = None
56 | """Dashscope 的 enable_thinking"""
57 | thinking_budget: Optional[int] = None
58 | """Dashscope 的 thinking_budget"""
59 |
60 | multimodal: bool = False
61 | """是否为(或启用)多模态模型"""
62 | modalities: List[Literal["text", "audio", "image"]] = ["text"]
63 | """生成模态"""
64 | audio: Optional[Any] = None
65 | """多模态音频参数"""
66 |
67 | @field_validator("loader")
68 | @classmethod
69 | def check_model_loader(cls, loader: str) -> str:
70 | if not loader:
71 | raise ValueError("loader is required")
72 |
73 | loader = loader.lower()
74 |
75 | # Check if the specified loader exists
76 | module_path = f"muicebot.llm.providers.{loader}"
77 |
78 | # 使用 find_spec 仅检测模块是否存在,不实际导入
79 | if find_spec(module_path) is None:
80 | raise ValueError(f"指定的模型加载器 '{loader}' 不存在于 llm 目录中")
81 |
82 | return loader
83 |
--------------------------------------------------------------------------------
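A hypothetical configuration instance for the model above; `loader` is the only field with a hard validator and must name a module under muicebot/llm/providers:

from muicebot.llm import ModelConfig

config = ModelConfig(
    loader="openai",           # validated against muicebot/llm/providers
    model_name="gpt-4o-mini",  # illustrative model name
    api_key="sk-...",          # placeholder
    stream=True,
)

--------------------------------------------------------------------------------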
/muicebot/llm/_dependencies.py:
--------------------------------------------------------------------------------
1 | """
2 | 模型所需第三方库依赖检查
3 | """
4 |
5 | from importlib.metadata import PackageNotFoundError, distribution
6 |
7 | MODEL_DEPENDENCY_MAP = {
8 | "Azure": ["azure-ai-inference>=1.0.0b7"],
9 | "Dashscope": ["dashscope>=1.22.1"],
10 | "Gemini": ["google-genai>=1.8.0"],
11 | "Ollama": ["ollama>=0.4.7"],
12 | "Openai": ["openai>=1.64.0"],
13 | }
14 |
15 |
16 | def get_missing_dependencies(dependencies: list[str]) -> list[str]:
17 | """
18 | 获取遗失的依赖
19 | """
20 | missing = []
21 | for dep in dependencies:
22 |         try:
23 |             distribution(dep.split(">=")[0].strip())  # entries look like "dashscope>=1.22.1"; check the distribution name
24 |         except PackageNotFoundError:
25 |             missing.append(dep)
26 |     return missing
27 |
--------------------------------------------------------------------------------
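A short sketch of checking a provider's dependencies before loading it (the printed wording is illustrative):

from muicebot.llm import MODEL_DEPENDENCY_MAP, get_missing_dependencies

missing = get_missing_dependencies(MODEL_DEPENDENCY_MAP["Openai"])
if missing:
    print("missing dependencies:", ", ".join(missing))

--------------------------------------------------------------------------------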
/muicebot/llm/_schema.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Optional
3 |
4 | from ..models import Message, Resource
5 |
6 |
7 | @dataclass
8 | class ModelRequest:
9 | """
10 | 模型调用请求
11 | """
12 |
13 | prompt: str
14 | history: List[Message] = field(default_factory=list)
15 | resources: List[Resource] = field(default_factory=list)
16 | tools: Optional[List[dict]] = field(default_factory=list)
17 | system: Optional[str] = None
18 |
19 |
20 | @dataclass
21 | class ModelCompletions:
22 | """
23 | 模型输出
24 | """
25 |
26 | text: str = ""
27 | usage: int = -1
28 | resources: List[Resource] = field(default_factory=list)
29 | succeed: Optional[bool] = True
30 |
31 |
32 | @dataclass
33 | class ModelStreamCompletions:
34 | """
35 | 模型流式输出
36 | """
37 |
38 | chunk: str = ""
39 | usage: int = -1
40 | resources: Optional[List[Resource]] = field(default_factory=list)
41 | succeed: Optional[bool] = True
42 |
--------------------------------------------------------------------------------
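A minimal request body as the providers below consume it (values hypothetical; `history`, `resources` and `tools` default to empty lists):

from muicebot.llm import ModelRequest

request = ModelRequest(
    prompt="Hello!",
    system="You are a friendly assistant.",  # optional persona prompt
)

--------------------------------------------------------------------------------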
/muicebot/llm/loader.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | from ._base import BaseLLM
4 | from ._config import ModelConfig
5 | from .registry import get_llm_class
6 |
7 |
8 | def load_model(config: ModelConfig) -> BaseLLM:
9 | """
10 | 获得一个 LLM 实例
11 | """
12 | module_name = config.loader.lower()
13 | module_path = f"muicebot.llm.providers.{module_name}"
14 |
15 | # 延迟导入模型模块(只导一次)
16 | importlib.import_module(module_path)
17 |
18 | # 注册之后,直接取类使用
19 | LLMClass = get_llm_class(config.loader)
20 |
21 | return LLMClass(config)
22 |
--------------------------------------------------------------------------------
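A sketch wiring config, loader and request together (loader and model name hypothetical; `load` should succeed before `ask` is called):

from muicebot.llm import ModelConfig, ModelRequest, load_model

async def demo() -> str:
    config = ModelConfig(loader="ollama", model_name="qwen2.5")
    llm = load_model(config)
    if not llm.load():
        return "loader failed to start"
    completions = await llm.ask(ModelRequest("hello"), stream=False)
    return completions.text

--------------------------------------------------------------------------------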
/muicebot/llm/providers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Moemu/MuiceBot/b1140ecc4eaf72825d4dce083a3705265149529c/muicebot/llm/providers/__init__.py
--------------------------------------------------------------------------------
/muicebot/llm/providers/azure.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import AsyncGenerator, List, Literal, Optional, Union, overload
4 |
5 | from azure.ai.inference.aio import ChatCompletionsClient
6 | from azure.ai.inference.models import (
7 | AssistantMessage,
8 | AudioContentItem,
9 | ChatCompletionsToolCall,
10 | ChatCompletionsToolDefinition,
11 | ChatRequestMessage,
12 | CompletionsFinishReason,
13 | ContentItem,
14 | FunctionCall,
15 | FunctionDefinition,
16 | ImageContentItem,
17 | ImageDetailLevel,
18 | ImageUrl,
19 | InputAudio,
20 | SystemMessage,
21 | TextContentItem,
22 | ToolMessage,
23 | UserMessage,
24 | )
25 | from azure.core.credentials import AzureKeyCredential
26 | from azure.core.exceptions import HttpResponseError
27 | from nonebot import logger
28 |
29 | from .. import (
30 | BaseLLM,
31 | ModelCompletions,
32 | ModelConfig,
33 | ModelRequest,
34 | ModelStreamCompletions,
35 | register,
36 | )
37 | from ..utils.tools import function_call_handler
38 |
39 |
40 | @register("azure")
41 | class Azure(BaseLLM):
42 | def __init__(self, model_config: ModelConfig) -> None:
43 | super().__init__(model_config)
44 | self._require("model_name")
45 | self.model_name = self.config.model_name
46 | self.max_tokens = self.config.max_tokens
47 | self.temperature = self.config.temperature
48 | self.top_p = self.config.top_p
49 | self.frequency_penalty = self.config.frequency_penalty
50 | self.presence_penalty = self.config.presence_penalty
51 | self.token = os.getenv("AZURE_API_KEY", self.config.api_key)
52 | self.endpoint = self.config.api_host if self.config.api_host else "https://models.inference.ai.azure.com"
53 |
54 | self._tools: List[ChatCompletionsToolDefinition] = []
55 |
56 | def __build_multi_messages(self, request: ModelRequest) -> UserMessage:
57 | """
58 | 构建多模态类型
59 |
60 | 此模型加载器支持的多模态类型: `audio` `image`
61 | """
62 | multi_content_items: List[ContentItem] = []
63 |
64 | for resource in request.resources:
65 | if resource.path is None:
66 | continue
67 | elif resource.type == "audio":
68 | multi_content_items.append(
69 | AudioContentItem(
70 | input_audio=InputAudio.load(audio_file=resource.path, audio_format=resource.path.split(".")[-1])
71 | )
72 | )
73 | elif resource.type == "image":
74 | multi_content_items.append(
75 | ImageContentItem(
76 | image_url=ImageUrl.load(
77 | image_file=resource.path,
78 | image_format=resource.path.split(".")[-1],
79 | detail=ImageDetailLevel.AUTO,
80 | )
81 | )
82 | )
83 |
84 | content = [TextContentItem(text=request.prompt)] + multi_content_items
85 |
86 | return UserMessage(content=content)
87 |
88 | def __build_tools_definition(self, tools: List[dict]) -> List[ChatCompletionsToolDefinition]:
89 | tool_definitions = []
90 |
91 | for tool in tools:
92 | tool_definition = ChatCompletionsToolDefinition(
93 | function=FunctionDefinition(
94 | name=tool["function"]["name"],
95 | description=tool["function"]["description"],
96 | parameters=tool["function"]["parameters"],
97 | )
98 | )
99 | tool_definitions.append(tool_definition)
100 |
101 | return tool_definitions
102 |
103 | def _build_messages(self, request: ModelRequest) -> List[ChatRequestMessage]:
104 | messages: List[ChatRequestMessage] = []
105 |
106 | if request.system:
107 | messages.append(SystemMessage(request.system))
108 |
109 | for msg in request.history:
110 | user_msg = (
111 | UserMessage(msg.message)
112 | if not msg.resources
113 | else self.__build_multi_messages(ModelRequest(msg.message, resources=msg.resources))
114 | )
115 | messages.append(user_msg)
116 | messages.append(AssistantMessage(msg.respond))
117 |
118 | user_message = UserMessage(request.prompt) if not request.resources else self.__build_multi_messages(request)
119 |
120 | messages.append(user_message)
121 |
122 | return messages
123 |
124 | def _tool_messages_precheck(self, tool_calls: Optional[List[ChatCompletionsToolCall]] = None) -> bool:
125 | if not (tool_calls and len(tool_calls) == 1):
126 | return False
127 |
128 | tool_call = tool_calls[0]
129 |
130 | if isinstance(tool_call, ChatCompletionsToolCall):
131 | return True
132 |
133 | return False
134 |
135 | async def _ask_sync(self, messages: List[ChatRequestMessage]) -> ModelCompletions:
136 | client = ChatCompletionsClient(endpoint=self.endpoint, credential=AzureKeyCredential(self.token))
137 |
138 | completions = ModelCompletions()
139 |
140 | try:
141 | response = await client.complete(
142 | messages=messages,
143 | model=self.model_name,
144 | max_tokens=self.max_tokens,
145 | temperature=self.temperature,
146 | top_p=self.top_p,
147 | frequency_penalty=self.frequency_penalty,
148 | presence_penalty=self.presence_penalty,
149 | stream=False,
150 | tools=self._tools,
151 | )
152 | finish_reason = response.choices[0].finish_reason
153 | self._total_tokens += response.usage.total_tokens
154 |
155 | if finish_reason == CompletionsFinishReason.STOPPED:
156 | completions.text = response.choices[0].message.content
157 |
158 | elif finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
159 | completions.succeed = False
160 | completions.text = "(模型内部错误: 被内容过滤器阻止)"
161 |
162 | elif finish_reason == CompletionsFinishReason.TOKEN_LIMIT_REACHED:
163 | completions.succeed = False
164 | completions.text = "(模型内部错误: 达到了最大 token 限制)"
165 |
166 | elif finish_reason == CompletionsFinishReason.TOOL_CALLS:
167 | tool_calls = response.choices[0].message.tool_calls
168 | messages.append(AssistantMessage(tool_calls=tool_calls))
169 | if (tool_calls is None) or (not self._tool_messages_precheck(tool_calls=tool_calls)):
170 | completions.succeed = False
171 | completions.text = "(模型内部错误: tool_calls 内容为空)"
172 | return completions
173 |
174 | tool_call = tool_calls[0]
175 | function_args = json.loads(tool_call.function.arguments.replace("'", '"'))
176 |
177 | function_return = await function_call_handler(tool_call.function.name, function_args)
178 |
179 |             # Append the function call result to the chat history
180 | messages.append(ToolMessage(tool_call_id=tool_call.id, content=function_return))
181 |
182 | return await self._ask_sync(messages)
183 |
184 |             else:
185 |                 completions.succeed = False
186 |                 completions.text = f"(模型内部错误: 未知的结束原因 {finish_reason})"
187 |
188 | except HttpResponseError as e:
189 | logger.error(f"模型响应失败: {e.status_code} ({e.reason})")
190 | logger.error(f"{e.message}")
191 | completions.succeed = False
192 | completions.text = f"模型响应失败: {e.status_code} ({e.reason})"
193 |
194 | finally:
195 | await client.close()
196 | completions.usage = self._total_tokens
197 | return completions
198 |
199 | async def _ask_stream(self, messages: List[ChatRequestMessage]) -> AsyncGenerator[ModelStreamCompletions, None]:
200 | client = ChatCompletionsClient(endpoint=self.endpoint, credential=AzureKeyCredential(self.token))
201 |
202 | try:
203 | response = await client.complete(
204 | messages=messages,
205 | model=self.model_name,
206 | max_tokens=self.max_tokens,
207 | temperature=self.temperature,
208 | top_p=self.top_p,
209 | frequency_penalty=self.frequency_penalty,
210 | presence_penalty=self.presence_penalty,
211 | stream=True,
212 | tools=self._tools,
213 | model_extras={"stream_options": {"include_usage": True}}, # 需要显式声明获取用量
214 | )
215 |
216 | tool_call_id: str = ""
217 | function_name: str = ""
218 | function_args: str = ""
219 |
220 | async for chunk in response:
221 | stream_completions = ModelStreamCompletions()
222 |
223 |                 if chunk.usage:  # usage is only provided in the final chunk, whose choices list is empty
224 |                     self._total_tokens += chunk.usage.total_tokens
225 | stream_completions.usage = self._total_tokens
226 |
227 | if not chunk.choices:
228 | yield stream_completions
229 | continue
230 |
231 | finish_reason = chunk.choices[0].finish_reason
232 |
233 | if chunk.choices and chunk.choices[0].get("delta", {}).get("content", ""):
234 | stream_completions.chunk = chunk["choices"][0]["delta"]["content"]
235 |
236 | elif chunk.choices[0].delta.tool_calls is not None:
237 | tool_call = chunk.choices[0].delta.tool_calls[0]
238 |
239 | if tool_call.function.name is not None:
240 | function_name = tool_call.function.name
241 | if tool_call.id is not None:
242 | tool_call_id = tool_call.id
243 | function_args += tool_call.function.arguments or ""
244 | continue
245 |
246 | elif finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
247 | stream_completions.succeed = False
248 | stream_completions.chunk = "(模型内部错误: 被内容过滤器阻止)"
249 |
250 | elif finish_reason == CompletionsFinishReason.TOKEN_LIMIT_REACHED:
251 | stream_completions.succeed = False
252 | stream_completions.chunk = "(模型内部错误: 达到了最大 token 限制)"
253 |
254 | elif finish_reason == CompletionsFinishReason.TOOL_CALLS:
255 | messages.append(
256 | AssistantMessage(
257 | tool_calls=[
258 | ChatCompletionsToolCall(
259 | id=tool_call_id, function=FunctionCall(name=function_name, arguments=function_args)
260 | )
261 | ]
262 | )
263 | )
264 |
265 | function_arg = json.loads(function_args.replace("'", '"'))
266 |
267 | function_return = await function_call_handler(function_name, function_arg)
268 |
269 |                 # Append the function call result to the chat history
270 | messages.append(ToolMessage(tool_call_id=tool_call_id, content=function_return))
271 |
272 | async for content in self._ask_stream(messages):
273 | yield content
274 |
275 | return
276 |
277 | yield stream_completions
278 |
279 | except HttpResponseError as e:
280 | logger.error(f"模型响应失败: {e.status_code} ({e.reason})")
281 | logger.error(f"{e.message}")
282 | stream_completions = ModelStreamCompletions()
283 | stream_completions.chunk = f"模型响应失败: {e.status_code} ({e.reason})"
284 | stream_completions.succeed = False
285 | yield stream_completions
286 |
287 | finally:
288 | await client.close()
289 |
290 | @overload
291 | async def ask(self, request: ModelRequest, *, stream: Literal[False] = False) -> ModelCompletions: ...
292 |
293 | @overload
294 | async def ask(
295 | self, request: ModelRequest, *, stream: Literal[True] = True
296 | ) -> AsyncGenerator[ModelStreamCompletions, None]: ...
297 |
298 | async def ask(
299 | self, request: ModelRequest, *, stream: bool = False
300 | ) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
301 | self._total_tokens = 0
302 |
303 | messages = self._build_messages(request)
304 |
305 | self._tools = self.__build_tools_definition(request.tools) if request.tools else []
306 |
307 | if stream:
308 | return self._ask_stream(messages)
309 |
310 | return await self._ask_sync(messages)
311 |
--------------------------------------------------------------------------------
/muicebot/llm/providers/dashscope.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from dataclasses import dataclass
4 | from functools import partial
5 | from pathlib import Path
6 | from typing import (
7 | AsyncGenerator,
8 | Generator,
9 | List,
10 | Literal,
11 | Union,
12 | overload,
13 | )
14 |
15 | import dashscope
16 | from dashscope.api_entities.dashscope_response import (
17 | GenerationResponse,
18 | MultiModalConversationResponse,
19 | )
20 | from nonebot import logger
21 |
22 | from .. import (
23 | BaseLLM,
24 | ModelCompletions,
25 | ModelConfig,
26 | ModelRequest,
27 | ModelStreamCompletions,
28 | register,
29 | )
30 | from ..utils.tools import function_call_handler
31 |
32 |
33 | @dataclass
34 | class FunctionCallStream:
35 | enable: bool = False
36 | id: str = ""
37 | function_name: str = ""
38 | function_args: str = ""
39 |
40 | def from_chunk(self, chunk: GenerationResponse | MultiModalConversationResponse):
41 | tool_calls = chunk.output.choices[0].message.tool_calls
42 | tool_call = tool_calls[0]
43 |
44 | if tool_call.get("id", ""):
45 | self.id = tool_call["id"]
46 |
47 | if tool_call.get("function", {}).get("name", ""):
48 | self.function_name = tool_call.get("function").get("name")
49 |
50 | function_arg = tool_call.get("function", {}).get("arguments", "")
51 |
52 | if function_arg and self.function_args != function_arg:
53 | self.function_args += function_arg
54 |
55 | self.enable = True
56 |
57 |
58 | class ThoughtStream:
59 | def __init__(self):
60 | self.is_insert_think_label: bool = False
61 |
62 | def process_chunk(self, chunk: GenerationResponse | MultiModalConversationResponse) -> str:
63 | choice = chunk.output.choices[0].message
64 | answer_content = choice.content
65 | reasoning_content = choice.get("reasoning_content", "")
66 | reasoning_content = reasoning_content.replace("\n", "") if reasoning_content else ""
67 |
68 | # 处理模型可能输出的 reasoning(思考内容)
69 | if reasoning_content:
70 | if not self.is_insert_think_label:
71 | self.is_insert_think_label = True
72 |                 return reasoning_content
73 | else:
74 | return reasoning_content
75 |
76 | if not answer_content:
77 | answer_content = ""
78 |
79 | if isinstance(answer_content, list):
80 | answer_content = answer_content[0].get("text", "")
81 |
82 | if self.is_insert_think_label:
83 | self.is_insert_think_label = False
84 |             return answer_content
85 |
86 | return answer_content
87 |
88 |
89 | @register("dashscope")
90 | class Dashscope(BaseLLM):
91 | def __init__(self, model_config: ModelConfig) -> None:
92 | super().__init__(model_config)
93 | self._require("api_key", "model_name")
94 | self.api_key = self.config.api_key
95 | self.model = self.config.model_name
96 | self.max_tokens = self.config.max_tokens
97 | self.temperature = self.config.temperature
98 | self.top_p = self.config.top_p
99 | self.repetition_penalty = self.config.repetition_penalty
100 | self.enable_search = self.config.online_search
101 | self.enable_thinking = self.config.enable_thinking
102 | self.thinking_budget = self.config.thinking_budget
103 |
104 | self._tools: List[dict] = []
105 | self._last_call_total_tokens = 0
106 |
107 | self.extra_headers = (
108 | {"X-DashScope-DataInspection": '{"input":"cip","output":"cip"}'} if self.config.content_security else {}
109 | )
110 |
111 | self.stream = False
112 |
113 | def __build_multi_messages(self, request: ModelRequest) -> dict:
114 | """
115 | 构建多模态类型
116 |
117 | 此模型加载器支持的多模态类型: `audio` `image`
118 | """
119 | multi_contents: List[dict[str, str]] = []
120 |
121 | for item in request.resources:
122 |             if item.type == "audio":
123 |                 if not (item.path.startswith("http") or item.path.startswith("file:")):
124 |                     item.path = str(Path(item.path).resolve())  # resolve local paths in place so they are actually used
125 |
126 |                 multi_contents.append({"audio": item.path})
127 |
128 |             elif item.type == "image":
129 |                 if not (item.path.startswith("http") or item.path.startswith("file:")):
130 |                     item.path = str(Path(item.path).resolve())
131 |
132 |                 multi_contents.append({"image": item.path})
133 |
134 |         user_content = list(multi_contents)
135 |
136 | if not request.prompt:
137 | request.prompt = "请描述图像内容"
138 | user_content.append({"text": request.prompt})
139 |
140 | return {"role": "user", "content": user_content}
141 |
142 | def _build_messages(self, request: ModelRequest) -> list:
143 | messages = []
144 |
145 | if request.system:
146 | messages.append({"role": "system", "content": request.system})
147 |
148 | for msg in request.history:
149 | user_msg = (
150 | self.__build_multi_messages(ModelRequest(msg.message, resources=msg.resources))
151 | if all((self.config.multimodal, msg.resources))
152 | else {"role": "user", "content": msg.message}
153 | )
154 | messages.append(user_msg)
155 | messages.append({"role": "assistant", "content": msg.respond})
156 |
157 | user_msg = (
158 | {"role": "user", "content": request.prompt}
159 | if not request.resources
160 | else self.__build_multi_messages(ModelRequest(request.prompt, resources=request.resources))
161 | )
162 |
163 | messages.append(user_msg)
164 |
165 | return messages
166 |
167 | async def _GenerationResponse_handle(
168 | self, messages: list, response: GenerationResponse | MultiModalConversationResponse
169 | ) -> ModelCompletions:
170 | completions = ModelCompletions()
171 |
172 | if response.status_code != 200:
173 | completions.succeed = False
174 | logger.error(f"模型调用失败: {response.status_code}({response.code})")
175 | logger.error(f"{response.message}")
176 | completions.text = f"模型调用失败: {response.status_code}({response.code})"
177 | return completions
178 |
179 | self._total_tokens += int(response.usage.total_tokens)
180 | completions.usage = self._total_tokens
181 |
182 | if response.output.text:
183 | completions.text = response.output.text
184 | return completions
185 |
186 | message_content = response.output.choices[0].message.content
187 | if message_content:
188 | completions.text = message_content if isinstance(message_content, str) else message_content[0].get("text")
189 | return completions
190 |
191 | return await self._tool_calls_handle_sync(messages, response)
192 |
193 | async def _Generator_handle(
194 | self,
195 | messages: list,
196 | response: Generator[GenerationResponse, None, None] | Generator[MultiModalConversationResponse, None, None],
197 | ) -> AsyncGenerator[ModelStreamCompletions, None]:
198 | func_stream = FunctionCallStream()
199 | thought_stream = ThoughtStream()
200 |
201 | for chunk in response:
202 | logger.debug(chunk)
203 | stream_completions = ModelStreamCompletions()
204 |
205 | if chunk.status_code != 200:
206 | logger.error(f"模型调用失败: {chunk.status_code}({chunk.code})")
207 | logger.error(f"{chunk.message}")
208 | stream_completions.chunk = f"模型调用失败: {chunk.status_code}({chunk.code})"
209 | stream_completions.succeed = False
210 |
211 | yield stream_completions
212 | return
213 |
214 | # 更新 token 消耗
215 | current_call_total = chunk.usage.total_tokens
216 | delta = current_call_total - self._last_call_total_tokens
217 | if delta < 0:
218 | delta = current_call_total
219 | self._total_tokens += delta
220 | self._last_call_total_tokens = current_call_total
221 | stream_completions.usage = self._total_tokens
222 |
223 | # 优先判断是否是工具调用(OpenAI-style function calling)
224 | if chunk.output.choices and chunk.output.choices[0].message.get("tool_calls", []):
225 | func_stream.from_chunk(chunk)
226 | # 工具调用也可能在输出文本之后发生
227 |
228 | # DashScope 的 text 模式(非标准接口)
229 | if hasattr(chunk.output, "text") and chunk.output.text:
230 | stream_completions.chunk = chunk.output.text
231 | yield stream_completions
232 | continue
233 |
234 | if chunk.output.choices is None:
235 | continue
236 |
237 | stream_completions.chunk = thought_stream.process_chunk(chunk)
238 | yield stream_completions
239 |
240 | # 流式处理工具调用响应
241 | if func_stream.enable:
242 | self._last_call_total_tokens = 0
243 | async for final_chunk in await self._tool_calls_handle_stream(messages, func_stream):
244 | yield final_chunk
245 |
246 | async def _tool_calls_handle_sync(
247 | self, messages: List, response: GenerationResponse | MultiModalConversationResponse
248 | ) -> ModelCompletions:
249 | tool_call = response.output.choices[0].message.tool_calls[0]
250 | tool_call_id = tool_call["id"]
251 | function_name = tool_call["function"]["name"]
252 | function_args = json.loads(tool_call["function"]["arguments"])
253 |
254 | function_return = await function_call_handler(function_name, function_args)
255 |
256 | messages.append(response.output.choices[0].message)
257 | messages.append({"role": "tool", "content": function_return, "tool_call_id": tool_call_id})
258 |
259 | return await self._ask(messages) # type:ignore
260 |
261 | async def _tool_calls_handle_stream(
262 | self, messages: List, func_stream: FunctionCallStream
263 | ) -> AsyncGenerator[ModelStreamCompletions, None]:
264 | function_args = json.loads(func_stream.function_args)
265 |
266 | function_return = await function_call_handler(func_stream.function_name, function_args) # type:ignore
267 |
268 | messages.append(
269 | {
270 | "role": "assistant",
271 | "content": "",
272 | "tool_calls": [
273 | {
274 | "id": func_stream.id,
275 | "function": {
276 | "arguments": func_stream.function_args,
277 | "name": func_stream.function_name,
278 | },
279 | "type": "function",
280 | "index": 0,
281 | }
282 | ],
283 | }
284 | )
285 | messages.append({"role": "tool", "content": function_return, "tool_call_id": func_stream.id})
286 |
287 | return await self._ask(messages) # type:ignore
288 |
289 | async def _ask(self, messages: list) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
290 | loop = asyncio.get_event_loop()
291 |
292 | # 因为 Dashscope 对于多模态模型的接口不同,所以这里不能统一函数
293 | if not self.config.multimodal:
294 | response = await loop.run_in_executor(
295 | None,
296 | partial(
297 | dashscope.Generation.call,
298 | api_key=self.api_key,
299 | model=self.model,
300 | messages=messages,
301 | max_tokens=self.max_tokens,
302 | temperature=self.temperature,
303 | top_p=self.top_p,
304 | repetition_penalty=self.repetition_penalty,
305 | stream=self.stream,
306 | tools=self._tools,
307 | parallel_tool_calls=True,
308 | enable_search=self.enable_search,
309 |                     incremental_output=self.stream,  # must match `stream`: True is only supported for streaming calls
310 | headers=self.extra_headers,
311 | enable_thinking=self.enable_thinking,
312 | thinking_budget=self.thinking_budget,
313 | ),
314 | )
315 | else:
316 | response = await loop.run_in_executor(
317 | None,
318 | partial(
319 | dashscope.MultiModalConversation.call,
320 | api_key=self.api_key,
321 | model=self.model,
322 | messages=messages,
323 | max_tokens=self.max_tokens,
324 | temperature=self.temperature,
325 | top_p=self.top_p,
326 | repetition_penalty=self.repetition_penalty,
327 | stream=self.stream,
328 | tools=self._tools,
329 | parallel_tool_calls=True,
330 | enable_search=self.enable_search,
331 | incremental_output=self.stream,
332 | ),
333 | )
334 |
335 | if isinstance(response, GenerationResponse) or isinstance(response, MultiModalConversationResponse):
336 | return await self._GenerationResponse_handle(messages, response)
337 | return self._Generator_handle(messages, response)
338 |
339 | @overload
340 | async def ask(self, request: ModelRequest, *, stream: Literal[False] = False) -> ModelCompletions: ...
341 |
342 | @overload
343 | async def ask(
344 | self, request: ModelRequest, *, stream: Literal[True] = True
345 | ) -> AsyncGenerator[ModelStreamCompletions, None]: ...
346 |
347 | async def ask(
348 | self, request: ModelRequest, *, stream: bool = False
349 | ) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
350 | self._total_tokens = 0
351 | self._last_call_total_tokens = 0
352 | self.stream = stream if stream is not None else False
353 |
354 | self._tools = request.tools if request.tools else []
355 | messages = self._build_messages(request)
356 |
357 | return await self._ask(messages)
358 |
--------------------------------------------------------------------------------
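The incremental usage accounting in `_Generator_handle` above assumes DashScope reports a cumulative `total_tokens` in every chunk; a self-contained illustration of the delta logic with synthetic numbers:

def accumulate(cumulative_counts: list[int]) -> int:
    # mirror _Generator_handle: add per-chunk deltas; a drop in the cumulative
    # counter signals a new underlying call (e.g. after a tool call), so the
    # full value is added instead of the delta
    total, last = 0, 0
    for current in cumulative_counts:
        delta = current - last
        if delta < 0:
            delta = current
        total += delta
        last = current
    return total

assert accumulate([10, 25, 40, 12, 30]) == 70  # 40 from the first call + 30 from the second

--------------------------------------------------------------------------------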
/muicebot/llm/providers/gemini.py:
--------------------------------------------------------------------------------
1 | from typing import AsyncGenerator, List, Literal, Optional, Union, overload
2 |
3 | from google import genai
4 | from google.genai import errors
5 | from google.genai.types import (
6 | Content,
7 | ContentOrDict,
8 | GenerateContentConfig,
9 | GoogleSearch,
10 | HarmBlockThreshold,
11 | HarmCategory,
12 | Part,
13 | SafetySetting,
14 | Tool,
15 | )
16 | from httpx import ConnectError
17 | from nonebot import logger
18 |
19 | from muicebot.models import Resource
20 |
21 | from .. import (
22 | BaseLLM,
23 | ModelCompletions,
24 | ModelConfig,
25 | ModelRequest,
26 | ModelStreamCompletions,
27 | register,
28 | )
29 | from ..utils.images import get_file_base64
30 | from ..utils.tools import function_call_handler
31 |
32 |
33 | @register("gemini")
34 | class Gemini(BaseLLM):
35 | def __init__(self, model_config: ModelConfig) -> None:
36 | super().__init__(model_config)
37 | self._require("model_name", "api_key")
38 |
39 | self.model_name = self.config.model_name
40 | self.api_key = self.config.api_key
41 | self.enable_search = self.config.online_search
42 |
43 | self.client = genai.Client(api_key=self.api_key)
44 |
45 | self.gemini_config = GenerateContentConfig(
46 | temperature=self.config.temperature,
47 | top_p=self.config.top_p,
48 | top_k=self.config.top_k,
49 | max_output_tokens=self.config.max_tokens,
50 | presence_penalty=self.config.presence_penalty,
51 | frequency_penalty=self.config.frequency_penalty,
52 | response_modalities=[m.upper() for m in self.config.modalities if m in {"image", "text"}],
53 | safety_settings=(
54 | [
55 | SafetySetting(
56 | category=HarmCategory.HARM_CATEGORY_HARASSMENT,
57 | threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
58 | ),
59 | SafetySetting(
60 | category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
61 | threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
62 | ),
63 | SafetySetting(
64 | category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
65 | threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
66 | ),
67 | SafetySetting(
68 | category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
69 | threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
70 | ),
71 | ]
72 | if self.config.content_security
73 | else []
74 | ),
75 | )
76 |
77 | self.model = self.client.chats.create(model=self.model_name, config=self.gemini_config)
78 |
79 | def __build_user_parts(self, request: ModelRequest) -> list[Part]:
80 | user_parts: list[Part] = [Part.from_text(text=request.prompt)]
81 |
82 | if not request.resources:
83 | return user_parts
84 |
85 | for resource in request.resources:
86 | if resource.type == "image" and resource.path is not None:
87 | resource.ensure_mimetype()
88 | user_parts.append(
89 | Part.from_bytes(
90 | data=get_file_base64(resource.path), mime_type=resource.mimetype or "image/jpeg" # type:ignore
91 | )
92 | )
93 |
94 | return user_parts
95 |
96 |     def __build_tools_list(self, tools: Optional[List] = None):
97 |         format_tools = []
98 |
99 |         for tool in tools or []:
100 | tool = tool["function"]
101 | required_parameters = tool["required"]
102 | del tool["required"]
103 | tool["parameters"]["required"] = required_parameters
104 | format_tools.append(tool)
105 |
106 | function_tools = Tool(function_declarations=format_tools) # type:ignore
107 |
108 | if self.enable_search:
109 | function_tools.google_search = GoogleSearch()
110 |
111 | if tools or self.enable_search:
112 | self.gemini_config.tools = [function_tools]
113 |
114 | def _build_messages(self, request: ModelRequest) -> list[ContentOrDict]:
115 | messages: List[ContentOrDict] = []
116 |
117 | if request.history:
118 |             for item in request.history:
119 | messages.append(
120 | Content(
121 | role="user", parts=self.__build_user_parts(ModelRequest(item.message, resources=item.resources))
122 | )
123 | )
124 | messages.append(Content(role="model", parts=[Part.from_text(text=item.respond)]))
125 |
126 | messages.append(Content(role="user", parts=self.__build_user_parts(request)))
127 |
128 | return messages
129 |
130 | async def _ask_sync(self, messages: list[ContentOrDict], **kwargs) -> ModelCompletions:
131 | completions = ModelCompletions()
132 |
133 | try:
134 | chat = self.client.aio.chats.create(model=self.model_name, config=self.gemini_config, history=messages[:-1])
135 | message = messages[-1].parts # type:ignore
136 | response = await chat.send_message(message=message) # type:ignore
137 | if response.usage_metadata:
138 | total_token_count = response.usage_metadata.total_token_count
139 | self._total_tokens += total_token_count if total_token_count else -1
140 |
141 | if response.text:
142 | completions.text = response.text
143 |
144 | if (
145 | response.candidates
146 | and response.candidates[0].content
147 | and response.candidates[0].content.parts
148 | and response.candidates[0].content.parts[0].inline_data
149 | and response.candidates[0].content.parts[0].inline_data.data
150 | ):
151 | completions.resources = [
152 | Resource(type="image", raw=response.candidates[0].content.parts[0].inline_data.data)
153 | ]
154 |
155 | if response.function_calls:
156 | function_call = response.function_calls[0]
157 | function_name = function_call.name
158 | function_args = function_call.args
159 |
160 | function_return = await function_call_handler(function_name, function_args) # type:ignore
161 |
162 | function_response_part = Part.from_function_response(
163 | name=function_name, # type:ignore
164 | response={"result": function_return},
165 | )
166 |
167 | messages.append(Content(role="model", parts=[Part(function_call=function_call)]))
168 | messages.append(Content(role="user", parts=[function_response_part]))
169 |
170 | return await self._ask_sync(messages)
171 |
172 | completions.text = completions.text or "(警告:模型无输出!)"
173 | completions.usage = self._total_tokens
174 | return completions
175 |
176 | except errors.APIError as e:
177 | error_message = f"API 状态异常: {e.code}({e.response})"
178 | completions.text = error_message
179 | completions.succeed = False
180 | logger.error(error_message)
181 | logger.error(e.message)
182 | return completions
183 |
184 | except ConnectError:
185 | error_message = "模型加载器连接超时"
186 | completions.text = error_message
187 | completions.succeed = False
188 | logger.error(error_message)
189 | return completions
190 |
191 | async def _ask_stream(self, messages: list, **kwargs) -> AsyncGenerator[ModelStreamCompletions, None]:
192 | try:
193 | total_tokens = 0
194 | stream = await self.client.aio.models.generate_content_stream(
195 | model=self.model_name, contents=messages, config=self.gemini_config
196 | )
197 | async for chunk in stream:
198 | stream_completions = ModelStreamCompletions()
199 |
200 | if chunk.text:
201 | stream_completions.chunk = chunk.text
202 | yield stream_completions
203 |
204 | if chunk.usage_metadata and chunk.usage_metadata.total_token_count:
205 | total_tokens = chunk.usage_metadata.total_token_count
206 |
207 | if (
208 | chunk.candidates
209 | and chunk.candidates[0].content
210 | and chunk.candidates[0].content.parts
211 | and chunk.candidates[0].content.parts[0].inline_data
212 | and chunk.candidates[0].content.parts[0].inline_data.data
213 | ):
214 | stream_completions.resources = [
215 | Resource(type="image", raw=chunk.candidates[0].content.parts[0].inline_data.data)
216 | ]
217 | yield stream_completions
218 |
219 | if chunk.function_calls:
220 | function_call = chunk.function_calls[0]
221 | function_name = function_call.name
222 | function_args = function_call.args
223 |
224 | function_return = await function_call_handler(function_name, function_args) # type:ignore
225 |
226 | function_response_part = Part.from_function_response(
227 | name=function_name, # type:ignore
228 | response={"result": function_return},
229 | )
230 |
231 | messages.append(Content(role="model", parts=[Part(function_call=function_call)]))
232 | messages.append(Content(role="user", parts=[function_response_part]))
233 |
234 | async for final_chunk in self._ask_stream(messages):
235 | yield final_chunk
236 |
237 | totaltokens_completions = ModelStreamCompletions()
238 |
239 | self._total_tokens += total_tokens
240 | totaltokens_completions.usage = self._total_tokens
241 | yield totaltokens_completions
242 |
243 | except errors.APIError as e:
244 | stream_completions = ModelStreamCompletions()
245 | error_message = f"API 状态异常: {e.code}({e.response})"
246 | stream_completions.chunk = error_message
247 | logger.error(error_message)
248 | logger.error(e.message)
249 | stream_completions.succeed = False
250 | yield stream_completions
251 | return
252 |
253 | except ConnectError:
254 | stream_completions = ModelStreamCompletions()
255 | error_message = "模型加载器连接超时"
256 | stream_completions.chunk = error_message
257 | logger.error(error_message)
258 | stream_completions.succeed = False
259 | yield stream_completions
260 | return
261 |
262 | @overload
263 | async def ask(self, request: ModelRequest, *, stream: Literal[False] = False) -> ModelCompletions: ...
264 |
265 | @overload
266 | async def ask(
267 | self, request: ModelRequest, *, stream: Literal[True] = True
268 | ) -> AsyncGenerator[ModelStreamCompletions, None]: ...
269 |
270 | async def ask(
271 | self, request: ModelRequest, *, stream: bool = False
272 | ) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
273 | self._total_tokens = 0
274 | self.__build_tools_list(request.tools)
275 | self.gemini_config.system_instruction = request.system
276 |
277 | messages = self._build_messages(request)
278 |
279 | if stream:
280 | return self._ask_stream(messages)
281 |
282 | return await self._ask_sync(messages)
283 |
--------------------------------------------------------------------------------
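`__build_tools_list` above adapts OpenAI-style tool dicts by hoisting the top-level `required` list into `parameters`; a synthetic example of the shape it expects and produces:

tool = {
    "function": {
        "name": "get_current_time",
        "description": "Return the current time",
        "parameters": {"type": "object", "properties": {}},
        "required": [],
    }
}
func = tool["function"]
func["parameters"]["required"] = func.pop("required")  # what the loop body does, minus the Tool wrapper

--------------------------------------------------------------------------------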
/muicebot/llm/providers/ollama.py:
--------------------------------------------------------------------------------
1 | from typing import AsyncGenerator, List, Literal, Union, overload
2 |
3 | import ollama
4 | from nonebot import logger
5 | from ollama import ResponseError
6 |
7 | from .. import (
8 | BaseLLM,
9 | ModelCompletions,
10 | ModelConfig,
11 | ModelRequest,
12 | ModelStreamCompletions,
13 | register,
14 | )
15 | from ..utils.images import get_file_base64
16 | from ..utils.tools import function_call_handler
17 |
18 |
19 | @register("ollama")
20 | class Ollama(BaseLLM):
21 | """
22 | 使用 Ollama 模型服务调用模型
23 | """
24 |
25 | def __init__(self, model_config: ModelConfig) -> None:
26 | super().__init__(model_config)
27 | self._require("model_name")
28 | self.model = self.config.model_name
29 | self.host = self.config.api_host if self.config.api_host else "http://localhost:11434"
30 | self.top_k = self.config.top_k
31 | self.top_p = self.config.top_p
32 | self.temperature = self.config.temperature
33 | self.repeat_penalty = self.config.repetition_penalty or 1
34 | self.presence_penalty = self.config.presence_penalty or 0
35 |         self.frequency_penalty = self.config.frequency_penalty or 0  # 0 leaves the penalty disabled
36 | self.stream = self.config.stream
37 |
38 | self._tools: List[dict] = []
39 |
40 | def load(self) -> bool:
41 | try:
42 | self.client = ollama.AsyncClient(host=self.host)
43 | self.is_running = True
44 | except ResponseError as e:
45 | logger.error(f"加载 Ollama 加载器时发生错误: {e}")
46 | except ConnectionError as e:
47 | logger.error(f"加载 Ollama 加载器时发生错误: {e}")
48 | finally:
49 | return self.is_running
50 |
51 | def __build_multi_messages(self, request: ModelRequest) -> dict:
52 | """
53 | 构建多模态类型
54 |
55 | 当前模型加载器支持的多模态类型: `image`
56 | """
57 | images = []
58 |
59 | for resource in request.resources:
60 | if resource.path is None:
61 | continue
62 | image_base64 = get_file_base64(local_path=resource.path)
63 | images.append(image_base64)
64 |
65 | message = {"role": "user", "content": request.prompt, "images": images}
66 |
67 | return message
68 |
69 | def _build_messages(self, request: ModelRequest) -> list:
70 | messages = []
71 |
72 | if request.system:
73 | messages.append({"role": "system", "content": request.system})
74 |
75 |         for item in request.history:
76 | messages.append(self.__build_multi_messages(ModelRequest(item.message, resources=item.resources)))
77 | messages.append({"role": "assistant", "content": item.respond})
78 |
79 | message = self.__build_multi_messages(request)
80 |
81 | messages.append(message)
82 |
83 | return messages
84 |
85 | async def _ask_sync(self, messages: list) -> ModelCompletions:
86 | completions = ModelCompletions()
87 |
88 | try:
89 | response = await self.client.chat(
90 | model=self.model,
91 | messages=messages,
92 | tools=self._tools,
93 | stream=False,
94 | options={
95 | "temperature": self.temperature,
96 | "top_k": self.top_k,
97 | "top_p": self.top_p,
98 | "repeat_penalty": self.repeat_penalty,
99 | "presence_penalty": self.presence_penalty,
100 | "frequency_penalty": self.frequency_penalty,
101 | },
102 | )
103 |
104 | tool_calls = response.message.tool_calls
105 |
106 | if not tool_calls:
107 | completions.text = response.message.content or "(警告:模型无返回)"
108 | return completions
109 |
110 |             # append the assistant message once, then one tool message per call
111 |             messages.append(response.message)
112 |
113 |             for tool in tool_calls:
114 |                 function_name = tool.function.name
115 |                 function_args = tool.function.arguments
116 |
117 |                 function_return = await function_call_handler(function_name, dict(function_args))
118 |                 messages.append({"role": "tool", "content": str(function_return), "name": function_name})
119 |
120 |             # re-ask with the tool results appended so the model can produce a final answer
121 |             return await self._ask_sync(messages)
122 |
123 |
124 | except ollama.ResponseError as e:
125 | error_info = f"模型调用错误: {e.error}"
126 | logger.error(error_info)
127 | completions.succeed = False
128 | completions.text = error_info
129 | return completions
130 |
131 | async def _ask_stream(self, messages: list) -> AsyncGenerator[ModelStreamCompletions, None]:
132 | try:
133 | response = await self.client.chat(
134 | model=self.model,
135 | messages=messages,
136 | tools=self._tools,
137 | stream=True,
138 | options={
139 | "temperature": self.temperature,
140 | "top_k": self.top_k,
141 | "top_p": self.top_p,
142 | "repeat_penalty": self.repeat_penalty,
143 | "presence_penalty": self.presence_penalty,
144 | "frequency_penalty": self.frequency_penalty,
145 | },
146 | )
147 |
148 | async for chunk in response:
149 | stream_completions = ModelStreamCompletions()
150 |
151 | tool_calls = chunk.message.tool_calls
152 |
153 | if chunk.message.content:
154 | stream_completions.chunk = chunk.message.content
155 | yield stream_completions
156 | continue
157 |
158 | if not tool_calls:
159 | continue
160 |
161 |                 # append the assistant message once, then one tool message per call
162 |                 messages.append(chunk.message)  # type:ignore
163 |
164 |                 for tool in tool_calls:
165 |                     function_return = await function_call_handler(tool.function.name, dict(tool.function.arguments))
166 |                     messages.append(
167 |                         {"role": "tool", "content": str(function_return), "name": tool.function.name}
168 |                     )
169 |
170 | async for content in self._ask_stream(messages):
171 | yield content
172 |
173 | except ollama.ResponseError as e:
174 | stream_completions = ModelStreamCompletions()
175 | error_info = f"模型调用错误: {e.error}"
176 | logger.error(error_info)
177 | stream_completions.chunk = error_info
178 | stream_completions.succeed = False
179 | yield stream_completions
180 | return
181 |
182 | @overload
183 | async def ask(self, request: ModelRequest, *, stream: Literal[False] = False) -> ModelCompletions: ...
184 |
185 | @overload
186 | async def ask(
187 | self, request: ModelRequest, *, stream: Literal[True] = True
188 | ) -> AsyncGenerator[ModelStreamCompletions, None]: ...
189 |
190 | async def ask(
191 | self, request: ModelRequest, *, stream: bool = False
192 | ) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
193 | self._tools = request.tools if request.tools else []
194 | messages = self._build_messages(request)
195 |
196 | if stream:
197 | return self._ask_stream(messages)
198 |
199 | return await self._ask_sync(messages)
200 |
--------------------------------------------------------------------------------
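A sketch of consuming the streaming path above (requires a reachable Ollama server; the model name is hypothetical):

from muicebot.llm import ModelConfig, ModelRequest, load_model

async def stream_demo() -> None:
    llm = load_model(ModelConfig(loader="ollama", model_name="llama3.1"))
    assert llm.load()
    # ask(..., stream=True) returns an async generator of ModelStreamCompletions
    async for part in await llm.ask(ModelRequest("hello"), stream=True):
        print(part.chunk, end="")

--------------------------------------------------------------------------------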
/muicebot/llm/providers/openai.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | from io import BytesIO
4 | from typing import AsyncGenerator, List, Literal, Union, overload
5 |
6 | import openai
7 | from nonebot import logger
8 | from openai import NOT_GIVEN, NotGiven
9 | from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
10 |
11 | from muicebot.models import Resource
12 |
13 | from .. import (
14 | BaseLLM,
15 | ModelCompletions,
16 | ModelConfig,
17 | ModelRequest,
18 | ModelStreamCompletions,
19 | register,
20 | )
21 | from ..utils.images import get_file_base64
22 | from ..utils.tools import function_call_handler
23 |
24 |
25 | @register("openai")
26 | class Openai(BaseLLM):
27 | _tools: List[ChatCompletionToolParam]
28 | modalities: Union[List[Literal["text", "audio"]], NotGiven]
29 |
30 | def __init__(self, model_config: ModelConfig) -> None:
31 | super().__init__(model_config)
32 | self._require("api_key", "model_name")
33 | self.api_key = self.config.api_key
34 | self.model = self.config.model_name
35 | self.api_base = self.config.api_host or "https://api.openai.com/v1"
36 | self.max_tokens = self.config.max_tokens
37 | self.temperature = self.config.temperature
38 | self.stream = self.config.stream
39 | self.modalities = [m for m in self.config.modalities if m in {"text", "audio"}] or NOT_GIVEN # type:ignore
40 | self.audio = self.config.audio if (self.modalities and self.config.audio) else NOT_GIVEN
41 | self.extra_body = self.config.extra_body
42 |
43 | self.client = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.api_base, timeout=30)
44 | self._tools = []
45 |
46 | def __build_multi_messages(self, request: ModelRequest) -> dict:
47 | """
48 | 构建多模态类型
49 |
50 | 此模型加载器支持的多模态类型: `audio` `image` `video` `file`
51 | """
52 | user_content: List[dict] = [{"type": "text", "text": request.prompt}]
53 |
54 | for resource in request.resources:
55 | if resource.path is None:
56 | continue
57 |
58 | elif resource.type == "audio":
59 | file_format = resource.path.split(".")[-1]
60 | file_data = f"data:;base64,{get_file_base64(local_path=resource.path)}"
61 | user_content.append({"type": "input_audio", "input_audio": {"data": file_data, "format": file_format}})
62 |
63 | elif resource.type == "image":
64 | file_format = resource.path.split(".")[-1]
65 | file_data = f"data:image/{file_format};base64,{get_file_base64(local_path=resource.path)}"
66 | user_content.append({"type": "image_url", "image_url": {"url": file_data}})
67 |
68 | elif resource.type == "video":
69 | file_format = resource.path.split(".")[-1]
70 | file_data = f"data:;base64,{get_file_base64(local_path=resource.path)}"
71 | user_content.append({"type": "video_url", "video_url": {"url": file_data}})
72 |
73 | elif resource.type == "file":
74 | file_format = resource.path.split(".")[-1]
75 | file_data = f"data:;base64,{get_file_base64(local_path=resource.path)}"
76 | user_content.append({"type": "file", "file": {"file_data": file_data}})
77 |
78 | return {"role": "user", "content": user_content}
79 |
80 | def _build_messages(self, request: ModelRequest) -> list:
81 | messages = []
82 |
83 | if request.system:
84 | messages.append({"role": "system", "content": request.system})
85 |
86 | if request.history:
87 |             for item in request.history:
88 | user_content = (
89 | {"role": "user", "content": item.message}
90 | if not all([item.resources, self.config.multimodal])
91 | else self.__build_multi_messages(ModelRequest(item.message, resources=item.resources))
92 | )
93 |
94 | messages.append(user_content)
95 | messages.append({"role": "assistant", "content": item.respond})
96 |
97 | user_content = (
98 | {"role": "user", "content": request.prompt}
99 | if not request.resources
100 | else self.__build_multi_messages(request)
101 | )
102 |
103 | messages.append(user_content)
104 |
105 | return messages
106 |
107 | def _tool_call_request_precheck(self, message: ChatCompletionMessage) -> bool:
108 | """
109 | 工具调用请求预检
110 | """
111 | # We expect a single tool call
112 | if not (message.tool_calls and len(message.tool_calls) == 1):
113 | return False
114 |
115 | # We expect the tool to be a function call
116 | tool_call = message.tool_calls[0]
117 | if tool_call.type != "function":
118 | return False
119 |
120 | return True
121 |
122 | async def _ask_sync(self, messages: list, **kwargs) -> ModelCompletions:
123 | completions = ModelCompletions()
124 |
125 | try:
126 | response = await self.client.chat.completions.create(
127 | audio=self.audio,
128 | model=self.model,
129 | modalities=self.modalities,
130 | messages=messages,
131 | max_tokens=self.max_tokens,
132 | temperature=self.temperature,
133 | stream=False,
134 | tools=self._tools,
135 | extra_body=self.extra_body,
136 | )
137 |
138 | result = ""
139 | message = response.choices[0].message # type:ignore
140 | self._total_tokens += response.usage.total_tokens if response.usage else -1
141 |
142 | if (
143 | hasattr(message, "reasoning_content") # type:ignore
144 | and message.reasoning_content # type:ignore
145 | ):
146 | result += f"{message.reasoning_content}" # type:ignore
147 |
148 | if response.choices[0].finish_reason == "tool_calls" and self._tool_call_request_precheck(
149 | response.choices[0].message
150 | ):
151 | messages.append(response.choices[0].message)
152 | tool_call = response.choices[0].message.tool_calls[0] # type:ignore
153 | arguments = json.loads(tool_call.function.arguments.replace("'", '"'))
154 |
155 | function_return = await function_call_handler(tool_call.function.name, arguments)
156 |
157 | messages.append(
158 | {
159 | "tool_call_id": tool_call.id,
160 | "role": "tool",
161 | "name": tool_call.function.name,
162 | "content": function_return,
163 | }
164 | )
165 | return await self._ask_sync(messages)
166 |
167 | if message.content: # type:ignore
168 | result += message.content # type:ignore
169 |
170 | # 多模态消息处理(目前仅支持 audio 输出)
171 | if response.choices[0].message.audio:
172 | wav_bytes = base64.b64decode(response.choices[0].message.audio.data)
173 | completions.resources = [Resource(type="audio", raw=wav_bytes)]
174 |
175 | completions.text = result or "(警告:模型无输出!)"
176 | completions.usage = self._total_tokens
177 |
178 | except openai.APIConnectionError as e:
179 | error_message = f"API 连接错误: {e}"
180 | completions.text = error_message
181 | logger.error(error_message)
182 | logger.error(e.__cause__)
183 | completions.succeed = False
184 |
185 | except openai.APIStatusError as e:
186 | error_message = f"API 状态异常: {e.status_code}({e.response})"
187 | completions.text = error_message
188 | logger.error(error_message)
189 | completions.succeed = False
190 |
191 | return completions
192 |
193 | async def _ask_stream(self, messages: list, **kwargs) -> AsyncGenerator[ModelStreamCompletions, None]:
194 | is_insert_think_label = False
195 | function_id = ""
196 | function_name = ""
197 | function_arguments = ""
198 | audio_string = ""
199 |
200 | try:
201 | response = await self.client.chat.completions.create(
202 | audio=self.audio,
203 | model=self.model,
204 | modalities=self.modalities,
205 | messages=messages,
206 | max_tokens=self.max_tokens,
207 | temperature=self.temperature,
208 | stream=True,
209 | stream_options={"include_usage": True},
210 | tools=self._tools,
211 | extra_body=self.extra_body,
212 | )
213 |
214 | async for chunk in response:
215 | stream_completions = ModelStreamCompletions()
216 |
217 | # 获取 usage (最后一个包中返回)
218 | if chunk.usage:
219 | self._total_tokens += chunk.usage.total_tokens
220 | stream_completions.usage = self._total_tokens
221 |
222 | if not chunk.choices:
223 | yield stream_completions
224 | continue
225 |
226 | # 处理 Function call
227 | if chunk.choices[0].delta.tool_calls:
228 | tool_call = chunk.choices[0].delta.tool_calls[0]
229 | if tool_call.id:
230 | function_id = tool_call.id
231 | if tool_call.function:
232 | if tool_call.function.name:
233 | function_name += tool_call.function.name
234 | if tool_call.function.arguments:
235 | function_arguments += tool_call.function.arguments
236 |
237 | delta = chunk.choices[0].delta
238 | answer_content = delta.content
239 |
240 | # 处理思维过程 reasoning_content
241 | if (
242 | hasattr(delta, "reasoning_content") and delta.reasoning_content # type:ignore
243 | ):
244 | reasoning_content = chunk.choices[0].delta.reasoning_content # type:ignore
245 | stream_completions.chunk = (
246 | reasoning_content if is_insert_think_label else "<think>" + reasoning_content
247 | )
248 | yield stream_completions
249 | is_insert_think_label = True
250 |
251 | elif answer_content:
252 | stream_completions.chunk = (
253 | answer_content if not is_insert_think_label else "</think>" + answer_content
254 | )
255 | yield stream_completions
256 | is_insert_think_label = False
257 |
258 | # 处理多模态消息 (audio-only) (非标准方法,可能出现问题)
259 | if hasattr(chunk.choices[0].delta, "audio"):
260 | audio = chunk.choices[0].delta.audio # type:ignore
261 | if audio.get("data", None):
262 | audio_string += audio.get("data")
263 | stream_completions.chunk = audio.get("transcript", "")
264 | yield stream_completions
265 |
266 | if function_id:
267 |
268 | function_return = await function_call_handler(function_name, json.loads(function_arguments))
269 |
270 | messages.append(
271 | {
272 | "role": "assistant",
273 | "content": None,
274 | "tool_calls": [
275 | {
276 | "id": function_id,
277 | "type": "function",
278 | "function": {"name": function_name, "arguments": function_arguments},
279 | }
280 | ],
281 | }
282 | )
283 | messages.append(
284 | {
285 | "tool_call_id": function_id,
286 | "role": "tool",
287 | "content": function_return,
288 | }
289 | )
290 |
291 | async for chunk in self._ask_stream(messages):
292 | yield chunk
293 |
294 | # 处理多模态返回
295 | if audio_string:
296 | import numpy as np
297 | import soundfile as sf
298 |
299 | wav_bytes = base64.b64decode(audio_string)
300 | pcm_data = np.frombuffer(wav_bytes, dtype=np.int16)
301 | wav_io = BytesIO()
302 | sf.write(wav_io, pcm_data, samplerate=24000, format="WAV")
303 |
304 | stream_completions = ModelStreamCompletions()
305 | stream_completions.resources = [Resource(type="audio", raw=wav_io)]
306 | yield stream_completions
307 |
308 | except openai.APIConnectionError as e:
309 | error_message = f"API 连接错误: {e}"
310 | logger.error(error_message)
311 | logger.error(e.__cause__)
312 | stream_completions = ModelStreamCompletions()
313 | stream_completions.chunk = error_message
314 | stream_completions.succeed = False
315 | yield stream_completions
316 |
317 | except openai.APIStatusError as e:
318 | error_message = f"API 状态异常: {e.status_code}({e.response})"
319 | logger.error(error_message)
320 | stream_completions = ModelStreamCompletions()
321 | stream_completions.chunk = error_message
322 | stream_completions.succeed = False
323 | yield stream_completions
324 |
325 | @overload
326 | async def ask(self, request: ModelRequest, *, stream: Literal[False] = False) -> ModelCompletions: ...
327 |
328 | @overload
329 | async def ask(
330 | self, request: ModelRequest, *, stream: Literal[True]
331 | ) -> AsyncGenerator[ModelStreamCompletions, None]: ...
332 |
333 | async def ask(
334 | self, request: ModelRequest, *, stream: bool = False
335 | ) -> Union[ModelCompletions, AsyncGenerator[ModelStreamCompletions, None]]:
336 | self._tools = request.tools if request.tools else NOT_GIVEN # type:ignore
337 | self._total_tokens = 0
338 |
339 | messages = self._build_messages(request)
340 |
341 | if stream:
342 | return self._ask_stream(messages)
343 |
344 | return await self._ask_sync(messages)
345 |
--------------------------------------------------------------------------------
/muicebot/llm/registry.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Type
2 |
3 | from ._base import BaseLLM
4 |
5 | LLM_REGISTRY: Dict[str, Type[BaseLLM]] = {}
6 |
7 |
8 | def register(name: str):
9 | """
10 | 注册一个 LLM 实现
11 |
12 | :param name: LLM 实现名
13 | """
14 |
15 | def decorator(cls: Type[BaseLLM]):
16 | LLM_REGISTRY[name.lower()] = cls
17 | return cls
18 |
19 | return decorator
20 |
21 |
22 | def get_llm_class(name: str) -> Type[BaseLLM]:
23 | """
24 | 获得一个 LLM 实现类
25 |
26 | :param name: LLM 实现名
27 | """
28 | if name.lower() not in LLM_REGISTRY:
29 | raise ValueError(f"未注册模型:{name}")
30 |
31 | return LLM_REGISTRY[name.lower()]
32 |
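33 | # 用法示意(MyLLM 为假设的 BaseLLM 子类,仅作演示):
34 | #
35 | # @register("myllm")
36 | # class MyLLM(BaseLLM):
37 | #     ...
38 | #
39 | # llm_cls = get_llm_class("myllm")  # 名称匹配不区分大小写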
--------------------------------------------------------------------------------
/muicebot/llm/utils/images.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from typing import Optional
3 |
4 |
5 | def get_file_base64(local_path: Optional[str] = None, file_bytes: Optional[bytes] = None) -> str:
6 | """
7 | 获取本地文件或原始字节串的 Base64 编码
8 | """
9 | if local_path:
10 | with open(local_path, "rb") as f:
11 | image_data = base64.b64encode(f.read()).decode("utf-8")
12 | return image_data
13 | if file_bytes:
14 | image_base64 = base64.b64encode(file_bytes)
15 | return image_base64.decode("utf-8")
16 | raise ValueError("You must pass in a valid parameter!")
17 |
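18 | # 用法示意(文件路径与字节串仅作演示):
19 | #
20 | # b64 = get_file_base64(local_path="./demo.png")   # 读取本地文件
21 | # b64 = get_file_base64(file_bytes=b"...")         # 或直接传入字节串
22 | # data_uri = f"data:image/png;base64,{b64}"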
--------------------------------------------------------------------------------
/muicebot/llm/utils/tools.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from nonebot import logger
4 |
5 | from muicebot.plugin.func_call import get_function_calls
6 | from muicebot.plugin.mcp import handle_mcp_tool
7 |
8 |
9 | async def function_call_handler(func: str, arguments: dict[str, str] | None = None) -> Any:
10 | """
11 | 模型 Function Call 请求处理
12 | """
13 | arguments = arguments if arguments and arguments != {"dummy_param": ""} else {}
14 |
15 | if func_caller := get_function_calls().get(func):
16 | logger.info(f"Function call 请求 {func}, 参数: {arguments}")
17 | result = await func_caller.run(**arguments)
18 | logger.success(f"Function call 成功,返回: {result}")
19 | return result
20 |
21 | if mcp_result := await handle_mcp_tool(func, arguments):
22 | logger.success(f"MCP 工具执行成功,返回: {mcp_result}")
23 | return mcp_result
24 |
25 | return "(Unknown Function)"
26 |
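27 | # 调用示意(工具名仅作演示):先查 Function Call 注册表,未命中再尝试 MCP 工具,
28 | # 两者均未命中时返回 "(Unknown Function)":
29 | #
30 | # result = await function_call_handler("get_current_time", {})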
--------------------------------------------------------------------------------
/muicebot/models.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, dataclass, field
2 | from datetime import datetime
3 | from functools import total_ordering
4 | from io import BytesIO
5 | from typing import List, Literal, Optional, Union
6 |
7 |
8 | @dataclass
9 | class Resource:
10 | """多模态消息"""
11 |
12 | type: Literal["image", "video", "audio", "file"]
13 | """消息类型"""
14 | path: str = field(default_factory=str)
15 | """本地存储地址(对于模型处理是必需的)"""
16 | url: Optional[str] = field(default=None)
17 | """远程存储地址(一般不传入模型处理中)"""
18 | raw: Optional[Union[bytes, BytesIO]] = field(default=None)
19 | """二进制数据(只使用于模型返回且不保存到数据库中)"""
20 | mimetype: Optional[str] = field(default=None)
21 | """文件元数据类型"""
22 |
23 | def ensure_mimetype(self):
24 | from .utils.utils import guess_mimetype
25 |
26 | if not self.mimetype:
27 | self.mimetype = guess_mimetype(self)
28 |
29 | def to_dict(self) -> dict:
30 | return {
31 | "type": self.type,
32 | "path": self.path,
33 | }
34 |
35 |
36 | @total_ordering
37 | @dataclass
38 | class Message:
39 | """格式化后的 bot 消息"""
40 |
41 | id: int | None = None
42 | """每条消息的唯一ID"""
43 | time: str = field(default_factory=lambda: datetime.strftime(datetime.now(), "%Y.%m.%d %H:%M:%S"))
44 | """
45 | 字符串形式的时间数据:%Y.%m.%d %H:%M:%S
46 | 若要获取格式化的 datetime 对象,请使用 format_time
47 | """
48 | userid: str = ""
49 | """Nonebot 的用户id"""
50 | groupid: str = "-1"
51 | """群组id,私聊设为-1"""
52 | message: str = ""
53 | """消息主体"""
54 | respond: str = ""
55 | """模型回复(不包含思维过程)"""
56 | history: int = 1
57 | """消息是否可用于对话历史中,以整数形式映射布尔值"""
58 | resources: List[Resource] = field(default_factory=list)
59 | """多模态消息内容"""
60 | usage: int = -1
61 | """使用的总 tokens, 若模型加载器不支持则设为-1"""
62 |
63 | @property
64 | def format_time(self) -> datetime:
65 | """将时间字符串转换为 datetime 对象"""
66 | return datetime.strptime(self.time, "%Y.%m.%d %H:%M:%S")
67 |
68 | def to_dict(self) -> dict:
69 | return asdict(self)
70 |
71 | @staticmethod
72 | def from_dict(data: dict) -> "Message":
73 | return Message(**data)
74 |
75 | def __hash__(self) -> int:
76 | return hash(self.id)
77 |
78 | def __lt__(self, other: "Message") -> bool:
79 | return self.format_time < other.format_time
80 |
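81 | # 构造示意(字段值仅作演示):
82 | #
83 | # msg = Message(
84 | #     userid="123456",
85 | #     message="你好",
86 | #     resources=[Resource(type="image", path="./cache/demo.png")],
87 | # )
88 | # msg.to_dict()  # 通过 asdict 递归序列化为 dict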
--------------------------------------------------------------------------------
/muicebot/muice.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import AsyncGenerator, Optional, Union
4 |
5 | from nonebot import logger
6 |
7 | from .config import ModelConfig, get_model_config, model_config_manager, plugin_config
8 | from .database import Database
9 | from .llm import (
10 | MODEL_DEPENDENCY_MAP,
11 | ModelCompletions,
12 | ModelRequest,
13 | ModelStreamCompletions,
14 | get_missing_dependencies,
15 | load_model,
16 | )
17 | from .models import Message, Resource
18 | from .plugin.func_call import get_function_list
19 | from .plugin.hook import HookType, hook_manager
20 | from .plugin.mcp import get_mcp_list
21 | from .templates import generate_prompt_from_template
22 | from .utils.utils import get_username
23 |
24 |
25 | class Muice:
26 | """
27 | Muice交互类
28 | """
29 |
30 | _instance = None
31 |
32 | def __new__(cls):
33 | if cls._instance is None:
34 | cls._instance = super().__new__(cls)
35 | cls._instance._initialized = False
36 | return cls._instance
37 |
38 | def __init__(self):
39 | if self._initialized:
40 | return
41 |
42 | self.model_config = get_model_config()
43 |
44 | self.database = Database()
45 | self.max_history_epoch = plugin_config.max_history_epoch
46 |
47 | self.system_prompt = ""
48 | self.user_instructions = ""
49 |
50 | self._load_config()
51 | self._init_model()
52 |
53 | model_config_manager.register_listener(self._on_config_changed)
54 |
55 | self._initialized = True
56 |
57 | def __del__(self):
58 | # 注销监听器
59 | try:
60 | model_config_manager.unregister_listener(self._on_config_changed)
61 | except (AttributeError, RuntimeError) as e:
62 | logger.debug(f"Muice __del__ 清理失败: {e}")
63 |
64 | def _load_config(self):
65 | """
66 | 加载配置项
67 | """
68 | self.model_loader = self.model_config.loader
69 | self.multimodal = self.model_config.multimodal
70 | self.template = self.model_config.template
71 | self.template_mode = self.model_config.template_mode
72 |
73 | def _init_model(self) -> None:
74 | """
75 | 初始化模型类
76 | """
77 | try:
78 | self.model = load_model(self.model_config)
79 |
80 | except ImportError as e:
81 | logger.critical(f"导入模型加载器 '{self.model_loader}' 失败:{e}")
82 | dependencies = MODEL_DEPENDENCY_MAP.get(self.model_loader, [])
83 | missing = get_missing_dependencies(dependencies)
84 | if missing:
85 | install_command = "pip install " + " ".join(missing)
86 | logger.critical(f"缺少依赖库:{', '.join(missing)}\n请运行以下命令安装缺失项:\n\n{install_command}")
87 |
88 | def load_model(self) -> bool:
89 | """
90 | 加载模型
91 |
92 | return: 是否加载成功
93 | """
94 | if not self.model.load():
95 | logger.error("模型加载失败: self.model.load 函数失败")
96 | return False
97 |
98 | return True
99 |
100 | def _on_config_changed(self, new_config: ModelConfig, old_config: ModelConfig):
101 | """配置文件变更时的回调函数"""
102 | logger.info("检测到配置文件变更,自动重载模型...")
103 | # 更新配置
104 | self.model_config = new_config
105 |
106 | # 重新加载模型
107 | self._load_config()
108 | self._init_model()
109 | self.load_model()
110 | logger.success(f"模型自动重载完成: {old_config.loader} -> {new_config.loader}")
111 |
112 | def change_model_config(self, config_name: str) -> str:
113 | """
114 | 更换模型配置文件并重新加载模型
115 | """
116 | try:
117 | self.model_config = get_model_config(config_name)
118 | except (ValueError, FileNotFoundError) as e:
119 | return str(e)
120 |
121 | self._load_config()
122 | self._init_model()
123 | self.load_model()
124 |
125 | return f"已成功加载 {config_name}" if config_name else "未指定模型配置名,已加载默认模型配置"
126 |
127 | async def _prepare_prompt(self, message: str, userid: str, is_private: bool) -> str:
128 | """
129 | 准备提示词(包含系统提示)
130 |
131 | :param message: 消息主体
132 | :param userid: 用户 Nonebot ID
133 | :param is_private: 是否为私聊信息
134 | :return: 最终模型提示词
135 | """
136 | if self.template is None:
137 | return message
138 |
139 | system_prompt = generate_prompt_from_template(self.template, userid, is_private).strip()
140 |
141 | if self.template_mode == "system":
142 | self.system_prompt = system_prompt
143 | else:
144 | self.user_instructions = system_prompt
145 |
146 | if is_private:
147 | return f"{self.user_instructions}\n\n{message}" if self.user_instructions else message
148 |
149 | group_prompt = f"<{await get_username()}> {message}"
150 |
151 | return f"{self.user_instructions}\n\n{group_prompt}" if self.user_instructions else group_prompt
152 |
153 | async def _prepare_history(self, userid: str, groupid: str = "-1", enable_history: bool = True) -> list[Message]:
154 | """
155 | 准备对话历史
156 |
157 | :param userid: 用户名
158 | :param groupid: 群组ID等(私聊时此值为-1)
159 | :param enable_history: 是否启用历史记录
160 | :return: 最终模型提示词
161 | """
162 | user_history = await self.database.get_user_history(userid, self.max_history_epoch) if enable_history else []
163 |
164 | # 验证多模态资源路径是否可用
165 | for item in user_history:
166 | item.resources = [
167 | resource for resource in item.resources if resource.path and os.path.isfile(resource.path)
168 | ]
169 |
170 | if groupid == "-1":
171 | return user_history[-self.max_history_epoch :]
172 |
173 | group_history = await self.database.get_group_history(groupid, self.max_history_epoch)
174 |
175 | for item in group_history:
176 | item.resources = [
177 | resource for resource in item.resources if resource.path and os.path.isfile(resource.path)
178 | ]
179 |
180 | # 群聊历史构建成 Message 的格式,避免上下文混乱
181 | for item in group_history:
182 | user_name = await get_username(item.userid)
183 | item.message = f"<{user_name}> {item.message}"
184 |
185 | final_history = sorted(set(user_history + group_history))  # 去重后按消息时间排序
186 |
187 | return final_history[-self.max_history_epoch :]
188 |
189 | async def ask(
190 | self,
191 | message: Message,
192 | enable_history: bool = True,
193 | enable_plugins: bool = True,
194 | ) -> ModelCompletions:
195 | """
196 | 调用模型
197 |
198 | :param message: 消息文本
199 | :param enable_history: 是否启用历史记录
200 | :param enable_plugins: 是否启用工具插件
201 | :return: 模型回复
202 | """
203 | if not (self.model and self.model.is_running):
204 | logger.error("模型未加载")
205 | return ModelCompletions("模型未加载", succeed=False)
206 |
207 | is_private = message.groupid == "-1"
208 | logger.info("正在调用模型...")
209 |
210 | await hook_manager.run(HookType.BEFORE_PRETREATMENT, message)
211 |
212 | prompt = await self._prepare_prompt(message.message, message.userid, is_private)
213 | history = await self._prepare_history(message.userid, message.groupid, enable_history) if enable_history else []
214 | tools = (
215 | (await get_function_list() + await get_mcp_list())
216 | if self.model_config.function_call and enable_plugins
217 | else []
218 | )
219 | system = self.system_prompt if self.system_prompt else None
220 |
221 | model_request = ModelRequest(prompt, history, message.resources, tools, system)
222 | await hook_manager.run(HookType.BEFORE_MODEL_COMPLETION, model_request)
223 |
224 | start_time = time.perf_counter()
225 | logger.debug(f"模型调用参数:Prompt: {prompt}, History: {history}")
226 |
227 | response = await self.model.ask(model_request, stream=False)
228 |
229 | end_time = time.perf_counter()
230 |
231 | logger.success(f"模型调用{'成功' if response.succeed else '失败'}: {response}")
232 | logger.debug(f"模型调用时长: {end_time - start_time} s (token用量: {response.usage})")
233 |
234 | await hook_manager.run(HookType.AFTER_MODEL_COMPLETION, response)
235 |
236 | message.respond = response.text
237 | message.usage = response.usage
238 |
239 | await hook_manager.run(HookType.ON_FINISHING_CHAT, message)
240 |
241 | if response.succeed:
242 | await self.database.add_item(message)
243 |
244 | return response
245 |
246 | async def ask_stream(
247 | self,
248 | message: Message,
249 | enable_history: bool = True,
250 | enable_plugins: bool = True,
251 | ) -> AsyncGenerator[ModelStreamCompletions, None]:
252 | """
253 | 调用模型
254 |
255 | :param message: 消息文本
256 | :param enable_history: 是否启用历史记录
257 | :param enable_plugins: 是否启用工具插件
258 | :return: 模型回复
259 | """
260 | if not (self.model and self.model.is_running):
261 | logger.error("模型未加载")
262 | yield ModelStreamCompletions("模型未加载")
263 | return
264 |
265 | is_private = message.groupid == "-1"
266 | logger.info("正在调用模型...")
267 |
268 | await hook_manager.run(HookType.BEFORE_PRETREATMENT, message)
269 |
270 | prompt = await self._prepare_prompt(message.message, message.userid, is_private)
271 | history = await self._prepare_history(message.userid, message.groupid, enable_history) if enable_history else []
272 | tools = (
273 | (await get_function_list() + await get_mcp_list())
274 | if self.model_config.function_call and enable_plugins
275 | else []
276 | )
277 | system = self.system_prompt if self.system_prompt else None
278 |
279 | model_request = ModelRequest(prompt, history, message.resources, tools, system)
280 | await hook_manager.run(HookType.BEFORE_MODEL_COMPLETION, model_request)
281 |
282 | start_time = time.perf_counter()
283 | logger.debug(f"模型调用参数:Prompt: {prompt}, History: {history}")
284 |
285 | response = await self.model.ask(model_request, stream=True)
286 |
287 | total_reply = ""
288 | total_resources: list[Resource] = []
289 | item: Optional[ModelStreamCompletions] = None
290 |
291 | async for item in response:
292 | await hook_manager.run(HookType.ON_STREAM_CHUNK, item)
293 | total_reply += item.chunk
294 | yield item
295 | if item.resources:
296 | total_resources.extend(item.resources)
297 |
298 | if item is None:
299 | raise RuntimeError("模型调用器返回的值应至少包含一个元素")
300 |
301 | end_time = time.perf_counter()
302 | logger.success(f"已完成流式回复: {total_reply}")
303 | logger.debug(f"模型调用时长: {end_time - start_time} s (token用量: {item.usage})")
304 |
305 | final_model_completions = ModelCompletions(
306 | text=total_reply, usage=item.usage, resources=total_resources.copy(), succeed=item.succeed
307 | )
308 | await hook_manager.run(HookType.AFTER_MODEL_COMPLETION, final_model_completions)
309 |
310 | # 提取挂钩函数的可能的 resources 资源
311 | new_resources = [r for r in final_model_completions.resources if r not in total_resources]
312 |
313 | # yield 新资源
314 | for r in new_resources:
315 | yield ModelStreamCompletions(resources=[r])
316 |
317 | message.respond = total_reply
318 | message.usage = item.usage
319 |
320 | await hook_manager.run(HookType.ON_FINISHING_CHAT, message)
321 |
322 | if item.succeed:
323 | await self.database.add_item(message)
324 |
325 | async def refresh(self, userid: str) -> Union[AsyncGenerator[ModelStreamCompletions, None], ModelCompletions]:
326 | """
327 | 刷新对话
328 |
329 | :userid: 用户唯一标识id
330 | """
331 | logger.info(f"用户 {userid} 请求刷新")
332 |
333 | user_history = await self.database.get_user_history(userid, limit=1)
334 |
335 | if not user_history:
336 | logger.warning("用户对话数据不存在,拒绝刷新")
337 | return ModelCompletions("你都还没和我说过一句话呢,得和我至少聊上一段才能刷新哦")
338 |
339 | last_item = user_history[0]
340 |
341 | await self.database.mark_history_as_unavailable(userid, 1)
342 |
343 | if not self.model_config.stream:
344 | return await self.ask(last_item)
345 |
346 | return self.ask_stream(last_item)
347 |
348 | async def reset(self, userid: str) -> str:
349 | """
350 | 清空历史对话(将用户对话历史记录标记为不可用)
351 | """
352 | await self.database.mark_history_as_unavailable(userid)
353 | return "已成功移除对话历史~"
354 |
355 | async def undo(self, userid: str) -> str:
356 | await self.database.mark_history_as_unavailable(userid, 1)
357 | return "已成功撤销上一段对话~"
358 |
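359 | # 调用示意(Muice 为单例,重复实例化得到同一对象;字段值仅作演示):
360 | #
361 | # muice = Muice()
362 | # completions = await muice.ask(Message(userid="123456", message="你好"))
363 | # async for chunk in muice.ask_stream(Message(userid="123456", message="你好")):
364 | #     print(chunk.chunk, end="")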
--------------------------------------------------------------------------------
/muicebot/plugin/__init__.py:
--------------------------------------------------------------------------------
1 | from .context import get_bot, get_ctx, get_event, get_mather, get_state, set_ctx
2 | from .loader import (
3 | get_plugin_by_module_name,
4 | get_plugin_data_dir,
5 | get_plugins,
6 | load_plugin,
7 | load_plugins,
8 | )
9 | from .models import Plugin, PluginMetadata
10 |
11 | __all__ = [
12 | "get_bot",
13 | "get_state",
14 | "get_ctx",
15 | "get_event",
16 | "get_mather",
17 | "load_plugin",
18 | "load_plugins",
19 | "get_plugins",
20 | "get_plugin_by_module_name",
21 | "PluginMetadata",
22 | "Plugin",
23 | "set_ctx",
24 | "get_plugin_data_dir",
25 | ]
26 |
--------------------------------------------------------------------------------
/muicebot/plugin/context.py:
--------------------------------------------------------------------------------
1 | """
2 | 存储并获取 Nonebot 依赖注入中的上下文
3 | """
4 |
5 | from contextvars import ContextVar
6 | from typing import Tuple
7 |
8 | from nonebot.adapters import Bot, Event
9 | from nonebot.matcher import Matcher
10 | from nonebot.typing import T_State
11 |
12 | # 定义上下文变量
13 | bot_context: ContextVar[Bot] = ContextVar("bot")
14 | event_context: ContextVar[Event] = ContextVar("event")
15 | state_context: ContextVar[T_State] = ContextVar("state")
16 | mather_context: ContextVar[Matcher] = ContextVar("matcher")
17 |
18 |
19 | # 获取当前上下文的各种信息
20 | def get_bot() -> Bot:
21 | return bot_context.get()
22 |
23 |
24 | def get_event() -> Event:
25 | return event_context.get()
26 |
27 |
28 | def get_state() -> T_State:
29 | return state_context.get()
30 |
31 |
32 | def get_mather() -> Matcher:
33 | return mather_context.get()
34 |
35 |
36 | def set_ctx(bot: Bot, event: Event, state: T_State, matcher: Matcher):
37 | """
38 | 注册 Nonebot 中的上下文信息
39 | """
40 | bot_context.set(bot)
41 | event_context.set(event)
42 | state_context.set(state)
43 | mather_context.set(matcher)
44 |
45 |
46 | def get_ctx() -> Tuple[Bot, Event, T_State, Matcher]:
47 | """
48 | 获取当前上下文
49 | """
50 | return (get_bot(), get_event(), get_state(), get_mather())
51 |
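52 | # 用法示意:在事件响应器入口处调用 set_ctx 注册上下文,
53 | # 此后同一异步上下文中的插件代码即可随处取用:
54 | #
55 | # set_ctx(bot, event, state, matcher)
56 | # bot, event, state, matcher = get_ctx()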
--------------------------------------------------------------------------------
/muicebot/plugin/func_call/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Muicebot Function Call Plugin
3 | """
4 |
5 | from .caller import get_function_calls, get_function_list, on_function_call
6 |
7 | __all__ = ["get_function_calls", "get_function_list", "on_function_call"]
8 |
--------------------------------------------------------------------------------
/muicebot/plugin/func_call/_types.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Coroutine, TypeVar, Union
2 |
3 | SYNC_FUNCTION_CALL_FUNC = Callable[..., str]
4 | ASYNC_FUNCTION_CALL_FUNC = Callable[..., Coroutine[Any, Any, str]]
5 | FUNCTION_CALL_FUNC = Union[SYNC_FUNCTION_CALL_FUNC, ASYNC_FUNCTION_CALL_FUNC]
6 |
7 | F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
8 |
--------------------------------------------------------------------------------
/muicebot/plugin/func_call/caller.py:
--------------------------------------------------------------------------------
1 | """
2 | 机器人框架中处理函数调用的模块。
3 | 该模块提供了一个系统,用于注册可被AI系统调用的函数,
4 | 具有自动依赖注入和参数验证功能。
5 | 它包括:
6 | - Caller类:管理函数注册、依赖注入和执行
7 | - 用于函数调用的注册装饰器
8 | - 用于获取已注册函数调用的实用函数
9 | """
10 |
11 | import inspect
12 | from typing import Any, Optional, get_type_hints
13 |
14 | from nonebot import logger
15 | from nonebot.adapters import Bot, Event
16 | from nonebot.matcher import Matcher
17 | from nonebot.rule import Rule
18 | from nonebot.typing import T_State
19 |
20 | from ..context import get_bot, get_event, get_mather
21 | from ..utils import is_coroutine_callable
22 | from ._types import ASYNC_FUNCTION_CALL_FUNC, F
23 | from .parameter import Parameter
24 | from .utils import async_wrap
25 |
26 | _caller_data: dict[str, "Caller"] = {}
27 | """函数注册表,存储所有注册的函数"""
28 |
29 |
30 | class Caller:
31 | def __init__(self, description: str, rule: Optional[Rule] = None):
32 | self._name: str = ""
33 | """函数名称"""
34 | self._description: str = description
35 | """函数描述"""
36 | self._rule: Optional[Rule] = rule
37 | """启用规则"""
38 | self._parameters: dict[str, Parameter] = {}
39 | """函数参数字典"""
40 | self.function: ASYNC_FUNCTION_CALL_FUNC
41 | """函数对象"""
42 | self.default: dict[str, Any] = {}
43 | """默认值"""
44 |
45 | self.module_name: str = ""
46 | """函数所在模块名称"""
47 |
48 | def __call__(self, func: F) -> F:
49 | """
50 | 修饰器:注册一个 Function_call 函数
51 | """
52 | # 确保为异步函数
53 | if is_coroutine_callable(func):
54 | self.function = func # type: ignore
55 | else:
56 | self.function = async_wrap(func) # type:ignore
57 |
58 | self._name = func.__name__
59 |
60 | # 获取模块名
61 | if module := inspect.getmodule(func):
62 | module_name = module.__name__.split(".")[-1]
63 | else:
64 | module_name = ""
65 | self.module_name = module_name
66 |
67 | _caller_data[self._name] = self
68 | logger.success(f"Function Call 函数 {self.module_name}.{self._name} 已成功加载")
69 | return func
70 |
71 | async def _inject_dependencies(self, kwargs: dict) -> dict:
72 | """
73 | 自动解析参数并进行依赖注入
74 | """
75 | sig = inspect.signature(self.function)
76 | hints = get_type_hints(self.function)
77 |
78 | inject_args = kwargs.copy()
79 |
80 | for name, param in sig.parameters.items():
81 | param_type = hints.get(name, None)
82 |
83 | if param_type and isinstance(param_type, type) and issubclass(param_type, Bot):
84 | inject_args[name] = get_bot()
85 |
86 | elif param_type and isinstance(param_type, type) and issubclass(param_type, Event):
87 | inject_args[name] = get_event()
88 |
89 | elif param_type and isinstance(param_type, type) and issubclass(param_type, Matcher):
90 | inject_args[name] = get_mather()
91 |
92 | # elif param_type and issubclass(param_type, T_State):
93 | # inject_args[name] = get_state()
94 |
95 | # 填充默认值
96 | elif param.default != inspect.Parameter.empty:
97 | inject_args[name] = kwargs.get(name, param.default)
98 |
99 | # 参数未提供且无默认值,抛出异常
100 | elif name not in inject_args:
101 | raise ValueError(f"缺少必要参数: {name}")
102 |
103 | return inject_args
104 |
105 | def params(self, **kwargs: Parameter) -> "Caller":
106 | self._parameters.update(kwargs)
107 | return self
108 |
109 | async def run(self, **kwargs) -> Any:
110 | """
111 | 执行 function call
112 | """
113 | if self.function is None:
114 | raise ValueError("未注册函数对象")
115 |
116 | inject_args = await self._inject_dependencies(kwargs)
117 |
118 | return await self.function(**inject_args)
119 |
120 | def data(self) -> dict[str, Any]:
121 | """
122 | 生成函数描述信息
123 |
124 | :return: 可用于 Function_call 的字典
125 | """
126 | if not self._parameters:
127 | properties = {
128 | "dummy_param": {"type": "string", "description": "为了兼容性设置的一个虚拟参数,因此不需要填写任何值"}
129 | }
130 | required = []
131 | else:
132 | properties = {key: value.data() for key, value in self._parameters.items()}
133 | required = [key for key, value in self._parameters.items() if value.default is None]
134 |
135 | return {
136 | "type": "function",
137 | "function": {
138 | "name": self._name,
139 | "description": self._description,
140 | "parameters": {
141 | "type": "object",
142 | "properties": properties,
143 | "required": required,
144 | },
145 | },
146 | }
147 |
148 |
149 | def on_function_call(description: str, rule: Optional[Rule] = None) -> Caller:
150 | """
151 | 返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数
152 |
153 | :param description: 函数描述
154 | :param rule: 启用规则。不满足规则则不启用此 function call
155 |
156 | :return: Caller对象
157 | """
158 | caller = Caller(description=description, rule=rule)
159 | return caller
160 |
161 |
162 | def get_function_calls() -> dict[str, Caller]:
163 | """获取所有已注册的function call函数
164 |
165 | Returns:
166 | dict[str, Caller]: 所有已注册的function call类
167 | """
168 | return _caller_data
169 |
170 |
171 | async def get_function_list() -> list[dict[str, dict]]:
172 | """
173 | 获取所有已注册的function call函数,并转换为工具格式
174 |
175 | :return: 所有已注册的function call函数列表
176 | """
177 | tools: list[dict[str, dict]] = []
178 | bot: Bot = get_bot()
179 | event: Event = get_event()
180 | state: T_State = {}
181 |
182 | for name, caller in _caller_data.items():
183 | if caller._rule is None or await caller._rule(bot, event, state):
184 | tools.append(caller.data())
185 |
186 | return tools
187 |
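188 | # 注册示意(函数与参数均为假设,String 见 parameter.py;
189 | # 未显式提供 default 的参数会被列入 required):
190 | #
191 | # @on_function_call(description="获取指定城市的天气").params(
192 | #     city=String(description="城市名称")
193 | # )
194 | # async def get_weather(city: str) -> str:
195 | #     return f"{city}: 晴"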
--------------------------------------------------------------------------------
/muicebot/plugin/func_call/parameter.py:
--------------------------------------------------------------------------------
1 | """
2 | 定义 Function_call 插件的参数类
3 | """
4 |
5 | from typing import Any
6 |
7 | from pydantic import BaseModel, Field
8 |
9 |
10 | class Parameter(BaseModel):
11 | """
12 | Function_call 插件参数对象
13 | """
14 |
15 | type: str
16 | """参数类型描述 string integer等"""
17 | description: str
18 | """参数描述"""
19 | default: Any = None
20 | """默认值"""
21 | properties: dict[str, Any] = {}
22 | """参数定义属性,例如最大值最小值等"""
23 | required: bool = False
24 | """是否必须"""
25 |
26 | def data(self) -> dict[str, Any]:
27 | """
28 | 生成参数描述信息
29 |
30 | :return: 可用于 Function_call 的字典
31 | """
32 | return {
33 | "type": self.type,
34 | "description": self.description,
35 | **{key: value for key, value in self.properties.items() if value is not None},
36 | }
37 |
38 |
39 | class ParamTypes:
40 | STRING = "string"
41 | INTEGER = "integer"
42 | ARRAY = "array"
43 | OBJECT = "object"
44 | BOOLEAN = "boolean"
45 | NUMBER = "number"
46 |
47 |
48 | class String(Parameter):
49 | type: str = ParamTypes.STRING
50 | properties: dict[str, Any] = Field(default_factory=dict)
51 | enum: list[str] | None = None
52 |
53 |
54 | class Integer(Parameter):
55 | type: str = ParamTypes.INTEGER
56 | properties: dict[str, Any] = Field(default_factory=lambda: {"minimum": 0, "maximum": 100})
57 |
58 | minimum: int | None = None
59 | maximum: int | None = None
60 |
61 |
62 | class Array(Parameter):
63 | type: str = ParamTypes.ARRAY
64 | properties: dict[str, Any] = Field(default_factory=lambda: {"items": {"type": "string"}})
65 | items: str = Field("string", description="数组元素类型")
66 |
--------------------------------------------------------------------------------
/muicebot/plugin/func_call/utils.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from typing import Any
3 |
4 | from ._types import ASYNC_FUNCTION_CALL_FUNC, SYNC_FUNCTION_CALL_FUNC
5 |
6 |
7 | def async_wrap(func: SYNC_FUNCTION_CALL_FUNC) -> ASYNC_FUNCTION_CALL_FUNC:
8 | """
9 | 装饰器,将同步函数包装为异步函数
10 | """
11 |
12 | @wraps(func)
13 | async def wrapper(*args: Any, **kwargs: Any) -> Any:
14 | return func(*args, **kwargs)
15 |
16 | return wrapper
17 |
--------------------------------------------------------------------------------
/muicebot/plugin/hook/__init__.py:
--------------------------------------------------------------------------------
1 | from ._types import HookType
2 | from .manager import (
3 | hook_manager,
4 | on_after_completion,
5 | on_before_completion,
6 | on_before_pretreatment,
7 | on_finish_chat,
8 | on_stream_chunk,
9 | )
10 |
11 | __all__ = [
12 | "HookType",
13 | "hook_manager",
14 | "on_after_completion",
15 | "on_before_completion",
16 | "on_before_pretreatment",
17 | "on_finish_chat",
18 | "on_stream_chunk",
19 | ]
20 |
--------------------------------------------------------------------------------
/muicebot/plugin/hook/_types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, auto
2 | from typing import Awaitable, Callable, Union
3 |
4 | from muicebot.llm import ModelCompletions, ModelRequest, ModelStreamCompletions
5 | from muicebot.models import Message
6 |
7 | SYNC_HOOK_FUNC = Callable[..., None]
8 | ASYNC_HOOK_FUNC = Callable[..., Awaitable[None]]
9 | HOOK_FUNC = Union[SYNC_HOOK_FUNC, ASYNC_HOOK_FUNC]
10 |
11 | HOOK_ARGS = Union[Message, ModelCompletions, ModelStreamCompletions, ModelRequest]
12 |
13 |
14 | class HookType(Enum):
15 | """可用的 Hook 类型"""
16 |
17 | BEFORE_PRETREATMENT = auto()
18 | """预处理前"""
19 | BEFORE_MODEL_COMPLETION = auto()
20 | """模型调用前"""
21 | ON_STREAM_CHUNK = auto()
22 | """模型流式输出中"""
23 | AFTER_MODEL_COMPLETION = auto()
24 | """模型调用后"""
25 | ON_FINISHING_CHAT = auto()
26 | """结束对话时(存库前)"""
27 |
--------------------------------------------------------------------------------
/muicebot/plugin/hook/manager.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from collections import defaultdict
3 | from typing import (
4 | Any,
5 | Awaitable,
6 | Callable,
7 | Dict,
8 | List,
9 | Optional,
10 | Union,
11 | get_args,
12 | get_origin,
13 | get_type_hints,
14 | )
15 |
16 | from nonebot import logger
17 | from nonebot.adapters import Bot, Event
18 | from nonebot.matcher import Matcher
19 | from nonebot.rule import Rule
20 | from nonebot.typing import T_State
21 |
22 | from ..context import get_bot, get_event, get_mather
23 | from ._types import HOOK_ARGS, HOOK_FUNC, HookType
24 |
25 | DEPENDENCY_PROVIDERS: dict[type, Callable[[], object]] = {
26 | Bot: get_bot,
27 | Event: get_event,
28 | Matcher: get_mather,
29 | # T_State: get_state,
30 | }
31 |
32 |
33 | def _match_union(param_type: type, arg: object) -> bool:
34 | if get_origin(param_type) is Union:
35 | return any(isinstance(arg, t) for t in get_args(param_type))
36 | return False
37 |
38 |
39 | class HookManager:
40 | def __init__(self):
41 | self._hooks: Dict[HookType, List["Hooked"]] = defaultdict(list)
42 |
43 | async def _inject_dependencies(self, function: HOOK_FUNC, hook_arg: HOOK_ARGS) -> dict:
44 | """
45 | 自动解析参数并进行依赖注入
46 | """
47 | sig = inspect.signature(function)
48 | hints = get_type_hints(function)
49 |
50 | inject_args: dict[str, Any] = {}
51 |
52 | for name, param in sig.parameters.items():
53 | param_type = hints.get(name, None)
54 | if not param_type:
55 | continue
56 |
57 | # 1. Union 类型注入 hook_arg
58 | if _match_union(param_type, hook_arg):
59 | inject_args[name] = hook_arg
60 | continue
61 |
62 | # 2. 直接匹配 hook_arg
63 | if isinstance(hook_arg, param_type):
64 | inject_args[name] = hook_arg
65 | continue
66 |
67 | # 3. 依赖提供者匹配(Bot、Event、Matcher...)
68 | for dep_type, provider in DEPENDENCY_PROVIDERS.items():
69 | if isinstance(param_type, type) and issubclass(param_type, dep_type):
70 | inject_args[name] = provider()
71 | break
72 |
73 | return inject_args
74 |
75 | def register(self, hook_type: HookType, hooked: "Hooked"):
76 | """
77 | 注册一个挂钩函数
78 | """
79 | self._hooks[hook_type].append(hooked)
80 | return hooked
81 |
82 | async def run(self, hook_type: HookType, hook_arg: HOOK_ARGS, stream: bool = False):
83 | """
84 | 运行所有的钩子函数
85 |
86 | :param hook_type: 钩子类型
87 | :param hook_arg: 消息处理流程中对应的数据类
88 | :param stream: 当前是否为流式状态
89 | """
90 | hookeds = self._hooks[hook_type]
91 | hookeds.sort(key=lambda x: x.priority)
92 |
93 | bot: Bot = get_bot()
94 | event: Event = get_event()
95 | state: T_State = {}
96 |
97 | for hooked in hookeds:
98 | args = await self._inject_dependencies(hooked.function, hook_arg)
99 |
100 | if (hooked.stream is not None and hooked.stream != stream) or (
101 | hooked.rule and not await hooked.rule(bot, event, state)
102 | ):
103 | continue
104 |
105 | result = hooked.function(**args)
106 | if isinstance(result, Awaitable):
107 | await result
108 |
109 |
110 | hook_manager = HookManager()
111 |
112 |
113 | class Hooked:
114 | """挂钩函数对象"""
115 |
116 | def __init__(
117 | self, hook_type: HookType, priority: int = 10, stream: Optional[bool] = None, rule: Optional[Rule] = None
118 | ):
119 | self.hook_type = hook_type
120 | """钩子函数类型"""
121 | self.priority = priority
122 | """函数调用优先级"""
123 | self.stream = stream
124 | """是否仅在(非)流式中运行"""
125 | self.rule: Optional[Rule] = rule
126 | """启用规则"""
127 |
128 | self.function: HOOK_FUNC
129 | """函数对象"""
130 |
131 | def __call__(self, func: HOOK_FUNC) -> HOOK_FUNC:
132 | """
133 | 修饰器:注册一个 Hook 函数
134 | """
135 | self.function = func
136 |
137 | # 获取模块名
138 | if module := inspect.getmodule(func):
139 | module_name = module.__name__.split(".")[-1]
140 | else:
141 | module_name = ""
142 |
143 | hook_manager.register(self.hook_type, self)
144 | logger.success(f"挂钩函数 {module_name}.{func.__name__} 已成功加载")
145 | return func
146 |
147 |
148 | def on_before_pretreatment(priority: int = 10, rule: Optional[Rule] = None) -> Hooked:
149 | """
150 | 注册一个钩子函数
151 | 这个函数将在传入消息 (`Muice` 的 `_prepare_prompt()`) 前调用
152 | 它可接受一个 `Message` 类参数
153 |
154 | :param priority: 调用优先级
155 | :param rule: Nonebot 的响应规则
156 | """
157 | return Hooked(HookType.BEFORE_PRETREATMENT, priority=priority, rule=rule)
158 |
159 |
160 | def on_before_completion(priority: int = 10, rule: Optional[Rule] = None) -> Hooked:
161 | """
162 | 注册一个钩子函数。
163 | 这个函数将在传入模型(`Muice` 的 `model.ask()`)前调用
164 | 它可接受一个 `ModelRequest` 类参数
165 |
166 | :param priority: 调用优先级
167 | :param rule: Nonebot 的响应规则
168 | """
169 | return Hooked(HookType.BEFORE_MODEL_COMPLETION, priority=priority, rule=rule)
170 |
171 |
172 | def on_stream_chunk(priority: int = 10, rule: Optional[Rule] = None) -> Hooked:
173 | """
174 | 注册一个钩子函数。
175 | 这个函数将在流式调用中途(`Muice` 的 `model.ask()`)调用
176 | 它可接受一个 `ModelStreamCompletions` 类参数
177 |
178 | :param priority: 调用优先级
179 | :param rule: Nonebot 的响应规则
180 | """
181 | return Hooked(HookType.ON_STREAM_CHUNK, priority=priority, rule=rule)
182 |
183 |
184 | def on_after_completion(priority: int = 10, stream: Optional[bool] = None, rule: Optional[Rule] = None) -> Hooked:
185 | """
186 | 注册一个钩子函数。
187 | 这个函数将在传入模型(`Muice` 的 `model.ask()`)后调用(流式则传入整合后的数据)
188 | 它可接受一个 `ModelCompletions` 类参数
189 |
190 | 请注意:当启用流式时,对 `ModelCompletions` 的任何修改将不生效(新增的 resources 除外)
191 |
192 | :param priority: 调用优先级
193 | :param stream: 是否仅在(非)流式中处理,None 则无限制
194 | :param rule: Nonebot 的响应规则
195 | """
196 | return Hooked(HookType.AFTER_MODEL_COMPLETION, priority=priority, stream=stream, rule=rule)
197 |
198 |
199 | def on_finish_chat(priority: int = 10, rule: Optional[Rule] = None) -> Hooked:
200 | """
201 | 注册一个钩子函数。
202 | 这个函数将在结束对话(存库前)调用
203 | 它可接受一个 `Message` 类参数
204 |
205 | :param priority: 调用优先级
206 | :param rule: Nonebot 的响应规则
207 | """
208 | return Hooked(HookType.ON_FINISHING_CHAT, priority=priority, rule=rule)
209 |
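210 | # 注册示意(函数体为假设;参数的类型注解决定注入的内容):
211 | #
212 | # @on_before_completion(priority=5)
213 | # async def append_instruction(request: ModelRequest):
214 | #     request.prompt += "\n(请简短回答)"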
--------------------------------------------------------------------------------
/muicebot/plugin/loader.py:
--------------------------------------------------------------------------------
1 | """
2 | 实现插件的加载和管理
3 |
4 | Attributes:
5 | _plugins (Dict[str, Plugin]): 插件注册表,存储已加载的插件
6 | Functions:
7 | load_plugin: 加载单个插件
8 | load_plugins: 加载指定目录下的所有插件
9 | get_plugins: 获取已加载的插件列表
10 | """
11 |
12 | import importlib
13 | import inspect
14 | import os
15 | from pathlib import Path
16 | from typing import Dict, Optional, Set
17 |
18 | import nonebot_plugin_localstore as store
19 | from nonebot import logger
20 |
21 | from .models import Plugin
22 | from .utils import path_to_module_name
23 |
24 | _plugins: Dict[str, Plugin] = {}
25 | """插件注册表"""
26 | _declared_plugins: Set[str] = set()
27 | """已声明插件注册表(不一定加载成功)"""
28 |
29 |
30 | def load_plugin(plugin_path: Path | str, base_path=Path.cwd()) -> Optional[Plugin]:
31 | """
32 | 加载单个插件
33 |
34 | :param plugin_path: 插件路径
35 | :param base_path: 外部插件的基准路径
36 | :return: 插件对象;加载失败时返回 None
37 | """
38 | try:
39 | logger.info(f"加载插件: {plugin_path}")
40 | if isinstance(plugin_path, Path):
41 | plugin_path = path_to_module_name(plugin_path, base_path)
42 |
43 | if plugin_path in _declared_plugins:
44 | raise ValueError(f"插件 {plugin_path} 包名出现冲突!")
45 | _declared_plugins.add(plugin_path)
46 |
47 | module = importlib.import_module(plugin_path)
48 | plugin = Plugin(name=module.__name__.split(".")[-1], module=module, package_name=plugin_path)
49 |
50 | _plugins[plugin.package_name] = plugin
51 |
52 | return plugin
53 |
54 | except Exception as e:
55 | logger.error(f"加载插件 {plugin_path} 失败: {e}")
56 | return None
57 |
58 |
59 | def load_plugins(*plugins_dirs: Path | str, base_path=Path.cwd()) -> set[Plugin]:
60 | """
61 | 加载传入插件目录中的所有插件
62 |
63 | :param plugins_dirs: 插件目录
64 | :param base_path: 外部插件的基准路径
65 | :return: 插件对象集合
66 | """
67 |
68 | plugins = set()
69 |
70 | for plugin_dir in plugins_dirs:
71 | plugin_dir_path = Path(plugin_dir) if isinstance(plugin_dir, str) else plugin_dir
72 |
73 | for plugin in os.listdir(plugin_dir_path):
74 | plugin_path = Path(os.path.join(plugin_dir_path, plugin))
75 | module_name = None
76 |
77 | if plugin_path.is_file() and plugin_path.suffix == ".py":
78 | module_name = path_to_module_name(plugin_path.with_suffix(""), base_path)
79 | elif plugin_path.is_dir() and (plugin_path / Path("__init__.py")).exists():
80 | module_name = path_to_module_name(plugin_path, base_path)
81 | if module_name and (loaded_plugin := load_plugin(module_name)):
82 | plugins.add(loaded_plugin)
83 |
84 | return plugins
85 |
86 |
87 | def _get_caller_plugin_name() -> Optional[str]:
88 | """
89 | 获取当前调用插件名
90 | """
91 | current_frame = inspect.currentframe()
92 | if current_frame is None:
93 | return None
94 |
95 | # find plugin
96 | frame = current_frame
97 | while frame := frame.f_back: # type:ignore
98 | module_name = (module := inspect.getmodule(frame)) and module.__name__
99 |
100 | if module_name is None:
101 | return None
102 |
103 | # skip muicebot itself
104 | package_name = module_name.split(".", maxsplit=1)[0]
105 | if package_name == "muicebot" and not module_name.startswith("muicebot.builtin_plugins"):
106 | continue
107 |
108 | # 将模块路径拆解为层级列表(例如 a.b.c → ["a", "a.b", "a.b.c"])
109 | module_segments = module_name.split(".")
110 | candidate_paths = [".".join(module_segments[: i + 1]) for i in range(len(module_segments))]
111 |
112 | # 从长到短查找最长匹配
113 | for candidate in reversed(candidate_paths):
114 | if candidate in _declared_plugins:
115 | return candidate.split(".")[-1]
116 |
117 | return None
118 |
119 |
120 | def get_plugins() -> Dict[str, Plugin]:
121 | """
122 | 获取插件列表
123 | """
124 | return _plugins
125 |
126 |
127 | def get_plugin_by_module_name(module_name: str) -> Optional[Plugin]:
128 | """
129 | 通过包名获取插件对象
130 | """
131 | return _plugins.get(module_name, None)
132 |
133 |
134 | def get_plugin_data_dir() -> Path:
135 | """
136 | 获取 Muicebot 插件数据目录
137 |
138 | 对于 Muicebot 的插件,其数据目录位于 Muicebot 插件数据目录下的 `plugin` 文件夹中,并以插件名命名
139 | (`nonebot_plugin_localstore.get_plugin_data_dir`)
140 | """
141 | plugin_name = _get_caller_plugin_name()
142 | plugin_name = plugin_name or ".unknown"
143 |
144 | plugin_dir = store.get_plugin_data_dir() / "plugin"
145 | plugin_dir = plugin_dir.joinpath(plugin_name).resolve()
146 | plugin_dir.mkdir(parents=True, exist_ok=True)
147 |
148 | logger.debug(plugin_dir)
149 |
150 | return plugin_dir
151 |
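152 | # 用法示意(目录与插件名仅作演示):
153 | #
154 | # load_plugins("./plugins")                      # 加载目录下的全部插件
155 | # plugin = load_plugin(Path("./plugins/demo"))   # 或按路径加载单个插件
156 | # data_dir = get_plugin_data_dir()               # 在插件代码内获取其专属数据目录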
--------------------------------------------------------------------------------
/muicebot/plugin/mcp/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from: https://github.com/modelcontextprotocol/python-sdk/tree/main/examples/clients/simple-chatbot
3 |
4 | MIT License
5 |
6 | Copyright (c) 2024 Anthropic, PBC
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 | """
26 |
27 | from .client import cleanup_servers, get_mcp_list, handle_mcp_tool, initialize_servers
28 |
29 | __all__ = ["handle_mcp_tool", "cleanup_servers", "initialize_servers", "get_mcp_list"]
30 |
--------------------------------------------------------------------------------
/muicebot/plugin/mcp/client.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Any, Optional
3 |
4 | from nonebot import logger
5 |
6 | from .config import server_config
7 | from .server import Server, Tool
8 |
9 | _servers: list[Server] = list()
10 |
11 |
12 | async def initialize_servers() -> None:
13 | """
14 | 初始化全部 MCP 实例
15 | """
16 | _servers.extend([Server(name, srv_config) for name, srv_config in server_config.items()])
17 | for server in _servers:
18 | logger.info(f"初始化 MCP Server: {server.name}")
19 | try:
20 | await server.initialize()
21 | except Exception as e:
22 | logger.error(f"初始化 MCP Server 实例时出现问题: {e}")
23 | await cleanup_servers()
24 | raise
25 |
26 |
27 | async def handle_mcp_tool(tool: str, arguments: Optional[dict[str, Any]] = None) -> Optional[str]:
28 | """
29 | 处理 MCP Tool 调用
30 | """
31 | logger.info(f"执行 MCP 工具: {tool} (参数: {arguments})")
32 |
33 | for server in _servers:
34 | server_tools = await server.list_tools()
35 | if not any(server_tool.name == tool for server_tool in server_tools):
36 | continue
37 |
38 | try:
39 | result = await server.execute_tool(tool, arguments)
40 |
41 | if isinstance(result, dict) and "progress" in result:
42 | progress = result["progress"]
43 | total = result["total"]
44 | percentage = (progress / total) * 100
45 | logger.info(f"工具执行进度: {progress}/{total} ({percentage:.1f}%)")
46 |
47 | return f"Tool execution result: {result}"
48 | except Exception as e:
49 | error_msg = f"Error executing tool: {str(e)}"
50 | logger.error(error_msg)
51 | return error_msg
52 |
53 | return None # Not found.
54 |
55 |
56 | async def cleanup_servers() -> None:
57 | """
58 | 清理 MCP 实例
59 | """
60 | cleanup_tasks = [asyncio.create_task(server.cleanup()) for server in _servers]
61 | if cleanup_tasks:
62 | try:
63 | await asyncio.gather(*cleanup_tasks, return_exceptions=True)
64 | except Exception as e:
65 | logger.warning(f"清理 MCP 实例时出现错误: {e}")
66 |
67 |
68 | async def transform_json(tool: Tool) -> dict[str, Any]:
69 | """
70 | 将 MCP Tool 转换为 OpenAI 所需的 parameters 格式,并删除多余字段
71 | """
72 | func_desc = {"name": tool.name, "description": tool.description, "parameters": {}, "required": []}
73 |
74 | if tool.input_schema:
75 | parameters = {
76 | "type": tool.input_schema.get("type", "object"),
77 | "properties": tool.input_schema.get("properties", {}),
78 | "required": tool.input_schema.get("required", []),
79 | }
80 | func_desc["parameters"] = parameters
81 |
82 | output = {"type": "function", "function": func_desc}
83 |
84 | return output
85 |
86 |
87 | async def get_mcp_list() -> list[dict[str, dict]]:
88 | """
89 | 获得适用于 OpenAI Tool Call 输入格式的 MCP 工具列表
90 | """
91 | all_tools: list[dict[str, dict]] = []
92 |
93 | for server in _servers:
94 | tools = await server.list_tools()
95 | all_tools.extend([await transform_json(tool) for tool in tools])
96 |
97 | return all_tools
98 |
--------------------------------------------------------------------------------
/muicebot/plugin/mcp/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import Any, Dict
4 |
5 | from pydantic import BaseModel, Field
6 |
7 | CONFIG_PATH = Path("./configs/mcp.json")
8 |
9 |
10 | class mcpServer(BaseModel):
11 | command: str
12 | """执行指令"""
13 | args: list = Field(default_factory=list)
14 | """命令参数"""
15 | env: dict[str, Any] = Field(default_factory=dict)
16 | """环境配置"""
17 |
18 |
19 | mcpConfig = Dict[str, mcpServer]
20 |
21 |
22 | def get_mcp_server_config() -> mcpConfig:
23 | """
24 | 从 MCP 配置文件 `config/mcp.json` 中获取 MCP Server 配置
25 | """
26 | if not CONFIG_PATH.exists():
27 | return {}
28 |
29 | with open(CONFIG_PATH, "r", encoding="utf-8") as f:
30 | configs = json.load(f) or {}
31 |
32 | mcp_config: mcpConfig = dict()
33 |
34 | for name, srv_config in configs.get("mcpServers", {}).items():
35 | mcp_config[name] = mcpServer(**srv_config)
36 |
37 | return mcp_config
38 |
39 |
40 | server_config = get_mcp_server_config()
41 |
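42 | # configs/mcp.json 结构示意(服务名与参数均为假设):
43 | #
44 | # {
45 | #     "mcpServers": {
46 | #         "filesystem": {
47 | #             "command": "npx",
48 | #             "args": ["-y", "@modelcontextprotocol/server-filesystem", "."],
49 | #             "env": {}
50 | #         }
51 | #     }
52 | # }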
--------------------------------------------------------------------------------
/muicebot/plugin/mcp/server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 | import shutil
5 | from contextlib import AsyncExitStack
6 | from typing import Any, Optional
7 |
8 | from mcp import ClientSession, StdioServerParameters
9 | from mcp.client.stdio import stdio_client
10 |
11 | from .config import mcpServer
12 |
13 |
14 | class Tool:
15 | """
16 | MCP Tool
17 | """
18 |
19 | def __init__(self, name: str, description: str, input_schema: dict[str, Any]) -> None:
20 | self.name: str = name
21 | self.description: str = description
22 | self.input_schema: dict[str, Any] = input_schema
23 |
24 | def format_for_llm(self) -> str:
25 | """
26 | 为 llm 生成工具描述
27 |
28 | :return: 工具描述
29 | """
30 | args_desc = []
31 | if "properties" in self.input_schema:
32 | for param_name, param_info in self.input_schema["properties"].items():
33 | arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}"
34 | if param_name in self.input_schema.get("required", []):
35 | arg_desc += " (required)"
36 | args_desc.append(arg_desc)
37 |
38 | return f"Tool: {self.name}\n" f"Description: {self.description}\n" f"Arguments:{chr(10).join(args_desc)}" ""
39 |
40 |
41 | class Server:
42 | """
43 | 管理 MCP 服务器连接和工具执行的 Server 实例
44 | """
45 |
46 | def __init__(self, name: str, config: mcpServer) -> None:
47 | self.name: str = name
48 | self.config: mcpServer = config
49 | self.stdio_context: Any | None = None
50 | self.session: ClientSession | None = None
51 | self._cleanup_lock: asyncio.Lock = asyncio.Lock()
52 | self.exit_stack: AsyncExitStack = AsyncExitStack()
53 |
54 | async def initialize(self) -> None:
55 | """
56 | 初始化实例
57 | """
58 | command = shutil.which("npx") if self.config.command == "npx" else self.config.command
59 | if command is None:
60 | raise ValueError(f"command 字段必须为一个有效值, 且目标指令必须存在于环境变量中: {self.config.command}")
61 |
62 | server_params = StdioServerParameters(
63 | command=command,
64 | args=self.config.args,
65 | env={**os.environ, **self.config.env} if self.config.env else None,
66 | )
67 | try:
68 | stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
69 | read, write = stdio_transport
70 | session = await self.exit_stack.enter_async_context(ClientSession(read, write))
71 | await session.initialize()
72 | self.session = session
73 | except Exception as e:
74 | logging.error(f"初始化 MCP Server 实例时遇到错误 {self.name}: {e}")
75 | await self.cleanup()
76 | raise
77 |
78 | async def list_tools(self) -> list[Tool]:
79 | """
80 | 从 MCP 服务器获得可用工具列表
81 |
82 | :return: 工具列表
83 |
84 | :raises RuntimeError: 如果服务器未启动
85 | """
86 | if not self.session:
87 | raise RuntimeError(f"Server {self.name} not initialized")
88 |
89 | tools_response = await self.session.list_tools()
90 | tools: list[Tool] = []
91 |
92 | for item in tools_response:
93 | if isinstance(item, tuple) and item[0] == "tools":
94 | tools.extend(Tool(tool.name, tool.description, tool.inputSchema) for tool in item[1])
95 |
96 | return tools
97 |
98 | async def execute_tool(
99 | self,
100 | tool_name: str,
101 | arguments: Optional[dict[str, Any]] = None,
102 | retries: int = 2,
103 | delay: float = 1.0,
104 | ) -> Any:
105 | """
106 | 执行一个 MCP 工具
107 |
108 | :param tool_name: 工具名称
109 | :param arguments: 工具参数
110 | :param retries: 重试次数
111 | :param delay: 重试间隔
112 |
113 | :return: 工具执行结果
114 |
115 | :raises RuntimeError: 如果服务器未初始化
116 | :raises Exception: 工具在所有重试中均失败
117 | """
118 | if not self.session:
119 | raise RuntimeError(f"Server {self.name} not initialized")
120 |
121 | attempt = 0
122 | while attempt < retries:
123 | try:
124 | logging.info(f"Executing {tool_name}...")
125 | result = await self.session.call_tool(tool_name, arguments)
126 |
127 | return result
128 |
129 | except Exception as e:
130 | attempt += 1
131 | logging.warning(f"Error executing tool: {e}. Attempt {attempt} of {retries}.")
132 | if attempt < retries:
133 | logging.info(f"Retrying in {delay} seconds...")
134 | await asyncio.sleep(delay)
135 | else:
136 | logging.error("Max retries reached. Failing.")
137 | raise
138 |
139 | async def cleanup(self) -> None:
140 | """Clean up server resources."""
141 | async with self._cleanup_lock:
142 | try:
143 | await self.exit_stack.aclose()
144 | self.session = None
145 | self.stdio_context = None
146 | except Exception as e:
147 | logging.error(f"Error during cleanup of server {self.name}: {e}")
148 |
--------------------------------------------------------------------------------
/muicebot/plugin/models.py:
--------------------------------------------------------------------------------
1 | from types import ModuleType
2 | from typing import Any
3 |
4 | from pydantic import BaseModel
5 |
6 |
7 | class PluginMetadata(BaseModel):
8 | """MuiceBot 插件元数据"""
9 |
10 | name: str
11 | """插件名"""
12 | description: str
13 | """插件描述"""
14 | usage: str
15 | """插件用法"""
16 | homepage: str | None = None
17 | """(可选) 插件主页,通常为开源存储库地址"""
18 | config: type[BaseModel] | None = None
19 | """插件配置项类,如无需配置可不填写"""
20 | extra: dict[Any, Any] | None = None
21 | """不知道干嘛的 extra 信息,我至今都没搞懂,喜欢的可以填"""
22 |
23 |
24 | class Plugin(BaseModel):
25 | """MuiceBot 插件对象"""
26 |
27 | name: str
28 | """插件名称"""
29 | module: ModuleType
30 | """插件模块对象"""
31 | package_name: str
32 | """模块包名"""
33 |
34 | def __hash__(self) -> int:
35 | return hash(self.package_name)
36 |
37 | def __eq__(self, other: Any) -> bool:
38 | return self.package_name == other.package_name if hasattr(other, "package_name") else False
39 |
40 | def __str__(self) -> str:
41 | return self.package_name
42 |
43 | class Config:
44 | arbitrary_types_allowed = True
45 |
--------------------------------------------------------------------------------
/muicebot/plugin/utils.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from pathlib import Path
3 | from typing import Any, Callable
4 |
5 |
6 | def path_to_module_name(module_path: Path, base_path: Path) -> str:
7 | """
8 | 动态计算模块名,基于明确的基准路径
9 | """
10 | try:
11 | rel_path = module_path.resolve().relative_to(base_path.resolve())
12 | except ValueError:
13 | # 处理绝对路径与相对路径的兼容性问题
14 | rel_path = module_path.resolve()
15 |
16 | if rel_path.stem == "__init__":
17 | parts = rel_path.parts[:-1]
18 | else:
19 | parts = rel_path.parts
20 |
21 | # 过滤空字符串和无效部分
22 | module_names = [p for p in parts if p not in ("", ".", "..")]
23 | return ".".join(module_names)
24 |
25 |
26 | def is_coroutine_callable(call: Callable[..., Any]) -> bool:
27 | """
28 | 检查 call 是否是一个 callable 协程函数
29 | """
30 | if inspect.isroutine(call):
31 | return inspect.iscoroutinefunction(call)
32 | if inspect.isclass(call):
33 | return False
34 | func_ = getattr(call, "__call__", None)
35 | return inspect.iscoroutinefunction(func_)
36 |
--------------------------------------------------------------------------------
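Expected behaviour of the two helpers above, sketched with made-up paths and callables.

from pathlib import Path

from muicebot.plugin.utils import is_coroutine_callable, path_to_module_name

base = Path("/project/plugins")
print(path_to_module_name(Path("/project/plugins/foo/bar.py"), base))       # foo.bar
print(path_to_module_name(Path("/project/plugins/foo/__init__.py"), base))  # foo


async def handler() -> None: ...


class Runner:
    async def __call__(self) -> None: ...


print(is_coroutine_callable(handler))   # True: a plain coroutine function
print(is_coroutine_callable(Runner))    # False: classes are never treated as coroutines
print(is_coroutine_callable(Runner()))  # True: an instance with an async __call__
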
/muicebot/scheduler.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from apscheduler.schedulers.asyncio import AsyncIOScheduler
4 | from apscheduler.triggers.cron import CronTrigger
5 | from apscheduler.triggers.interval import IntervalTrigger
6 | from nonebot import get_bot, logger
7 | from nonebot_plugin_alconna.uniseg import Target, UniMessage
8 |
9 | from .config import get_schedule_configs
10 | from .models import Message
11 | from .muice import Muice
12 |
13 |
14 | async def send_message(target_id: str, message: str, probability: float = 1):
15 | """
16 | 定时任务:发送信息
17 |
18 | :param target_id: 目标id;若为群聊则为 group_id 或者 channel_id,若为私聊则为 user_id
19 | :param message: 要发送的消息
20 | :param probability: 发送几率
21 | """
22 | if not (random.random() < probability):
23 | return
24 |
25 | logger.info(f"Scheduled job send_message: {message}")
26 |
27 | target = Target(target_id)
28 | await UniMessage(message).send(target=target, bot=get_bot())
29 |
30 |
31 | async def model_ask(muice_app: Muice, target_id: str, prompt: str, probability: float = 1):
32 | """
33 | 定时任务:向模型发送消息
34 |
35 | :param muice_app: 沐雪核心类,用于与大语言模型交互
36 | :param target_id: 目标id;若为群聊则为 group_id 或者 channel_id,若为私聊则为 user_id
37 | :param prompt: 模型提示词
38 | :param probability: 发送几率
39 | """
40 | if not (random.random() < probability):
41 | return
42 |
43 | logger.info(f"Scheduled job model_ask: {prompt}")
44 |
45 | if muice_app.model and muice_app.model.is_running:
46 | message = Message(message=prompt, userid=f"(bot_ask){target_id}")
47 | response = await muice_app.ask(message, enable_history=False, enable_plugins=False)
48 |
49 | target = Target(target_id)
50 | await UniMessage(response.text).send(target=target, bot=get_bot())
51 |
52 |
53 | def setup_scheduler(muice: Muice) -> AsyncIOScheduler:
54 | """
55 | 设置任务调度器
56 |
57 | :param muice: 沐雪核心类,用于与大语言模型交互
58 | """
59 | jobs = get_schedule_configs()
60 | scheduler = AsyncIOScheduler()
61 |
62 | for job in jobs:
63 | job_id = job.id
64 | job_type = "send_message" if job.say else "model_ask"
65 | trigger_type = job.trigger
66 | trigger_args = job.args
67 |
68 | # Resolve the trigger
69 | if trigger_type == "cron":
70 | trigger = CronTrigger(**trigger_args)
71 |
72 | elif trigger_type == "interval":
73 | trigger = IntervalTrigger(**trigger_args)
74 |
75 | else:
76 | logger.error(f"Unknown trigger type: {trigger_type}")
77 | continue
78 |
79 | # Register the job
80 | if job_type == "send_message":
81 | scheduler.add_job(
82 | send_message,
83 | trigger,
84 | id=job_id,
85 | replace_existing=True,
86 | args=[job.target, job.say, job.probability],
87 | )
88 | else:
89 | scheduler.add_job(
90 | model_ask,
91 | trigger,
92 | id=job_id,
93 | replace_existing=True,
94 | args=[muice, job.target, job.ask, job.probability],
95 | )
96 |
97 | logger.success(f"Registered scheduled job: {job_id}")
98 |
99 | if jobs:
100 | scheduler.start()
101 | return scheduler
102 |
--------------------------------------------------------------------------------
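The job fields consumed above (`id`, `trigger`, `args`, `target`, `say`/`ask`, `probability`) come from `get_schedule_configs()` in `config.py`; a short sketch of how the `args` mapping feeds the APScheduler triggers, with made-up values:

from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger

# trigger == "cron": args are CronTrigger keyword arguments
morning = CronTrigger(**{"hour": 8, "minute": 30})   # every day at 08:30

# trigger == "interval": args are IntervalTrigger keyword arguments
heartbeat = IntervalTrigger(**{"hours": 2})          # every two hours

print(morning, heartbeat)
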
/muicebot/templates/__init__.py:
--------------------------------------------------------------------------------
1 | from .loader import generate_prompt_from_template
2 |
3 | __all__ = ["generate_prompt_from_template"]
4 |
--------------------------------------------------------------------------------
/muicebot/templates/loader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import yaml
4 | from jinja2 import Environment, FileSystemLoader
5 | from jinja2.exceptions import TemplateNotFound
6 | from nonebot import logger
7 |
8 | from .model import PromptTemplatesConfig, PromptTemplatesData
9 |
10 | SEARCH_PATH = ["./templates", Path(__file__).parent.parent / "builtin_templates"]
11 |
12 | TEMPLATES_CONFIG_PATH = "./configs/templates.yml"
13 |
14 |
15 | def load_templates_config() -> dict:
16 | """
17 | 获取模板配置
18 | """
19 | try:
20 | with open(TEMPLATES_CONFIG_PATH, "r", encoding="utf-8") as f:
21 | return yaml.safe_load(f) or {}
22 | except (FileNotFoundError, yaml.YAMLError):
23 | return {}
24 |
25 |
26 | def load_templates_data(userid: str, is_private: bool = False) -> PromptTemplatesData:
27 | """
28 | 获取模板数据
29 | """
30 | config = load_templates_config()
31 | templates_config = PromptTemplatesConfig(**config)
32 | return PromptTemplatesData.from_config(templates_config, userid=userid, is_private=is_private)
33 |
34 |
35 | def generate_prompt_from_template(template_name: str, userid: str, is_private: bool = False) -> str:
36 | """
37 | 获取提示词
38 | """
39 | env = Environment(loader=FileSystemLoader(SEARCH_PATH))
40 |
41 | if not template_name.endswith((".j2", ".jinja2")):
42 | template_name += ".jinja2"
43 | try:
44 | template = env.get_template(template_name)
45 | except TemplateNotFound:
46 | logger.error(f"Template file {template_name} not found!")
47 | return ""
48 |
49 | templates_data = load_templates_data(userid, is_private)
50 | prompt = template.render(templates_data.model_dump())
51 |
52 | return prompt
53 |
--------------------------------------------------------------------------------
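A rendering sketch, assuming a hypothetical `./templates/greeting.jinja2` that uses the `PromptTemplatesData` fields; the template body in the comment is illustrative only.

# ./templates/greeting.jinja2 (hypothetical):
#   You are {{ ai_nickname }}, created by {{ master_nickname }}.
#   {% if private %}You are chatting privately with {{ user_name }}. {{ user_info }}{% endif %}

from muicebot.templates import generate_prompt_from_template

# The ".jinja2" suffix is appended automatically when missing
prompt = generate_prompt_from_template("greeting", userid="123456", is_private=True)
print(prompt)  # empty string if the template file cannot be found
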
/muicebot/templates/model.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from pydantic import BaseModel, ConfigDict, Field
4 |
5 |
6 | class Userinfo(BaseModel):
7 | """用户信息配置"""
8 |
9 | name: str
10 | """用户名称"""
11 | id: str
12 | """用户 Nonebot ID"""
13 | info: str
14 | """用户信息"""
15 |
16 |
17 | class PromptTemplatesConfig(BaseModel):
18 | """提示词模板配置"""
19 |
20 | ai_nickname: str = "沐雪"
21 | """AI 昵称"""
22 | master_nickname: str = "沐沐(Muika)"
23 | """AI 开发者昵称"""
24 |
25 | userinfos: List[Userinfo] = Field(default=[])
26 | """用户信息列表"""
27 |
28 | model_config = ConfigDict(extra="allow")
29 | """允许其他模板参数传入"""
30 |
31 |
32 | class PromptTemplatesData(BaseModel):
33 | """提示词模板数据"""
34 |
35 | ai_nickname: str = "沐雪"
36 | """AI 昵称"""
37 | master_nickname: str = "沐沐(Muika)"
38 | """AI 开发者昵称"""
39 |
40 | private: bool = False
41 | """当前对话是否为私聊"""
42 | user_name: str = ""
43 | """目标用户名"""
44 | user_info: str = ""
45 | """目标用户信息"""
46 |
47 | model_config = ConfigDict(extra="allow")
48 | """允许其他模板参数传入"""
49 |
50 | @classmethod
51 | def from_config(cls, templates_config: PromptTemplatesConfig, userid: str, is_private: bool = False):
52 | base = templates_config.model_dump()
53 | data = cls(**base)
54 |
55 | user = next((u for u in templates_config.userinfos if u.id == userid), None)
56 | if user:
57 | data.user_name = user.name
58 | data.user_info = user.info
59 |
60 | data.private = is_private
61 |
62 | return data
63 |
--------------------------------------------------------------------------------
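A minimal sketch of how `from_config` resolves per-user fields; the ids and names are made up.

from muicebot.templates.model import PromptTemplatesConfig, PromptTemplatesData, Userinfo

config = PromptTemplatesConfig(userinfos=[Userinfo(name="Alice", id="123456", info="Likes astronomy")])

data = PromptTemplatesData.from_config(config, userid="123456", is_private=True)
assert data.user_name == "Alice" and data.private is True

unknown = PromptTemplatesData.from_config(config, userid="999999")
assert unknown.user_name == ""  # unmatched users keep the empty defaults
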
/muicebot/utils/SessionManager.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Dict, List, Optional
3 |
4 | from nonebot import logger
5 | from nonebot.adapters import Event
6 | from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
7 |
8 | from ..config import plugin_config
9 |
10 |
11 | class SessionManager:
12 | def __init__(self) -> None:
13 | self.sessions: Dict[str, List[UniMsg]] = {}
14 | self._lock: asyncio.Lock = asyncio.Lock()
15 | self._timeout = plugin_config.input_timeout
16 |
17 | async def _put(self, sid: str, msg: UniMsg) -> None:
18 | async with self._lock:
19 | if sid not in self.sessions:
20 | self.sessions[sid] = []
21 | self.sessions[sid].append(msg)
22 |
23 | async def _get_messages_length(self, sid: str) -> int:
24 | async with self._lock:
25 | return len(self.sessions.get(sid, []))
26 |
27 | def merge_messages(self, sid: str) -> UniMessage:
28 | merged_message = UniMessage()
29 |
30 | for message in self.sessions.pop(sid, []):
31 | merged_message += message
32 |
33 | return merged_message
34 |
35 | async def put_and_wait(self, event: Event, message: UniMsg) -> Optional[UniMessage]:
36 | sid = event.get_session_id()
37 | await self._put(sid, message)
38 |
39 | old_length = await self._get_messages_length(sid)
40 | logger.debug(f"Waiting {self._timeout}s for follow-up messages: session {sid}, current message count {old_length}")
41 | await asyncio.sleep(self._timeout)
42 | new_length = await self._get_messages_length(sid)
43 |
44 | if new_length != old_length:
45 | logger.debug(f"New message detected; this handler exits and session {sid} is handed to a later handler")
46 | return None
47 |
48 | logger.debug(f"No new messages; this handler takes over session {sid}")
49 | return self.merge_messages(sid)
50 |
--------------------------------------------------------------------------------
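A handler-side sketch of the debounce pattern above (illustrative: `event` is a NoneBot `Event`, `message` a `UniMsg`, and `process` a hypothetical downstream step). During a rapid burst, every call but the last returns None; the final call receives the whole burst merged into one `UniMessage`.

from muicebot.utils.SessionManager import SessionManager

session_manager = SessionManager()

async def handle_message(event, message):
    merged = await session_manager.put_and_wait(event, message)
    if merged is None:
        return  # a newer message arrived during the wait; a later call takes over
    await process(merged)  # hypothetical downstream processing of the merged burst
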
/muicebot/utils/adapters.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from functools import lru_cache
3 |
4 | ADAPTER_CLASSES = {}
5 | """动态适配器注册表"""
6 |
7 |
8 | @lru_cache()
9 | def safe_import(path: str):
10 | """
11 | 安全导入:即使导入出现问题也不会出现报错
12 | """
13 | try:
14 | module_path, class_name = path.rsplit(".", 1)
15 | module = importlib.import_module(module_path)
16 | return getattr(module, class_name)
17 | except ImportError:
18 | return None
19 |
20 |
21 | ADAPTER_CLASSES = {
22 | "onebot_v12": safe_import("nonebot.adapters.onebot.v12.Bot"),
23 | "UnsupportedParam": safe_import("nonebot.adapters.onebot.v12.exception.UnsupportedParam"),
24 | "onebot_v11": safe_import("nonebot.adapters.onebot.v11.Bot"),
25 | "telegram_event": safe_import("nonebot.adapters.telegram.Event"),
26 | "telegram_file": safe_import("nonebot.adapters.telegram.message.File"),
27 | "qq_event": safe_import("nonebot.adapters.telegram.qq.Event"),
28 | }
29 |
--------------------------------------------------------------------------------
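The guarded-use pattern this registry enables, as a minimal sketch: every entry may be None when its adapter package is not installed, so always check before `isinstance`.

from muicebot.utils.adapters import ADAPTER_CLASSES

OnebotV11Bot = ADAPTER_CLASSES["onebot_v11"]

def is_onebot_v11(bot) -> bool:
    # False both when the adapter is missing and when the bot is another type
    return OnebotV11Bot is not None and isinstance(bot, OnebotV11Bot)
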
/muicebot/utils/migrations.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import shutil
5 | from pathlib import Path
6 | from typing import TYPE_CHECKING
7 |
8 | import aiosqlite
9 | from nonebot import logger
10 |
11 | if TYPE_CHECKING:
12 | from ..database import Database
13 |
14 |
15 | class MigrationManager:
16 | """数据库迁移管理器"""
17 |
18 | def __init__(self, db: Database) -> None:
19 | self.db = db
20 | self.path: Path = db.DB_PATH
21 |
22 | self.migrations = {0: self._migrate_v0_to_v1, 1: self._migrate_v1_to_v2}
23 |
24 | self.latest_version = max(self.migrations.keys()) + 1
25 |
26 | async def migrate_if_needed(self):
27 | """
28 | 检查数据库更新并迁移
29 | """
30 | await self.__init_version_table()
31 | current_version = await self.__get_version()
32 |
33 | while current_version in self.migrations:
34 | logger.info(f"Database migration needed, current version v{current_version}...")
35 | backup_path = self.path.with_suffix(f".backup.v{current_version}.db")
36 | shutil.copyfile(self.path, backup_path)
37 |
38 | try:
39 | async with self.db.connect() as conn:
40 | await conn.execute("BEGIN")
41 | await self.migrations[current_version](conn)
42 | await conn.commit()
43 | current_version += 1
44 | await self.__set_version(current_version)
45 | logger.success(f"Database successfully migrated to v{current_version} ⭐")
46 | except Exception as e:
47 | logger.error(f"Migration to v{current_version + 1} failed: {e}")
48 | shutil.copyfile(backup_path, self.path)
49 | logger.info("Rolled back to the pre-migration state")
50 | break
51 |
52 | async def __init_version_table(self):
53 | await self.db.execute(
54 | """
55 | CREATE TABLE IF NOT EXISTS schema_version (
56 | version INTEGER NOT NULL
57 | );
58 | """
59 | )
60 | result = await self.db.execute("SELECT COUNT(*) FROM schema_version", fetchone=True)
61 | if not result or result[0] == 0:  # only v0.2.x databases lack a version record
62 | await self.db.execute("INSERT INTO schema_version (version) VALUES (0)")
63 | logger.info("No version record found; probably a v0 database, setting version to v0")
64 |
65 | async def __get_version(self) -> int:
66 | """
67 | 获取数据库版本号,默认值为0
68 | """
69 | result = await self.db.execute("SELECT version FROM schema_version", fetchone=True)
70 | return result[0] if result else 0
71 |
72 | async def __set_version(self, version: int):
73 | await self.db.execute("UPDATE schema_version SET version = ?", (version,))
74 |
75 | async def _migrate_v0_to_v1(self, conn: aiosqlite.Connection):
76 | logger.info("v1 changes: add TOTALTOKENS and GROUPID columns")
77 | await conn.execute("ALTER TABLE MSG ADD COLUMN TOTALTOKENS INTEGER DEFAULT -1;")
78 | await conn.execute("ALTER TABLE MSG ADD COLUMN GROUPID TEXT DEFAULT '-1';")
79 |
80 | async def _migrate_v1_to_v2(self, conn: aiosqlite.Connection):
81 | logger.info("v2 changes: rename total_token to usage, generalize the images list into resources")
82 |
83 | # Create the replacement table
84 | await conn.execute(
85 | """
86 | CREATE TABLE MSG_NEW(
87 | ID INTEGER PRIMARY KEY AUTOINCREMENT,
88 | TIME TEXT NOT NULL,
89 | USERID TEXT NOT NULL,
90 | GROUPID TEXT NOT NULL DEFAULT (-1),
91 | MESSAGE TEXT NOT NULL,
92 | RESPOND TEXT NOT NULL,
93 | HISTORY INTEGER NOT NULL DEFAULT (1),
94 | RESOURCES TEXT NOT NULL DEFAULT "[]",
95 | USAGE INTEGER NOT NULL DEFAULT (-1)
96 | );
97 | """
98 | )
99 |
100 | cursor = await conn.execute(
101 | "SELECT ID, TIME, USERID, GROUPID, MESSAGE, RESPOND, HISTORY, IMAGES, TOTALTOKENS FROM MSG"
102 | )
103 | rows = await cursor.fetchall()
104 |
105 | for row in rows:
106 | id_, time_, userid, groupid, message, respond, history, images_json, totaltokens = row
107 |
108 | # Convert images -> resources
109 | try:
110 | images = json.loads(images_json) if images_json else []
111 | except Exception:
112 | images = []
113 |
114 | resources = []
115 | for url in images:
116 | resources.append({"type": "image", "path": url})
117 | resources_json = json.dumps(resources, ensure_ascii=False)
118 |
119 | # Insert into the new table
120 | await conn.execute(
121 | """
122 | INSERT INTO MSG_NEW (ID, TIME, USERID, GROUPID, MESSAGE, RESPOND, HISTORY, RESOURCES, USAGE)
123 | VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
124 | """,
125 | (id_, time_, userid, groupid, message, respond, history, resources_json, totaltokens),
126 | )
127 |
128 | await conn.execute("DROP TABLE MSG")
129 | await conn.execute("ALTER TABLE MSG_NEW RENAME TO MSG")
130 |
--------------------------------------------------------------------------------
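A sketch of how a future schema change would slot in, using a hypothetical v2 -> v3 step written as a method on MigrationManager; only the `self.migrations` wiring is real, the column change is made up.

# Register in MigrationManager.__init__ (latest_version then follows automatically):
#     self.migrations = {0: self._migrate_v0_to_v1, 1: self._migrate_v1_to_v2, 2: self._migrate_v2_to_v3}
# migrate_if_needed() applies it, with backup and rollback, on the next start.

async def _migrate_v2_to_v3(self, conn: aiosqlite.Connection):
    logger.info("v3 changes: add a hypothetical EXAMPLE column")
    await conn.execute("ALTER TABLE MSG ADD COLUMN EXAMPLE TEXT DEFAULT '';")
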
/muicebot/utils/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import mimetypes
3 | import os
4 | import ssl
5 | import sys
6 | import time
7 | from importlib.metadata import PackageNotFoundError, version
8 | from typing import Optional
9 |
10 | import fleep
11 | import httpx
12 | import nonebot_plugin_localstore as store
13 | from nonebot import get_bot, logger
14 | from nonebot.adapters import Event, MessageSegment
15 | from nonebot.log import default_filter, logger_id
16 | from nonebot_plugin_userinfo import get_user_info
17 |
18 | from ..config import plugin_config
19 | from ..models import Resource
20 | from ..plugin.context import get_event
21 | from .adapters import ADAPTER_CLASSES
22 |
23 | IMG_DIR = store.get_plugin_data_dir() / ".cache" / "images"
24 | IMG_DIR.mkdir(parents=True, exist_ok=True)
25 |
26 | User_Agent = (
27 | "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
28 | "AppleWebKit/537.36 (KHTML, like Gecko) "
29 | "Chrome/134.0.0.0 Safari/537.36 Edg/134.0.0.0"
30 | )
31 |
32 |
33 | async def download_file(file_url: str, file_name: Optional[str] = None, proxy: Optional[str] = None) -> str:
34 | """
35 | 保存文件至本地目录
36 |
37 | :param file_url: 图片在线地址
38 | :param file_name: 要保存的文件名
39 | :param proxy: 代理地址
40 |
41 | :return: 保存后的本地目录
42 | """
43 | ssl_context = ssl.create_default_context()
44 | ssl_context.set_ciphers("DEFAULT")
45 | file_name = file_name if file_name else str(time.time_ns()) + ".jpg"
46 |
47 | async with httpx.AsyncClient(proxy=proxy, verify=ssl_context) as client:
48 | r = await client.get(file_url, headers={"User-Agent": User_Agent})
49 | local_path = (IMG_DIR / file_name).resolve()
50 | with open(local_path, "wb") as file:
51 | file.write(r.content)
52 | return str(local_path)
53 |
54 |
55 | async def save_image_as_base64(image_url: str, proxy: Optional[str] = None) -> str:
56 | """
57 | 从在线 url 获取图像 Base64
58 |
59 | :image_url: 图片在线地址
60 | :return: 本地地址
61 | """
62 | ssl_context = ssl.create_default_context()
63 | ssl_context.set_ciphers("DEFAULT")
64 |
65 | async with httpx.AsyncClient(proxy=proxy, verify=ssl_context) as client:
66 | r = await client.get(image_url, headers={"User-Agent": User_Agent})
67 | image_base64 = base64.b64encode(r.content)
68 | return image_base64.decode("utf-8")
69 |
70 |
71 | async def get_file_via_adapter(message: MessageSegment, event: Event) -> Optional[str]:
72 | """
73 | 通过适配器自有方式获取文件地址并保存到本地
74 |
75 | :return: 本地地址
76 | """
77 | bot = get_bot()
78 |
79 | Onebotv12Bot = ADAPTER_CLASSES["onebot_v12"]
80 | UnsupportedParam = ADAPTER_CLASSES["UnsupportedParam"]
81 | Onebotv11Bot = ADAPTER_CLASSES["onebot_v11"]
82 | TelegramEvent = ADAPTER_CLASSES["telegram_event"]
83 | TelegramFile = ADAPTER_CLASSES["telegram_file"]
84 |
85 | if Onebotv12Bot and UnsupportedParam and isinstance(bot, Onebotv12Bot):
86 | # if message.type != "image":
87 | # return None
88 |
89 | try:
90 | file_path = await bot.get_file(type="url", file_id=message.data["file_id"])
91 | except UnsupportedParam as e:
92 | logger.error(f"The OneBot implementation does not support fetching file URLs; file retrieval failed: {e}")
93 | return None
94 |
95 | return str(file_path)
96 |
97 | elif Onebotv11Bot and isinstance(bot, Onebotv11Bot):
98 | if "url" in message.data and "file" in message.data:
99 | return await download_file(message.data["url"], message.data["file"])
100 |
101 | elif TelegramEvent and TelegramFile and isinstance(event, TelegramEvent):
102 | if not isinstance(message, TelegramFile):
103 | return None
104 |
105 | file_id = message.data["file"]
106 | file = await bot.get_file(file_id=file_id)
107 | if not file.file_path:
108 | return None
109 |
110 | url = f"https://api.telegram.org/file/bot{bot.bot_config.token}/{file.file_path}" # type: ignore
111 | # filename = file.file_path.split("/")[1]
112 | return await download_file(url, proxy=plugin_config.telegram_proxy)
113 |
114 | return None
115 |
116 |
117 | def init_logger():
118 | console_handler_level = plugin_config.log_level
119 |
120 | log_dir = "logs"
121 | if not os.path.exists(log_dir):
122 | os.mkdir(log_dir)
123 |
124 | log_file_path = f"{log_dir}/{time.strftime('%Y-%m-%d')}.log"
125 |
126 | # Remove NoneBot's default log handler
127 | logger.remove(logger_id)
128 | # Add new log handlers
129 | logger.add(
130 | sys.stdout,
131 | level=console_handler_level,
132 | diagnose=True,
133 | format="[{level}] {function}: {message}",
134 | filter=default_filter,
135 | colorize=True,
136 | )
137 |
138 | logger.add(
139 | log_file_path,
140 | level="DEBUG",
141 | format="[{time:YYYY-MM-DD HH:mm:ss}] [{level}] {function}: {message}",
142 | encoding="utf-8",
143 | rotation="1 day",
144 | retention="7 days",
145 | )
146 |
147 |
148 | def get_version() -> str:
149 | """
150 | 获取当前版本号
151 |
152 | 优先尝试从已安装包中获取版本号, 否则从 `pyproject.toml` 读取
153 | """
154 | package_name = "muicebot"
155 |
156 | try:
157 | return version(package_name)
158 | except PackageNotFoundError:
159 | pass
160 |
161 | toml_path = os.path.join(os.path.dirname(__file__), "..", "..", "pyproject.toml")  # repo root
162 |
163 | if not os.path.isfile(toml_path):
164 | return "Unknown"
165 |
166 | try:
167 | if sys.version_info >= (3, 11):
168 | import tomllib
169 |
170 | with open(toml_path, "rb") as f:
171 | pyproject_data = tomllib.load(f)
172 |
173 | else:
174 | import toml
175 |
176 | with open(toml_path, "r", encoding="utf-8") as f:
177 | pyproject_data = toml.load(f)
178 |
179 | # [tool.pdm.version] uses SCM versioning, so the file only carries the fallback version
180 | return pyproject_data["tool"]["pdm"]["version"]["fallback_version"]
181 |
182 | except (FileNotFoundError, KeyError, ModuleNotFoundError):
183 | return "Unknown"
184 |
185 |
186 | async def get_username(user_id: Optional[str] = None) -> str:
187 | """
188 | 获取当前对话的用户名,如果失败就返回用户id
189 | """
190 | bot = get_bot()
191 | event = get_event()
192 | user_id = user_id if user_id else event.get_user_id()
193 | user_info = await get_user_info(bot, event, user_id)
194 | return user_info.user_name if user_info else user_id
195 |
196 |
197 | def guess_mimetype(resource: Resource) -> Optional[str]:
198 | """
199 | 尝试获取 minetype 类型
200 | """
201 | # raw 不落库,因此无法从 raw 判断
202 | if resource.path and os.path.exists(resource.path):
203 | try:
204 | with open(resource.path, "rb") as file:
205 | header = file.read(128)
206 | info = fleep.get(header)
207 | if info.mime:
208 | return info.mime[0] # type:ignore
209 | else:
210 | return mimetypes.guess_type(resource.path)[0]
211 | except Exception:
212 | return mimetypes.guess_type(resource.path)[0]
213 | elif resource.url:
214 | return mimetypes.guess_type(resource.url)[0]
215 | return None
216 |
--------------------------------------------------------------------------------
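A quick sketch of `guess_mimetype`; the path is made up, and the `Resource` field names are assumed from the attributes used above (`muicebot/models.py` is not shown here).

from muicebot.models import Resource
from muicebot.utils.utils import guess_mimetype

res = Resource(type="image", path="/tmp/example.png")  # hypothetical construction
# Local files are sniffed via fleep's magic bytes, falling back to the
# extension-based mimetypes guess; URLs use the extension alone.
print(guess_mimetype(res))  # e.g. "image/png", or None if undetectable
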
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "MuiceBot"
3 | dynamic = ["version"]
4 | description = "A NoneBot project that can call a variety of models, built as a learning exercise"
5 | readme = "README.md"
6 | requires-python = ">=3.10, <3.13"
7 | dependencies = [
8 | "aiosqlite>=0.17.0",
9 | "APScheduler>=3.11.0",
10 | "fleep>=1.0.1",
11 | "jinja2>=3.1.6",
12 | "nonebot2>=2.4.1",
13 | "nonebot-adapter-onebot>=2.4.6",
14 | "nonebot_plugin_alconna>=0.54.2",
15 | "nonebot_plugin_apscheduler>=0.5.0",
16 | "nonebot_plugin_localstore>=0.7.3",
17 | "nonebot_plugin_session>=0.3.2",
18 | "nonebot_plugin_userinfo>=0.2.6",
19 | "openai>=1.64.0",
20 | "pydantic>=2.10.5",
21 | "httpx>=0.27.0",
22 | "ruamel.yaml>=0.18.10",
23 | "SQLAlchemy>=2.0.38",
24 | "toml>=0.10.2; python_version < '3.11'",
25 | "websocket_client>=1.8.0",
26 | "watchdog>=6.0.0",
27 | "mcp[cli]>=1.9.0"
28 | ]
29 | authors = [
30 | { name = "Moemu", email = "i@snowy.moe" },
31 | ]
32 |
33 | [project.optional-dependencies]
34 | standard = [
35 | "azure-ai-inference>=1.0.0b7",
36 | "dashscope>=1.22.1",
37 | "google-genai==1.8.0",
38 | "numpy>=1.26.4",
39 | "ollama>=0.4.7",
40 | "soundfile>=0.13.1"
41 | ]
42 | dev = [
43 | "pre-commit>=4.1.0",
44 | "mypy>=1.15.0",
45 | "black>=25.1.0",
46 | "types-PyYAML",
47 | ]
48 |
49 | [tool.nonebot]
50 | adapters = [
51 | { name = "OneBot V12", module_name = "nonebot.adapters.onebot.v12" },
52 | { name = "OneBot V11", module_name = "nonebot.adapters.onebot.v11" }
53 | ]
54 | plugins = ["nonebot_plugin_alconna", "nonebot_plugin_localstore", "nonebot_plugin_apscheduler", "nonebot_plugin_session", "nonebot_plugin_userinfo"]
55 | builtin_plugins = []
56 |
57 | [tool.black]
58 | line-length = 120
59 |
60 | [tool.isort]
61 | profile = "black"
62 |
63 | [tool.pdm]
64 | distribution = true
65 |
66 | [tool.pdm.version]
67 | source = "scm"
68 | tag_filter = "v*"
69 | tag_regex = '^v(?:\D*)?(?P<version>([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|c|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$)$'
70 | fallback_version = "0.2.0"
71 |
72 | [tool.pdm.build]
73 | includes = []
74 |
75 | [build-system]
76 | requires = ["pdm-backend"]
77 | build-backend = "pdm.backend"
--------------------------------------------------------------------------------
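A quick sanity check of the `tag_regex` above (a standalone sketch, not part of the build): a tag such as `v1.2.3` must yield the version `1.2.3`.

import re

tag_regex = (
    r"^v(?:\D*)?(?P<version>([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*"
    r"((a|b|c|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$)$"
)

match = re.match(tag_regex, "v1.2.3")
assert match and match.group("version") == "1.2.3"
assert re.match(tag_regex, "1.2.3") is None  # tags must start with "v"
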
/requirements.txt:
--------------------------------------------------------------------------------
1 | APScheduler==3.11.0
2 | fleep>=1.0.1
3 | httpx>=0.27.0
4 | jinja2>=3.1.6
5 | nonebot2==2.4.1
6 | nonebot_plugin_alconna==0.54.2
7 | nonebot_plugin_apscheduler>=0.5.0
8 | nonebot_plugin_localstore==0.7.3
9 | nonebot_plugin_session>=0.3.2
10 | nonebot_plugin_userinfo>=0.2.6
11 | openai==1.64.0
12 | pydantic==2.10.5
13 | ruamel.yaml==0.18.10
14 | SQLAlchemy==2.0.38
15 | toml>=0.10.2; python_version < '3.11'
16 | websocket_client==1.8.0
17 | watchdog==6.0.0
18 | aiosqlite>=0.17.0
19 | mcp[cli]>=1.9.0
--------------------------------------------------------------------------------