在 run_bot 中捕获到未处理的异常: {e}", exc_info=True)
126 | handle_critical_error(sys.exc_info())
127 |
128 | async def main():
129 | print("")
130 | print(" _ _ _ _ ")
131 | print(" / \\ _ __ ___ _ _ / \\ | |_ __ ___ ___ _ __ __| |")
132 | print(" / _ \\ | '_ ` _ \\| | | | / _ \\ | | '_ ` _ \\ / _ \\| '_ \\ / _` |")
133 | print(" / ___ \\| | | | | | |_| |/ ___ \\| | | | | | | (_) | | | | (_| |")
134 | print(" /_/ \\_|_| |_| |_|\\__, /_/ \\_|_|_| |_| |_|\\___/|_| |_|\\__,_|")
135 | print(" |___/ ")
136 | print("")
137 |
138 | logger.info(">>> SYSTEM INITIATING...")
139 |
140 | # 并行启动 Uvicorn 服务器和机器人客户端
141 | uvicorn_task = asyncio.create_task(start_uvicorn())
142 |
143 | # 运行机器人客户端(同步)
144 | bot_task = asyncio.to_thread(run_bot)
145 |
146 | await asyncio.gather(uvicorn_task, bot_task)
147 |
148 | if __name__ == "__main__":
149 | # 在主线程中创建事件循环
150 | asyncio.run(main())
151 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # AmyAlmond Chatbot
4 |
5 | [](hhttps://opensource.org/license/mpl-2-0)
6 | [](https://www.python.org/downloads/)
7 | [](https://github.com/shuakami/amyalmond_bot/stargazers)
8 | [](https://github.com/shuakami/amyalmond_bot)
9 | [-yellow.svg)](https://github.com/shuakami/amyalmond_bot/releases)
10 |
11 | [English](README_en.md) | 简体中文
12 |
13 | ⭐ Your go-to chatbot for supercharging group chats ⭐
14 |
15 | [Features](#功能特性) • [Screenshots](#先看效果) • [Docs](#安装部署开发) • [Contribute](#开发与贡献) • [License](#许可证)
16 |
17 |
18 | ## Features
19 |
20 | AmyAlmond is an LLM API-powered smart chatbot designed to seamlessly integrate into QQ groups and channels.
21 |
22 | By leveraging LLM API, AmyAlmond offers context-aware intelligent responses, enhancing user interaction and supporting long-term memory management. Whether it’s automating replies or boosting user engagement, she handles complex conversations like a breeze.
23 |
24 | - 🌈 She uses the **LLM API** to generate human-like responses based on conversation context, with customizable prompts.
25 | - 💗 Integrated with QQ’s official Python SDK, so you don’t have to worry about being blocked.
26 | - 🔥 Automatically recognizes and remembers user names, providing a personalized interaction experience.
27 | - 🧠 Equipped with **long-term and short-term memory**, she can record and recall important information, ensuring continuity in conversations.
28 | - 🐳 Administrators can control her behavior with specific commands.
29 | - ⭐ **Full configuration hot-reloading** reduces restart times, boosting efficiency.
30 | - 🪝 Detailed logs and code comments make debugging and monitoring a breeze.
31 |
32 | ## Curious about the results?
33 |
34 | 
35 | 
36 |
37 | ## Installation/Deployment/Development
38 |
39 |
40 |
41 |
42 |
43 |
44 | Click the image to jump in
45 |
46 |
47 | ## Contributing
48 |
49 | We'd love to have you on board! Whether it’s adding new features, fixing bugs, or improving documentation, your contributions are welcome!
50 |
51 | ### Branch Strategy
52 |
53 | We follow the Git Flow branching model:
54 |
55 | - **main**: The stable branch, always ready for production.
56 | - **develop**: The development branch, where all new features are integrated.
57 | - **feature/**: Feature branches, created from `develop`, merged back once the feature is complete.
58 | - **hotfix/**: Hotfix branches, used to quickly patch bugs, merged back into `main` and `develop`.
59 |
60 | ### How to Contribute
61 |
62 | 1. **Fork this repo**
63 | Fork the project to your GitHub account.
64 |
65 | 2. **Create a branch**
66 | Create a new feature branch for your changes:
67 | ```bash
68 | git checkout -b feature/AmazingFeature
69 | ```
70 |
71 | 3. **Commit your changes**
72 | Commit your code with clear and concise messages:
73 | ```bash
74 | git commit -m 'Add some AmazingFeature'
75 | ```
76 |
77 | 4. **Push to GitHub**
78 | Push your branch to GitHub:
79 | ```bash
80 | git push origin feature/AmazingFeature
81 | ```
82 |
83 | 5. **Create a Pull Request**
84 | Create a Pull Request on GitHub, describing your changes and their impact.
85 |
86 | ## License
87 |
88 | [](https://opensource.org/licenses/MPL-2.0)
89 |
90 | AmyAlmond is licensed under the [MPL 2.0 License](LICENSE). You are free to use, modify, and distribute this project, but you must open source any modified versions and retain the original author's copyright notice.
91 |
92 | ## Disclaimer
93 |
94 | This project is for learning and research purposes only. The developers are not responsible for any consequences resulting from the use of this project. Please ensure compliance with relevant laws and respect others' intellectual property rights when using this project.
95 |
96 | ## Roadmap
97 |
98 | Check out our [Project Board](https://github.com/users/shuakami/projects/1) for the latest updates!
99 |
100 | q(≧▽≦q) You've read this far—how about dropping us a ⭐️?
101 |
102 |
103 |
109 |
115 |
119 |
--------------------------------------------------------------------------------
/core/ace/secure.py:
--------------------------------------------------------------------------------
1 | # core/ace/secure.py
2 | import json
3 | import os
4 | import time
5 | import random
6 | import string
7 | from threading import Thread
8 |
9 |
10 | class VerificationCode:
11 | def __init__(self, data=None):
12 | if data:
13 | self.code = data.get("code")
14 | self.generated_time = data.get("generated_time")
15 | self.used_codes = set(data.get("used_codes", []))
16 | self.last_verified_time = data.get("last_verified_time", 0)
17 | self.last_rejected_time = data.get("last_rejected_time", 0)
18 | else:
19 | self.code = None
20 | self.generated_time = None
21 | self.used_codes = set()
22 | self.last_verified_time = 0
23 | self.last_rejected_time = 0
24 |
25 | def to_dict(self):
26 | return {
27 | "code": self.code,
28 | "generated_time": self.generated_time,
29 | "used_codes": list(self.used_codes),
30 | "last_verified_time": self.last_verified_time,
31 | "last_rejected_time": self.last_rejected_time,
32 | }
33 |
34 | def generate_code(self):
35 | # 生成7天内不重复的验证码
36 | while True:
37 | new_code = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
38 | if new_code not in self.used_codes:
39 | self.code = new_code
40 | self.generated_time = time.time()
41 | self.used_codes.add(new_code)
42 | # 定期清理 used_codes,防止无限增长
43 | if len(self.used_codes) > 1000:
44 | self.used_codes = set(list(self.used_codes)[-500:]) # 保留最近500个验证码
45 | break
46 |
47 | def is_valid(self):
48 | if self.code is None:
49 | return False
50 | return time.time() - self.generated_time < 300 # 5分钟有效期
51 |
52 | def is_rejected_recently(self):
53 | return time.time() - getattr(self, 'last_rejected_time', 0) < 3600 # 1小时内拒绝过
54 |
55 | def mark_rejected(self):
56 | self.last_rejected_time = time.time()
57 |
58 | def mark_verified(self):
59 | self.last_verified_time = time.time()
60 |
61 | def is_verified_recently(self):
62 | return time.time() - self.last_verified_time < 604800 # 7天内验证过
63 |
64 | class SecureInterface:
65 | def __init__(self):
66 | self.secure_file = "configs/secure.json"
67 | self.verification_code = self._load_verification_code()
68 |
69 | def _load_verification_code(self):
70 | if os.path.exists(self.secure_file):
71 | try:
72 | with open(self.secure_file, "r") as f:
73 | data = json.load(f)
74 | return VerificationCode(data)
75 | except json.JSONDecodeError:
76 | # 文件为空或格式错误,返回默认的 VerificationCode 对象
77 | return VerificationCode()
78 | else:
79 | return VerificationCode()
80 |
81 | def _save_verification_code(self):
82 | # 获取文件所在的目录路径
83 | directory = os.path.dirname(self.secure_file)
84 |
85 | # 如果目录不存在,则创建目录
86 | if not os.path.exists(directory):
87 | os.makedirs(directory)
88 |
89 | # 打开文件并写入数据
90 | with open(self.secure_file, "w") as f:
91 | json.dump(self.verification_code.to_dict(), f)
92 |
93 | def _show_verification_dialog(self):
94 | if self.verification_code.is_valid():
95 | code = self.verification_code.code
96 | else:
97 | self.verification_code.generate_code()
98 | code = self.verification_code.code
99 |
100 | print(f" 验证码: {code}")
101 | print(f"您的关键API在被请求,触发了ACE模块拦截。您可以:")
102 | print("1. 输入验证码以允许本次请求")
103 | print("2. 拒绝本次请求")
104 | print("3. 强制拒绝本次及后续1小时内的所有请求")
105 |
106 | while True:
107 | user_input = input("请选择 (1/2/3): ")
108 | if user_input in ["1", "2", "3"]:
109 | break
110 | print("无效的选择,请重新输入")
111 |
112 | if user_input == "1":
113 | verification_code_input = input("请输入验证码: ")
114 | if verification_code_input == code:
115 | self.verification_code.mark_verified() # 标记为已验证
116 | return True
117 | else:
118 | print("验证码错误")
119 | return False
120 | elif user_input == "2":
121 | return False
122 | elif user_input == "3":
123 | self.last_rejected_time = time.time()
124 | return False
125 |
126 | def verify_request(self):
127 | if time.time() - self.verification_code.last_rejected_time < 3600:
128 | return False # 1小时内强制拒绝过,直接拒绝
129 |
130 | if self.verification_code.is_verified_recently():
131 | return True # 7天内验证过,直接允许
132 |
133 | # 在新线程中打开对话框,避免阻塞主线程
134 | verification_result = False
135 |
136 | def verification_thread():
137 | nonlocal verification_result
138 | verification_result = self._show_verification_dialog()
139 |
140 | thread = Thread(target=verification_thread)
141 | thread.start()
142 | thread.join()
143 |
144 | if verification_result:
145 | self.verification_code.code = None # 验证成功后清空验证码,7天内不再重复验证
146 | self._save_verification_code() # 保存数据到文件
147 | return True
148 | else:
149 | self._save_verification_code() # 保存数据到文件
150 | return False
--------------------------------------------------------------------------------
/core/api/controllers/es_controller.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, HTTPException, Body, Depends
2 | from core.ace.secure import SecureInterface
3 | from pydantic import BaseModel
4 | from core.utils.logger import get_logger
5 | from core.db.elasticsearch_index_manager import ElasticsearchIndexManager
6 |
7 | logger = get_logger()
8 | router = APIRouter()
9 |
10 | # 依赖注入 ElasticsearchIndexManager 实例
11 | async def get_es():
12 | es = ElasticsearchIndexManager()
13 | try:
14 | yield es
15 | finally:
16 | pass # Elasticsearch 连接不需要手动关闭
17 |
18 | class UpdateDocumentModel(BaseModel):
19 | update: dict
20 |
21 | @router.get("/indices")
22 | async def get_all_indices(es: ElasticsearchIndexManager = Depends(get_es)):
23 | """获取所有索引名称"""
24 | secure_interface = SecureInterface()
25 | if not secure_interface.verify_request():
26 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
27 |
28 | try:
29 | indices = es.get_all_indices()
30 | return {"status": "success", "indices": indices}
31 | except Exception as e:
32 | logger.error(f"获取索引列表时出错: {e}")
33 | raise HTTPException(status_code=500, detail=str(e))
34 |
35 | @router.get("/mapping/{index_name}")
36 | async def get_index_mapping(index_name: str, es: ElasticsearchIndexManager = Depends(get_es)):
37 | """获取指定索引的映射"""
38 | secure_interface = SecureInterface()
39 | if not secure_interface.verify_request():
40 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
41 |
42 | try:
43 | mapping = es.get_index_mapping(index_name)
44 | if mapping:
45 | return {"status": "success", "mapping": mapping}
46 | else:
47 | return {"status": "not_found", "message": f"索引 '{index_name}' 不存在"}
48 | except Exception as e:
49 | logger.error(f"获取索引映射时出错: {e}")
50 | raise HTTPException(status_code=500, detail=str(e))
51 |
52 | @router.get("/documents/{index_name}")
53 | async def get_all_documents(index_name: str, es: ElasticsearchIndexManager = Depends(get_es)):
54 | """获取指定索引中的所有文档"""
55 | secure_interface = SecureInterface()
56 | if not secure_interface.verify_request():
57 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
58 |
59 | try:
60 | # 使用 match_all 查询获取所有文档
61 | query = {"query": {"match_all": {}}}
62 | documents = es.search(index_name, query)
63 | return {"status": "success", "documents": documents}
64 | except Exception as e:
65 | logger.error(f"获取文档列表时出错: {e}")
66 | raise HTTPException(status_code=500, detail=str(e))
67 |
68 | @router.post("/documents/{index_name}")
69 | async def insert_document(index_name: str, document: dict = Body(...), es: ElasticsearchIndexManager = Depends(get_es)):
70 | """向指定索引中插入文档"""
71 | secure_interface = SecureInterface()
72 | if not secure_interface.verify_request():
73 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
74 |
75 | try:
76 | # 使用 Elasticsearch 的 index API 插入文档
77 | result = es.es.index(index=index_name, body=document) # 直接使用 es.es
78 | return {"status": "success", "inserted_id": result["_id"]}
79 | except Exception as e:
80 | logger.error(f"插入文档时出错: {e}")
81 | raise HTTPException(status_code=500, detail=str(e))
82 |
83 | @router.get("/documents/{index_name}/find")
84 | async def find_document(index_name: str, query: dict = Body(...), es: ElasticsearchIndexManager = Depends(get_es)):
85 | """根据查询条件查找文档"""
86 | secure_interface = SecureInterface()
87 | if not secure_interface.verify_request():
88 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
89 |
90 | try:
91 | documents = es.search(index_name, query)
92 | if documents:
93 | return {"status": "success", "documents": documents}
94 | else:
95 | return {"status": "not_found", "message": "未找到符合条件的文档"}
96 | except Exception as e:
97 | logger.error(f"查找文档时出错: {e}")
98 | raise HTTPException(status_code=500, detail=str(e))
99 |
100 | @router.put("/documents/{index_name}/update/{document_id}")
101 | async def update_document(index_name: str, document_id: str, update_data: UpdateDocumentModel = Body(...), es: ElasticsearchIndexManager = Depends(get_es)):
102 | """根据文档ID更新文档"""
103 | secure_interface = SecureInterface()
104 | if not secure_interface.verify_request():
105 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
106 |
107 | try:
108 | # 使用 Elasticsearch 的 update API 更新文档
109 | result = es.es.update(index=index_name, id=document_id, body=update_data.update) # 直接使用 es.es
110 | return {"status": "success", "result": result}
111 | except Exception as e:
112 | logger.error(f"更新文档时出错: {e}")
113 | raise HTTPException(status_code=500, detail=str(e))
114 |
115 | @router.delete("/documents/{index_name}/delete/{document_id}")
116 | async def delete_document(index_name: str, document_id: str, es: ElasticsearchIndexManager = Depends(get_es)):
117 | """根据文档ID删除文档"""
118 | secure_interface = SecureInterface()
119 | if not secure_interface.verify_request():
120 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
121 |
122 | try:
123 | success = es.delete_document(index_name, document_id)
124 | if success:
125 | return {"status": "success", "message": f"文档 '{document_id}' 已删除"}
126 | else:
127 | return {"status": "not_found", "message": f"索引 '{index_name}' 或文档 '{document_id}' 不存在"}
128 | except Exception as e:
129 | logger.error(f"删除文档时出错: {e}")
130 | raise HTTPException(status_code=500, detail=str(e))
--------------------------------------------------------------------------------
/tools/db_tools.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import subprocess
4 |
5 |
6 | from game import play_game
7 |
8 |
9 | def run_external_script(script_relative_path):
10 | """分离运行外部脚本"""
11 | script_dir = os.path.dirname(os.path.abspath(__file__))
12 | script_path = os.path.abspath(os.path.join(script_dir, script_relative_path))
13 | try:
14 | subprocess.run([sys.executable, script_path], check=True)
15 | except subprocess.CalledProcessError as e:
16 | print(f"执行 {script_path} 时出错: {e}")
17 | return False
18 | return True
19 |
20 |
21 | def manage_mongodb(action):
22 | """管理MongoDB的安装、启动和配置"""
23 | if action == '1': # 安装
24 | print("开始安装MongoDB...")
25 | if not run_external_script("install/mongodb/mongodb_install.py"):
26 | print("MongoDB安装失败。")
27 | return False
28 | elif action == '2': # 启动
29 | print("开始启动MongoDB...")
30 | if not run_external_script("setup/mongodb/mongodb_setup.py"):
31 | print("MongoDB启动失败。")
32 | return False
33 | elif action == '3': # 配置
34 | if not run_external_script("setup/mongodb/mongodb_setup_configs.py"):
35 | print("MongoDB启动失败。")
36 | return False
37 | return True
38 |
39 |
40 | def print_help():
41 | """打印帮助信息,详细解释各个操作及其使用方法"""
42 | help_text = """
43 | +----------------------------------------------------------------------------------+
44 | | 帮助文档 |
45 | +----------------------------------------------------------------------------------+
46 | | 欢迎来到数据库管理工具的帮助页面! |
47 | | 在这里,我们将详细介绍每一个功能的作用,以及如何使用这个工具来高效地管理你的数据库。|
48 | +----------------------------------------------------------------------------------+
49 |
50 | 1. 安装 (帮助你安装数据库)
51 | - 这个选项将引导你通过一个简单的步骤来安装MongoDB或Elasticsearch数据库。
52 | - 安装过程是自动化的,你只需要选择数据库类型,然后工具将自动运行对应的安装脚本。
53 | - 适用于初次使用该数据库或在新环境中重新搭建数据库的用户。
54 |
55 | 2. 启动 (帮助你启动数据库)
56 | - 如果你已经安装了数据库,并且需要启动它们,那么这个选项适合你。
57 | - 启动操作将执行预设的启动脚本,确保数据库正确启动,并且能够接受连接。
58 | - 通常用于在服务器重新启动后,或者需要手动启动数据库的场景。
59 |
60 | 3. 配置 (配置数据库账号密码)
61 | - 此选项用于配置数据库的基本安全设置,比如账号和密码。
62 | - 对于MongoDB和Elasticsearch,我们提供了专用的配置脚本来设置数据库的账号和密码。
63 | - 强烈建议在生产环境中进行适当的配置,以确保数据库的安全性。
64 |
65 | 4. 升级原数据库(数据迁移)
66 | - 当你需要将现有的数据库升级到新版本或迁移数据时,选择这个选项。
67 | - 该操作将自动调用升级脚本,确保数据在升级过程中不会丢失。
68 | - 请在操作之前备份你的数据,以防万一。
69 |
70 | +----------------------------------------------------------------------------------+
71 | | 注意:每一个操作都有它特定的目的,请根据你的需求选择相应的功能。
72 | | 如果你在使用过程中遇到了问题,建议参考对应的日志文件,以获取更详细的信息。
73 | | 我们的目标是让你以最少的操作成本,完成对数据库的管理工作。
74 | +----------------------------------------------------------------------------------+
75 |
76 |
77 |
78 |
79 | egg you cai dan o
80 | """
81 | print(help_text)
82 |
83 | def manage_elasticsearch(action):
84 | """管理Elasticsearch的安装、启动和配置"""
85 | if action == '1': # 安装
86 | print("开始安装Elasticsearch...")
87 | if not run_external_script("install/elasticsearch/elasticsearch_install.py"):
88 | print("Elasticsearch安装失败。")
89 | return False
90 | elif action == '2': # 启动
91 | print("开始启动Elasticsearch...")
92 | if not run_external_script("setup/elasticsearch/elasticsearch_setup.py"):
93 | print("Elasticsearch启动失败。")
94 | return False
95 | elif action == '3': # 配置
96 | print("开始配置Elasticsearch...")
97 | if not run_external_script("setup/elasticsearch/elasticsearch_configs.py"):
98 | print("Elasticsearch启动失败。")
99 | return False
100 | return True
101 |
102 |
103 | if __name__ == "__main__":
104 | if len(sys.argv) > 1 and sys.argv[1].lower() == 'egg':
105 | play_game()
106 | sys.exit(0)
107 |
108 | print("+----------------------------------------+")
109 | print("| 欢迎使用数据库管理工具 |")
110 | print("+----------------------------------------+")
111 | print("| 请选择操作: |")
112 | print("| 1. 安装 (帮助你安装数据库) |")
113 | print("| 2. 启动 (帮助你启动数据库) |")
114 | print("| 3. 配置 (配置数据库账号密码) |")
115 | print("| 4. 升级原数据库(数据迁移) |")
116 | print("| 需要帮助请按在脚本后缀加h(就是h,不是-h) |")
117 | print("+----------------------------------------+")
118 |
119 | if len(sys.argv) > 1 and sys.argv[1] == 'h':
120 | print_help()
121 | sys.exit(0)
122 |
123 | choice = input("请输入数字选择操作: ")
124 |
125 | if choice == '4':
126 | print("开始升级原数据库...")
127 | if not run_external_script("upgrade/db_upgrade.py"):
128 | print("数据库升级失败。")
129 | sys.exit(1)
130 |
131 | if choice not in ['1', '2', '3']:
132 | print("无效的选择,程序将退出。")
133 | sys.exit(1)
134 |
135 | print("+----------------------------------------+")
136 | print("| 请选择数据库: |")
137 | print("| 1. MongoDB |")
138 | print("| 2. Elasticsearch |")
139 | print("+----------------------------------------+")
140 |
141 | db_choice = input("请输入数字选择数据库: ")
142 |
143 | if db_choice == '1':
144 | print(f"您选择了MongoDB,执行 {choice} 操作。")
145 | if not manage_mongodb(choice):
146 | print("MongoDB操作失败。")
147 | sys.exit(1)
148 | elif db_choice == '2':
149 | print(f"您选择了Elasticsearch,执行 {choice} 操作。")
150 | if not manage_elasticsearch(choice):
151 | print("Elasticsearch操作失败。")
152 | sys.exit(1)
153 | else:
154 | print("无效的选择,程序将退出。")
155 | sys.exit(1)
156 |
157 | print("操作完成。")
158 | sys.exit(0)
159 |
--------------------------------------------------------------------------------
/tools/setup/mongodb/mongodb_setup_configs.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import getpass
3 | import sys
4 | from pathlib import Path
5 | from pymongo import MongoClient, errors
6 |
7 | # 配置文件路径
8 | MONGO_CONFIG_PATH = Path(__file__).parent.parent.parent / "configs/mongodb.yaml"
9 | PROJECT_CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "configs/config.yaml"
10 |
11 |
12 | def prompt_user_for_mongo_credentials():
13 | print("+-----------------------------------------------------+")
14 | print("| MongoDB 配置尚未设置。请输入用户名和密码: |")
15 | print("+-----------------------------------------------------+")
16 |
17 | username = input("请输入 MongoDB 用户名: ").strip()
18 | while not username:
19 | print("用户名不能为空,请重新输入。")
20 | username = input("请输入 MongoDB 用户名: ").strip()
21 |
22 | password = getpass.getpass("请输入 MongoDB 密码: ").strip()
23 | while not password:
24 | print("密码不能为空,请重新输入。")
25 | password = getpass.getpass("请输入 MongoDB 密码: ").strip()
26 |
27 | return username, password
28 |
29 |
30 | def update_mongo_config(username, password):
31 | try:
32 | # 确保配置目录存在
33 | if not MONGO_CONFIG_PATH.parent.exists():
34 | MONGO_CONFIG_PATH.parent.mkdir(parents=True)
35 |
36 | # 加载现有的 mongodb.yaml 配置文件,如果存在
37 | config = {}
38 | if MONGO_CONFIG_PATH.exists():
39 | with open(MONGO_CONFIG_PATH, 'r', encoding='utf-8') as f:
40 | config = yaml.safe_load(f) or {}
41 |
42 | # 更新 mongodb.yaml 文件中的配置
43 | config['mongodb'] = {'username': username, 'password': password}
44 |
45 | # 保存更新后的配置到 mongodb.yaml 文件
46 | with open(MONGO_CONFIG_PATH, 'w', encoding='utf-8') as f:
47 | yaml.dump(config, f, allow_unicode=True)
48 |
49 | # 更新项目根目录的 config.yaml 文件
50 | if PROJECT_CONFIG_PATH.exists():
51 | with open(PROJECT_CONFIG_PATH, 'r', encoding='utf-8') as f:
52 | project_config_lines = f.readlines()
53 | else:
54 | project_config_lines = []
55 |
56 | # 移除旧的 MongoDB 配置
57 | new_config_lines = []
58 | inside_mongodb_block = False
59 | for line in project_config_lines:
60 | if line.strip() == '# ---------- MongoDB Configuration ----------':
61 | inside_mongodb_block = True
62 | continue
63 | if line.strip() == '# ---------- End MongoDB Configuration ------':
64 | inside_mongodb_block = False
65 | continue
66 | if not inside_mongodb_block:
67 | new_config_lines.append(line)
68 |
69 | # 将新的 MongoDB 配置添加到文件末尾
70 | new_config_lines.append('\n')
71 | new_config_lines.append('# ---------- MongoDB Configuration ----------\n')
72 | new_config_lines.append(f'mongodb_url: "mongodb://localhost:27017"\n')
73 | new_config_lines.append(f'mongodb_username: "{username}"\n')
74 | new_config_lines.append(f'mongodb_password: "{password}"\n')
75 | new_config_lines.append('# ---------- End MongoDB Configuration ------\n')
76 |
77 | # 保存更新后的内容到 config.yaml 文件
78 | with open(PROJECT_CONFIG_PATH, 'w', encoding='utf-8') as f:
79 | f.writelines(new_config_lines)
80 |
81 | print(f"> MongoDB 配置已保存至:{MONGO_CONFIG_PATH} 和 {PROJECT_CONFIG_PATH}")
82 | print(f"> -------------------------------------------------")
83 | print(f"> 请不要擅自修改已添加的配置内容及注释,否则可能导致配置系统无法正常工作。")
84 | print(f"> -------------------------------------------------")
85 |
86 | except Exception as e:
87 | print(f"! 保存 MongoDB 配置时出错:{e}")
88 | raise
89 |
90 |
91 | def apply_mongo_config(username, password):
92 | try:
93 | client = MongoClient("mongodb://localhost:27017/")
94 | db = client.admin
95 |
96 | # 创建管理员用户
97 | db.command("createUser", username, pwd=password, roles=[{"role": "root", "db": "admin"}])
98 |
99 | print("> MongoDB 配置已成功应用。")
100 |
101 | # 验证连接是否成功
102 | test_mongo_connection(username, password)
103 |
104 | except errors.OperationFailure as err:
105 | print(f"! 应用 MongoDB 配置时失败:{err}")
106 | # 如果报错already exists
107 | if "already exists" in str(err):
108 | print("> 用户已存在,跳过创建用户步骤。")
109 | raise
110 | except Exception as e:
111 | print(f"! 应用 MongoDB 配置时出错:{e}")
112 | # 如果报错already exists
113 | if "already exists" in str(e):
114 | print("> 用户已存在,跳过创建用户步骤。")
115 | # 检测一下链接
116 | if test_mongo_connection(username, password):
117 | print("> 使用新的用户名和密码连接 MongoDB 成功!")
118 | raise
119 |
120 |
121 | def test_mongo_connection(username, password):
122 | try:
123 | uri = f"mongodb://{username}:{password}@localhost:27017/"
124 | client = MongoClient(uri, serverSelectionTimeoutMS=5000)
125 | client.server_info()
126 | print("> 使用新的用户名和密码连接 MongoDB 成功!")
127 | return True
128 | except errors.ServerSelectionTimeoutError as err:
129 | print(f"! 使用新的用户名和密码无法连接到MongoDB服务器:{err}")
130 | return False
131 | except Exception as e:
132 | print(f"! 测试 MongoDB 连接时发生错误:{e}")
133 | return False
134 |
135 |
136 | def configure_mongodb():
137 | try:
138 | # 提示用户输入用户名和密码
139 | username, password = prompt_user_for_mongo_credentials()
140 |
141 | # 更新配置文件
142 | update_mongo_config(username, password)
143 |
144 | # 应用到 MongoDB 并验证
145 | apply_mongo_config(username, password)
146 |
147 | except Exception as e:
148 | print(f"! 配置 MongoDB 时发生错误:{e}")
149 | sys.exit(1)
150 |
151 |
152 | if __name__ == "__main__":
153 | print("> 开始MongoDB配置...")
154 | print(f"> MongoDB 配置文件路径:{MONGO_CONFIG_PATH}")
155 | print(f"> 项目配置文件路径:{PROJECT_CONFIG_PATH}")
156 | configure_mongodb()
157 | print("> MongoDB配置完成。")
158 | sys.exit(0)
159 |
--------------------------------------------------------------------------------
/core/api/controllers/db_controller.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, HTTPException, Body, Depends
2 | from core.ace.secure import SecureInterface
3 | from pydantic import BaseModel
4 | from core.utils.logger import get_logger
5 | from core.utils.mongodb_utils import MongoDBUtils
6 |
7 | logger = get_logger()
8 | router = APIRouter()
9 |
10 | # 依赖注入 MongoDBUtils 实例
11 | async def get_db():
12 | db = MongoDBUtils()
13 | try:
14 | yield db
15 | finally:
16 | db.close_connection()
17 |
18 | class UpdateDocumentModel(BaseModel):
19 | update: dict
20 |
21 | @router.get("/databases")
22 | async def get_all_databases(db: MongoDBUtils = Depends(get_db)):
23 | """获取所有数据库名称"""
24 | secure_interface = SecureInterface()
25 | if not secure_interface.verify_request():
26 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
27 |
28 | try:
29 | databases = db.get_all_database_names()
30 | return {"status": "success", "databases": databases}
31 | except Exception as e:
32 | logger.error(f"获取数据库列表时出错: {e}")
33 | raise HTTPException(status_code=500, detail=str(e))
34 |
35 | @router.get("/collections/{db_name}")
36 | async def get_all_collections(db_name: str, db: MongoDBUtils = Depends(get_db)):
37 | """获取指定数据库中的所有集合名称"""
38 | secure_interface = SecureInterface()
39 | if not secure_interface.verify_request():
40 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
41 |
42 | try:
43 | collections = db.get_all_collection_names(db_name)
44 | return {"status": "success", "collections": collections}
45 | except Exception as e:
46 | logger.error(f"获取数据库集合列表时出错: {e}")
47 | raise HTTPException(status_code=500, detail=str(e))
48 |
49 | @router.get("/documents/{db_name}/{collection_name}")
50 | async def get_all_documents(db_name: str, collection_name: str, db: MongoDBUtils = Depends(get_db)):
51 | """获取指定集合中的所有文档"""
52 | secure_interface = SecureInterface()
53 | if not secure_interface.verify_request():
54 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
55 |
56 | try:
57 | # 获取集合对象
58 | collection = db.client[db_name][collection_name]
59 | documents = list(collection.find({}))
60 |
61 | # 将 ObjectId 转换为字符串
62 | for document in documents:
63 | if "_id" in document:
64 | document["_id"] = str(document["_id"])
65 |
66 | return {"status": "success", "documents": documents}
67 | except Exception as e:
68 | logger.error(f"获取数据库文档列表时出错: {e}")
69 | raise HTTPException(status_code=500, detail=str(e))
70 |
71 | @router.post("/documents/{db_name}/{collection_name}")
72 | async def insert_document(db_name: str, collection_name: str, document: dict = Body(...), db: MongoDBUtils = Depends(get_db)):
73 | """向指定集合中插入文档"""
74 | secure_interface = SecureInterface()
75 | if not secure_interface.verify_request():
76 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
77 |
78 | try:
79 | # 获取集合对象
80 | collection = db.client[db_name][collection_name]
81 | result = collection.insert_one(document)
82 | return {"status": "success", "inserted_id": str(result.inserted_id)}
83 | except Exception as e:
84 | logger.error(f"插入数据库文档时出错: {e}")
85 | raise HTTPException(status_code=500, detail=str(e))
86 |
87 |
88 | @router.post("/documents/find/{db_name}/{collection_name}")
89 | async def find_document(db_name: str, collection_name: str, query: dict = Body(...),
90 | db: MongoDBUtils = Depends(get_db)):
91 | """根据查询条件查找文档"""
92 | secure_interface = SecureInterface()
93 | if not secure_interface.verify_request():
94 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
95 |
96 | try:
97 | # 获取集合对象
98 | collection = db.client[db_name][collection_name]
99 | # 查找文档
100 | cursor = collection.find(query)
101 | result = []
102 | for document in cursor:
103 | if "_id" in document:
104 | document["_id"] = str(document["_id"]) # 将 ObjectId 转换为字符串
105 | result.append(document)
106 |
107 | return {"status": "success", "documents": result}
108 |
109 | except Exception as e:
110 | logger.error(f"查找数据库文档时出错: {e}")
111 | raise HTTPException(status_code=500, detail=str(e))
112 |
113 |
114 | @router.put("/update/documents/{db_name}/{collection_name}")
115 | async def update_document(db_name: str, collection_name: str, query: dict = Body(...), update_data: UpdateDocumentModel = Body(...), db: MongoDBUtils = Depends(get_db)):
116 | """根据查询条件更新文档"""
117 | secure_interface = SecureInterface()
118 | if not secure_interface.verify_request():
119 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
120 |
121 | try:
122 | # 获取集合对象
123 | collection = db.client[db_name][collection_name]
124 | result = collection.update_one(query, update_data.update)
125 | return {"status": "success", "matched_count": result.matched_count, "modified_count": result.modified_count}
126 | except Exception as e:
127 | logger.error(f"更新数据库文档时出错: {e}")
128 | raise HTTPException(status_code=500, detail=str(e))
129 |
130 | @router.delete("/delete/documents/{db_name}/{collection_name}")
131 | async def delete_document(db_name: str, collection_name: str, query: dict = Body(...), db: MongoDBUtils = Depends(get_db)):
132 | """根据查询条件删除文档"""
133 | secure_interface = SecureInterface()
134 | if not secure_interface.verify_request():
135 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"}
136 |
137 | try:
138 | # 获取集合对象
139 | collection = db.client[db_name][collection_name]
140 | result = collection.delete_one(query)
141 | return {"status": "success", "deleted_count": result.deleted_count}
142 | except Exception as e:
143 | logger.error(f"删除数据库文档时出错: {e}")
144 | raise HTTPException(status_code=500, detail=str(e))
--------------------------------------------------------------------------------
/core/api/controllers/plugin_controller.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 |
3 | import requests
4 | from fastapi import APIRouter, UploadFile, File, HTTPException, Body
5 | from core.plugins.plugin_manager import PluginManager
6 | from core.plugins.tools.add_plugin import create_plugin
7 | from core.utils.logger import get_logger
8 | import tempfile
9 | import os
10 | import shutil
11 |
12 | logger = get_logger()
13 | router = APIRouter()
14 | plugin_manager = PluginManager(bot_client=None)
15 |
16 |
17 |
18 | @router.post("/install")
19 | async def install_plugin(file: UploadFile = File(...)):
20 | if not file:
21 | raise HTTPException(status_code=400, detail="上传文件不存在")
22 |
23 | try:
24 | # 使用 tempfile 来生成一个跨平台的临时文件
25 | with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
26 | shutil.copyfileobj(file.file, tmp_file)
27 | tmp_file_path = tmp_file.name
28 |
29 | logger.info(f"文件 {file.filename} 已上传到 {tmp_file_path}")
30 |
31 | # 获取插件的目标安装目录
32 | plugin_name = os.path.splitext(file.filename)[0]
33 | plugins_dir = os.path.join("core", "plugins", plugin_name)
34 |
35 | if os.path.exists(plugins_dir):
36 | logger.warning(f"插件目录 {plugins_dir} 已存在,正在删除...")
37 | shutil.rmtree(plugins_dir)
38 |
39 | os.makedirs(plugins_dir, exist_ok=True)
40 |
41 | # 解压到插件目录
42 | with zipfile.ZipFile(tmp_file_path, 'r') as zip_ref:
43 | zip_ref.extractall(plugins_dir)
44 |
45 | logger.info(f"插件 {file.filename} 已成功解压到 {plugins_dir}")
46 | plugin_manager.load_plugins() # 重新加载插件
47 |
48 | return {"status": "success", "message": f"插件 {file.filename} 已成功安装"}
49 | except Exception as e:
50 | logger.error(f"安装插件时出错: {e}")
51 | raise HTTPException(status_code=500, detail=str(e))
52 | finally:
53 | # 清理上传的文件
54 | if os.path.exists(tmp_file_path):
55 | os.remove(tmp_file_path)
56 |
57 | @router.post("/url_install")
58 | async def install_plugin_from_url(url: str):
59 | """
60 | 通过提供zip文件链接安装插件。
61 | """
62 | tmp_file_path = None # 初始化临时文件路径
63 | try:
64 | # 处理 zip 文件链接
65 | response = requests.get(url, stream=True)
66 | response.raise_for_status()
67 | with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
68 | for chunk in response.iter_content(chunk_size=8192):
69 | tmp_file.write(chunk)
70 | tmp_file_path = tmp_file.name
71 | logger.info(f"文件 {url} 已下载到 {tmp_file_path}")
72 | file_name = os.path.basename(url) # 从链接中提取文件名
73 |
74 | # 获取插件的目标安装目录
75 | plugin_name = os.path.splitext(file_name)[0]
76 | plugins_dir = os.path.join("core", "plugins", plugin_name)
77 |
78 | if os.path.exists(plugins_dir):
79 | logger.warning(f"插件目录 {plugins_dir} 已存在,正在删除...")
80 | shutil.rmtree(plugins_dir)
81 |
82 | os.makedirs(plugins_dir, exist_ok=True)
83 |
84 | # 解压到插件目录
85 | with zipfile.ZipFile(tmp_file_path, 'r') as zip_ref:
86 | zip_ref.extractall(plugins_dir)
87 |
88 | logger.info(f"插件 {file_name} 已成功解压到 {plugins_dir}")
89 | plugin_manager.load_plugins() # 重新加载插件
90 |
91 | return {"status": "success", "message": f"插件 {file_name} 已成功安装"}
92 |
93 | except Exception as e:
94 | logger.error(f"安装插件时出错: {e}")
95 | raise HTTPException(status_code=500, detail=str(e))
96 | finally:
97 | # 清理下载的临时文件
98 | if tmp_file_path and os.path.exists(tmp_file_path):
99 | os.remove(tmp_file_path)
100 |
101 |
102 | @router.post("/uninstall")
103 | async def uninstall_plugin(plugin_name: str):
104 | try:
105 | plugin_manager.uninstall_plugin(plugin_name)
106 | return {"status": "success", "message": f"插件 {plugin_name} 已成功卸载"}
107 | except Exception as e:
108 | logger.error(f"卸载插件时出错: {e}")
109 | raise HTTPException(status_code=500, detail=str(e))
110 |
111 |
112 | # @router.post("/enable")
113 | # async def enable_plugin(plugin_name: str):
114 | # try:
115 | # plugin_manager.enable_plugin(plugin_name)
116 | # return {"status": "success", "message": f"插件 {plugin_name} 已启用"}
117 | # except Exception as e:
118 | # logger.error(f"启用插件时出错: {e}")
119 | # raise HTTPException(status_code=500, detail=str(e))
120 | #
121 | #
122 | # @router.post("/disable")
123 | # async def disable_plugin(plugin_name: str):
124 | # try:
125 | # plugin_manager.disable_plugin(plugin_name)
126 | # return {"status": "success", "message": f"插件 {plugin_name} 已禁用"}
127 | # except Exception as e:
128 | # logger.error(f"禁用插件时出错: {e}")
129 | # raise HTTPException(status_code=500, detail=str(e))
130 |
131 |
132 | @router.get("/list")
133 | async def get_plugin_list():
134 | try:
135 | plugin_list = plugin_manager.get_plugin_list()
136 | return {"status": "success", "plugins": plugin_list}
137 | except Exception as e:
138 | logger.error(f"获取插件列表时出错: {e}")
139 | raise HTTPException(status_code=500, detail=str(e))
140 |
141 |
142 | @router.post("/reload")
143 | async def reload_plugins():
144 | """
145 | 热重载所有插件
146 | """
147 | try:
148 | plugin_manager.reload_plugins()
149 | return {"status": "success", "message": "插件已成功热重载"}
150 | except Exception as e:
151 | logger.error(f"热重载插件时出错: {e}")
152 | raise HTTPException(status_code=500, detail=str(e))
153 |
154 | @router.post("/add_plugin")
155 | async def add_plugin(
156 | system_prompt: str = Body(..., embed=True),
157 | user_input: str = Body(..., embed=True)
158 | ):
159 | """
160 | 使用LLM帮助用户创建插件的API接口
161 |
162 | 参数:
163 | system_prompt (str): 逗号分隔的系统提示词,将在服务端转换为列表
164 | user_input (str): 用户输入的插件需求
165 |
166 | 返回:
167 | dict: LLM生成的插件代码或相关信息
168 | """
169 | try:
170 | # 调用 create_plugin 函数来生成插件代码
171 | result = await create_plugin(system_prompt, user_input)
172 | return {"status": "success", "plugin_code": result}
173 | except Exception as e:
174 | raise HTTPException(status_code=500, detail=f"创建插件失败: {e}")
--------------------------------------------------------------------------------
/core/llm/plugins/openai_client.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | import httpx
4 | from core.utils.logger import get_logger
5 | from core.llm.llm_client import LLMClient
6 | from config import REQUEST_TIMEOUT
7 |
8 | _log = get_logger()
9 |
10 |
11 | class OpenAIClient(LLMClient):
12 | """
13 | OpenAI API 客户端,实现了 LLMClient 接口。
14 | """
15 |
16 | async def on_message(self, message, reply_message):
17 | pass
18 |
19 | def __init__(self, openai_secret, openai_model, openai_api_url):
20 | self.openai_secret = openai_secret
21 | self.openai_model = openai_model
22 | self.openai_api_url = openai_api_url
23 |
24 | # 初始化 last_request_time 和 last_request_content
25 | self.last_request_time = 0
26 | self.last_request_content = None
27 |
28 | # 从配置文件中读取超时设置,默认为7秒
29 | self.timeout = REQUEST_TIMEOUT or 7
30 |
31 | async def get_response(self, context, user_input, system_prompt, retries=2):
32 | """
33 | 根据给定的上下文和用户输入,从 OpenAI 模型获取回复
34 |
35 | 参数:
36 | context (list): 对话上下文,包含之前的对话内容
37 | user_input (str): 用户的输入内容
38 | system_prompt (str): 系统提示
39 | retries (int): 出现错误时的最大重试次数,默认值为2次
40 |
41 | 返回:
42 | str: OpenAI 模型生成的回复内容
43 |
44 | 异常:
45 | httpx.HTTPStatusError: 当请求 OpenAI API 出现问题时引发
46 | """
47 | # 检查是否为重复请求
48 | if time.time() - self.last_request_time < 0.6 and user_input == self.last_request_content and "" not in user_input:
49 | _log.warning(" 检测到重复请求,已忽略:")
50 | _log.warning(f" ↳ 用户输入: {user_input}")
51 | return None
52 |
53 | payload = {
54 | "model": self.openai_model,
55 | "temperature": 0.85,
56 | "top_p": 1,
57 | "presence_penalty": 1,
58 | "max_tokens": 3450,
59 | "messages": [
60 | {"role": "system", "content": system_prompt}
61 | ] + context + [
62 | {"role": "user", "content": user_input}
63 | ]
64 | }
65 |
66 | headers = {
67 | "Content-Type": "application/json",
68 | "Authorization": f"Bearer {self.openai_secret}"
69 | }
70 |
71 | # 记录请求的 payload 和 headers
72 | _log.debug(" 请求参数:")
73 | _log.debug(f" ↳ Payload: {payload}")
74 | _log.debug(f" ↳ Headers: {headers}")
75 |
76 | for attempt in range(retries + 1):
77 | try:
78 | async with httpx.AsyncClient(timeout=self.timeout) as client:
79 | response = await client.post(self.openai_api_url, headers=headers, json=payload)
80 | response.raise_for_status()
81 | response_data = response.json()
82 |
83 | # 记录完整的响应数据
84 | _log.debug(" 完整响应数据:")
85 | _log.debug(f" ↳ {response_data}")
86 |
87 | reply = response_data['choices'][0]['message']['content'] if 'choices' in response_data and \
88 | response_data['choices'][0]['message'][
89 | 'content'] else None
90 |
91 | # 更新 last_request_time 和 last_request_content
92 | self.last_request_time = time.time()
93 | self.last_request_content = user_input
94 |
95 | if reply is None:
96 | _log.warning(" OpenAI 回复为空:")
97 | _log.warning(f" ↳ 用户输入: {user_input}")
98 | else:
99 | # 记录 OpenAI 的回复内容
100 | _log.info(" OpenAI 回复:")
101 | _log.info(f" ↳ 内容: {reply}")
102 |
103 | return reply
104 |
105 | except httpx.HTTPStatusError as e:
106 | _log.error(" 🚨请求错误:")
107 | _log.error(f" ↳ 状态码: {e.response.status_code}")
108 | _log.error(f" ↳ 错误详情: {e}")
109 | _log.error(f" ↳ 返回内容: {e.response.text}")
110 | if e.response.status_code in {503, 504, 500}: # 处理常见错误状态码
111 | _log.info(f"请求失败,状态码:{e.response.status_code}。正在尝试重试...({attempt + 1}/{retries})")
112 | if attempt < retries:
113 | await asyncio.sleep(2) # 等待2秒后重试
114 | continue
115 | return f"请求失败,状态码:{e.response.status_code}。请稍后再试。"
116 |
117 |
118 | except httpx.RequestError as e:
119 | _log.error(" 请求异常:")
120 | _log.error(f" ↳ 错误详情: {e}")
121 | _log.error(f" ↳ 错误类型: {type(e)}")
122 | if attempt < retries:
123 | _log.info(f"请求异常,正在尝试重试...({attempt + 1}/{retries})")
124 | await asyncio.sleep(2) # 等待2秒后重试
125 | continue
126 | return "请求超时或网络错误,请稍后再试。"
127 |
128 |
129 | except Exception as e:
130 |
131 | _log.error(" 未知错误:")
132 | _log.error(f" ↳ 错误详情: {e}")
133 | _log.error(f" ↳ 错误类型: {type(e)}")
134 | if attempt < retries:
135 | _log.info(f"发生未知错误,正在尝试重试...({attempt + 1}/{retries})")
136 | await asyncio.sleep(2) # 等待2秒后重试
137 | continue
138 |
139 | return "发生未知错误,请联系管理员。"
140 |
141 | return "请求失败,请稍后再试。"
142 |
143 | async def test(self):
144 | """
145 | 测试 OpenAIClient 类的方法
146 | """
147 | context = [
148 | {"role": "user", "content": "你好!"}
149 | ]
150 | user_input = "你还记得我之前说了多少个“你好”吗"
151 | system_prompt = "你是一个友好的助手。"
152 |
153 | response = await self.get_response(context, user_input, system_prompt)
154 | print("API Response:", response)
155 |
156 |
157 | # 使用方法
158 | if __name__ == "__main__":
159 | # 配置 OpenAIClient
160 | openai_secret = "sk-s2lDjPP1AdigpPBO53845f5d134a406d96CbE24aEeBe2d36"
161 | openai_model = "Meta-Llama-3.1-8B-Instruct"
162 | openai_api_url = "https://ngedlktfticp.cloud.sealos.io/v1/chat/completions"
163 | # 我可以把我的秘钥公开,因为额度很小,而且是用来测试的。但你一定不要像我一样把秘钥明文写在代码中。
164 |
165 | # 创建 OpenAIClient 实例
166 | client = OpenAIClient(openai_secret=openai_secret, openai_model=openai_model, openai_api_url=openai_api_url)
167 |
168 | # 运行测试
169 | asyncio.run(client.test())
170 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | {
17 | "lastFilter": {
18 | "state": "OPEN",
19 | "assignee": "shuakami"
20 | }
21 | }
22 | {
23 | "selectedUrlAndAccountId": {
24 | "url": "https://github.com/shuakami/amyalmond_bot.git",
25 | "accountId": "d941c301-caf6-4448-9c7f-a86b9568da98"
26 | }
27 | }
28 | {
29 | "associatedIndex": 5
30 | }
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 | 1727063987284
70 |
71 |
72 | 1727063987284
73 |
74 |
75 |
76 |
77 |
78 | 1727064360801
79 |
80 |
81 |
82 | 1727064360801
83 |
84 |
85 |
86 | 1727064672958
87 |
88 |
89 |
90 | 1727064672958
91 |
92 |
93 |
94 | 1727064750333
95 |
96 |
97 |
98 | 1727064750333
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
--------------------------------------------------------------------------------
/tools/upgrade/db_upgrade.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import shutil
4 | import datetime
5 | import sys
6 | import time
7 | from pymongo import MongoClient
8 | from elasticsearch import Elasticsearch, helpers
9 |
10 | # 手动指定项目根目录
11 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
12 | # 将项目根目录添加到 Python 的搜索路径中
13 | sys.path.append(project_root)
14 | from config import MONGODB_URI, MONGODB_USERNAME, MONGODB_PASSWORD, ELASTICSEARCH_URL, ELASTICSEARCH_USERNAME, \
15 | ELASTICSEARCH_PASSWORD
16 |
17 | # 路径配置
18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19 | PROJECT_ROOT = os.path.abspath(os.path.join(BASE_DIR, "../../"))
20 | DATA_DIR = os.path.join(PROJECT_ROOT, "data")
21 | BACKUP_DIR = os.path.join(BASE_DIR, "backup", "data")
22 |
23 | # 连接数据库
24 | mongo_client = MongoClient(MONGODB_URI, username=MONGODB_USERNAME, password=MONGODB_PASSWORD)
25 | mongo_db = mongo_client["amyalmond"]
26 | mongo_collection = mongo_db["conversations"]
27 |
28 | es_client = Elasticsearch(
29 | [ELASTICSEARCH_URL],
30 | basic_auth=(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
31 | )
32 |
33 |
34 | def backup_mongodb():
35 | """备份 MongoDB 数据"""
36 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
37 | backup_dir = os.path.join(BACKUP_DIR, "mongodb", timestamp)
38 |
39 | if not os.path.exists(backup_dir):
40 | os.makedirs(backup_dir)
41 |
42 | for document in mongo_collection.find():
43 | with open(os.path.join(backup_dir, f"{document['_id']}.json"), "w", encoding="utf-8") as file:
44 | json.dump(document, file, default=str)
45 |
46 | print(f"> MongoDB 数据备份完成,备份目录: {backup_dir}")
47 |
48 |
49 | def backup_elasticsearch():
50 | """备份 Elasticsearch 数据"""
51 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
52 | backup_dir = os.path.join(BACKUP_DIR, "elasticsearch", timestamp)
53 |
54 | if not os.path.exists(backup_dir):
55 | os.makedirs(backup_dir)
56 |
57 | query = {"query": {"match_all": {}}}
58 | results = helpers.scan(es_client, index="messages", query=query)
59 |
60 | for i, result in enumerate(results):
61 | with open(os.path.join(backup_dir, f"{i}.json"), "w", encoding="utf-8") as file:
62 | json.dump(result, file, default=str)
63 |
64 | print(f"> Elasticsearch 数据备份完成,备份目录: {backup_dir}")
65 |
66 |
67 | def backup_data():
68 | """备份旧数据"""
69 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
70 | backup_dir = os.path.join(BACKUP_DIR, timestamp)
71 |
72 | if not os.path.exists(backup_dir):
73 | os.makedirs(backup_dir)
74 |
75 | for filename in os.listdir(DATA_DIR):
76 | file_path = os.path.join(DATA_DIR, filename)
77 | if os.path.isfile(file_path):
78 | shutil.copy(file_path, backup_dir)
79 | print(f"> 备份文件: {filename} 到 {backup_dir}")
80 |
81 | print("> 文件数据备份完成.")
82 |
83 |
84 | def prompt_clear_databases():
85 | """提示用户是否清空数据库"""
86 | print("> 即将清空 MongoDB 和 Elasticsearch 数据库。")
87 | print("> 请确认是否要继续,10 秒后将自动清空数据库...")
88 | time.sleep(10)
89 |
90 | print("> 清空 MongoDB 'conversations' 集合...")
91 | mongo_collection.delete_many({})
92 | print("> MongoDB 清空完成.")
93 |
94 | print("> 清空 Elasticsearch 'messages' 索引...")
95 | if es_client.indices.exists(index="messages"):
96 | es_client.options(ignore_status=[400, 404]).indices.delete(index="messages")
97 | es_client.indices.create(index="messages")
98 | print("> Elasticsearch 清空完成.")
99 |
100 |
101 | def migrate_memory_json():
102 | """迁移 memory.json 数据到 MongoDB"""
103 | memory_json_path = os.path.join(DATA_DIR, "memory.json")
104 |
105 | if not os.path.exists(memory_json_path):
106 | print(f"! 未找到 memory.json 文件: {memory_json_path}")
107 | return
108 |
109 | with open(memory_json_path, "r", encoding="utf-8") as file:
110 | try:
111 | memory_data = json.load(file)
112 | except json.JSONDecodeError as e:
113 | print(f"! 解析 memory.json 文件时发生错误: {e}")
114 | return
115 |
116 | for group_id, conversations in memory_data.items():
117 | for conversation in conversations:
118 | try:
119 | document = {
120 | "group_id": group_id,
121 | "role": conversation.get("role"),
122 | "content": conversation.get("content"),
123 | "timestamp": datetime.datetime.now(datetime.timezone.utc)
124 | }
125 | if not document["role"] or not document["content"]:
126 | raise ValueError("无效数据,跳过")
127 |
128 | mongo_collection.insert_one(document)
129 | print(f"> 成功迁移对话记录到 MongoDB: group_id={group_id}, role={document['role']}")
130 |
131 | except Exception as e:
132 | print(f"! 迁移数据时发生错误,跳过: {e}, 数据: {conversation}")
133 |
134 |
135 | def migrate_long_term_memory():
136 | """迁移 long_term_memory_*.txt 数据到 Elasticsearch"""
137 | for filename in os.listdir(DATA_DIR):
138 | if filename.startswith("long_term_memory_") and filename.endswith(".txt"):
139 | group_id = filename.split("long_term_memory_")[-1].replace(".txt", "")
140 | long_term_memory_path = os.path.join(DATA_DIR, filename)
141 |
142 | with open(long_term_memory_path, "r", encoding="utf-8") as file:
143 | lines = file.readlines()
144 |
145 | actions = []
146 | for line in lines:
147 | content = line.strip()
148 | if not content:
149 | continue
150 |
151 | action = {
152 | "_index": "messages",
153 | "_source": {
154 | "group_id": group_id,
155 | "role": "system",
156 | "content": content
157 | }
158 | }
159 | actions.append(action)
160 |
161 | try:
162 | helpers.bulk(es_client, actions)
163 | print(f"> 成功迁移长时间记忆数据到 Elasticsearch: group_id={group_id}, 文件={filename}")
164 |
165 | except Exception as e:
166 | print(f"! 迁移 Elasticsearch 数据时发生错误: {e}, 文件: {filename}")
167 |
168 |
169 | def main():
170 | # 打印项目根目录
171 | print("+------------------------------+")
172 | print("| 开始数据库迁移... |")
173 | print("+------------------------------+")
174 |
175 | # 备份数据
176 | backup_data()
177 | backup_mongodb()
178 | backup_elasticsearch()
179 |
180 | # 提示用户是否清空数据库
181 | prompt_clear_databases()
182 |
183 | # 迁移 memory.json 数据到 MongoDB
184 | migrate_memory_json()
185 |
186 | # 迁移 long_term_memory_*.txt 数据到 Elasticsearch
187 | migrate_long_term_memory()
188 |
189 | print("+------------------------------+")
190 | print("| 数据库迁移完成 |")
191 | print("+------------------------------+")
192 |
193 |
194 | if __name__ == "__main__":
195 | main()
196 |
--------------------------------------------------------------------------------
/core/bot/bot_client.py:
--------------------------------------------------------------------------------
1 | """
2 | AmyAlmond Project - core/bot/bot_client.py
3 |
4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot
5 | Developer: Shuakami <3 LuoXiaoHei
6 |
7 | Copyright (c) 2024 Amyalmond_bot. All rights reserved.
8 | Version: 1.3.0 (Stable_923001)
9 |
10 | bot_client.py 包含 AmyAlmond 机器人的主要客户端类,链接其他模块进行处理。
11 | """
12 | import asyncio
13 | import random
14 | import subprocess
15 | import sys
16 | import watchdog.observers
17 | import botpy
18 | from botpy.message import GroupMessage
19 |
20 | from core.plugins.plugin_manager import PluginManager
21 | # user_management.py模块 - <用户管理模块化文件>
22 | from core.utils.user_management import load_user_names
23 | # utils.py模块 - <工具模块化文件>
24 | from core.utils.utils import load_system_prompt
25 | # config.py模块 - <配置管理模块化文件>
26 | from config import SYSTEM_PROMPT_FILE, test_config
27 | # file_handler.py模块 - <文件处理模块化文件>
28 | from core.utils.file_handler import ConfigFileHandler
29 | # logger.py模块 - <日志记录模块>
30 | from core.utils.logger import get_logger
31 | # message_handler.py模块 - <消息处理模块化文件>
32 | from core.bot.message_handler import MessageHandler
33 | # memory_manager.py模块 - <内存管理模块化文件>
34 | from core.memory.memory_manager import MemoryManager
35 | # keep_alive.py模块 -
36 | from core.keep_alive import keep_alive
37 | # llm_client.py模块 -
38 | from core.llm.llm_factory import LLMFactory
39 |
40 | _log = get_logger()
41 |
42 |
43 | class MyClient(botpy.Client):
44 | """
45 | AmyAlmond 项目的主要客户端类,继承自 botpy.Client
46 | 处理机器人的各种事件和请求
47 | """
48 |
49 | def __init__(self, *args, **kwargs):
50 | """
51 | 初始化客户端
52 |
53 | 初始化待处理用户列表、加载系统提示、设置内存管理器和消息处理器
54 | 读取配置并验证必要的配置项是否设置
55 | 初始化文件系统观察器以监听配置文件变化
56 | """
57 | super().__init__(*args, **kwargs)
58 | # 初始化插件管理器
59 | self.plugin_manager = PluginManager(self)
60 |
61 | # 初始化 LLM 客户端
62 | llm_factory = LLMFactory()
63 | self.llm_client = llm_factory.create_llm_client()
64 |
65 | # 加载插件
66 | self.plugin_manager.register_plugins()
67 |
68 | self.pending_users = {}
69 | self.system_prompt = load_system_prompt(SYSTEM_PROMPT_FILE)
70 | self.memory_manager = MemoryManager()
71 | self.message_handler = MessageHandler(self, self.memory_manager)
72 |
73 | # 读取配置
74 | self.openai_secret = test_config.get("openai_secret", "")
75 | self.openai_model = test_config.get("openai_model", "gpt-4o-mini")
76 | self.openai_api_url = test_config.get("openai_api_url", "https://api.openai-hk.com/v1/chat/completions")
77 | self.ADMIN_ID = test_config.get("admin_id", "")
78 |
79 | if not self.openai_secret:
80 | _log.critical(" OpenAI API 密钥缺失,请检查 config.yaml 文件")
81 | raise ValueError("OpenAI API key is missing in config.yaml")
82 | if not self.openai_model:
83 | _log.critical(" OpenAI 模型缺失,请检查 config.yaml 文件")
84 | raise ValueError("OpenAI model is missing in config.yaml")
85 | if not self.openai_api_url:
86 | _log.critical(" OpenAI API URL 缺失,请检查 config.yaml 文件")
87 | raise ValueError("OpenAI API URL is missing in config.yaml")
88 | if not self.ADMIN_ID:
89 | _log.critical(" 管理员 ID 缺失,请检查 config.yaml 文件")
90 | raise ValueError("Admin ID is missing in config.yaml")
91 |
92 |
93 | # 初始化 last_request_time 和 last_request_content
94 | self.last_request_time = 0
95 | self.last_request_content = None
96 |
97 | # 设置文件监视器
98 | self.observer = watchdog.observers.Observer()
99 | event_handler = ConfigFileHandler(self)
100 | self.observer.schedule(event_handler, path='.', recursive=False)
101 | self.observer.start()
102 |
103 | async def on_message(self, message: botpy.message):
104 | """
105 | 当收到消息时调用
106 |
107 | Args:
108 | message (botpy.Message): 收到的消息对象
109 | """
110 | # 通过事件总线发布 on_message 事件,让所有订阅的插件处理该消息
111 | await self.plugin_manager.event_bus.publish("on_message", message)
112 |
113 | def load_system_prompt(self):
114 | """
115 | 加载机器人SystemPrompt
116 | """
117 | self.system_prompt = load_system_prompt(SYSTEM_PROMPT_FILE)
118 | _log.info(f">>> SYSTEM PROMPT LOADED")
119 | _log.info(f" ↳ Prompt count: {len(self.system_prompt)}")
120 |
121 | def reload_system_prompt(self):
122 | """
123 | 重新加载机器人SystemPrompt
124 | """
125 | self.system_prompt = load_system_prompt(SYSTEM_PROMPT_FILE)
126 | _log.info(">>> SYSTEM PROMPT RELOADED")
127 |
128 | async def on_ready(self):
129 | """
130 | 当机器人准备好时调用
131 | """
132 | _log.info(f">>> ROBOT 「{self.robot.name}」 IS READY!")
133 | load_user_names()
134 |
135 | # 加载记忆
136 | _log.info(">>> MEMORY LOADING...")
137 | await self.memory_manager.load_memory()
138 | _log.info(" ↳ 记忆加载完成")
139 |
140 | # 启动 Keep-Alive 任务
141 | await asyncio.create_task(keep_alive(self.openai_api_url, self.openai_secret))
142 |
143 | # 通知插件准备就绪
144 | await self.plugin_manager.on_ready()
145 |
146 | async def on_group_at_message_create(self, message: GroupMessage):
147 | """
148 | 当接收到群组中提及机器人的消息时调用
149 |
150 | 参数:
151 | message (GroupMessage): 接收到的消息对象
152 | """
153 | await self.message_handler.handle_group_message(message)
154 |
155 | async def get_gpt_response(self, context, user_input):
156 | """
157 | 根据给定的上下文和用户输入,从 LLM 模型获取回复
158 | """
159 | return await self.llm_client.get_response(context, user_input, self.system_prompt)
160 |
161 | async def restart_bot(self, group_id, msg_id):
162 | """
163 | 重启机器人
164 |
165 | 参数:
166 | group_id (str): 羡组ID
167 | msg_id (str): 消息ID
168 | """
169 | await self.api.post_group_message(
170 | group_openid=group_id,
171 | content=f"子网重启,请稍后... ({random.randint(1000, 9999)})",
172 | msg_id=msg_id
173 | )
174 |
175 | _log.info(">>> RESTARTING BOT...")
176 |
177 | self.observer.stop()
178 | self.observer.join()
179 |
180 | _log.info(">>> BOT RESTART COMMAND RECEIVED, SHUTTING DOWN...")
181 |
182 | python = sys.executable
183 | subprocess.Popen([python] + sys.argv)
184 |
185 | sys.exit()
186 |
187 | async def hot_reload(self, group_id, msg_id):
188 | """
189 | 热重载系统
190 |
191 | 参数:
192 | group_id (str): 群组ID
193 | msg_id (str): 消息ID
194 | """
195 | _log.info(">>> HOT RELOAD INITIATED...")
196 | self.system_prompt = load_system_prompt(SYSTEM_PROMPT_FILE)
197 | load_user_names()
198 | _log.info(" ↳ 热重载完成,系统已更新")
199 | await self.api.post_group_message(
200 | group_openid=group_id,
201 | content="热重载完成,系统已更新。",
202 | msg_id=msg_id
203 | )
204 |
205 |
--------------------------------------------------------------------------------
/core/db/auto_tune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import psutil
4 | import time
5 | import yaml
6 | from pathlib import Path
7 | from concurrent.futures import ThreadPoolExecutor
8 | from tqdm import tqdm # 进度条库
9 |
10 | PROJECT_CONFIG_PATH = Path(__file__).resolve().parent.parent.parent / 'configs' / 'config.yaml'
11 |
12 | class AutoTuner:
13 | def __init__(self):
14 | self.cpu_count = psutil.cpu_count(logical=False) # 物理核心数
15 | self.total_memory = psutil.virtual_memory().total
16 | self.system = platform.system()
17 | self.config = {}
18 |
19 | def run_load_test(self):
20 | """
21 | 运行负载测试,测量系统在高负载下的性能
22 | """
23 | load_test_results = {
24 | 'cpu_usage': [],
25 | 'memory_usage': [],
26 | 'response_times': [],
27 | 'io_usage': [],
28 | 'network_usage': []
29 | }
30 |
31 | def simulate_load():
32 | start_time = time.time()
33 | cpu_usage = psutil.cpu_percent(interval=1)
34 | memory_usage = psutil.virtual_memory().percent
35 | io_counters = psutil.disk_io_counters()
36 | network_counters = psutil.net_io_counters()
37 | duration = time.time() - start_time
38 | return cpu_usage, memory_usage, io_counters, network_counters, duration
39 |
40 | with ThreadPoolExecutor(max_workers=self.cpu_count) as executor:
41 | futures = [executor.submit(simulate_load) for _ in range(min(50, self.cpu_count * 10))] # 动态调整并发任务数量
42 | for future in tqdm(futures, desc="Running load test", ncols=100): # 添加进度条显示
43 | cpu_usage, memory_usage, io_counters, network_counters, duration = future.result()
44 | load_test_results['cpu_usage'].append(cpu_usage)
45 | load_test_results['memory_usage'].append(memory_usage)
46 | load_test_results['io_usage'].append(io_counters.read_bytes + io_counters.write_bytes)
47 | load_test_results['network_usage'].append(network_counters.bytes_sent + network_counters.bytes_recv)
48 | load_test_results['response_times'].append(duration)
49 |
50 | # 计算平均值
51 | load_test_results['avg_cpu_usage'] = sum(load_test_results['cpu_usage']) / len(load_test_results['cpu_usage'])
52 | load_test_results['avg_memory_usage'] = sum(load_test_results['memory_usage']) / len(load_test_results['memory_usage'])
53 | load_test_results['avg_io_usage'] = sum(load_test_results['io_usage']) / len(load_test_results['io_usage'])
54 | load_test_results['avg_network_usage'] = sum(load_test_results['network_usage']) / len(load_test_results['network_usage'])
55 | load_test_results['avg_response_time'] = sum(load_test_results['response_times']) / len(load_test_results['response_times'])
56 | return load_test_results
57 |
58 | def determine_optimal_parameters(self, load_test_results):
59 | """
60 | 根据负载测试结果和系统资源,自动调整参数,兼顾高配和低配系统
61 | """
62 | # 动态调整基准值,分别针对高配和低配系统
63 | if self.total_memory <= 8 * 1024 ** 3: # 小于等于8GB内存
64 | memory_ratio = self.total_memory / (8 * 1024 ** 3) # 基准为8GB
65 | cpu_ratio = self.cpu_count / 4 # 基准为4核
66 | base_context_tokens = 1024
67 | base_query_terms = 10
68 | else:
69 | memory_ratio = self.total_memory / (16 * 1024 ** 3) # 基准为16GB
70 | cpu_ratio = self.cpu_count / 8 # 基准为8核
71 | base_context_tokens = 2048
72 | base_query_terms = 18
73 |
74 | # 采用非线性映射调整参数
75 | def adjust_based_on_ratio(base_value, ratio, load_threshold, response_time_threshold):
76 | if load_test_results['avg_memory_usage'] < load_threshold and load_test_results['avg_response_time'] < response_time_threshold:
77 | return int(base_value * ratio * 1.5)
78 | elif load_test_results['avg_memory_usage'] < 80 and load_test_results['avg_response_time'] < 2:
79 | return int(base_value * ratio)
80 | else:
81 | return int(base_value * ratio * 0.75)
82 |
83 | # 动态调整 max_context_tokens 基于内存、响应时间和 I/O 性能
84 | self.config['max_context_tokens'] = adjust_based_on_ratio(base_context_tokens, memory_ratio, 60, 1)
85 |
86 | # 动态调整 Elasticsearch 查询参数基于 CPU 使用率和响应时间
87 | self.config['elasticsearch_query_terms'] = adjust_based_on_ratio(base_query_terms, cpu_ratio, 50, 1)
88 |
89 | # 为低配系统设置最低值限制
90 | if self.total_memory <= 4 * 1024 ** 3: # 4GB以下内存
91 | self.config['max_context_tokens'] = max(self.config['max_context_tokens'], 512)
92 | self.config['elasticsearch_query_terms'] = max(self.config['elasticsearch_query_terms'], 4)
93 |
94 | def update_config_file(self):
95 | """
96 | 将调整后的参数保存到 config.yaml 文件
97 | """
98 | try:
99 | if PROJECT_CONFIG_PATH.exists():
100 | with open(PROJECT_CONFIG_PATH, 'r', encoding='utf-8') as f:
101 | project_config_lines = f.readlines()
102 | else:
103 | project_config_lines = []
104 |
105 | # 移除旧的配置
106 | new_config_lines = []
107 | inside_custom_block = False
108 | for line in project_config_lines:
109 | if line.strip() == '# ---------- Auto-tuned Configuration ----------':
110 | inside_custom_block = True
111 | continue
112 | if line.strip() == '# ---------- End Auto-tuned Configuration ------':
113 | inside_custom_block = False
114 | continue
115 | if not inside_custom_block:
116 | new_config_lines.append(line)
117 |
118 | # 将新的配置添加到文件末尾
119 | new_config_lines.append('\n')
120 | new_config_lines.append('# ---------- Auto-tuned Configuration ----------\n')
121 | for key, value in self.config.items():
122 | new_config_lines.append(f'{key}: {value}\n')
123 | new_config_lines.append('# ---------- End Auto-tuned Configuration ------\n')
124 |
125 | # 保存更新后的内容到 config.yaml 文件
126 | with open(PROJECT_CONFIG_PATH, 'w', encoding='utf-8') as f:
127 | f.writelines(new_config_lines)
128 |
129 | print(f"> 自动调整的配置已保存至:{PROJECT_CONFIG_PATH}")
130 | print(f"> -------------------------------------------------")
131 | print(f"> 请不要擅自修改已添加的配置内容及注释,否则可能导致配置系统无法正常工作。")
132 | print(f"> PLEASE DO NOT MODIFY THE ADDED CONFIGURATION CONTENTS AND COMMENTS,")
133 | print(f"> OR ELSE THE CONFIGURATION SYSTEM MAY NOT WORK PROPERLY.")
134 | print(f"> -------------------------------------------------")
135 |
136 | except Exception as e:
137 | print(f"! 保存自动调整的配置时出错:{e}")
138 | raise
139 |
140 | def tune(self):
141 | """
142 | 执行自动调优过程
143 | """
144 | load_test_results = self.run_load_test()
145 | self.determine_optimal_parameters(load_test_results)
146 | self.update_config_file()
147 |
148 |
149 | if __name__ == "__main__":
150 | tuner = AutoTuner()
151 | tuner.tune()
152 |
--------------------------------------------------------------------------------
/core/update_manager.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 | import json
3 | import os
4 | import asyncio
5 | from datetime import datetime, timedelta
6 | from core.utils.logger import get_logger
7 | from core.utils.version_utils import is_newer_version
8 | import subprocess
9 | import sys
10 | import urllib.parse
11 | import platform
12 |
13 | _log = get_logger()
14 |
15 | FETCH_RELEASE_URL = "https://bot.luoxiaohei.cn/api/fetchLatestRelease"
16 | AUTO_UPDATE_SCRIPT_URL = "https://bot.luoxiaohei.cn/auto_update.py"
17 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
18 | CURRENT_VERSION = "1.1.5 (Alpha_829001)"
19 | CONFIG_PATH = os.path.join(ROOT_DIR, "configs", "update_config.json")
20 |
21 | async def fetch_latest_release():
22 | """
23 | 获取最新版本的发布信息。
24 | """
25 | async with aiohttp.ClientSession() as session:
26 | async with session.get(FETCH_RELEASE_URL) as response:
27 | if response.status == 200:
28 | return await response.json()
29 | else:
30 | _log.warning(f"获取最新版本信息失败,状态码: {response.status}")
31 | return None
32 |
33 | async def prompt_user_for_update(stable_version_info, dev_version_info):
34 | """
35 | 询问用户是否要更新到最新版本以及选择更新的版本类型。
36 | """
37 | stable_version = stable_version_info.get("latestVersion")
38 | dev_version = dev_version_info.get("latestVersion")
39 | print(f"检测到新版本: 稳定版: {stable_version}, 开发版: {dev_version}。您当前版本是: {CURRENT_VERSION}。")
40 | print("请选择更新选项:")
41 | print("1. 更新到最新稳定版")
42 | print("2. 更新到最新开发版")
43 | print("3. 不更新")
44 | print("4. 不更新,且7天内不再提示")
45 |
46 | user_choice = input("请输入您的选择 (1/2/3/4): ")
47 | if user_choice == "1":
48 | await handle_user_choice("stable", stable_version_info)
49 | elif user_choice == "2":
50 | await handle_user_choice("development", dev_version_info)
51 | elif user_choice == "3":
52 | _log.info("已选择不更新。")
53 | elif user_choice == "4":
54 | _log.info("已选择不更新,并在7天内不再提示。")
55 | # 保存配置以便7天内不再提示
56 | update_config = {
57 | "snooze_until": (datetime.now() + timedelta(days=7)).isoformat()
58 | }
59 | # 如果没有自动创建
60 | with open(CONFIG_PATH, 'w') as f:
61 | json.dump(update_config, f)
62 | else:
63 | print("无效选择,请重新输入。")
64 | await prompt_user_for_update(stable_version_info, dev_version_info)
65 |
66 |
67 | async def handle_user_choice(choice, version_info):
68 | """
69 | 处理用户的更新选择。
70 | """
71 | download_url = version_info.get("downloadUrl")
72 | full_download_url = f"https://bot.luoxiaohei.cn{download_url}" # 拼接完整的下载URL
73 |
74 | if choice in ["stable", "development"]:
75 | # 正确解析下载URL并获取文件名
76 | parsed_url = urllib.parse.urlparse(full_download_url)
77 | query_params = urllib.parse.parse_qs(parsed_url.query)
78 | actual_download_url = query_params.get('url', [None])[0]
79 | if actual_download_url:
80 | zip_file_name = os.path.basename(urllib.parse.unquote(actual_download_url))
81 | else:
82 | _log.error("无法解析下载的URL")
83 | return
84 |
85 | zip_download_path = os.path.join(ROOT_DIR, zip_file_name)
86 |
87 | await download_file_with_progress(full_download_url, zip_download_path)
88 | await download_file_with_progress(AUTO_UPDATE_SCRIPT_URL, os.path.join(ROOT_DIR, "auto_update.py"))
89 |
90 | _log.info("更新文件已下载,准备退出程序并执行更新。")
91 | await shutdown_and_update(zip_download_path)
92 |
93 | async def download_file_with_progress(url, dest_path):
94 | """
95 | 下载文件到指定路径,并显示下载进度条。
96 | """
97 | async with aiohttp.ClientSession() as session:
98 | async with session.get(url) as response:
99 | if response.status == 200:
100 | total_size = int(response.headers.get('content-length', 0))
101 | with open(dest_path, 'wb') as f:
102 | downloaded_size = 0
103 | async for data in response.content.iter_chunked(1024):
104 | f.write(data)
105 | downloaded_size += len(data)
106 | progress = (downloaded_size / total_size) * 100
107 | print(f'\r下载进度: [{progress:.2f}%]', end='')
108 | print() # 换行
109 | _log.info(f"文件已下载到 {dest_path}")
110 | else:
111 | _log.error(f"下载文件失败,状态码: {response.status}")
112 |
113 | async def handle_updates():
114 | """
115 | 检查是否有新版本可用,并根据版本类型(正式版或开发版)提醒用户更新。
116 | """
117 | version_info_list = await fetch_latest_release()
118 |
119 | if version_info_list:
120 | stable_info = version_info_list.get('stable', None)
121 | dev_info = version_info_list.get('development', None)
122 |
123 | stable_version = stable_info.get("latestVersion") if stable_info else None
124 | dev_version = dev_info.get("latestVersion") if dev_info else None
125 |
126 | if stable_version or dev_version:
127 | # 检查更新配置文件
128 | if os.path.exists(CONFIG_PATH):
129 | with open(CONFIG_PATH, 'r') as f:
130 | config = json.load(f)
131 | if config.get('snooze_until'):
132 | snooze_until = datetime.fromisoformat(config.get('snooze_until'))
133 | if snooze_until > datetime.now():
134 | _log.info("更新检查已被用户暂停,直到指定日期。")
135 | return
136 |
137 | # 如果有更高版本的更新,提示用户
138 | if (stable_info and is_newer_version(CURRENT_VERSION, stable_version)[0]) or (dev_info and is_newer_version(CURRENT_VERSION, dev_version)[0]):
139 | await prompt_user_for_update(stable_info, dev_info)
140 | else:
141 | _log.info("当前版本已是最新,无需更新。")
142 |
143 | else:
144 | _log.warning(" 无法检查更新")
145 |
146 | async def shutdown_and_update(zip_download_path):
147 | """
148 | 关闭所有进程并执行更新。
149 | """
150 | _log.info("正在退出所有进程以便进行更新...")
151 | await asyncio.sleep(1) # 等待其他任务完成
152 |
153 | # 获取当前操作系统类型
154 | current_os = platform.system()
155 |
156 | # 切换到项目根目录
157 | os.chdir(ROOT_DIR)
158 |
159 | # 启动更新脚本
160 | if current_os == "Windows":
161 | # Windows 下使用 PowerShell 启动新窗口运行 Python 脚本,并确保在根目录下
162 | subprocess.Popen(
163 | f'start powershell -Command "{sys.executable} {os.path.join(ROOT_DIR, "auto_update.py")} {zip_download_path}"',
164 | shell=True
165 | )
166 |
167 | elif current_os == "Linux":
168 | # Linux 下使用终端模拟器(如 gnome-terminal)确保在根目录下
169 | subprocess.Popen(
170 | f'gnome-terminal -- bash -c "cd {ROOT_DIR} && {sys.executable} auto_update.py {zip_download_path}"',
171 | shell=True
172 | )
173 |
174 | elif current_os == "Darwin": # macOS 的系统标识符
175 | # macOS 使用 open -a Terminal 启动脚本
176 | subprocess.Popen(
177 | f'open -a Terminal "{sys.executable} {os.path.join(ROOT_DIR, "auto_update.py")} {zip_download_path}"',
178 | shell=True
179 | )
180 |
181 | else:
182 | _log.error(f"不支持的操作系统: {current_os}")
183 | return
184 |
185 | _log.info("更新脚本已启动,主程序即将退出。")
186 |
187 | # 使用 os._exit() 退出当前进程
188 | os._exit(0)
189 |
190 | # 主测试
191 | async def test():
192 | await shutdown_and_update("v1.2.0-Stable_827001.zip")
193 |
194 | if __name__ == "__main__":
195 | asyncio.run(test())
196 |
--------------------------------------------------------------------------------
/tools/setup/elasticsearch/elasticsearch_setup.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import sys
3 | import subprocess
4 | import time
5 | import yaml
6 | import psutil
7 | from pathlib import Path
8 |
9 | # 配置文件路径
10 | DEFAULT_WINDOWS_PATH = Path(r"C:\Elasticsearch\8.15.0\elasticsearch-8.15.0\bin")
11 | DEFAULT_LINUX_PATH = Path("/usr/share/elasticsearch/bin")
12 | ELASTIC_CONFIG_PATH = Path(__file__).parent.parent.parent / "configs/elasticsearch.yaml"
13 |
14 |
15 | def detect_os_and_version():
16 | if sys.platform.startswith('win'):
17 | return "Windows", sys.getwindowsversion().platform_version
18 | elif sys.platform.startswith('linux'):
19 | return "Linux", subprocess.getoutput('uname -r')
20 | else:
21 | return sys.platform, "Unknown"
22 |
23 |
24 | def is_port_open(host, port):
25 | """检查指定主机的端口是否开放"""
26 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
27 | try:
28 | s.connect((host, port))
29 | s.shutdown(socket.SHUT_RDWR)
30 | return True
31 | except:
32 | return False
33 | finally:
34 | s.close()
35 |
36 |
37 | def check_elasticsearch_installed():
38 | os_name, os_version = detect_os_and_version()
39 |
40 | if os_name == "Windows":
41 | return check_elasticsearch_installed_windows()
42 | elif os_name == "Linux":
43 | return check_elasticsearch_installed_linux()
44 | else:
45 | print(f"! 暂不支持的操作系统:{os_name}")
46 | sys.exit(1)
47 |
48 |
49 | def check_elasticsearch_installed_windows():
50 | # 检查默认安装路径
51 | if DEFAULT_WINDOWS_PATH.exists():
52 | print(f"> 检测到Elasticsearch安装在默认路径:{DEFAULT_WINDOWS_PATH}")
53 | save_elasticsearch_config(DEFAULT_WINDOWS_PATH)
54 | return True
55 |
56 | # 提示用户手动输入路径
57 | print("! 请注意,如果你现在在执行main.py,而且你没有安装的话 下面的安装路径可以直接回车或者编一个哦~")
58 | user_path = input("无法自动检测到Elasticsearch安装路径,请手动输入:")
59 | user_path = Path(user_path)
60 | if user_path.exists():
61 | save_elasticsearch_config(user_path)
62 | return True
63 | else:
64 | print(f"! 输入的路径无效:{user_path}")
65 | sys.exit(1)
66 |
67 |
68 | def check_elasticsearch_installed_linux():
69 | # 检查默认路径
70 | possible_paths = [DEFAULT_LINUX_PATH, Path("/usr/local/elasticsearch/bin")]
71 | for path in possible_paths:
72 | if path.exists():
73 | print(f"> 检测到Elasticsearch安装在路径:{path}")
74 | save_elasticsearch_config(path)
75 | return True
76 |
77 | # 通过包管理器检测安装
78 | try:
79 | result = subprocess.run(["which", "elasticsearch"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
80 | if result.returncode == 0:
81 | install_path = Path(result.stdout.decode().strip())
82 | print(f"> 通过包管理器检测到Elasticsearch安装路径:{install_path}")
83 | save_elasticsearch_config(install_path)
84 | return True
85 | except Exception as e:
86 | print(f"! 无法通过包管理器检测到Elasticsearch安装路径:{e}")
87 |
88 | print("! 无法检测到Elasticsearch安装。请检查您的安装状态。")
89 | sys.exit(1)
90 |
91 |
92 | def save_elasticsearch_config(install_path):
93 | # 确保配置目录存在
94 | if not ELASTIC_CONFIG_PATH.parent.exists():
95 | ELASTIC_CONFIG_PATH.parent.mkdir(parents=True)
96 |
97 | config = {"elasticsearch": {"install_path": str(install_path)}}
98 | with open(ELASTIC_CONFIG_PATH, 'w', encoding='utf-8') as f:
99 | yaml.dump(config, f, allow_unicode=True)
100 | print(f"> Elasticsearch安装路径已保存到配置文件:{ELASTIC_CONFIG_PATH}")
101 |
102 |
103 | def is_elasticsearch_running():
104 | for proc in psutil.process_iter(['pid', 'name']):
105 | if proc.info['name'] == "elasticsearch":
106 | print("> Elasticsearch正在运行")
107 | return True
108 | print("> Elasticsearch未运行")
109 | return False
110 |
111 |
112 | def start_elasticsearch():
113 | os_name, os_version = detect_os_and_version()
114 |
115 | if os_name == "Windows":
116 | return start_elasticsearch_windows()
117 | elif os_name == "Linux":
118 | return start_elasticsearch_linux()
119 | else:
120 | print(f"! 暂不支持的操作系统:{os_name}")
121 | sys.exit(1)
122 |
123 |
124 | def start_elasticsearch_windows():
125 | try:
126 | with open(ELASTIC_CONFIG_PATH, 'r', encoding='utf-8') as f:
127 | config = yaml.safe_load(f)
128 | install_path = Path(config['elasticsearch']['install_path'])
129 |
130 | print("> 正在尝试启动Elasticsearch服务...")
131 |
132 | # 直接运行 elasticsearch.bat 并确保它在后台运行
133 | subprocess.Popen([str(install_path / "elasticsearch.bat")], creationflags=subprocess.CREATE_NEW_CONSOLE)
134 |
135 | # 等待Elasticsearch启动(检查端口是否开放)
136 | for _ in range(20): # 尝试20次
137 | if is_port_open("127.0.0.1", 9200):
138 | print("> Elasticsearch已成功启动并监听端口9200。")
139 | return True
140 | time.sleep(2)
141 |
142 | print("! Elasticsearch启动失败,未能在预期端口上监听。")
143 | return False
144 |
145 | except Exception as e:
146 | print(f"! 启动Elasticsearch服务失败:{e}")
147 | return False
148 |
149 |
150 | def start_elasticsearch_linux():
151 | try:
152 | print("> 正在尝试启动Elasticsearch服务...")
153 |
154 | # 检查 systemctl 是否可用
155 | if subprocess.run(["which", "systemctl"], stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0:
156 | subprocess.run(["sudo", "systemctl", "start", "elasticsearch"], check=True)
157 | else:
158 | subprocess.run(["sudo", "service", "elasticsearch", "start"], check=True)
159 |
160 | print("> Elasticsearch已成功启动。")
161 | except Exception as e:
162 | print(f"! 启动Elasticsearch服务失败:{e}")
163 | sys.exit(1)
164 |
165 |
166 | def stop_all_elasticsearch_processes():
167 | print("> 正在停止所有Elasticsearch相关进程...")
168 |
169 | # 定义要终止的进程名称列表
170 | elasticsearch_related_processes = ["elasticsearch", "controller.exe", "OpenJDK Platform binary", "java.exe"]
171 |
172 | for proc in psutil.process_iter(['pid', 'name']):
173 | try:
174 | if proc.info['name'] in elasticsearch_related_processes:
175 | print(f"> 正在终止进程:{proc.info['name']} (PID: {proc.info['pid']})")
176 | proc.terminate()
177 | except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
178 | print(f"! 无法终止进程:{proc.info['name']} (PID: {proc.info['pid']}) - 权限不足或进程已结束。")
179 | continue
180 |
181 | print("> 所有Elasticsearch相关进程终止命令已发送。")
182 |
183 |
184 | def check_elasticsearch_connection():
185 | try:
186 | print("> 正在测试Elasticsearch是否已启动...")
187 | # 通过ping本地端口检查服务是否启动
188 | if is_port_open("127.0.0.1", 9200):
189 | print("> Elasticsearch已经成功启动并在9200端口监听。")
190 | return True
191 | else:
192 | print("! Elasticsearch未能在9200端口启动。")
193 | return False
194 | except Exception as e:
195 | print(f"! 检查Elasticsearch连接时发生错误:{e}")
196 | return False
197 |
198 |
199 | if __name__ == "__main__":
200 | print("> 开始Elasticsearch启动检测...")
201 |
202 | # 添加连接测试
203 | if check_elasticsearch_connection():
204 | print("> Elasticsearch已经在运行,连接正常。")
205 | sys.exit(0)
206 |
207 | if not check_elasticsearch_installed():
208 | print("! Elasticsearch未安装或安装检测失败。")
209 | sys.exit(1)
210 |
211 | if not is_elasticsearch_running():
212 | print("> Elasticsearch未运行,尝试启动...")
213 | stop_all_elasticsearch_processes()
214 | start_elasticsearch()
215 |
216 | if not check_elasticsearch_connection():
217 | print("! Elasticsearch启动失败或无法连接,请检查安装和配置。")
218 | sys.exit(1)
219 |
220 | print("> Elasticsearch准备就绪。")
221 | sys.exit(0)
222 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | AmyAlmond Project - config.py
3 |
4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot
5 | Developer: Shuakami <3 LuoXiaoHei
6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved.
7 | Version: 1.3.0 (Stable_923001)
8 |
9 | config.py - 配置文件读取与验证
10 | """
11 | import os
12 | from botpy.ext.cog_yaml import read
13 | from core.utils.logger import get_logger
14 | import subprocess
15 | import time
16 | from ruamel.yaml import YAML
17 |
18 | # 获取 logger 对象
19 | logger = get_logger()
20 |
21 | # 定义目录结构
22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23 | CONFIG_DIR = os.path.join(BASE_DIR, "configs")
24 | LOG_DIR = os.path.join(BASE_DIR, "logs")
25 | DATA_DIR = os.path.join(BASE_DIR, "data")
26 |
27 | # 确保目录存在
28 | os.makedirs(CONFIG_DIR, exist_ok=True)
29 | os.makedirs(LOG_DIR, exist_ok=True)
30 | os.makedirs(DATA_DIR, exist_ok=True)
31 |
32 | # 配置文件路径
33 | CONFIG_FILE = os.path.join(CONFIG_DIR, "config.yaml")
34 | SYSTEM_PROMPT_FILE = os.path.join(CONFIG_DIR, "system-prompt.txt")
35 |
36 | # 日志文件路径
37 | LOG_FILE = os.path.join(LOG_DIR, "bot.log")
38 |
39 | # 数据文件路径
40 | MEMORY_FILE = os.path.join(DATA_DIR, "memory.json")
41 | LONG_TERM_MEMORY_FILE = os.path.join(DATA_DIR, "long_term_memory_{}.txt")
42 | USER_NAMES_FILE = os.path.join(DATA_DIR, "user_names.json")
43 | FAISS_INDEX_PATH = "./data/faiss_index.bin"
44 |
45 | # 读取配置文件
46 | test_config = {}
47 | logger.info("")
48 | logger.info(">>> CONFIG LOADING...")
49 | if os.path.exists(CONFIG_FILE):
50 | loaded_config = read(CONFIG_FILE)
51 | if loaded_config:
52 | test_config.update(loaded_config)
53 | logger.info(" ↳ 配置文件加载成功")
54 | else:
55 | logger.critical(" 配置文件为空")
56 | logger.critical(f" ↳ 路径: {CONFIG_FILE}")
57 | logger.critical(" ↳ 请检查配置文件是否正确填写,并确保其格式为 YAML")
58 | exit(1)
59 | else:
60 | logger.critical(" 找不到配置文件")
61 | logger.critical(f" ↳ 路径: {CONFIG_FILE}")
62 | logger.critical(f" ↳ 请确保在 {CONFIG_DIR} 目录下存在 config.yaml 文件")
63 | exit(1)
64 |
65 | # 配置参数
66 | MAX_CONTEXT_TOKENS = test_config.get("max_context_tokens", None)
67 | ELASTICSEARCH_QUERY_TERMS = test_config.get("elasticsearch_query_terms", None)
68 |
69 | # 检查是否需要自动调优
70 | if MAX_CONTEXT_TOKENS is None or ELASTICSEARCH_QUERY_TERMS is None:
71 | logger.warning(" 未找到必要的配置参数,正在调用自动调优程序...")
72 | try:
73 | start_time = time.time()
74 | # 调用 auto_tune.py 自动调优
75 | result = subprocess.run(["python", "core/db/auto_tune.py"], timeout=60)
76 | elapsed_time = time.time() - start_time
77 |
78 | if result.returncode == 0:
79 | logger.info(" 自动调优完成")
80 | logger.info(f" ↳ 耗时: {elapsed_time:.2f} 秒")
81 | # 重新读取配置文件
82 | if os.path.exists(CONFIG_FILE):
83 | loaded_config = read(CONFIG_FILE)
84 | if loaded_config:
85 | test_config.update(loaded_config)
86 | MAX_CONTEXT_TOKENS = test_config.get("max_context_tokens", 2400) # 默认值 2400
87 | ELASTICSEARCH_QUERY_TERMS = test_config.get("elasticsearch_query_terms", 16) # 默认值 16
88 | else:
89 | logger.critical(" 配置文件读取失败,使用默认值")
90 | MAX_CONTEXT_TOKENS = 2400
91 | ELASTICSEARCH_QUERY_TERMS = 8
92 | else:
93 | logger.error(" 自动调优程序执行失败,使用默认值")
94 | MAX_CONTEXT_TOKENS = 2400
95 | ELASTICSEARCH_QUERY_TERMS = 16
96 | except subprocess.TimeoutExpired:
97 | logger.error(" 自动调优超时,使用默认值")
98 | MAX_CONTEXT_TOKENS = 2400
99 | ELASTICSEARCH_QUERY_TERMS = 16
100 |
101 |
102 |
103 | # 其他配置
104 | REQUEST_LIMIT_TIME_FRAME = test_config.get("request_limit_time_frame", 10)
105 | REQUEST_LIMIT_COUNT = test_config.get("request_limit_count", 7)
106 | GLOBAL_RATE_LIMIT = test_config.get("global_rate_limit", 75)
107 |
108 | MEMORY_THRESHOLD = 150
109 | FORGET_THRESHOLD = 5
110 |
111 |
112 | MEMORY_BATCH_SIZE = test_config.get("memory_batch_size", 1)
113 | REQUEST_TIMEOUT= test_config.get("request_timeout", 7)
114 |
115 | MONGODB_URI = test_config.get("mongodb_url", "")
116 | MONGODB_USERNAME = test_config.get("mongodb_username", "")
117 | MONGODB_PASSWORD = test_config.get("mongodb_password", "")
118 |
119 | ELASTICSEARCH_URL = test_config.get("elasticsearch_url", "")
120 | ELASTICSEARCH_USERNAME = test_config.get("elasticsearch_username", "")
121 | ELASTICSEARCH_PASSWORD = test_config.get("elasticsearch_password", "")
122 |
123 | OPENAI_SECRET = test_config.get("openai_secret", "")
124 | OPENAI_MODEL = test_config.get("openai_model", "gpt-4o-mini")
125 | OPENAI_API_URL = test_config.get("openai_api_url", "https://api.openai-hk.com/v1/chat/completions")
126 |
127 | ADMIN_ID = test_config.get("admin_id", "")
128 |
129 | # KEEP_ALIVE 配置
130 | OPENAI_KEEP_ALIVE = test_config.get("openai_keep_alive", True)
131 | UPDATE_KEEP_ALIVE = test_config.get("update_keep_alive", True)
132 |
133 | # 日志配置
134 | LOG_LEVEL = test_config.get("log_level", "INFO").upper()
135 | DEBUG_MODE = test_config.get("debug", False)
136 |
137 | # 验证关键配置
138 | if not MONGODB_USERNAME:
139 | logger.warning(" MongoDB 用户名缺失")
140 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}")
141 | if not MONGODB_PASSWORD:
142 | logger.warning(" MongoDB 密码缺失")
143 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}")
144 | if not MONGODB_URI:
145 | logger.warning(" MongoDB URI 缺失")
146 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}")
147 | if not OPENAI_SECRET:
148 | logger.warning(" OpenAI API 密钥缺失")
149 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}")
150 | if not OPENAI_MODEL:
151 | logger.warning(" OpenAI 模型缺失")
152 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}")
153 | if not OPENAI_API_URL:
154 | logger.warning(" OpenAI API URL 缺失")
155 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}")
156 | if not ADMIN_ID:
157 | logger.warning(" 管理员 ID 缺失")
158 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}")
159 |
160 | def _write_config():
161 | """将配置写入 config.yaml 文件,保留原始格式"""
162 | yaml = YAML()
163 | yaml.indent(mapping=2, sequence=4, offset=2)
164 | yaml.preserve_quotes = True
165 |
166 | with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
167 | yaml_data = yaml.load(f)
168 |
169 | # 更新配置项的值
170 | for key, value in test_config.items():
171 | if key in yaml_data:
172 | yaml_data[key] = value
173 | else:
174 | yaml_data[key] = value
175 |
176 | with open(CONFIG_FILE, 'w', encoding='utf-8') as f:
177 | yaml.dump(yaml_data, f)
178 |
179 | def get_all_config():
180 | """获取所有配置"""
181 | return test_config
182 |
183 | def add_config(key, value):
184 | """添加新的配置项"""
185 | if key in test_config:
186 | logger.warning(f" 配置项 '{key}' 已存在,无法添加")
187 | return False
188 | test_config[key] = value
189 | _write_config()
190 | logger.info(f" 配置项 '{key}' 添加成功")
191 | return True
192 |
193 | def update_config(key, value):
194 | """修改或添加配置项"""
195 | test_config[key] = value # 如果 key 不存在,则添加新的配置项
196 | _write_config()
197 | logger.info(f" 配置项 '{key}' 修改成功")
198 | return True
199 |
200 | def delete_config(key):
201 | """删除配置项"""
202 | if key not in test_config:
203 | logger.warning(f" 配置项 '{key}' 不存在,无法删除")
204 | return False
205 | del test_config[key]
206 | _write_config()
207 | logger.info(f" 配置项 '{key}' 删除成功")
208 | return True
209 |
210 |
211 |
212 | # DEBUG情况下
213 | if DEBUG_MODE:
214 | if OPENAI_SECRET and OPENAI_MODEL and OPENAI_API_URL and ADMIN_ID:
215 | masked_secret = '*' * (len(OPENAI_SECRET) - 4) + OPENAI_SECRET[-4:]
216 | masked_admin_id = '*' * (len(ADMIN_ID) - 4) + ADMIN_ID[-4:]
217 | logger.info("")
218 | logger.info(" OpenAI API Configuration")
219 | logger.info(f" ↳ API Key : {masked_secret}")
220 | logger.info(f" ↳ Model : {OPENAI_MODEL}")
221 | logger.info(f" ↳ API URL : {OPENAI_API_URL}")
222 | logger.info(f" ↳ Admin ID : {masked_admin_id}")
223 | logger.info(f" ↳ Log Level : {LOG_LEVEL}")
224 | logger.info(f" ↳ Debug Mode: {'Enabled' if DEBUG_MODE else 'Disabled'}")
225 |
226 |
--------------------------------------------------------------------------------
/tools/setup/mongodb/mongodb_setup.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import sys
3 | import subprocess
4 | import time
5 | import yaml
6 | import psutil
7 | from pathlib import Path
8 | from pymongo import MongoClient, errors
9 |
10 | # 配置文件路径
11 | DEFAULT_WINDOWS_PATH = Path(r"C:\Program Files\MongoDB\Server\7.0\bin")
12 | DEFAULT_LINUX_PATH = Path("/usr/bin/mongod")
13 | DEFAULT_DB_PATH = Path(r"C:\data\db")
14 | MONGO_CONFIG_PATH = Path(__file__).parent.parent.parent / "configs/mongodb.yaml"
15 |
16 |
17 | def detect_os_and_version():
18 | if sys.platform.startswith('win'):
19 | return "Windows", sys.getwindowsversion().platform_version
20 | elif sys.platform.startswith('linux'):
21 | return "Linux", subprocess.getoutput('uname -r')
22 | else:
23 | return sys.platform, "Unknown"
24 |
25 |
26 | def is_port_open(host, port):
27 | """检查指定主机的端口是否开放"""
28 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
29 | try:
30 | s.connect((host, port))
31 | s.shutdown(socket.SHUT_RDWR)
32 | return True
33 | except:
34 | return False
35 | finally:
36 | s.close()
37 |
38 |
39 | def check_mongodb_installed():
40 | os_name, os_version = detect_os_and_version()
41 |
42 | if os_name == "Windows":
43 | return check_mongodb_installed_windows()
44 | elif os_name == "Linux":
45 | return check_mongodb_installed_linux()
46 | else:
47 | print(f"! 暂不支持的操作系统:{os_name}")
48 | sys.exit(1)
49 |
50 |
51 | def check_mongodb_installed_windows():
52 | # 检查默认安装路径
53 | if DEFAULT_WINDOWS_PATH.exists():
54 | print(f"> 检测到MongoDB安装在默认路径:{DEFAULT_WINDOWS_PATH}")
55 | save_mongodb_config(DEFAULT_WINDOWS_PATH)
56 | return True
57 |
58 | # 尝试通过注册表检测
59 | try:
60 | import winreg
61 | key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, r"SOFTWARE\MongoDB\Server")
62 | install_path, _ = winreg.QueryValueEx(key, "InstallPath")
63 | install_path = Path(install_path)
64 | print(f"> 通过注册表检测到MongoDB安装路径:{install_path}")
65 | save_mongodb_config(install_path)
66 | return True
67 | except Exception as e:
68 | print(f"! 无法通过注册表检测到MongoDB安装路径:{e}")
69 |
70 | # 如果前面的方法都失败,提示用户手动输入路径
71 | print("! 请注意,如果你现在在执行main.py,而且你没有安装的话,下面的安装路径可以直接回车或者编一个哦~")
72 | user_path = input("无法自动检测到MongoDB安装路径,请手动输入:")
73 | user_path = Path(user_path)
74 | if user_path.exists():
75 | save_mongodb_config(user_path)
76 | return True
77 | else:
78 | print(f"! 输入的路径无效:{user_path}")
79 | sys.exit(1)
80 |
81 |
82 | def check_mongodb_installed_linux():
83 | # 检查默认路径
84 | possible_paths = [DEFAULT_LINUX_PATH, Path("/usr/local/bin/mongod")]
85 | for path in possible_paths:
86 | if path.exists():
87 | print(f"> 检测到MongoDB安装在路径:{path}")
88 | save_mongodb_config(path)
89 | return True
90 |
91 | # 通过包管理器检测安装
92 | try:
93 | result = subprocess.run(["which", "mongod"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
94 | if result.returncode == 0:
95 | install_path = Path(result.stdout.decode().strip())
96 | print(f"> 通过包管理器检测到MongoDB安装路径:{install_path}")
97 | save_mongodb_config(install_path)
98 | return True
99 |
100 | # 针对Ubuntu/Debian
101 | result = subprocess.run(["dpkg", "-l", "|", "grep", "mongodb"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
102 | if result.returncode == 0:
103 | print("> 检测到MongoDB已通过dpkg安装")
104 | save_mongodb_config(Path("/usr/bin/mongod"))
105 | return True
106 |
107 | # 针对CentOS/RHEL
108 | result = subprocess.run(["rpm", "-qa", "|", "grep", "mongodb"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
109 | if result.returncode == 0:
110 | print("> 检测到MongoDB已通过rpm安装")
111 | save_mongodb_config(Path("/usr/bin/mongod"))
112 | return True
113 | except Exception as e:
114 | print(f"! 无法通过包管理器检测到MongoDB安装路径:{e}")
115 |
116 | print("! 无法检测到MongoDB安装。请检查您的安装状态。")
117 | sys.exit(1)
118 |
119 |
120 | def save_mongodb_config(install_path):
121 | # 确保配置目录存在
122 | if not MONGO_CONFIG_PATH.parent.exists():
123 | MONGO_CONFIG_PATH.parent.mkdir(parents=True)
124 |
125 | config = {"mongodb": {"install_path": str(install_path)}}
126 | with open(MONGO_CONFIG_PATH, 'w', encoding='utf-8') as f:
127 | yaml.dump(config, f, allow_unicode=True)
128 | print(f"> MongoDB安装路径已保存到配置文件:{MONGO_CONFIG_PATH}")
129 |
130 |
131 | def is_mongodb_running():
132 | for proc in psutil.process_iter(['pid', 'name']):
133 | if proc.info['name'] == "mongod.exe" or proc.info['name'] == "mongod":
134 | print("> MongoDB正在运行")
135 | return True
136 | print("> MongoDB未运行")
137 | return False
138 |
139 |
140 | def start_mongodb():
141 | os_name, os_version = detect_os_and_version()
142 |
143 | if os_name == "Windows":
144 | return start_mongodb_windows()
145 | elif os_name == "Linux":
146 | return start_mongodb_linux()
147 | else:
148 | print(f"! 暂不支持的操作系统:{os_name}")
149 | sys.exit(1)
150 |
151 |
152 | def start_mongodb_windows():
153 | try:
154 | with open(MONGO_CONFIG_PATH, 'r', encoding='utf-8') as f:
155 | config = yaml.safe_load(f)
156 | install_path = Path(config['mongodb']['install_path'])
157 |
158 | db_path = DEFAULT_DB_PATH
159 | if not db_path.exists():
160 | db_path.mkdir(parents=True) # 确保数据目录存在
161 |
162 | print("> 正在尝试启动MongoDB服务...")
163 |
164 | subprocess.Popen(
165 | [str(install_path / "mongod.exe"), "--dbpath", str(db_path), "--quiet"],
166 | creationflags=subprocess.DETACHED_PROCESS
167 | )
168 | # 等待MongoDB启动(检查端口是否开放)
169 | for _ in range(10): # 尝试10次
170 | if is_port_open("127.0.0.1", 27017):
171 | print("> MongoDB已成功启动并监听端口27017。")
172 | return True
173 | time.sleep(1)
174 |
175 | print("! MongoDB启动失败,未能在预期端口上监听。")
176 | return False
177 |
178 | except Exception as e:
179 | print(f"! 启动MongoDB服务失败:{e}")
180 | return False
181 |
182 |
183 | def start_mongodb_linux():
184 | try:
185 | print("> 正在尝试启动MongoDB服务...")
186 |
187 | # 检查 systemctl 是否可用
188 | if subprocess.run(["which", "systemctl"], stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0:
189 | subprocess.run(["sudo", "systemctl", "start", "mongod"], check=True)
190 | else:
191 | subprocess.run(["sudo", "service", "mongod", "start"], check=True)
192 |
193 | print("> MongoDB已成功启动。")
194 | except Exception as e:
195 | print(f"! 启动MongoDB服务失败:{e}")
196 | sys.exit(1)
197 |
198 |
199 | def stop_all_mongodb_processes():
200 | print("> 正在停止所有MongoDB进程...")
201 | for proc in psutil.process_iter(['pid', 'name']):
202 | if proc.info['name'] == "mongod.exe" or proc.info['name'] == "mongod":
203 | print(f"> 正在终止进程:{proc.info['pid']}")
204 | proc.terminate()
205 |
206 | print("> 所有MongoDB进程已停止。")
207 |
208 |
209 | def check_mongodb_connection():
210 | try:
211 | print("> 正在测试与MongoDB的连接...")
212 | client = MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=5000)
213 | # 尝试连接到MongoDB服务器
214 | client.server_info() # 发送一个ping以确认连接成功
215 | print("> MongoDB连接成功!")
216 | return True
217 | except errors.ServerSelectionTimeoutError as err:
218 | print(f"! 无法连接到MongoDB服务器:{err}")
219 | return False
220 | except Exception as e:
221 | print(f"! 连接MongoDB时发生错误:{e}")
222 | return False
223 |
224 |
225 | if __name__ == "__main__":
226 | print("> 开始MongoDB启动检测...")
227 | if not check_mongodb_installed():
228 | print("! MongoDB未安装或安装检测失败。")
229 | sys.exit(1)
230 |
231 | if not is_mongodb_running():
232 | print("> MongoDB未运行,尝试启动...")
233 |
234 | if not is_mongodb_running():
235 | print("MongoDB未运行,尝试启动...")
236 | stop_all_mongodb_processes()
237 | start_mongodb()
238 |
239 | if not check_mongodb_connection():
240 | print("MongoDB启动失败或无法连接,请检查安装和配置。")
241 | sys.exit(1)
242 |
243 | print("MongoDB已成功启动并连接!系统准备就绪。")
244 | # 退出此进程
245 | sys.exit(0)
246 |
--------------------------------------------------------------------------------