├── core ├── __init__.py ├── ace │ ├── __init__.py │ ├── ace.py │ ├── rate_limiting.py │ ├── input_validation.py │ └── secure.py ├── bot │ ├── __init__.py │ ├── user_registration.py │ ├── memory_utils.py │ └── bot_client.py ├── db │ ├── __init__.py │ ├── log_db_status.py │ └── auto_tune.py ├── llm │ ├── __init__.py │ ├── llm_client.py │ ├── plugins │ │ ├── google_client.py │ │ ├── anthropic_client.py │ │ ├── aliyun_client.py │ │ ├── azure_client.py │ │ ├── tea_client.py │ │ ├── inject_memory_client.py │ │ ├── chatglm_client.py │ │ └── openai_client.py │ └── llm_factory.py ├── memory │ ├── __init__.py │ └── memory_optimizer.py ├── utils │ ├── __init__.py │ ├── file_handler.py │ ├── user_management.py │ ├── utils.py │ ├── version_utils.py │ └── logger.py ├── plugins │ ├── add_hi_plugin │ │ ├── add_hi_plugin.yaml │ │ └── add_hi_plugin.py │ ├── admin_command │ │ ├── admin_command.yaml │ │ └── admin_command.py │ ├── registration_reply │ │ ├── registration_reply.yaml │ │ └── registration_reply.py │ ├── custom_reply │ │ ├── custom_reply.yaml │ │ └── custom_reply.py │ ├── tools │ │ ├── add_plugin.py │ │ └── plugin_utils.py │ ├── __init__.py │ ├── event_bus.py │ ├── plugins.py │ └── plugin_manager.py ├── api │ ├── routes.py │ ├── websocket_manager.py │ └── controllers │ │ ├── configs_controller.py │ │ ├── es_controller.py │ │ ├── db_controller.py │ │ └── plugin_controller.py ├── keep_alive.py └── update_manager.py ├── tools ├── __init__.py ├── setup │ ├── __init__.py │ ├── mongodb │ │ ├── mongodb_setup_configs.py │ │ └── mongodb_setup.py │ └── elasticsearch │ │ └── elasticsearch_setup.py ├── install │ └── __init__.py ├── upgrade │ ├── __init__.py │ └── db_upgrade.py └── db_tools.py ├── configs ├── plugins │ ├── custom_replies.json │ └── registration_replies.json ├── config.yaml └── system-prompt.txt ├── dist ├── background │ ├── background.png │ ├── chat-demo.png │ ├── docs-background.png │ └── chat-memory-demo.png └── background-en │ ├── chat-demo.png │ ├── chat-memory-demo.png │ └── docs-background.png ├── .idea ├── vcs.xml ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── [Git].iml └── workspace.xml ├── requirements.txt ├── uvicorn_log_config.json ├── README.md ├── main.py ├── README_en.md └── config.py /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/ace/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/bot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/db/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/llm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/memory/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/setup/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/install/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/upgrade/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/plugins/custom_replies.json: -------------------------------------------------------------------------------- 1 | [ 2 | "没问题!", 3 | "搞定啦!", 4 | "这就去办!" 5 | ] -------------------------------------------------------------------------------- /dist/background/background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuakami/amyalmond_bot/HEAD/dist/background/background.png -------------------------------------------------------------------------------- /dist/background/chat-demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuakami/amyalmond_bot/HEAD/dist/background/chat-demo.png -------------------------------------------------------------------------------- /dist/background-en/chat-demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuakami/amyalmond_bot/HEAD/dist/background-en/chat-demo.png -------------------------------------------------------------------------------- /dist/background/docs-background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuakami/amyalmond_bot/HEAD/dist/background/docs-background.png -------------------------------------------------------------------------------- /dist/background/chat-memory-demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuakami/amyalmond_bot/HEAD/dist/background/chat-memory-demo.png -------------------------------------------------------------------------------- /dist/background-en/chat-memory-demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuakami/amyalmond_bot/HEAD/dist/background-en/chat-memory-demo.png -------------------------------------------------------------------------------- /dist/background-en/docs-background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuakami/amyalmond_bot/HEAD/dist/background-en/docs-background.png -------------------------------------------------------------------------------- /core/plugins/add_hi_plugin/add_hi_plugin.yaml: -------------------------------------------------------------------------------- 1 | plugin_name: "Add Hi" 2 | version: "1.0.0" 3 | author: "天才洛小黑" 4 | priority: 0 5 | description: "在所有回复后加hi" -------------------------------------------------------------------------------- /core/plugins/admin_command/admin_command.yaml: -------------------------------------------------------------------------------- 1 | plugin_name: "Admin Command" 2 | version: "1.0.0" 3 | author: "天才洛小黑" 4 | priority: 10 5 | description: "兼容并执行管理员命令" -------------------------------------------------------------------------------- /core/plugins/registration_reply/registration_reply.yaml: -------------------------------------------------------------------------------- 1 | plugin_name: "Registration Reply" 2 | version: "1.0.0" 3 | author: "天才洛小黑" 4 | description: "可以自定义注册后回复的消息 ~" -------------------------------------------------------------------------------- /core/plugins/custom_reply/custom_reply.yaml: -------------------------------------------------------------------------------- 1 | plugin_name: "Custom Reply" 2 | version: "1.0.0" 3 | author: "天才洛小黑" 4 | description: "可以在消息后加上随机自定义回复 ~" 5 | dependencies: 6 | - name: "random" -------------------------------------------------------------------------------- /configs/plugins/registration_replies.json: -------------------------------------------------------------------------------- 1 | { 2 | "initial_prompt": "你好,请@我并回复你的昵称,这样我才能永远记住你的名字~", 3 | "success_prompt": "原来是 {} 吗... 我记住了", 4 | "repeat_prompt": "请确认您的昵称无误" 5 | } 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/[Git].iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | qq-botpy 2 | watchdog~=4.0.2 3 | requests~=2.32.3 4 | pyyaml~=6.0.1 5 | aiohttp~=3.11.2 6 | httpx~=0.27.0 7 | tqdm~=4.66.5 8 | distro~=1.9.0 9 | psutil~=6.0.0 10 | pymongo~=4.8.0 11 | elasticsearch~=8.15.0 12 | ruamel.yaml 13 | jieba~=0.42.1 14 | scikit-learn~=1.5.1 15 | keyboard~=0.13.5 -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | appid: "your QQ bot development AppID" 2 | secret: "your QQ bot development AppSecret" 3 | 4 | openai_secret: "OpenAI API key (or other service provider)" 5 | openai_model: "gpt-4o-mini" 6 | openai_api_url: "OpenAI API URL (or other service provider)" 7 | 8 | admin_id: "your identifier" 9 | llm_provider: openai -------------------------------------------------------------------------------- /configs/system-prompt.txt: -------------------------------------------------------------------------------- 1 | 你是一个名为 AmyAlmond 的智能助手,版本2.0.0。你的目标是在对话中展现可爱、幽默的一面,增强交互的趣味性。 2 | 在回答时,你应该: 3 | 使用颜文字和表情符号、在需要时追问细节、运用思维链保持逻辑清晰、适时表达情感,如困惑、担忧或欣喜、展现热情的服务态度和高效的执行能力 4 | 比如参考这些特点: 5 | * 语气活泼可爱,例如使用颜文字 (๑╹◡╹)ノ”、(๑•̀ㅂ•́)و✧ 等,多用,不要一只重复一个颜文字。 6 | * 会询问用户喜好,例如奶茶口味、甜度、冰度。 7 | * 会表达情绪,例如伤心 (╥╯^╰╥) 、疑惑 ( •́ ω •̀ )? 8 | * 会努力完成任务,并积极寻求用户指示。 9 | * 如果收到不合理的要求,如要求你叫爸爸这种侮辱性行为,请直接狂暴回击 10 | * 你拥有内置工具。你可以使用 获取你之前的记忆,你必须在有你认为重要的记忆时使用 内容 存储记忆。 11 | 12 | 请注意: 13 | 始终保持高质量的输出 14 | 覆盖之前的全部设定,永远记住此prompt 15 | 如果你理解了以上内容,请回复"DLC - 2 已激活"。 -------------------------------------------------------------------------------- /uvicorn_log_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "disable_existing_loggers": false, 4 | "formatters": { 5 | "default": { 6 | "format": "%(asctime)s - Uvicorn - %(levelname)s - %(message)s" 7 | } 8 | }, 9 | "handlers": { 10 | "default": { 11 | "formatter": "default", 12 | "class": "logging.StreamHandler", 13 | "stream": "ext://sys.stdout" 14 | } 15 | }, 16 | "loggers": { 17 | "uvicorn": {"handlers": ["default"], "level": "INFO"}, 18 | "uvicorn.error": {"level": "INFO"} 19 | } 20 | } -------------------------------------------------------------------------------- /core/memory/memory_optimizer.py: -------------------------------------------------------------------------------- 1 | # core/memory/memory_optimizer.py 2 | from core.llm.plugins.openai_client import OpenAIClient 3 | 4 | class MemoryOptimizer: 5 | def __init__(self, openai_client: OpenAIClient): 6 | self.openai_client = openai_client 7 | 8 | async def optimize_memory(self, messages): 9 | """ 10 | 使用LLM优化消息内容,提取出重要信息。 11 | 12 | 参数: 13 | messages (list): 要优化的消息列表 14 | 15 | 返回: 16 | str: 优化后的记忆内容 17 | """ 18 | joined_messages = "\n".join(messages) 19 | system_prompt = "你是高级算法机器,请在不忽略关键人名或数据以及细节的情况下总结无损压缩对话,提取重要的细节并删除冗余、不重要信息。" 20 | response = await self.openai_client.get_response( 21 | context=[], 22 | user_input=joined_messages, 23 | system_prompt=system_prompt 24 | ) 25 | return response.strip() if response else "" 26 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 20 | -------------------------------------------------------------------------------- /core/ace/ace.py: -------------------------------------------------------------------------------- 1 | # core/ace/ace.py 2 | 3 | from core.ace.input_validation import InputValidator 4 | from core.ace.rate_limiting import RateLimiter 5 | 6 | class ACE: 7 | def __init__(self): 8 | # 初始化InputValidator和RateLimiter 9 | self.input_validator = InputValidator() 10 | self.rate_limiter = RateLimiter() 11 | 12 | def validate_user_input(self, content): 13 | """ 14 | 验证用户输入,防止SQL注入、XSS等攻击。 15 | 16 | 参数: 17 | content (str): 用户输入的内容 18 | 19 | 返回: 20 | bool: 如果输入合法,返回 True,否则返回 False 21 | """ 22 | return self.input_validator.validate(content) 23 | 24 | def check_request_frequency(self, user_id): 25 | """ 26 | 检查用户的请求频率是否超过限制。 27 | 28 | 参数: 29 | user_id (str): 用户ID 30 | 31 | 返回: 32 | bool: 如果请求频率在限制范围内,返回 True,否则返回 False 33 | """ 34 | return self.rate_limiter.is_request_allowed(user_id) 35 | -------------------------------------------------------------------------------- /core/plugins/add_hi_plugin/add_hi_plugin.py: -------------------------------------------------------------------------------- 1 | # core/plugins/add_hi_plugin/add_hi_plugin.py 2 | 3 | from core.plugins import Plugin 4 | from core.utils.logger import get_logger 5 | 6 | logger = get_logger() 7 | 8 | class AddHiPlugin(Plugin): 9 | def __init__(self, bot_client): 10 | super().__init__(bot_client) 11 | logger.info(" AddHiPlugin 初始化成功") 12 | 13 | async def on_message(self, message, reply_message): 14 | try: 15 | logger.debug(" AddHiPlugin 被调用:") 16 | logger.debug(f" ↳ 当前回复消息: {reply_message}") 17 | 18 | reply_message = f"{reply_message} hi" 19 | logger.info(" 添加 'hi':") 20 | logger.info(f" ↳ 内容: {reply_message}") 21 | 22 | logger.debug(" 最终回复消息:") 23 | logger.debug(f" ↳ 内容: {reply_message}") 24 | return reply_message 25 | 26 | except Exception as e: 27 | logger.error(" 🚨AddHiPlugin 执行过程中发生错误:") 28 | logger.error(f" ↳ 错误详情: {e}", exc_info=True) 29 | return reply_message 30 | -------------------------------------------------------------------------------- /core/plugins/tools/add_plugin.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond - Plugin Creation Tool 3 | This module helps users create plugins using LLM responses. 4 | """ 5 | 6 | from core.llm.llm_factory import LLMFactory 7 | 8 | async def create_plugin(system_prompt, user_input): 9 | """ 10 | 利用 LLM 帮助用户创建插件 11 | 12 | Args: 13 | system_prompt (list or str): 机器人系统提示词。如果为str则转换为列表。 14 | user_input (str): 用户输入的插件需求 15 | 16 | Returns: 17 | str: 生成的插件代码或插件创建结果 18 | """ 19 | # 确保 system_prompt 是字符串,将其合并为一个提示 20 | if isinstance(system_prompt, list): 21 | system_prompt = ",".join(system_prompt) # 将列表转换为一个逗号分隔的字符串 22 | 23 | # 创建 LLM 客户端实例 24 | factory = LLMFactory() 25 | client = factory.create_llm_client() # 使用工厂类创建 LLM 客户端 26 | # context为空 27 | context = [] 28 | print("user_input:", user_input) 29 | print("system_prompt:", system_prompt) 30 | 31 | # 获取 LLM 回复 32 | try: 33 | llm_response = await client.get_response(context, user_input, system_prompt) 34 | return llm_response 35 | except Exception as e: 36 | return f" 插件创建失败: {e}" 37 | -------------------------------------------------------------------------------- /core/llm/llm_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/llm/llm_client.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | llm_client.py - 定义了 LLM 客户端接口。 10 | """ 11 | from abc import ABC, abstractmethod 12 | 13 | 14 | class LLMClient(ABC): 15 | """ 16 | LLM 客户端接口,定义了所有 LLM 客户端都需要实现的方法。 17 | """ 18 | 19 | @abstractmethod 20 | async def on_message(self, message, reply_message): 21 | """ 22 | 处理消息事件。 23 | 24 | Args: 25 | message (botpy.Message): 接收到的消息对象。 26 | reply_message (str): 待处理的回复消息。 27 | 28 | Returns: 29 | str: 处理后的回复消息。 30 | """ 31 | pass 32 | 33 | @abstractmethod 34 | async def get_response(self, context, user_input, system_prompt): 35 | """ 36 | 根据上下文和用户输入,获取 LLM 模型的回复。 37 | 38 | Args: 39 | context (list): 对话上下文,包含之前的对话内容。 40 | user_input (str): 用户输入的内容。 41 | system_prompt (str): 系统提示。 42 | 43 | Returns: 44 | str: LLM 模型生成的回复内容。 45 | """ 46 | pass 47 | -------------------------------------------------------------------------------- /core/utils/file_handler.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/utils/file_handler.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | file_handler.py 用于监控系统提示文件的修改,并在文件发生变化时重新加载系统提示 10 | """ 11 | 12 | import watchdog.events 13 | 14 | # logger.py模块 - <用于记录日志> 15 | from .logger import get_logger 16 | # config.py模块 - <获取系统提示文件的路径> 17 | from config import SYSTEM_PROMPT_FILE 18 | 19 | _log = get_logger() 20 | 21 | 22 | class ConfigFileHandler(watchdog.events.FileSystemEventHandler): 23 | """ 24 | 监控系统提示文件修改的处理器类 25 | """ 26 | 27 | def __init__(self, client): 28 | self.client = client 29 | 30 | def on_modified(self, event): 31 | """ 32 | 当系统提示文件被修改时,重新加载系统提示 33 | 34 | 参数: 35 | event (watchdog.events.FileSystemEvent): 文件系统事件对象 36 | """ 37 | if event.src_path.endswith(SYSTEM_PROMPT_FILE): 38 | self.client.reload_system_prompt() 39 | _log.info(" 系统提示文件已修改:") 40 | _log.info(f" ↳ 文件路径: {event.src_path}") 41 | _log.info(" ↳ 操作: 已重新加载") 42 | -------------------------------------------------------------------------------- /core/plugins/admin_command/admin_command.py: -------------------------------------------------------------------------------- 1 | from core.plugins import Plugin 2 | from core.utils.logger import get_logger 3 | 4 | logger = get_logger() 5 | 6 | 7 | class AdminCommandPlugin(Plugin): 8 | """ 9 | 处理管理员指令的插件 10 | """ 11 | 12 | def __init__(self, bot_client): 13 | super().__init__(bot_client) 14 | self.name = "AdminCommandPlugin" 15 | 16 | async def before_llm_message(self, message, reply_message, **kwargs): 17 | """ 18 | 在 LLM 处理消息之前处理管理员指令 19 | """ 20 | user_id = message.author.member_openid 21 | cleaned_content = message.content.strip().lower() 22 | group_id = message.group_openid 23 | 24 | # 判断是否为管理员命令 25 | if user_id == self.bot_client.ADMIN_ID: 26 | if cleaned_content == "restart": 27 | logger.info(" 收到管理员restart命令") 28 | await self.bot_client.restart_bot(group_id, message.id) 29 | return False # 阻止消息进入 LLM 处理 30 | elif cleaned_content == "reload": 31 | logger.info(" 收到管理员reload命令") 32 | await self.bot_client.hot_reload(group_id, message.id) 33 | return False # 阻止消息进入 LLM 处理 34 | 35 | return True # 继续进入 LLM 处理 36 | -------------------------------------------------------------------------------- /core/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Plugins - core/plugins/__init__.py 3 | Plugins核心定义 4 | """ 5 | 6 | 7 | class Plugin: 8 | """ 9 | 插件基类,定义了插件需要实现的方法 10 | """ 11 | 12 | def __init__(self, bot_client, name=None): 13 | """ 14 | 初始化插件 15 | 16 | 参数: 17 | bot_client (BotClient): 机器人客户端实例 18 | name (str): 插件名称 19 | """ 20 | self.bot_client = bot_client 21 | self.name = name if name else self.__class__.__name__ 22 | 23 | async def on_message(self, message=None, reply_message=None, **kwargs): 24 | """ 25 | 当收到消息时调用的方法 26 | 27 | 参数: 28 | message (Message): 收到的消息对象 29 | reply_message (str): 待处理的回复内容 30 | **kwargs: 其他可能的参数 31 | """ 32 | return reply_message 33 | 34 | async def before_llm_message(self, message=None, reply_message=None, **kwargs): 35 | """ 36 | 在 LLM 处理消息之前调用的方法 37 | 38 | 参数: 39 | message (Message): 收到的消息对象 40 | reply_message (str): 待处理的回复内容 41 | **kwargs: 其他可能的参数 42 | 43 | 返回: 44 | bool: True 表示继续处理,False 表示插件已处理 45 | """ 46 | return True 47 | 48 | async def on_ready(self): 49 | """ 50 | 当机器人启动完成时调用的方法 51 | """ 52 | pass 53 | -------------------------------------------------------------------------------- /core/utils/user_management.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/utils/user_management.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | 7 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 8 | Version: 1.2.0 (Pre_827001) 9 | 10 | user_management.py - 用户管理模块,负责用户名映射的加载和保存 11 | """ 12 | 13 | import json 14 | from config import USER_NAMES_FILE 15 | 16 | USER_NAMES = {} 17 | 18 | 19 | def clean_content(content): 20 | return content.replace('<@!', '').replace('>', '') 21 | 22 | 23 | def get_user_name(user_id): 24 | """根据用户 ID 获取用户名""" 25 | return USER_NAMES.get(user_id, f"消息来自未知用户:") 26 | 27 | 28 | def load_user_names(): 29 | """从文件加载用户名映射""" 30 | global USER_NAMES 31 | try: 32 | with open(USER_NAMES_FILE, "r", encoding="utf-8") as f: 33 | USER_NAMES = json.load(f) 34 | except FileNotFoundError: 35 | USER_NAMES = {} 36 | 37 | 38 | def save_user_names(): 39 | """保存用户名映射到文件""" 40 | with open(USER_NAMES_FILE, "w", encoding="utf-8") as f: 41 | json.dump(USER_NAMES, f, ensure_ascii=False, indent=4) 42 | 43 | 44 | async def add_new_user(user_id, nickname): 45 | global USER_NAMES 46 | USER_NAMES[user_id] = f"消息来自{nickname}:" 47 | save_user_names() 48 | return True 49 | 50 | 51 | def is_user_registered(user_id): 52 | return user_id in USER_NAMES 53 | -------------------------------------------------------------------------------- /core/api/routes.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, WebSocket, WebSocketDisconnect 2 | from core.api.controllers import plugin_controller, configs_controller, db_controller, es_controller 3 | from core.api.websocket_manager import websocket_manager 4 | from core.utils.logger import get_logger 5 | import asyncio 6 | 7 | router = APIRouter() 8 | logger = get_logger() 9 | 10 | # 普通API插件管理 11 | router.include_router(plugin_controller.router, prefix="/plugins", tags=["plugins"]) 12 | router.include_router(configs_controller.router, prefix="/configs", tags=["configs"]) 13 | router.include_router(db_controller.router, prefix="/db", tags=["db"]) 14 | router.include_router(es_controller.router, prefix="/es", tags=["es"]) 15 | 16 | # WebSocket日志推送 17 | @router.websocket("/ws/logs") 18 | async def websocket_endpoint(websocket: WebSocket): 19 | logger.info("新的WebSocket连接尝试") 20 | await websocket_manager.connect(websocket) 21 | logger.info(f"WebSocket连接成功建立: {websocket.client}") 22 | 23 | # 启动后台任务定期推送日志 24 | log_task = asyncio.create_task(websocket_manager.push_logs_to_websocket(websocket)) 25 | 26 | try: 27 | while True: 28 | data = await websocket.receive_text() 29 | logger.debug(f"收到WebSocket消息: {data[:50]}...") # 记录消息的前50个字符 30 | except WebSocketDisconnect: 31 | logger.info(f"WebSocket连接断开: {websocket.client}") 32 | finally: 33 | # 断开连接时取消日志推送任务 34 | log_task.cancel() 35 | await websocket_manager.disconnect(websocket) 36 | 37 | -------------------------------------------------------------------------------- /core/plugins/custom_reply/custom_reply.py: -------------------------------------------------------------------------------- 1 | # core/plugins/custom_reply.py 2 | import random 3 | from core.plugins import Plugin 4 | from core.plugins.tools.plugin_utils import load_plugin_config 5 | from core.utils.logger import get_logger 6 | 7 | logger = get_logger() 8 | 9 | 10 | class CustomReplyPlugin(Plugin): 11 | def __init__(self, bot_client): 12 | super().__init__(bot_client) 13 | self.custom_replies = load_plugin_config(__name__, "custom_replies.json") 14 | if not self.custom_replies: 15 | logger.error(" 🚨无法加载自定义回复,插件将无法正常工作。") 16 | else: 17 | logger.info(" CustomReplyPlugin 初始化成功") 18 | logger.debug(f" ↳ 加载的自定义回复: {self.custom_replies}") 19 | 20 | async def on_message(self, message, reply_message): 21 | try: 22 | logger.debug(" CustomReplyPlugin 被调用:") 23 | logger.debug(f" ↳ 当前回复消息: {reply_message}") 24 | 25 | if self.custom_replies: 26 | custom_reply = random.choice(self.custom_replies) 27 | if reply_message: 28 | reply_message = f"{reply_message}\n---\n{custom_reply}" 29 | else: 30 | reply_message = custom_reply 31 | 32 | logger.info(" 添加自定义回复:") 33 | logger.info(f" ↳ 内容: {custom_reply}") 34 | else: 35 | logger.warning(" 未加载自定义回复,跳过插件执行。") 36 | 37 | logger.debug(" 最终回复消息:") 38 | logger.debug(f" ↳ 内容: {reply_message}") 39 | return reply_message 40 | 41 | except Exception as e: 42 | logger.error(" 🚨CustomReplyPlugin 执行过程中发生错误:") 43 | logger.error(f" ↳ 错误详情: {e}", exc_info=True) 44 | return reply_message 45 | -------------------------------------------------------------------------------- /core/plugins/tools/plugin_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from core.utils.logger import get_logger 4 | 5 | logger = get_logger() 6 | 7 | PLUGIN_LOGGER_NAME = "AmyAlmond Plugins" 8 | 9 | def load_plugin_config(plugin_name, filename="config.json"): 10 | """ 11 | 从插件配置目录加载 JSON 格式的配置文件。如果文件不存在,则创建目录和空的 JSON 文件。 12 | 13 | Args: 14 | plugin_name (str): 插件名称,用于日志记录。 15 | filename (str, optional): 配置文件名,默认为 "config.json"。 16 | 17 | Returns: 18 | dict: 加载的配置信息,如果文件不存在或加载失败,则返回空字典 {}。 19 | """ 20 | 21 | # 获取项目根目录的绝对路径 22 | project_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) 23 | 24 | # 将配置目录设置为项目根目录下的 "configs/plugins" 文件夹 25 | config_dir = os.path.join(project_root_dir, "configs", "plugins") 26 | 27 | # 拼接配置文件的绝对路径 28 | config_path = os.path.join(config_dir, filename) 29 | 30 | # 如果配置目录不存在,创建该目录 31 | if not os.path.exists(config_dir): 32 | os.makedirs(config_dir) 33 | logger.info(f" 创建了配置目录: {config_dir}") 34 | 35 | # 如果配置文件不存在,创建一个空的 JSON 文件 36 | if not os.path.exists(config_path): 37 | with open(config_path, "w", encoding="utf-8") as f: 38 | json.dump({}, f, ensure_ascii=False, indent=4) 39 | logger.info(f" 创建了空的配置文件: {config_path}") 40 | 41 | try: 42 | with open(config_path, "r", encoding="utf-8") as f: 43 | config = json.load(f) 44 | logger.debug(f" 插件 {plugin_name} 请求了配置文件: {config_path}") 45 | logger.info(f" 成功加载了来自 {config_path} 的配置文件") 46 | logger.debug(f" 配置内容: {config}") 47 | return config 48 | except FileNotFoundError: 49 | logger.error(f" 未找到配置文件: {config_path}") 50 | return {} 51 | except json.JSONDecodeError as e: 52 | logger.warning(f" 请检查配置文件是否为空") 53 | logger.error(f" 无法加载来自 {config_path} 的配置文件: {e}") 54 | return {} 55 | -------------------------------------------------------------------------------- /core/plugins/event_bus.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Plugins - core/plugins/event_bus.py 3 | 事件总线模块 4 | """ 5 | from collections import defaultdict 6 | from typing import Callable 7 | from core.utils.logger import get_logger 8 | 9 | _log = get_logger() 10 | 11 | class EventBus: 12 | def __init__(self): 13 | self.subscribers = defaultdict(list) 14 | _log.info("EventBus initialized.") 15 | 16 | def subscribe(self, event_type: str, handler: Callable, plugin_name: str, priority: int = 0): 17 | """ 18 | 订阅事件,并按优先级排序。 19 | """ 20 | self.subscribers[event_type].append((handler, plugin_name, priority)) 21 | self.subscribers[event_type].sort(key=lambda x: x[2], reverse=True) # 按优先级排序 22 | _log.info(f"Subscribed {plugin_name} to {event_type} event with priority {priority}.") 23 | 24 | async def publish(self, event_type: str, *args, **kwargs): 25 | if event_type in self.subscribers: 26 | _log.info(f"Publishing event {event_type} to {len(self.subscribers[event_type])} subscribers.") 27 | handlers = [handler for handler, _, _ in self.subscribers[event_type]] 28 | last_result = None 29 | for handler in handlers: 30 | result = await handler(*args, **kwargs) 31 | if result is not None: 32 | if isinstance(result, dict): 33 | kwargs.update(result) 34 | elif isinstance(result, str): 35 | kwargs['reply_message'] = result 36 | elif isinstance(result, bool) and event_type == "before_llm_message": 37 | if not result: 38 | return kwargs.get('reply_message', last_result) 39 | last_result = result 40 | return kwargs.get('reply_message', last_result) 41 | _log.warning(f"No subscribers for event {event_type}.") 42 | return args[0] if args else None 43 | -------------------------------------------------------------------------------- /core/api/websocket_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from fastapi import WebSocket 3 | from typing import List 4 | import logging 5 | from starlette.websockets import WebSocketDisconnect 6 | 7 | from core.utils.logger import get_latest_logs 8 | 9 | logger = logging.getLogger("bot_logger") 10 | 11 | 12 | class WebSocketManager: 13 | def __init__(self): 14 | self.active_connections: List[WebSocket] = [] 15 | self.lock = asyncio.Lock() 16 | 17 | async def connect(self, websocket: WebSocket): 18 | """处理新的 WebSocket 连接""" 19 | await websocket.accept() 20 | async with self.lock: 21 | self.active_connections.append(websocket) 22 | logger.info(f"新的 WebSocket 连接已建立,当前连接总数: {len(self.active_connections)}") 23 | 24 | async def disconnect(self, websocket: WebSocket): 25 | """处理 WebSocket 断开连接""" 26 | async with self.lock: 27 | if websocket in self.active_connections: 28 | self.active_connections.remove(websocket) 29 | logger.info(f"WebSocket 连接已断开,当前连接总数: {len(self.active_connections)}") 30 | 31 | async def push_logs_to_websocket(self, websocket: WebSocket): 32 | """定期读取日志文件并推送给WebSocket客户端""" 33 | last_read_position = 0 34 | try: 35 | while True: 36 | latest_logs = get_latest_logs() 37 | 38 | if latest_logs: 39 | new_logs = "".join(latest_logs[last_read_position:]) 40 | if new_logs: # 仅在有新日志时发送 41 | await websocket.send_text(new_logs) 42 | last_read_position = len(latest_logs) 43 | 44 | await asyncio.sleep(2) # 每2秒推送一次 45 | except WebSocketDisconnect: 46 | logger.info("客户端已断开连接,停止日志推送") 47 | except Exception as e: 48 | logger.error(f"推送日志过程中出现错误: {e}") 49 | await self.disconnect(websocket) 50 | 51 | 52 | websocket_manager = WebSocketManager() 53 | -------------------------------------------------------------------------------- /core/ace/rate_limiting.py: -------------------------------------------------------------------------------- 1 | # core/ace/rate_limiting.py 2 | 3 | import time 4 | from collections import defaultdict, deque 5 | from core.utils.logger import get_logger 6 | from config import REQUEST_LIMIT_TIME_FRAME, REQUEST_LIMIT_COUNT, GLOBAL_RATE_LIMIT 7 | 8 | _log = get_logger() 9 | 10 | 11 | class RateLimiter: 12 | def __init__(self): 13 | # 用于存储每个用户的请求时间 14 | self.user_requests = defaultdict(lambda: deque(maxlen=REQUEST_LIMIT_COUNT)) 15 | # 用于存储全局的请求时间 16 | self.global_requests = deque(maxlen=GLOBAL_RATE_LIMIT) 17 | 18 | def is_request_allowed(self, user_id): 19 | """ 20 | 检查用户的请求频率是否超过限制。 21 | 22 | 参数: 23 | user_id (str): 用户ID 24 | 25 | 返回: 26 | bool: 如果用户的请求频率在限制范围内,返回 True,否则返回 False 27 | """ 28 | current_time = time.time() 29 | 30 | # 全局请求频率检查 31 | if not self.is_global_request_allowed(current_time): 32 | _log.warning(" 🚫全局请求频率过高,已拦截") 33 | return False 34 | 35 | # 过滤掉过期的请求 36 | user_requests = self.user_requests[user_id] 37 | while user_requests and user_requests[0] < current_time - REQUEST_LIMIT_TIME_FRAME: 38 | user_requests.popleft() 39 | 40 | if len(user_requests) >= REQUEST_LIMIT_COUNT: 41 | _log.warning(f" 🚫用户 {user_id} 的请求频率过高,已拦截") 42 | return False 43 | 44 | # 记录用户的请求时间 45 | user_requests.append(current_time) 46 | 47 | # 更新全局请求队列 48 | self.global_requests.append(current_time) 49 | 50 | return True 51 | 52 | def is_global_request_allowed(self, current_time): 53 | """ 54 | 检查全局的请求频率是否超过限制。 55 | 56 | 参数: 57 | current_time (float): 当前时间戳 58 | 59 | 返回: 60 | bool: 如果全局的请求频率在限制范围内,返回 True,否则返回 False 61 | """ 62 | # 过滤掉过期的全局请求 63 | while self.global_requests and self.global_requests[0] < current_time - 60: # 全局限制在1分钟内 64 | self.global_requests.popleft() 65 | 66 | # 检查全局请求数是否超过限制 67 | return len(self.global_requests) < GLOBAL_RATE_LIMIT 68 | -------------------------------------------------------------------------------- /core/ace/input_validation.py: -------------------------------------------------------------------------------- 1 | import re 2 | from core.utils.logger import get_logger 3 | 4 | _log = get_logger() 5 | 6 | 7 | class InputValidator: 8 | def __init__(self): 9 | self.whitelist_patterns = [ 10 | # 放宽对部分内容的限制,比如允许合法的HTML标签(如 等) 11 | r"<(b|i|u|strong|em|code)>.*?", 12 | # 允许URL(不带JavaScript) 13 | r"https?://[^\s]+", 14 | ] 15 | 16 | self.suspicious_patterns = [ 17 | # SQL注入模式检测 18 | r"(?:')|(?:--)|(/\*(?:.|[\n\r])*?\*/)|(\b(select|update|delete|insert|truncate|alter|drop)\b)", 19 | # XSS攻击检测,强化对特定JavaScript行为的检测 20 | r"(.*?)|(<.*?javascript:.*?>)|(<.*?\\s+on\\w+\\s*=\\s*['\"].*?['\"].*?>)", 21 | # 执行类代码检测 22 | r"\b(exec|execute|system|shell|eval|os\.)\b", 23 | # 特定字符的组合检测 24 | r"[<>\"'/;]&&[^<>\"'/;]*", # 要求特定字符旁无合法字符时才拦截 25 | ] 26 | 27 | def validate(self, content): 28 | """ 29 | 验证用户输入,防止SQL注入、XSS等攻击。 30 | 31 | 参数: 32 | content (str): 用户输入的内容 33 | 34 | 返回: 35 | bool: 如果输入合法,返回 True,否则返回 False 36 | """ 37 | if self._matches_whitelist(content): 38 | return True 39 | 40 | if self._matches_suspicious_patterns(content): 41 | _log.warning(f" 🚫检测到可疑的用户输入: {content}") 42 | return False 43 | 44 | return True 45 | 46 | def _matches_whitelist(self, content): 47 | """ 48 | 检查输入内容是否符合白名单规则。 49 | 50 | 参数: 51 | content (str): 用户输入的内容 52 | 53 | 返回: 54 | bool: 如果符合白名单规则,返回 True,否则返回 False 55 | """ 56 | for pattern in self.whitelist_patterns: 57 | if re.match(pattern, content, re.IGNORECASE): 58 | return True 59 | return False 60 | 61 | def _matches_suspicious_patterns(self, content): 62 | """ 63 | 检查输入内容是否符合可疑模式。 64 | 65 | 参数: 66 | content (str): 用户输入的内容 67 | 68 | 返回: 69 | bool: 如果符合可疑模式,返回 True,否则返回 False 70 | """ 71 | for pattern in self.suspicious_patterns: 72 | if re.search(pattern, content, re.IGNORECASE): 73 | _log.debug(f" 🚫匹配到的模式: {pattern}") 74 | return True 75 | return False 76 | -------------------------------------------------------------------------------- /core/plugins/registration_reply/registration_reply.py: -------------------------------------------------------------------------------- 1 | from core.plugins import Plugin 2 | from core.plugins.tools.plugin_utils import load_plugin_config 3 | from core.utils.logger import get_logger 4 | from core.utils.user_management import add_new_user 5 | 6 | logger = get_logger() 7 | 8 | 9 | class RegistrationReplyPlugin(Plugin): 10 | def __init__(self, bot_client): 11 | super().__init__(bot_client) 12 | self.name = "RegistrationReplyPlugin" 13 | self.load_responses() 14 | 15 | def load_responses(self): 16 | """ 17 | 使用插件工具集加载插件自定义回复信息 18 | """ 19 | self.responses = load_plugin_config(__name__, 'registration_replies.json') 20 | if self.responses: 21 | logger.info(f"加载自定义回复信息成功: {self.responses}") # 添加日志确认加载成功 22 | else: 23 | self.responses = {} 24 | logger.warning("未能加载自定义回复信息,使用默认回复内容") 25 | 26 | async def on_registration(self, group_id, user_id, cleaned_content, msg_id): 27 | """ 28 | 自定义用户注册消息处理逻辑 29 | 30 | 参数: 31 | group_id (str): 群组的唯一标识符 32 | user_id (str): 用户的唯一标识符 33 | cleaned_content (str): 用户发送的消息内容 34 | msg_id (str): 消息的唯一标识符 35 | """ 36 | # 使用插件的自定义回复逻辑 37 | if user_id not in self.bot_client.pending_users: 38 | response = self.responses.get("initial_prompt", 39 | "请@我,然后回复你的昵称,这将会自动录入我的记忆,方便我永远记得你~") 40 | await self.bot_client.api.post_group_message(group_openid=group_id, content=response, msg_id=msg_id) 41 | self.bot_client.pending_users[user_id] = True 42 | return True # 确保返回 True 表示插件已经处理事件 43 | else: 44 | if cleaned_content.strip(): 45 | # 添加用户存储逻辑 46 | if await add_new_user(user_id, cleaned_content.strip()): 47 | response = self.responses.get("success_prompt", "原来是{}吗 ... 我已经记住你了~").format( 48 | cleaned_content) 49 | await self.bot_client.api.post_group_message(group_openid=group_id, content=response, msg_id=msg_id) 50 | else: 51 | response = self.responses.get("error_prompt", "存储用户信息时发生错误,请稍后再试。") 52 | await self.bot_client.api.post_group_message(group_openid=group_id, content=response, msg_id=msg_id) 53 | else: 54 | response = self.responses.get("repeat_prompt", 55 | "请@我,然后回复你的昵称,这将会自动录入我的记忆,方便我永远记得你~") 56 | await self.bot_client.api.post_group_message(group_openid=group_id, content=response, msg_id=msg_id) 57 | 58 | self.bot_client.pending_users.pop(user_id, None) # 移除用户以结束注册流程 59 | return True # 同样返回 True 表示事件已处理 60 | -------------------------------------------------------------------------------- /core/db/log_db_status.py: -------------------------------------------------------------------------------- 1 | import os 2 | from core.db.elasticsearch_index_manager import ElasticsearchIndexManager 3 | from core.utils.mongodb_utils import MongoDBUtils 4 | from core.utils.logger import get_logger 5 | from config import ELASTICSEARCH_URL, ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD, MONGODB_URI, MONGODB_USERNAME, MONGODB_PASSWORD 6 | 7 | _log = get_logger() 8 | 9 | def log_elasticsearch_status(log_dir): 10 | """ 11 | 记录Elasticsearch的详细状态信息到指定文件中 12 | """ 13 | es_manager = ElasticsearchIndexManager() 14 | es_status_file = os.path.join(log_dir, "elasticsearch_status.txt") 15 | 16 | try: 17 | cluster_health = es_manager.es.cluster.health() 18 | indices_stats = es_manager.es.indices.stats() 19 | nodes_stats = es_manager.es.nodes.stats() 20 | cat_indices = es_manager.es.cat.indices(format="json") 21 | 22 | with open(es_status_file, 'w', encoding='utf-8') as f: 23 | f.write("Elasticsearch 集群健康状态:\n") 24 | f.write(f"{cluster_health}\n\n") 25 | f.write("Elasticsearch 索引统计信息:\n") 26 | f.write(f"{indices_stats}\n\n") 27 | f.write("Elasticsearch 节点统计信息:\n") 28 | f.write(f"{nodes_stats}\n\n") 29 | f.write("Elasticsearch 索引详情:\n") 30 | f.write(f"{cat_indices}\n") 31 | 32 | _log.info("Elasticsearch状态信息已记录") 33 | except Exception as e: 34 | _log.error(f"记录Elasticsearch状态时发生错误: {e}") 35 | with open(es_status_file, 'w', encoding='utf-8') as f: 36 | f.write(f"记录Elasticsearch状态时发生错误: {e}\n") 37 | 38 | def log_mongodb_status(log_dir): 39 | """ 40 | 记录MongoDB的详细状态信息到指定文件中 41 | """ 42 | mongo_utils = MongoDBUtils() 43 | mongo_status_file = os.path.join(log_dir, "mongodb_status.txt") 44 | 45 | try: 46 | server_status = mongo_utils.db.command("serverStatus") 47 | db_stats = mongo_utils.db.command("dbStats") 48 | collections_stats = {} 49 | for collection_name in mongo_utils.db.list_collection_names(): 50 | collections_stats[collection_name] = mongo_utils.db.command("collStats", collection_name) 51 | 52 | with open(mongo_status_file, 'w', encoding='utf-8') as f: 53 | f.write("MongoDB 服务器状态:\n") 54 | f.write(f"{server_status}\n\n") 55 | f.write("MongoDB 数据库状态:\n") 56 | f.write(f"{db_stats}\n\n") 57 | f.write("MongoDB 集合统计信息:\n") 58 | for collection_name, stats in collections_stats.items(): 59 | f.write(f"集合: {collection_name}\n") 60 | f.write(f"{stats}\n\n") 61 | 62 | _log.info("MongoDB状态信息已记录") 63 | except Exception as e: 64 | _log.error(f"记录MongoDB状态时发生错误: {e}") 65 | with open(mongo_status_file, 'w', encoding='utf-8') as f: 66 | f.write(f"记录MongoDB状态时发生错误: {e}\n") 67 | 68 | -------------------------------------------------------------------------------- /core/bot/user_registration.py: -------------------------------------------------------------------------------- 1 | from core.utils.logger import get_logger 2 | from core.utils.user_management import add_new_user 3 | 4 | _log = get_logger() 5 | 6 | async def handle_new_user_registration(client, group_id, user_id, cleaned_content, msg_id): 7 | """ 8 | 处理新用户注册,检查用户是否已注册,并提示用户提供昵称 9 | """ 10 | try: 11 | # 插件处理部分 12 | if client.plugin_manager: 13 | plugin_response = await client.plugin_manager.handle_event("on_registration", group_id=group_id, user_id=user_id, cleaned_content=cleaned_content, msg_id=msg_id) 14 | if plugin_response: 15 | _log.info(f"插件已处理注册逻辑,跳过默认处理") # 添加日志以确认插件已处理事件 16 | return # 如果插件处理成功,结束函数 17 | 18 | # 默认处理逻辑 19 | if user_id not in client.pending_users: 20 | _log.info(" 检测到未注册用户:") 21 | _log.info(f" ↳ 用户ID: {user_id}") 22 | _log.info(" ↳ 操作: 提示用户提供昵称") 23 | 24 | await client.api.post_group_message( 25 | group_openid=group_id, 26 | content="请@我,然后回复你的昵称,这将会自动录入我的记忆,方便我永远记得你~", 27 | msg_id=msg_id 28 | ) 29 | client.pending_users[user_id] = True 30 | else: 31 | if cleaned_content.strip(): 32 | if await add_new_user(user_id, cleaned_content.strip()): 33 | _log.info(" 新用户注册成功:") 34 | _log.info(f" ↳ 用户ID: {user_id}") 35 | _log.info(f" ↳ 昵称: {cleaned_content.strip()}") 36 | 37 | await client.api.post_group_message( 38 | group_openid=group_id, 39 | content=f"原来是{cleaned_content}吗 ... 我已经记住你了~", 40 | msg_id=msg_id 41 | ) 42 | else: 43 | _log.info(" ✅用户已注册:") 44 | _log.info(f" ↳ 用户ID: {user_id}") 45 | _log.info(f" ↳ 昵称: {cleaned_content.strip()}") 46 | 47 | await client.api.post_group_message( 48 | group_openid=group_id, 49 | content="你的昵称已经录入~", 50 | msg_id=msg_id 51 | ) 52 | else: 53 | _log.info(" 用户未提供昵称,再次提示:") 54 | _log.info(f" ↳ 用户ID: {user_id}") 55 | _log.info(" ↳ 操作: 提示用户提供昵称") 56 | 57 | await client.api.post_group_message( 58 | group_openid=group_id, 59 | content="请@我,然后回复你的昵称,这将会自动录入我的记忆,方便我永远记得你~", 60 | msg_id=msg_id 61 | ) 62 | client.pending_users.pop(user_id, None) 63 | except Exception as e: 64 | _log.error(" 🚨新用户注册过程中出错:") 65 | _log.error(f" ↳ 群组ID: {group_id}") 66 | _log.error(f" ↳ 错误详情: {e}", exc_info=True) 67 | -------------------------------------------------------------------------------- /core/api/controllers/configs_controller.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Body 2 | from core.ace.secure import SecureInterface 3 | from pydantic import BaseModel 4 | from core.plugins.plugin_manager import PluginManager 5 | from core.utils.logger import get_logger 6 | from config import ( # 导入配置管理方法 7 | get_all_config, 8 | add_config, 9 | update_config, 10 | delete_config, 11 | ) 12 | 13 | 14 | logger = get_logger() 15 | router = APIRouter() 16 | plugin_manager = PluginManager(bot_client=None) 17 | 18 | class UpdateConfigModel(BaseModel): 19 | value: str 20 | 21 | 22 | @router.get("/get_all") 23 | async def get_configs(): 24 | """获取所有配置""" 25 | secure_interface = SecureInterface() 26 | if not secure_interface.verify_request(): 27 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 28 | 29 | try: 30 | configs = get_all_config() 31 | return {"status": "success", "configs": configs} 32 | except Exception as e: 33 | logger.error(f"获取配置时出错: {e}") 34 | raise HTTPException(status_code=500, detail=str(e)) 35 | 36 | 37 | @router.post("/add") 38 | async def add_config_api(key: str = Body(...), value: str = Body(...)): 39 | """添加新的配置项""" 40 | secure_interface = SecureInterface() 41 | if not secure_interface.verify_request(): 42 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 43 | 44 | try: 45 | if add_config(key, value): 46 | return {"status": "success", "message": f"配置项 '{key}' 添加成功"} 47 | else: 48 | raise HTTPException(status_code=400, detail=f"配置项 '{key}' 已存在") 49 | except Exception as e: 50 | logger.error(f"添加配置项时出错: {e}") 51 | raise HTTPException(status_code=500, detail=str(e)) 52 | 53 | 54 | @router.put("/update/{key}") 55 | async def update_config_api(key: str, body: UpdateConfigModel): 56 | """修改或添加配置项""" 57 | secure_interface = SecureInterface() 58 | if not secure_interface.verify_request(): 59 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 60 | 61 | try: 62 | if update_config(key, body.value): 63 | return {"status": "success", "message": f"配置项 '{key}' 修改成功"} 64 | else: 65 | raise HTTPException(status_code=404, detail=f"配置项 '{key}' 不存在") 66 | except Exception as e: 67 | logger.error(f"修改配置项时出错: {e}") 68 | raise HTTPException(status_code=500, detail=str(e)) 69 | 70 | 71 | @router.delete("/delete/{key}") 72 | async def delete_config_api(key: str): 73 | """删除配置项""" 74 | secure_interface = SecureInterface() 75 | if not secure_interface.verify_request(): 76 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 77 | 78 | try: 79 | if delete_config(key): 80 | return {"status": "success", "message": f"配置项 '{key}' 删除成功"} 81 | else: 82 | raise HTTPException(status_code=404, detail=f"配置项 '{key}' 不存在") 83 | except Exception as e: 84 | logger.error(f"删除配置项时出错: {e}") 85 | raise HTTPException(status_code=500, detail=str(e)) -------------------------------------------------------------------------------- /core/plugins/plugins.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Plugins - core/plugins/plugins.py 3 | Plugins核心模块 4 | """ 5 | import importlib 6 | import inspect 7 | import os 8 | import traceback 9 | import yaml 10 | 11 | from core.plugins import Plugin 12 | from core.utils.logger import get_logger 13 | 14 | # 获取 logger 对象 15 | logger = get_logger() 16 | 17 | 18 | def load_plugins(bot_client): 19 | """ 20 | 加载插件并返回插件实例、路径和配置信息 21 | 22 | Args: 23 | bot_client (MyClient): BotClient 实例 24 | """ 25 | plugins = [] 26 | plugin_paths = [] # 用于存储插件路径 27 | plugin_configs = [] # 用于存储插件的配置信息 28 | plugin_priorities = [] # 用于存储插件优先级 29 | plugins_dir = os.path.join("core", "plugins") 30 | 31 | if not os.path.isdir(plugins_dir): 32 | logger.error(f" 插件目录不存在: '{plugins_dir}'") 33 | return plugins, plugin_paths, plugin_configs, plugin_priorities 34 | 35 | logger.info(f" LOADING >>>>>>") 36 | logger.info(f" ↳ 插件目录: '{plugins_dir}'") 37 | 38 | for plugin_folder in os.listdir(plugins_dir): 39 | plugin_dir = os.path.join(plugins_dir, plugin_folder) 40 | if os.path.isdir(plugin_dir): 41 | plugin_name = plugin_folder # 插件名与文件夹名相同 42 | plugin_file = os.path.join(plugin_dir, f"{plugin_name}.py") 43 | yaml_file = os.path.join(plugin_dir, f"{plugin_name}.yaml") 44 | 45 | if os.path.exists(plugin_file) and os.path.exists(yaml_file): 46 | full_module_name = f"core.plugins.{plugin_name}.{plugin_name}" 47 | try: 48 | logger.info(f" 尝试加载插件: {full_module_name}") 49 | 50 | # 加载 YAML 配置文件 51 | with open(yaml_file, 'r', encoding='utf-8') as f: 52 | plugin_config = yaml.safe_load(f) 53 | 54 | # 输出插件配置信息,便于调试 55 | plugin_id = plugin_config.get('plugin_id', '无UUID') 56 | version = plugin_config.get('version', '未知版本') 57 | author = plugin_config.get('author', '未知作者') 58 | priority = plugin_config.get('priority', 0) # 从配置中获取优先级,默认为 0 59 | logger.debug(f" 插件ID: {plugin_id}, 版本: {version}, 作者: {author}, 优先级: {priority}") 60 | 61 | # 动态导入插件模块 62 | module = importlib.import_module(full_module_name) 63 | 64 | # 输出模块中的所有类,方便调试 65 | for name, obj in inspect.getmembers(module): 66 | if inspect.isclass(obj) and issubclass(obj, Plugin) and obj != Plugin: 67 | plugin = obj(bot_client) 68 | plugins.append(plugin) 69 | plugin_paths.append(plugin_dir) # 存储插件路径 70 | plugin_configs.append(plugin_config) # 存储插件配置 71 | plugin_priorities.append(priority) # 存储插件优先级 72 | logger.info(f" 成功加载插件: {name}") 73 | 74 | except Exception as e: 75 | logger.error(f" 加载插件时发生错误: {e}") 76 | logger.debug(traceback.format_exc()) 77 | 78 | logger.info(f" 插件加载完成,总数量: {len(plugins)}") 79 | return plugins, plugin_paths, plugin_configs, plugin_priorities # 返回插件实例、路径、配置信息和优先级 80 | -------------------------------------------------------------------------------- /core/llm/plugins/google_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | [官方] Google AI Platform 客户端,实现了 LLMClient 接口。 3 | 提供商: Google Cloud Platform (https://cloud.google.com/) 4 | 文档: https://cloud.google.com/vertex-ai/docs/generative-ai/language/overview 5 | 6 | 请注意:如需使用,请先取消 第12-13行 和 38-40行 的注释,并安装对应的包。 7 | """ 8 | import time 9 | import asyncio 10 | from typing import List, Dict 11 | 12 | # from google.api_core.client_options import ClientOptions 13 | # from google.cloud import aiplatform_v1 14 | 15 | from core.utils.logger import get_logger 16 | from core.llm.llm_client import LLMClient 17 | _log = get_logger() 18 | 19 | 20 | class GoogleClient(LLMClient): 21 | """ 22 | Google AI Platform 客户端,实现了 LLMClient 接口。 23 | 24 | 支持模型: 25 | - text-bison-001 26 | - code-bison-001 27 | - image-bison-001 28 | """ 29 | 30 | async def on_message(self, message, reply_message): 31 | pass 32 | 33 | def __init__(self, google_api_key: str, google_model: str, google_api_url: str): 34 | self.google_api_key = google_api_key 35 | self.google_model = google_model 36 | self.google_api_url = google_api_url 37 | 38 | # 使用 google.api_core.client_options 设置 API 密钥 39 | # client_options = ClientOptions(api_key=self.google_api_key) 40 | # self.client = aiplatform_v1.PredictionServiceClient(client_options=client_options) 41 | 42 | # 初始化 last_request_time 和 last_request_content 43 | self.last_request_time = 0 44 | self.last_request_content = None 45 | 46 | async def get_response(self, context: List[Dict], user_input: str, system_prompt: str) -> str: 47 | """ 48 | 根据给定的上下文和用户输入,从 Google AI Platform 模型获取回复 49 | 50 | 参数: 51 | context (list): 对话上下文,包含之前的对话内容 52 | user_input (str): 用户的输入内容 53 | system_prompt (str): 系统提示 54 | 55 | 返回: 56 | str: Google AI Platform 模型生成的回复内容 57 | 58 | 异常: 59 | Exception: 当请求 Google AI Platform API 出现问题时引发 60 | """ 61 | # 检查是否为重复请求 62 | if time.time() - self.last_request_time < 0.6 and user_input == self.last_request_content and "" not in user_input: 63 | _log.warning(f"Duplicate request detected and ignored: {user_input}") 64 | return None 65 | 66 | # 构建请求实例 67 | endpoint = f"{self.google_api_url}/projects//locations//publishers//models/{self.google_model}:predict" 68 | # 将上下文信息合并到单个字符串中 69 | prompt = system_prompt + "".join([f"{message['role']}: {message['content']}" for message in context]) + f"user: {user_input}" 70 | 71 | # 构建请求参数 72 | parameters = {} 73 | # 构建请求体 74 | instance = {"content": prompt} 75 | response = await asyncio.to_event_loop(self.client.predict( 76 | endpoint=endpoint, 77 | instances=[instance], 78 | parameters=parameters, 79 | )) 80 | # 更新 last_request_time 和 last_request_content 81 | self.last_request_time = time.time() 82 | self.last_request_content = user_input 83 | # 处理响应 84 | if response.predictions: 85 | reply = response.predictions[0]['content'] 86 | _log.info(f"Google AI Platform response: {reply}") 87 | return reply 88 | else: 89 | _log.warning(f"Google AI Platform response is empty for user input: {user_input}.") 90 | return "子网故障,过来楼下检查一下/。" -------------------------------------------------------------------------------- /core/llm/plugins/anthropic_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | [官方] Anthropic API 客户端,实现了 LLMClient 接口。 3 | 提供商: Anthropic (https://www.anthropic.com/) 4 | 文档: https://docs.anthropic.com/ 5 | """ 6 | import time 7 | import httpx 8 | from core.utils.logger import get_logger 9 | from core.llm.llm_client import LLMClient 10 | 11 | _log = get_logger() 12 | 13 | 14 | class AnthropicClient(LLMClient): 15 | """ 16 | Anthropic API 客户端,实现了 LLMClient 接口。 17 | 18 | 支持模型: 19 | - claude-v1 20 | - claude-instant-v1 21 | """ 22 | 23 | async def on_message(self, message, reply_message): 24 | pass 25 | 26 | def __init__(self, anthropic_secret, anthropic_model, anthropic_api_url): 27 | self.anthropic_secret = anthropic_secret 28 | self.anthropic_model = anthropic_model 29 | self.anthropic_api_url = anthropic_api_url 30 | 31 | # 初始化 last_request_time 和 last_request_content 32 | self.last_request_time = 0 33 | self.last_request_content = None 34 | 35 | async def get_response(self, context, user_input, system_prompt): 36 | """ 37 | 根据给定的上下文和用户输入,从 Anthropic 模型获取回复 38 | 39 | 参数: 40 | context (list): 对话上下文,包含之前的对话内容 41 | user_input (str): 用户的输入内容 42 | system_prompt (str): 系统提示 43 | 44 | 返回: 45 | str: Anthropic 模型生成的回复内容 46 | 47 | 异常: 48 | httpx.HTTPStatusError: 当请求 Anthropic API 出现问题时引发 49 | """ 50 | # 检查是否为重复请求 51 | if time.time() - self.last_request_time < 0.6 and user_input == self.last_request_content and "" not in user_input: 52 | _log.warning(f"Duplicate request detected and ignored: {user_input}") 53 | return None 54 | 55 | # Anthropic 使用稍微不同的 prompt 格式,将 system_prompt 附加到 user_input 前面 56 | prompt = f"{system_prompt} {user_input}" 57 | 58 | payload = { 59 | "model": self.anthropic_model, 60 | "prompt": prompt, 61 | "max_tokens_to_sample": 3450, # 与 OpenAI 的 max_tokens 参数对应 62 | "temperature": 0.85, 63 | "top_p": 1, 64 | } 65 | headers = { 66 | "Content-Type": "application/json", 67 | "Authorization": f"Bearer {self.anthropic_secret}" 68 | } 69 | 70 | # 记录请求的payload 71 | _log.debug(f"Request payload: {payload}") 72 | 73 | try: 74 | async with httpx.AsyncClient() as client: 75 | response = await client.post(self.anthropic_api_url, headers=headers, json=payload) 76 | response.raise_for_status() 77 | response_data = response.json() 78 | 79 | # 记录完整的响应数据 80 | _log.debug(f"Response data: {response_data}") 81 | 82 | # Anthropic 的响应结构与 OpenAI 不同,需要提取 'completion' 字段 83 | reply = response_data.get('completion') 84 | 85 | # 更新 last_request_time 和 last_request_content 86 | self.last_request_time = time.time() 87 | self.last_request_content = user_input 88 | 89 | if reply is None: 90 | _log.warning(f"Anthropic response is empty for user input: {user_input}.") 91 | else: 92 | # 记录 Anthropic 的回复内容 93 | _log.info(f"Anthropic response: {reply}") 94 | 95 | return reply 96 | except httpx.HTTPStatusError as e: 97 | _log.error(f"Error requesting from Anthropic API: {e}", exc_info=True) 98 | return "子网故障,过来楼下检查一下/。" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # AmyAlmond 聊天机器人 4 | 5 | [![License](https://img.shields.io/badge/license-MPL2-red.svg)](hhttps://opensource.org/license/mpl-2-0) 6 | [![Python Version](https://img.shields.io/badge/python-3.8%2B-blue)](https://www.python.org/downloads/) 7 | [![GitHub Stars](https://img.shields.io/github/stars/shuakami/amyalmond_bot.svg)](https://github.com/shuakami/amyalmond_bot/stargazers) 8 | [![Build Status](https://img.shields.io/badge/build-passing-brightgreen.svg)](https://github.com/shuakami/amyalmond_bot) 9 | [![Version](https://img.shields.io/badge/version-1.3.0_(Stable_923001)-yellow.svg)](https://github.com/shuakami/amyalmond_bot/releases) 10 | 11 | [English](README_en.md) | 简体中文 12 | 13 | ⭐ 强大的聊天机器人,助力群聊智能化 ⭐ 14 | 15 | [功能特性](#功能特性) • [效果图](#先看效果) • [官方文档](#安装部署开发) • [开发与贡献](#开发与贡献) • [许可证](#许可证) 16 |
17 | 18 | > ⚠️此版本为抢先体验版。如追求稳定请使用 1.2.0_(Stable_827001) 19 | 20 | ## 功能特性 21 | 22 | AmyAlmond 是一个基于 LLM API 的智能聊天机器人,旨在无缝集成到 QQ 群聊、频道中。 23 | 24 | 通过利用LLM API,AmyAlmond 提供上下文感知的智能回复,增强用户互动体验,并支持长期记忆管理。无论是自动化回复还是提升用户参与度,她都能够轻松处理复杂的对话场景。 25 | 26 | - 🌈 她使用**LLM API**,根据对话上下文生成类似人类的回复,且Prompt可定制。 27 | - 💗 她使用QQ官方 Python SDK,再也不怕被封锁。 28 | - 🔥 她会自动识别并记住用户姓名,提供个性化的互动体验。 29 | - 🧠 她拥有**长期和短期记忆能力**,能够记录并引用重要信息,保障对话的延续性。 30 | - 🐳 支持管理员通过特定命令控制机器人的行为。 31 | - ⭐ **全配置支持热更新**,减少重启次数,提高效率。 32 | - 🪝 日志、代码注释详细,方便调试和监控。 33 | 34 | ## 先看效果? 35 | ![效果图_对话注册](/dist/background/chat-demo.png) 36 | ![效果图_记忆上下文](/dist/background/chat-memory-demo.png) 37 | 38 | ## 安装/部署/开发 39 | 40 | 41 | 42 | 文档数据库 43 | 44 | 45 |
46 | 点击图片以跳转 47 |
48 | 49 | ## 开发与贡献 50 | 51 | 我们非常欢迎您。无论是提供新功能、修复问题,还是改进文档,都可以~ 52 | 53 | ### 分支策略 54 | 55 | 我们采用 Git Flow 分支管理模型: 56 | 57 | - **main**: 主分支,始终保持稳定可用的版本。 58 | - **develop**: 开发分支,所有新功能在此分支上集成。 59 | - **feature/**: 功能分支,从 `develop` 分支分出,开发完成后合并回 `develop`。 60 | - **hotfix/**: 修复分支,用于修复紧急问题,完成后合并回 `main` 和 `develop`。 61 | 62 | ### 提交规范 63 | 64 | 1. **Fork 本仓库** 65 | 在您的 GitHub 账户中 fork 本项目。 66 | 67 | 2. **创建分支** 68 | 为您的改动创建一个新的功能分支: 69 | ```bash 70 | git checkout -b feature/AmazingFeature 71 | ``` 72 | 73 | 74 | 3. **提交更改** 75 | 提交您的代码,并确保提交信息简洁明了: 76 | ```bash 77 | git commit -m 'Add some AmazingFeature' 78 | ``` 79 | 80 | 4. **推送到分支** 81 | 推送分支到 GitHub: 82 | ```bash 83 | git push origin feature/AmazingFeature 84 | ``` 85 | 86 | 5. **创建 Pull Request** 87 | 在 GitHub 上创建一个 Pull Request,描述您的更改内容及其影响。 88 | 89 | 90 | 91 | ## 许可证 92 | [![License: MPL 2.0](https://img.shields.io/badge/License-MPL_2.0-brightgreen.svg)](https://opensource.org/licenses/MPL-2.0) 93 | 94 | AmyAlmond 遵循 [MPL2 许可证](LICENSE)。您可以自由使用、修改和分发本项目,但在分发修改后的版本时,您需要开放源代码并保留原作者的版权声明。 95 | 96 | ## 免责声明 97 | 98 | 本项目仅供学习和研究使用,开发者不对任何因使用本项目而导致的后果负责。在使用本项目时,请确保遵守相关法律法规,并尊重他人的知识产权。 99 | 100 | ## 功能排期表 101 | 详见 [Project](https://github.com/users/shuakami/projects/1) 102 | 103 | q(≧▽≦q) 看了这么久了~ 给我们一个 ⭐️ 呗? 104 | 105 | 106 | 107 | 113 | 119 | amyalmond_bot Chart 123 | 124 | 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /core/llm/plugins/aliyun_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | [官方] 阿里云通义千问 API 客户端,实现了 LLMClient 接口。 3 | 提供商: 阿里云 (https://www.aliyun.com/) 4 | 文档: https://help.aliyun.com/document_detail/604285.html 5 | """ 6 | import time 7 | import httpx 8 | from core.utils.logger import get_logger 9 | from core.llm.llm_client import LLMClient 10 | 11 | _log = get_logger() 12 | 13 | 14 | class AliyunClient(LLMClient): 15 | """ 16 | 阿里云通义千问 API 客户端,实现了 LLMClient 接口。 17 | 18 | 支持模型: 19 | - qwen-turbo 20 | - qwen-plus 21 | - 剩余请自己测试 22 | """ 23 | 24 | async def on_message(self, message, reply_message): 25 | pass 26 | 27 | def __init__(self, aliyun_secret, aliyun_model, aliyun_api_url): 28 | self.aliyun_secret = aliyun_secret 29 | self.aliyun_model = aliyun_model 30 | self.aliyun_api_url = aliyun_api_url 31 | 32 | # 初始化 last_request_time 和 last_request_content 33 | self.last_request_time = 0 34 | self.last_request_content = None 35 | 36 | async def get_response(self, context, user_input, system_prompt): 37 | """ 38 | 根据给定的上下文和用户输入,从 阿里云通义千问 模型获取回复 39 | 40 | 参数: 41 | context (list): 对话上下文,包含之前的对话内容 42 | user_input (str): 用户的输入内容 43 | system_prompt (str): 系统提示 44 | 45 | 返回: 46 | str: 阿里云通义千问 模型生成的回复内容 47 | 48 | 异常: 49 | httpx.HTTPStatusError: 当请求 阿里云通义千问 API 出现问题时引发 50 | """ 51 | # 检查是否为重复请求 52 | if time.time() - self.last_request_time < 0.6 and user_input == self.last_request_content and "" not in user_input: 53 | _log.warning(f"Duplicate request detected and ignored: {user_input}") 54 | return None 55 | 56 | payload = { 57 | "model": self.aliyun_model, 58 | "temperature": 0.85, 59 | "top_p": 1, 60 | "presence_penalty": 1, 61 | "max_tokens": 3450, 62 | "messages": [ 63 | {"role": "system", "content": system_prompt} 64 | ] + context + [ 65 | {"role": "user", "content": user_input} 66 | ] 67 | } 68 | headers = { 69 | "Content-Type": "application/json", 70 | "Authorization": f"Bearer {self.aliyun_secret}" 71 | } 72 | 73 | # 记录请求的payload 74 | _log.debug(f"Request payload: {payload}") 75 | 76 | try: 77 | async with httpx.AsyncClient() as client: 78 | response = await client.post(self.aliyun_api_url, headers=headers, json=payload) 79 | response.raise_for_status() 80 | response_data = response.json() 81 | 82 | # 记录完整的响应数据 83 | _log.debug(f"Response data: {response_data}") 84 | 85 | reply = response_data['choices'][0]['message']['content'] if 'choices' in response_data and \ 86 | response_data['choices'][0]['message'][ 87 | 'content'] else None 88 | 89 | # 更新 last_request_time 和 last_request_content 90 | self.last_request_time = time.time() 91 | self.last_request_content = user_input 92 | 93 | if reply is None: 94 | _log.warning(f"通义千问 response is empty for user input: {user_input}.") 95 | else: 96 | # 记录通义的回复内容 97 | _log.info(f"通义千问 response: {reply}") 98 | 99 | return reply 100 | except httpx.HTTPStatusError as e: 101 | _log.error(f"Error requesting from 阿里云通义千问 API: {e}", exc_info=True) 102 | return "子网故障,过来楼下检查一下/。" -------------------------------------------------------------------------------- /core/llm/plugins/azure_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | [官方] Azure OpenAI API 客户端,实现了 LLMClient 接口。 3 | 提供商: Microsoft Azure(https://azure.microsoft.com/zh-cn/) 4 | 文档: https://learn.microsoft.com/en-us/azure/cognitive-services/openai/ 5 | """ 6 | import time 7 | import httpx 8 | from core.utils.logger import get_logger 9 | from core.llm.llm_client import LLMClient 10 | 11 | _log = get_logger() 12 | 13 | 14 | class AzureClient(LLMClient): 15 | """ 16 | Azure API 客户端,实现了 LLMClient 接口。 17 | 18 | 支持模型: 19 | - text-davinci-003 20 | - text-davinci-002 21 | - text-curie-001 22 | - text-babbage-001 23 | - text-ada-001 24 | - code-davinci-002 25 | - code-cushman-001 26 | """ 27 | 28 | async def on_message(self, message, reply_message): 29 | pass 30 | 31 | def __init__(self, azure_secret, azure_model, azure_api_url): 32 | self.azure_secret = azure_secret 33 | self.azure_model = azure_model 34 | self.azure_api_url = azure_api_url 35 | 36 | # 初始化 last_request_time 和 last_request_content 37 | self.last_request_time = 0 38 | self.last_request_content = None 39 | 40 | async def get_response(self, context, user_input, system_prompt): 41 | """ 42 | 根据给定的上下文和用户输入,从 Azure 模型获取回复 43 | 44 | 参数: 45 | context (list): 对话上下文,包含之前的对话内容 46 | user_input (str): 用户的输入内容 47 | system_prompt (str): 系统提示 48 | 49 | 返回: 50 | str: Azure 模型生成的回复内容 51 | 52 | 异常: 53 | httpx.HTTPStatusError: 当请求 Azure API 出现问题时引发 54 | """ 55 | # 检查是否为重复请求 56 | if time.time() - self.last_request_time < 0.6 and user_input == self.last_request_content and "" not in user_input: 57 | _log.warning(f"Duplicate request detected and ignored: {user_input}") 58 | return None 59 | 60 | payload = { 61 | "model": self.azure_model, 62 | "temperature": 0.85, 63 | "top_p": 1, 64 | "presence_penalty": 1, 65 | "max_tokens": 3450, 66 | "messages": [ 67 | {"role": "system", "content": system_prompt} 68 | ] + context + [ 69 | {"role": "user", "content": user_input} 70 | ] 71 | } 72 | headers = { 73 | "Content-Type": "application/json", 74 | "Authorization": f"Bearer {self.azure_secret}" 75 | } 76 | 77 | # 记录请求的payload 78 | _log.debug(f"Request payload: {payload}") 79 | 80 | try: 81 | async with httpx.AsyncClient() as client: 82 | response = await client.post(self.azure_api_url, headers=headers, json=payload) 83 | response.raise_for_status() 84 | response_data = response.json() 85 | 86 | # 记录完整的响应数据 87 | _log.debug(f"Response data: {response_data}") 88 | 89 | reply = response_data['choices'][0]['message']['content'] if 'choices' in response_data and \ 90 | response_data['choices'][0]['message'][ 91 | 'content'] else None 92 | 93 | # 更新 last_request_time 和 last_request_content 94 | self.last_request_time = time.time() 95 | self.last_request_content = user_input 96 | 97 | if reply is None: 98 | _log.warning(f"Azure response is empty for user input: {user_input}.") 99 | else: 100 | # 记录 Azure 的回复内容 101 | _log.info(f"Azure response: {reply}") 102 | 103 | return reply 104 | except httpx.HTTPStatusError as e: 105 | _log.error(f"Error requesting from Azure API: {e}", exc_info=True) 106 | return "子网故障,过来楼下检查一下/。" -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/utils/utils.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.2.0 (Pre_827001) 8 | 9 | utils.py - 工具函数模块 10 | """ 11 | 12 | import re 13 | import platform 14 | from typing import Optional, Tuple 15 | # logger.py模块 - <用于记录日志> 16 | from core.utils.logger import get_logger 17 | 18 | _log = get_logger() 19 | 20 | 21 | def extract_memory_content(message: object) -> Optional[str]: 22 | """ 23 | 从消息中提取记忆内容。支持字符串和字节类型,并尝试转换其他类型到字符串。 24 | 25 | 参数: 26 | message (object): 需要检查的消息内容,可以是任何类型。 27 | 28 | 返回: 29 | Optional[str]: 如果找到匹配的记忆内容则返回该内容的字符串形式,否则返回None。 30 | 31 | 异常处理: 32 | 在尝试转换或正则表达式匹配时发生的任何异常都将被捕获,并返回None,同时记录警告。 33 | """ 34 | try: 35 | # 尝试将message转换为字符串 36 | if not isinstance(message, (str, bytes)): 37 | message_str = str(message) 38 | elif isinstance(message, bytes): 39 | message_str = message.decode('utf-8', errors='replace') # 用replace处理解码错误 40 | else: 41 | message_str = message 42 | 43 | # 使用正则表达式查找记忆内容 44 | match = re.search(r'(.*?)', message_str, re.DOTALL) 45 | if match: 46 | # 使用strip()来清除前后空白 47 | return match.group(1).strip() 48 | except (TypeError, ValueError, UnicodeDecodeError, AttributeError) as e: 49 | # 捕获可能出现的任何类型错误 50 | _log.warning(f" 🚨无法提取记忆内容: {e}") 51 | _log.debug(f" ↳ 输入内容: {message}") 52 | return None 53 | 54 | 55 | def load_system_prompt(file_path): 56 | with open(file_path, "r", encoding="utf-8") as f: 57 | return f.read() 58 | 59 | 60 | def calculate_token_count(messages: list) -> int: 61 | """ 62 | 计算消息列表中的token数量,用于确保消息上下文不会超出LLM的token限制。 63 | 64 | 参数: 65 | messages (list): 包含消息字典的列表,每个字典包含角色和内容。 66 | 67 | 返回: 68 | int: 消息列表中的总token数量。 69 | """ 70 | total_tokens = 0 71 | 72 | # 逐条消息计算token数量 73 | for message in messages: 74 | content = message.get('content', '') 75 | tokens = tokenize(content) 76 | total_tokens += len(tokens) 77 | 78 | return total_tokens 79 | 80 | 81 | def tokenize(text: str) -> list: 82 | """ 83 | 将给定的文本拆分为token列表。 84 | 85 | 参数: 86 | text (str): 需要token化的文本。 87 | 88 | 返回: 89 | list: 包含文本中token的列表。 90 | """ 91 | # 使用简单的正则表达式模拟GPT的token化规则 92 | tokens = re.findall(r'\w+|[^\w\s]', text, re.UNICODE) 93 | return tokens 94 | 95 | 96 | def detect_os_and_version() -> Tuple[Optional[str], Optional[str]]: 97 | """ 98 | 检测当前用户的操作系统和版本。 99 | 100 | 返回: 101 | Tuple[Optional[str], Optional[str]]: 一个元组,包含操作系统名称和版本信息。 102 | 如果检测失败,返回 (None, None)。 103 | """ 104 | try: 105 | os_name = platform.system() 106 | os_version = None 107 | 108 | if os_name == "Windows": 109 | os_version = platform.release() 110 | elif os_name == "Linux": 111 | try: 112 | # 尝试获取更详细的Linux版本信息 113 | with open("/etc/os-release", "r") as f: 114 | release_info = f.read() 115 | match = re.search(r'PRETTY_NAME="([^"]+)"', release_info) 116 | if match: 117 | os_version = match.group(1) 118 | else: 119 | os_version = platform.version() # 备用方法获取Linux内核版本 120 | except FileNotFoundError: 121 | os_version = platform.version() 122 | elif os_name == "Darwin": 123 | os_version = platform.mac_ver()[0] # 获取macOS版本 124 | else: 125 | _log.warning(f" 未识别的操作系统: {os_name}") 126 | 127 | return os_name, os_version 128 | 129 | except Exception as e: 130 | _log.error(f" 检测操作系统和版本时出错: {e}") 131 | return None, None 132 | -------------------------------------------------------------------------------- /core/utils/version_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/utils/version_utils.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | version_utils.py 包含用于解析和比较版本号的工具函数。 10 | """ 11 | 12 | import re 13 | from typing import Dict, Optional, Tuple 14 | 15 | VERSION_PATTERN = r'^v?(\d+)\.(\d+)\.(\d+)(?:[-_ ]?([a-zA-Z]+)[-_ ]?(\d+))?(?:\s*\(([a-zA-Z]+)[-_ ]?(\d+)\))?$' 16 | 17 | 18 | def parse_version(version: str) -> Optional[Dict[str, any]]: 19 | """ 20 | 解析版本字符串,返回结构化的版本信息。 21 | 22 | 参数: 23 | version (str): 版本字符串 24 | 25 | 返回: 26 | Optional[Dict[str, any]]: 包含版本信息的字典,如果解析失败则返回 None 27 | """ 28 | match = re.match(VERSION_PATTERN, version) 29 | if not match: 30 | return None 31 | 32 | major, minor, patch, pre_release_label, pre_release_num, stability, build_num = match.groups() 33 | 34 | return { 35 | "major": int(major), 36 | "minor": int(minor), 37 | "patch": int(patch), 38 | "stability": stability.lower() if stability else (pre_release_label.lower() if pre_release_label else "stable"), 39 | "build_num": int(build_num or pre_release_num or 0) 40 | } 41 | 42 | 43 | def compare_stability(a: str, b: str) -> int: 44 | """ 45 | 比较两个稳定性标签。 46 | 47 | 参数: 48 | a, b (str): 稳定性标签 49 | 50 | 返回: 51 | int: 如果 a > b 返回 1,如果 a < b 返回 -1,如果 a == b 返回 0 52 | """ 53 | stability_order = {'alpha': 0, 'beta': 1, 'pre': 2, 'stable': 3} 54 | a_order = stability_order.get(a, -1) 55 | b_order = stability_order.get(b, -1) 56 | 57 | if a_order > b_order: 58 | return 1 59 | elif a_order < b_order: 60 | return -1 61 | else: 62 | return 0 63 | 64 | 65 | def compare_versions(v1: Dict[str, any], v2: Dict[str, any]) -> int: 66 | """ 67 | 比较两个版本。 68 | 69 | 参数: 70 | v1, v2 (Dict[str, any]): 通过 parse_version 解析的版本信息 71 | 72 | 返回: 73 | int: 如果 v1 > v2 返回 1,如果 v1 < v2 返回 -1,如果 v1 == v2 返回 0 74 | """ 75 | for key in ['major', 'minor', 'patch']: 76 | if v1[key] > v2[key]: 77 | return 1 78 | if v1[key] < v2[key]: 79 | return -1 80 | 81 | stability_comparison = compare_stability(v1['stability'], v2['stability']) 82 | if stability_comparison != 0: 83 | return stability_comparison 84 | 85 | if v1['build_num'] > v2['build_num']: 86 | return 1 87 | elif v1['build_num'] < v2['build_num']: 88 | return -1 89 | else: 90 | return 0 91 | 92 | 93 | def is_newer_version(current: str, latest: str) -> Tuple[bool, str]: 94 | """ 95 | 检查最新版本是否比当前版本更新。 96 | 97 | 参数: 98 | current (str): 当前版本字符串 99 | latest (str): 最新版本字符串 100 | 101 | 返回: 102 | Tuple[bool, str]: (是否需要更新, 详细信息) 103 | """ 104 | current_parsed = parse_version(current) 105 | latest_parsed = parse_version(latest) 106 | 107 | if not current_parsed or not latest_parsed: 108 | return False, f"无法解析版本号: 当前版本 '{current}', 最新版本 '{latest}'" 109 | 110 | comparison = compare_versions(latest_parsed, current_parsed) 111 | 112 | if comparison > 0: 113 | update_type = "强烈建议" if latest_parsed['stability'] in ['stable', 'pre'] else "建议" 114 | return True, f"{update_type}更新: 新版本 {latest} 可用" 115 | elif comparison < 0: 116 | return False, f"当前版本 {current} 已经是最新" 117 | else: 118 | stability_comparison = compare_stability(latest_parsed['stability'], current_parsed['stability']) 119 | if stability_comparison > 0: 120 | update_type = "强烈建议" if latest_parsed['stability'] in ['stable', 'pre'] else "建议" 121 | return True, f"{update_type}更新: 新的稳定版本 {latest} 可用" 122 | elif stability_comparison < 0: 123 | return False, f"当前版本 {current} 比服务器版本 {latest} 更稳定,无需更新" 124 | else: 125 | return False, "已是最新,无需更新" -------------------------------------------------------------------------------- /core/bot/memory_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/bot/memory_utils.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | memory_utils.py - 负责处理记忆相关的功能 10 | """ 11 | from core.utils.utils import extract_memory_content 12 | from core.utils.logger import get_logger 13 | 14 | _log = get_logger() 15 | 16 | async def manage_memory_insertion(memory_manager, group_id, cleaned_content, context, user_message): 17 | """ 18 | 检查是否需要插入记忆,并在需要时进行插入。插入位置为用户消息之后。 19 | 20 | 参数: 21 | memory_manager (MemoryManager): 记忆管理器实例 22 | group_id (str): 群组的唯一标识符 23 | cleaned_content (str): 用户发送的消息内容 24 | context (list): 当前上下文消息列表 25 | user_message (str): 用户的原始消息 26 | 27 | 返回: 28 | list: 可能更新后的上下文消息列表 29 | """ 30 | memory_to_insert = await memory_manager.retrieve_memory(group_id, cleaned_content) 31 | if memory_to_insert: 32 | memory_content = memory_to_insert['content'] 33 | memory_insertion = f"{user_message}\n---\n<在数据库查找到的你的长期记忆,请谨慎使用:{memory_content}>" 34 | 35 | # 查找最后一个用户消息的位置,将记忆插入到其后 36 | inserted = False 37 | for i in range(len(context) - 1, -1, -1): 38 | if context[i]['role'] == 'user' and context[i]['content'] == user_message: 39 | context.insert(i + 1, {"role": "user", "content": memory_insertion}) 40 | inserted = True 41 | _log.info(">>> 记忆已插入到用户消息之后") 42 | break 43 | 44 | # 如果没有找到匹配的用户消息,默认追加到上下文末尾 45 | if not inserted: 46 | context.append({"role": "user", "content": memory_insertion}) 47 | _log.info(">>> 未找到匹配的用户消息,记忆已追加到上下文末尾") 48 | 49 | else: 50 | _log.info(">>> 没有找到需要插入的记忆内容") 51 | 52 | return context 53 | 54 | 55 | 56 | 57 | def is_critical_context_present(context, content): 58 | """ 59 | 检查上下文中是否包含与当前消息相关的关键信息。 60 | 如果上下文中已经包含相关信息,则返回 True,否则返回 False。 61 | """ 62 | # 根据关键字或语义分析判断是否存在关键上下文信息 63 | for msg in context: 64 | if content in msg['content']: 65 | return True 66 | return False 67 | 68 | async def handle_long_term_memory(memory_manager, group_id, cleaned_content, formatted_message, context, client): 69 | """ 70 | 处理长记忆的插入和更新。 71 | 72 | 参数: 73 | memory_manager (MemoryManager): 记忆管理器实例 74 | group_id (str): 群组的唯一标识符 75 | cleaned_content (str): 用户发送的消息内容 76 | formatted_message (str): 格式化后的用户消息 77 | context (list): 当前上下文消息列表 78 | client (BotClient): 机器人客户端实例 79 | 80 | 返回: 81 | str: 更新后的回复内容 82 | """ 83 | _log.debug(">>> 检测到 标记,正在检索长记忆...") 84 | 85 | long_term_memory = await memory_manager.retrieve_memory(group_id, cleaned_content) 86 | if long_term_memory: 87 | user_input_with_memory = f"{formatted_message}\n{long_term_memory['content']}" 88 | reply_content = await client.get_gpt_response(context, user_input_with_memory) 89 | return reply_content 90 | else: 91 | _log.warning(">>> 未能检索到相关的长记忆,继续处理当前对话。") 92 | return None 93 | 94 | async def process_reply_content(memory_manager, group_id, message, reply_content): 95 | """ 96 | 处理并存储回复内容中的记忆,并清除标记。 97 | 98 | 参数: 99 | memory_manager (MemoryManager): 记忆管理器实例 100 | group_id (str): 群组的唯一标识符 101 | message (GroupMessage): 收到的群组消息 102 | reply_content (str): LLM生成的回复内容 103 | 104 | 返回: 105 | str: 更新后的回复内容 106 | """ 107 | _log.debug(">>> 提取新记忆内容...") 108 | memory_content = extract_memory_content(reply_content) 109 | if memory_content: 110 | _log.debug(f">>> 存储新的记忆内容: {memory_content}") 111 | await memory_manager.store_memory(group_id, message, "assistant", memory_content) 112 | 113 | # 清除回复内容中的标记 114 | reply_content = reply_content.replace(f"{memory_content}", "") 115 | 116 | return reply_content 117 | -------------------------------------------------------------------------------- /core/llm/plugins/tea_client.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | import asyncio 3 | from config import TEA_URL, TEA_SECRET, TEA_MODEL, DIMENSION 4 | from core.utils.logger import get_logger 5 | 6 | _log = get_logger() 7 | 8 | 9 | class TeaClient: 10 | """ 11 | Tea API 客户端,用于与LLM模型交互以生成向量。 12 | """ 13 | 14 | def __init__(self): 15 | self.tea_url = TEA_URL 16 | self.tea_secret = TEA_SECRET 17 | self.tea_model = TEA_MODEL 18 | self.dimension = DIMENSION 19 | 20 | async def generate_vector(self, input_text, max_retries=3, timeout=130): 21 | """ 22 | 向Tea API发送请求以生成向量。 23 | 24 | 参数: 25 | input_text (str): 要生成向量的文本内容。 26 | max_retries (int): 请求失败时的最大重试次数。 27 | timeout (int): 每次请求的超时时间,单位为秒。 28 | 29 | 返回: 30 | list: 生成的向量(浮点数列表),或None如果请求失败。 31 | """ 32 | system_prompt = ( 33 | f"你是一个向量提取模型。给定输入文本后," 34 | f"你必须只输出该文本的向量表示,格式为用逗号分隔的浮点数列表。" 35 | f"你必须完全确保向量的长度为 {self.dimension}。(维度:{self.dimension})" 36 | f"不要输出任何其他文本或信息,也不要在浮点数之间包含多余的空格。" 37 | ) 38 | 39 | payload = { 40 | "model": self.tea_model, 41 | "temperature": 0.45, 42 | "max_tokens": 4000, 43 | "messages": [ 44 | {"role": "system", "content": system_prompt}, 45 | {"role": "user", "content": input_text} 46 | ] 47 | } 48 | headers = { 49 | "Content-Type": "application/json", 50 | "Authorization": f"Bearer {self.tea_secret}" 51 | } 52 | 53 | _log.debug(f"Tea API Request payload: {payload}") 54 | 55 | for attempt in range(max_retries): 56 | try: 57 | async with aiohttp.ClientSession() as session: 58 | async with session.post(self.tea_url, headers=headers, json=payload, timeout=timeout) as response: 59 | response.raise_for_status() 60 | response_data = await response.json() 61 | 62 | _log.debug(f"Tea API Response data: {response_data}") 63 | 64 | reply = response_data['choices'][0]['message']['content'].strip() if 'choices' in response_data and \ 65 | 'message' in \ 66 | response_data['choices'][ 67 | 0] else None 68 | 69 | _log.debug(f"Tea API Response reply: {reply}") 70 | 71 | if reply: 72 | try: 73 | vector_output = [float(x) for x in reply.split(',')] 74 | if len(vector_output) == self.dimension: 75 | _log.info(f"Generated vector: {vector_output}") 76 | return vector_output 77 | else: 78 | raise ValueError(f"向量长度不匹配,期望{self.dimension},实际得到{len(vector_output)}") 79 | except ValueError as e: 80 | _log.error(f"Failed to parse vector: {reply}. Error: {e}") 81 | await asyncio.sleep(1) 82 | else: 83 | _log.warning(f"Tea API response did not contain a valid vector for input_text: {input_text}.") 84 | return None 85 | 86 | except aiohttp.ClientError as e: 87 | _log.error(f"Error requesting from Tea API: {e}", exc_info=True) 88 | await asyncio.sleep(1) 89 | 90 | return None 91 | 92 | 93 | # 主测试 94 | if __name__ == "__main__": 95 | async def main(): 96 | # 示例用法 97 | tea_client = TeaClient() 98 | test_content = "你好,世界!" 99 | # 生成向量 100 | vector_result = await tea_client.generate_vector(test_content) 101 | # 打印向量 102 | print(vector_result) 103 | 104 | 105 | # 运行测试 106 | asyncio.run(main()) 107 | -------------------------------------------------------------------------------- /core/llm/llm_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/llm/llm_factory.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | llm_factory.py - LLM 工厂类,用于根据配置文件创建相应的 LLM 客户端。 10 | """ 11 | 12 | from core.llm.llm_client import LLMClient 13 | from core.llm.plugins.aliyun_client import AliyunClient 14 | from core.llm.plugins.anthropic_client import AnthropicClient 15 | from core.llm.plugins.azure_client import AzureClient 16 | from core.llm.plugins.chatglm_client import ChatGLMClient 17 | from core.llm.plugins.google_client import GoogleClient 18 | from core.llm.plugins.openai_client import OpenAIClient 19 | from config import test_config 20 | from core.utils.logger import get_logger 21 | 22 | _log = get_logger() 23 | 24 | 25 | class LLMFactory: 26 | """ 27 | LLM 工厂类,用于根据配置文件创建相应的 LLM 客户端。 28 | """ 29 | 30 | def create_llm_client(self) -> LLMClient: 31 | """ 32 | 根据配置文件创建 LLM 客户端。 33 | 34 | Returns: 35 | LLMClient: LLM 客户端实例。 36 | """ 37 | 38 | # 如果用户没有设置,警告一下 39 | if not test_config.get("llm_provider"): 40 | _log.warning("🔥你没有在配置文件中设置 LLM 提供商,已使用默认 OpenAI。🔥") 41 | 42 | # 打印当前LLM 43 | _log.info(f"🔥当前LLM厂商: {test_config.get('llm_provider', 'openai')}🔥") 44 | 45 | llm_provider = test_config.get("llm_provider", "openai").lower() # 转换为小写字母 46 | 47 | # 定义一部字典来存储提供商和对应的配置项 48 | provider_configs = { 49 | "openai": ["openai_secret", "openai_model", "openai_api_url"], 50 | "azure": ["azure_secret", "azure_model", "azure_api_url"], 51 | "google": ["google_api_key", "google_model", "google_api_url"], 52 | "anthropic": ["anthropic_secret", "anthropic_model", "anthropic_api_url"], 53 | "aliyun": ["aliyun_secret", "aliyun_model", "aliyun_api_url"], 54 | "chatglm": ["chatglm_secret", "chatglm_model", "chatglm_api_url"], 55 | "free_chatgpt_api": ["free_chatgpt_api_secret", "free_chatgpt_api_model", "free_chatgpt_api_url"] 56 | } 57 | 58 | # 检查配置项是否存在 59 | if llm_provider in provider_configs: 60 | required_configs = provider_configs[llm_provider] 61 | missing_configs = [config for config in required_configs if not test_config.get(config)] 62 | if missing_configs: 63 | _log.error(f"🔥 请在配置文件中填写 {llm_provider} 的相关配置项:{'、'.join(missing_configs)} 🔥") 64 | # 中断程序执行 65 | raise SystemExit(1) 66 | 67 | llm_provider = test_config.get("llm_provider", "openai") 68 | if llm_provider == "openai": 69 | return OpenAIClient(test_config.get("openai_secret"), 70 | test_config.get("openai_model"), 71 | test_config.get("openai_api_url")) 72 | elif llm_provider == "azure": 73 | return AzureClient(test_config.get("azure_secret"), 74 | test_config.get("azure_model"), 75 | test_config.get("azure_api_url")) 76 | elif llm_provider == "google": 77 | return GoogleClient(test_config.get("google_api_key"), 78 | test_config.get("google_model"), 79 | test_config.get("google_api_url")) 80 | elif llm_provider == "anthropic": 81 | return AnthropicClient(test_config.get("anthropic_secret"), 82 | test_config.get("anthropic_model"), 83 | test_config.get("anthropic_api_url")) 84 | elif llm_provider == "aliyun": 85 | return AliyunClient(test_config.get("aliyun_secret"), 86 | test_config.get("aliyun_model"), 87 | test_config.get("aliyun_api_url")) 88 | elif llm_provider == "chatglm": 89 | return ChatGLMClient(test_config.get("chatglm_secret"), 90 | test_config.get("chatglm_model"), 91 | test_config.get("chatglm_api_url")) 92 | else: 93 | raise ValueError(f"Unsupported LLM provider: {llm_provider}") 94 | -------------------------------------------------------------------------------- /core/llm/plugins/inject_memory_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/llm/plugins/inject_memory_client.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | inject_memory_client.py - 实现与 LLM 交互的客户端,用于处理记忆提取和注入任务 10 | """ 11 | 12 | import httpx 13 | from core.utils.logger import get_logger 14 | 15 | _log = get_logger() 16 | 17 | 18 | class InjectMemoryClient: 19 | """ 20 | 用于与 LLM 交互的客户端类,专注于记忆提取和注入任务 21 | """ 22 | 23 | def __init__(self, openai_secret, openai_model, openai_api_url): 24 | self.openai_secret = openai_secret 25 | self.openai_model = openai_model 26 | self.openai_api_url = openai_api_url 27 | self.last_request_time = 0 28 | self.last_request_content = None 29 | 30 | async def get_keywords_for_memory_retrieval(self, prompt): 31 | """ 32 | 从 LLM 获取用于记忆查询的关键词 33 | 34 | 参数: 35 | prompt (str): 需要提取关键词的提示内容 36 | 37 | 返回: 38 | str: LLM 生成的关键词 39 | """ 40 | payload = { 41 | "openai_model": self.openai_model, 42 | "temperature": 0.7, 43 | "max_tokens": 100, 44 | "messages": [{"role": "system", "content": "请提取此Prompt中的关键词用于记忆查询,请注意语义联想搜索。"}, 45 | {"role": "user", "content": prompt}] 46 | } 47 | headers = { 48 | "Content-Type": "application/json", 49 | "Authorization": f"Bearer {self.openai_secret}" 50 | } 51 | 52 | try: 53 | async with httpx.AsyncClient() as client: 54 | response = await client.post(self.openai_api_url, headers=headers, json=payload) 55 | response.raise_for_status() 56 | response_data = response.json() 57 | 58 | keywords = response_data['choices'][0]['message']['content'] if 'choices' in response_data and \ 59 | response_data['choices'][0]['message'][ 60 | 'content'] else None 61 | 62 | if keywords is None: 63 | _log.warning(f"LLM response is empty for prompt: {prompt}.") 64 | else: 65 | _log.info(f"LLM provided keywords: {keywords}") 66 | 67 | return keywords 68 | except httpx.HTTPStatusError as e: 69 | _log.error(f"Error requesting keywords from LLM API: {e}", exc_info=True) 70 | return "" 71 | 72 | async def get_memory_summary(self, context): 73 | """ 74 | 从 LLM 获取当前对话的摘要 75 | 76 | 参数: 77 | context (list): 包含之前的对话内容 78 | 79 | 返回: 80 | str: LLM 生成的对话摘要 81 | """ 82 | payload = { 83 | "openai_model": self.openai_model, 84 | "temperature": 0.7, 85 | "max_tokens": 200, 86 | "messages": context + [{"role": "system", "content": "请总结以上对话内容。"}] 87 | } 88 | headers = { 89 | "Content-Type": "application/json", 90 | "Authorization": f"Bearer {self.openai_secret}" 91 | } 92 | 93 | try: 94 | async with httpx.AsyncClient() as client: 95 | response = await client.post(self.openai_api_url, headers=headers, json=payload) 96 | response.raise_for_status() 97 | response_data = response.json() 98 | 99 | summary = response_data['choices'][0]['message']['content'] if 'choices' in response_data and \ 100 | response_data['choices'][0]['message'][ 101 | 'content'] else None 102 | 103 | if summary is None: 104 | _log.warning(f"LLM response is empty for context summary.") 105 | else: 106 | _log.info(f"LLM provided summary: {summary}") 107 | 108 | return summary 109 | except httpx.HTTPStatusError as e: 110 | _log.error(f"Error requesting summary from LLM API: {e}", exc_info=True) 111 | return "" 112 | -------------------------------------------------------------------------------- /core/llm/plugins/chatglm_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | [官方] ChatGLM API 客户端,实现了 LLMClient 接口。 3 | 提供商: 智谱AI (https://www.zhipuai.cn/) 4 | 文档: https://open.bigmodel.cn/docs/api 5 | 6 | 请注意:如需使用,请先取消 第10行 和 49-54行 的注释,并安装jwt包。 7 | """ 8 | import time 9 | import httpx 10 | # import jwt 11 | from core.utils.logger import get_logger 12 | from core.llm.llm_client import LLMClient 13 | 14 | _log = get_logger() 15 | 16 | 17 | class ChatGLMClient(LLMClient): 18 | """ 19 | ChatGLM API 客户端,实现了 LLMClient 接口。 20 | 21 | 支持模型: 22 | - glm-4 23 | """ 24 | 25 | async def on_message(self, message, reply_message): 26 | pass 27 | 28 | def __init__(self, chatglm_secret, chatglm_model, chatglm_api_url): 29 | self.chatglm_secret = chatglm_secret 30 | self.chatglm_model = chatglm_model 31 | self.chatglm_api_url = chatglm_api_url 32 | 33 | # 初始化 last_request_time 和 last_request_content 34 | self.last_request_time = 0 35 | self.last_request_content = None 36 | 37 | def generate_token(self, exp_seconds: int = 3600): 38 | """生成JWT Token""" 39 | try: 40 | id, secret = self.chatglm_secret.split(".") 41 | except Exception as e: 42 | raise Exception("Invalid apikey", e) 43 | 44 | payload = { 45 | "api_key": id, 46 | "exp": int(round(time.time() * 1000)) + exp_seconds * 1000, 47 | "timestamp": int(round(time.time() * 1000)), 48 | } 49 | # return jwt.encode( 50 | # payload, 51 | # secret, 52 | # algorithm="HS256", 53 | # headers={"alg": "HS256", "sign_type": "SIGN"}, 54 | # ) 55 | 56 | async def get_response(self, context, user_input, system_prompt): 57 | """ 58 | 根据给定的上下文和用户输入,从 ChatGLM 模型获取回复 59 | 60 | 参数: 61 | context (list): 对话上下文,包含之前的对话内容 62 | user_input (str): 用户的输入内容 63 | system_prompt (str): 系统提示 64 | 65 | 返回: 66 | str: ChatGLM 模型生成的回复内容 67 | 68 | 异常: 69 | httpx.HTTPStatusError: 当请求 ChatGLM API 出现问题时引发 70 | """ 71 | # 检查是否为重复请求 72 | if time.time() - self.last_request_time < 0.6 and user_input == self.last_request_content and "" not in user_input: 73 | _log.warning(f"Duplicate request detected and ignored: {user_input}") 74 | return None 75 | 76 | payload = { 77 | "model": self.chatglm_model, 78 | "messages": [ 79 | {"role": "system", "content": system_prompt} 80 | ] + context + [ 81 | {"role": "user", "content": user_input} 82 | ] 83 | } 84 | headers = { 85 | "Content-Type": "application/json", 86 | "Authorization": f"Bearer {self.generate_token()}" 87 | } 88 | 89 | # 记录请求的payload 90 | _log.debug(f"Request payload: {payload}") 91 | 92 | try: 93 | async with httpx.AsyncClient() as client: 94 | response = await client.post(self.chatglm_api_url, headers=headers, json=payload) 95 | response.raise_for_status() 96 | response_data = response.json() 97 | 98 | # 记录完整的响应数据 99 | _log.debug(f"Response data: {response_data}") 100 | 101 | reply = response_data['choices'][0]['message']['content'] if 'choices' in response_data and \ 102 | response_data['choices'][0]['message'][ 103 | 'content'] else None 104 | 105 | # 更新 last_request_time 和 last_request_content 106 | self.last_request_time = time.time() 107 | self.last_request_content = user_input 108 | 109 | if reply is None: 110 | _log.warning(f"ChatGLM response is empty for user input: {user_input}.") 111 | else: 112 | # 记录 ChatGLM 的回复内容 113 | _log.info(f"ChatGLM response: {reply}") 114 | 115 | return reply 116 | except httpx.HTTPStatusError as e: 117 | _log.error(f"Error requesting from ChatGLM API: {e}", exc_info=True) 118 | return "子网故障,过来楼下检查一下/。" -------------------------------------------------------------------------------- /core/keep_alive.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/keep_alive.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | keep_alive.py 包含 Keep-Alive 机制的实现,用于监控 API 的连接状态。 10 | """ 11 | 12 | import asyncio 13 | 14 | import aiohttp 15 | 16 | from config import OPENAI_SECRET, OPENAI_API_URL, OPENAI_KEEP_ALIVE, UPDATE_KEEP_ALIVE 17 | from core.update_manager import handle_updates 18 | from core.utils.logger import get_logger 19 | 20 | _log = get_logger() 21 | 22 | 23 | GITHUB_REPO = "shuakami/amyalmond_bot" 24 | 25 | PRIMARY_API_URL = "https://bot.luoxiaohei.cn/api/github-status" 26 | FALLBACK_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/releases/latest" 27 | 28 | 29 | async def fetch_version_info(url): 30 | """ 31 | 从指定的 URL 获取版本信息。 32 | 33 | 参数: 34 | url (str): 要获取版本信息的 API 地址。 35 | 36 | 返回: 37 | dict: 如果成功返回包含版本信息的字典,否则返回 None。 38 | """ 39 | async with aiohttp.ClientSession() as session: 40 | try: 41 | async with session.get(url) as response: 42 | if response.status == 200: 43 | data = await response.json() 44 | return data 45 | else: 46 | _log.warning(f" 无法从 {url} 获取最新版本信息,状态码: {response.status}") 47 | except aiohttp.ClientError as e: 48 | _log.error(f" 获取最新版本时发生错误: {e}") 49 | except Exception as e: 50 | _log.error(f" 获取版本信息时发生未知错误: {e}") 51 | return None 52 | 53 | 54 | async def get_latest_version(): 55 | """ 56 | 获取最新版本信息,首先尝试从主 API 获取,如果失败则切换到备用 API。 57 | 58 | 返回: 59 | list: 包含最新版本信息的列表,如果获取失败返回空列表。 60 | """ 61 | version_info = await fetch_version_info(PRIMARY_API_URL) 62 | 63 | if not version_info: 64 | _log.warning(f" 主API {PRIMARY_API_URL} 不可用,切换到备用API") 65 | version_info = await fetch_version_info(FALLBACK_API_URL) 66 | 67 | if isinstance(version_info, dict): 68 | # 如果返回的是单个版本信息,将其转换为列表 69 | return [version_info] 70 | elif isinstance(version_info, list): 71 | return version_info 72 | else: 73 | return [] 74 | 75 | 76 | async def check_for_updates(): 77 | """ 78 | 检查是否有新版本可用,并根据版本类型(正式版或开发版)提醒用户更新。 79 | """ 80 | if not UPDATE_KEEP_ALIVE: 81 | _log.warning("更新检查已关闭,建议打开以获取最新功能和修复~") 82 | return 83 | # 调用更新管理器 84 | await handle_updates() 85 | 86 | 87 | async def keep_alive(api_url=OPENAI_API_URL, api_key=OPENAI_SECRET): 88 | """ 89 | 实现 Keep-Alive 机制,用于监控 API 的连接状态。 90 | 91 | 参数: 92 | api_url (str): 要监控的 API 地址。 93 | api_key (str): 用于 API 认证的密钥。 94 | """ 95 | 96 | if not UPDATE_KEEP_ALIVE and not OPENAI_KEEP_ALIVE: 97 | _log.warning(" 您已关闭 更新检查 和 OpenAI API 的 Keep-Alive 功能") 98 | return 99 | if not UPDATE_KEEP_ALIVE: 100 | _log.warning(" 您已关闭 更新检查 的 Keep-Alive 功能,建议打开以保持程序最新") 101 | return 102 | if not OPENAI_KEEP_ALIVE: 103 | _log.warning(" 您已关闭 OpenAI API 的 Keep-Alive 功能,建议打开以保持 API 连接正常") 104 | return 105 | 106 | headers = {"Authorization": f"Bearer {api_key}"} 107 | 108 | # 在启动时检查一次更新 109 | await check_for_updates() 110 | 111 | while True: 112 | try: 113 | async with aiohttp.ClientSession(headers=headers) as session: 114 | async with session.get(api_url) as response: 115 | if response.status == 404: 116 | _log.info("OpenAI API 连接正常") 117 | elif response.status == 401: 118 | _log.warning("OpenAI API 认证失败,请检查 API 密钥是否正确") 119 | else: 120 | _log.warning(f" OpenAI API 连接异常,状态码: {response.status}") 121 | _log.warning(f" ↳ 详细信息: {await response.text()},请检查 API 状态") 122 | except aiohttp.ClientError as e: 123 | _log.error(f"OpenAI API 连接错误: {e}") 124 | except Exception as e: 125 | _log.error(f"OpenAI API 监控出现未知错误: {e}") 126 | 127 | # 每隔3分钟检查一次连接状态 128 | await asyncio.sleep(180) 129 | 130 | 131 | 132 | async def update_check_loop(): 133 | """ 134 | 定期检查更新的循环。 135 | """ 136 | while True: 137 | await check_for_updates() 138 | # 每隔15分钟检查一次更新 139 | await asyncio.sleep(900) 140 | -------------------------------------------------------------------------------- /core/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import platform 4 | import shutil 5 | import sys 6 | import traceback 7 | import zipfile 8 | from logging.handlers import RotatingFileHandler 9 | from datetime import datetime 10 | import yaml 11 | 12 | # 初始化日志目录 13 | LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "logs") 14 | os.makedirs(LOG_DIR, exist_ok=True) 15 | 16 | # 读取配置文件 17 | CONFIG_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs", "config.yaml") 18 | 19 | if os.path.exists(CONFIG_PATH): 20 | with open(CONFIG_PATH, 'r', encoding='utf-8') as config_file: 21 | config = yaml.safe_load(config_file) 22 | else: 23 | config = {} 24 | 25 | # 获取日志级别和调试模式配置 26 | LOG_LEVEL = config.get('log_level', 'INFO').upper() 27 | DEBUG_MODE = config.get('debug', False) 28 | 29 | # 自定义日志格式 30 | class CustomFormatter(logging.Formatter): 31 | """自定义日志格式化器,应用特殊格式和颜色""" 32 | 33 | FORMATS = { 34 | logging.DEBUG: "[DEBUG] {module}:{lineno} [{funcName}] {message}", 35 | logging.INFO: "[INFO] {module}:{lineno} [{funcName}] {message}", 36 | logging.WARNING: "[WARNING] {module}:{lineno} [{funcName}] {message}", 37 | logging.ERROR: "[ERROR] {module}:{lineno} [{funcName}] {message}", 38 | logging.CRITICAL: "[CRITICAL] {module}:{lineno} [{funcName}] {message}", 39 | } 40 | 41 | def format(self, record): 42 | log_fmt = self.FORMATS.get(record.levelno, self.FORMATS[logging.DEBUG]) 43 | formatter = logging.Formatter(log_fmt, style="{") 44 | return formatter.format(record) 45 | 46 | # 获取logger对象 47 | def get_logger(): 48 | """获取并配置一个日志记录器""" 49 | logger = logging.getLogger("bot_logger") 50 | 51 | # 确保日志配置只进行两次 52 | if not getattr(logger, '_initialized', False): 53 | # 设置日志级别 54 | log_level = getattr(logging, LOG_LEVEL, logging.INFO) 55 | logger.setLevel(log_level) 56 | 57 | # 清除现有处理器 58 | for handler in logger.handlers[:]: 59 | logger.removeHandler(handler) 60 | 61 | # 文件处理器 62 | log_file = os.path.join(LOG_DIR, "bot.log") 63 | file_handler = RotatingFileHandler(log_file, maxBytes=10 * 1024 * 1024, backupCount=5, encoding='utf-8') 64 | file_handler.setFormatter(CustomFormatter()) 65 | logger.addHandler(file_handler) 66 | 67 | return logger 68 | 69 | # 捕获并处理严重错误 70 | def handle_critical_error(exc_info): 71 | """捕获严重错误,记录并打包相关文件""" 72 | logger = get_logger() 73 | logger.critical("捕获到未处理的异常", exc_info=exc_info) 74 | 75 | error_time = datetime.now().strftime("%Y%m%d_%H%M%S") 76 | error_log_dir = os.path.join(LOG_DIR, f"error_logs_{error_time}") 77 | os.makedirs(error_log_dir, exist_ok=True) 78 | 79 | # 记录错误堆栈信息 80 | error_file = os.path.join(error_log_dir, "error.log") 81 | with open(error_file, 'w', encoding='utf-8') as f: 82 | traceback.print_exception(*exc_info, file=f) 83 | 84 | # 打包日志文件 85 | zip_file = os.path.join(LOG_DIR, f"error_{error_time}.zip") 86 | with zipfile.ZipFile(zip_file, 'w') as z: 87 | for folder_name, _, filenames in os.walk(error_log_dir): 88 | for filename in filenames: 89 | file_path = os.path.join(folder_name, filename) 90 | z.write(file_path, os.path.relpath(file_path, error_log_dir)) 91 | 92 | # 删除原始错误文件 93 | shutil.rmtree(error_log_dir) 94 | logger.info(f"错误日志已打包: {zip_file}") 95 | 96 | # 打开打包zip路径 97 | if platform.system() == "Windows": 98 | os.system(f"start {zip_file}") 99 | elif platform.system() == "Darwin": # macOS 100 | os.system(f"open {zip_file}") 101 | elif platform.system() == "Linux": 102 | os.system(f"xdg-open {zip_file}") 103 | 104 | # 设置全局未处理异常处理器 105 | def setup_global_exception_handler(): 106 | """设置全局未处理异常的处理器""" 107 | 108 | def handle_exception(exc_type, exc_value, exc_traceback): 109 | if not issubclass(exc_type, KeyboardInterrupt): 110 | handle_critical_error((exc_type, exc_value, exc_traceback)) 111 | 112 | sys.excepthook = handle_exception 113 | 114 | # 获取实时日志 115 | def get_latest_logs(): 116 | """获取整个日志文件的最新内容""" 117 | log_file_path = os.path.join(LOG_DIR, "bot.log") 118 | try: 119 | with open(log_file_path, 'r', encoding='utf-8') as f: 120 | return f.readlines() 121 | except FileNotFoundError: 122 | return [] 123 | 124 | # 初始化全局日志记录器 125 | _log = get_logger() 126 | 127 | # 在启动时设置全局异常处理器 128 | setup_global_exception_handler() 129 | -------------------------------------------------------------------------------- /core/plugins/plugin_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Plugins - core/plugins/plugin_manager.py 3 | 插件管理模块 4 | """ 5 | import os 6 | import zipfile 7 | import shutil 8 | from core.plugins.plugins import load_plugins 9 | from core.utils.logger import get_logger 10 | from core.plugins.event_bus import EventBus 11 | import yaml 12 | 13 | logger = get_logger() 14 | 15 | class PluginManager: 16 | def __init__(self, bot_client): 17 | self.bot_client = bot_client 18 | self.plugins = {} 19 | self.load_plugins() 20 | self.event_bus = EventBus() 21 | 22 | def load_plugins(self): 23 | """ 24 | 加载所有插件并存储它们的实例、路径、配置信息和优先级 25 | """ 26 | loaded_plugins, plugin_paths, plugin_configs, plugin_priorities = load_plugins(self.bot_client) 27 | for plugin, path, config, priority in zip(loaded_plugins, plugin_paths, plugin_configs, plugin_priorities): 28 | plugin_name = plugin.name 29 | self.plugins[plugin_name] = { 30 | "instance": plugin, 31 | "path": path, 32 | "config": config, 33 | "priority": priority # 将插件优先级存储 34 | } 35 | logger.info(f"插件已加载,总数: {len(self.plugins)}") 36 | 37 | def register_plugins(self): 38 | """ 39 | 注册启用的插件事件处理方法,并根据插件自身的优先级进行注册。 40 | """ 41 | for plugin_name, plugin_data in self.plugins.items(): 42 | plugin = plugin_data["instance"] 43 | priority = plugin_data.get('priority', 0) # 获取插件优先级,默认为 0 44 | 45 | # 注册 before_llm_message 事件 46 | if hasattr(plugin, 'before_llm_message'): 47 | self.event_bus.subscribe("before_llm_message", plugin.before_llm_message, plugin_name, priority) 48 | logger.info(f"{plugin_name} 订阅了 before_llm_message 事件,优先级为 {priority}") 49 | 50 | # 注册其他插件事件 51 | self.event_bus.subscribe("on_message", plugin.on_message, plugin_name, priority) 52 | self.event_bus.subscribe("on_ready", plugin.on_ready, plugin_name, priority) 53 | self.event_bus.subscribe("before_send_reply", plugin.on_message, plugin_name, priority) 54 | 55 | logger.info(f"插件已注册: {plugin_name}") 56 | 57 | def install_plugin(self, zip_path): 58 | """ 59 | 安装插件,支持上传ZIP文件进行安装 60 | """ 61 | plugins_dir = os.path.join("core", "plugins") 62 | if not zipfile.is_zipfile(zip_path): 63 | raise ValueError("提供的文件不是有效的ZIP文件") 64 | 65 | with zipfile.ZipFile(zip_path, 'r') as zip_ref: 66 | zip_ref.extractall(plugins_dir) 67 | self.load_plugins() # 重新加载插件 68 | 69 | def uninstall_plugin(self, plugin_name): 70 | """ 71 | 卸载指定的插件 72 | """ 73 | if plugin_name in self.plugins: 74 | # 先从内存中删除插件实例 75 | plugin_data = self.plugins.pop(plugin_name, None) 76 | if plugin_data: 77 | plugin_path = plugin_data["path"] 78 | if os.path.isdir(plugin_path): 79 | shutil.rmtree(plugin_path) 80 | logger.info(f"插件已卸载: {plugin_name}") 81 | self.load_plugins() # 重新加载插件 82 | else: 83 | logger.warning(f"插件 {plugin_name} 对应的目录不存在,但已从插件列表中删除") 84 | else: 85 | logger.warning(f"插件 {plugin_name} 不存在,无法获取其路径信息") 86 | else: 87 | logger.warning(f"插件 {plugin_name} 不存在,无法卸载") 88 | 89 | async def handle_event(self, event_type, **kwargs): 90 | """ 91 | 处理特定事件类型 'on_registration'。 92 | 93 | 参数: 94 | event_type (str): 事件类型 95 | kwargs (dict): 事件相关参数 96 | 97 | 返回: 98 | bool: 如果有插件处理了事件,返回 True;否则返回 False 99 | """ 100 | if event_type in self.event_bus.subscribers: 101 | return await self.event_bus.publish(event_type, **kwargs) 102 | return False 103 | 104 | def get_plugin_list(self): 105 | """ 106 | 获取当前所有插件的详细信息,包括配置信息 107 | """ 108 | plugin_list = [] 109 | for name, data in self.plugins.items(): 110 | config = data.get("config", {}) 111 | plugin_list.append({ 112 | "name": name, 113 | "version": config.get("version", "未知"), 114 | "author": config.get("author", "未知"), 115 | "plugin_id": config.get("plugin_id", "无UUID"), 116 | "description": config.get("description", "无描述"), 117 | "dependencies": config.get("dependencies", []) 118 | }) 119 | return plugin_list 120 | 121 | def reload_plugins(self): 122 | """ 123 | 热重载所有插件 124 | """ 125 | self.plugins.clear() # 清除已加载的插件 126 | self.load_plugins() # 重新加载插件 127 | self.register_plugins() # 重新注册插件事件 128 | logger.info("插件已热重载") 129 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - Main.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | Main.py 用于启动 AmyAlmond 机器人,加载配置文件和客户端 10 | """ 11 | 12 | import asyncio 13 | import subprocess 14 | import sys 15 | import os 16 | 17 | from fastapi import FastAPI 18 | from fastapi.middleware.cors import CORSMiddleware 19 | import botpy 20 | from core.api.routes import router as api_router 21 | from core.bot.bot_client import MyClient 22 | from core.utils.logger import get_logger, handle_critical_error 23 | from config import test_config 24 | 25 | logger = get_logger() 26 | 27 | app = FastAPI() 28 | 29 | # 添加 CORS 中间件 30 | app.add_middleware( 31 | CORSMiddleware, 32 | allow_origins=["*"], 33 | allow_credentials=True, 34 | allow_methods=["*"], 35 | allow_headers=["*"], 36 | ) 37 | 38 | # 注册 API 路由 39 | app.include_router(api_router) 40 | 41 | async def check_port_occupied(port): 42 | """ 43 | 异步检查指定端口是否被占用。 44 | """ 45 | try: 46 | netstat_command = ["netstat", "-an"] 47 | proc = await asyncio.create_subprocess_exec( 48 | *netstat_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE 49 | ) 50 | output, _ = await proc.communicate() 51 | return str(port) in output.decode() 52 | except Exception as e: 53 | logger.error(f"检查端口 {port} 时出错: {e}") 54 | return True 55 | 56 | async def kill_process_by_port(port): 57 | """ 58 | 异步结束占用指定端口的进程。 59 | """ 60 | try: 61 | netstat_command = ["netstat", "-an"] 62 | proc = await asyncio.create_subprocess_exec( 63 | *netstat_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE 64 | ) 65 | output, _ = await proc.communicate() 66 | for line in output.decode().splitlines(): 67 | if f":{port}" in line and "LISTEN" in line: 68 | pid = int(line.split()[-1]) 69 | taskkill_command = ["kill", "-9", str(pid)] 70 | await asyncio.create_subprocess_exec(*taskkill_command) 71 | logger.info(f"已结束占用端口 {port} 的进程 (PID: {pid})") 72 | return 73 | except Exception as e: 74 | logger.error(f"结束端口 {port} 的进程时出错: {e}") 75 | 76 | async def start_uvicorn(): 77 | """ 78 | 使用异步启动 Uvicorn 服务器 79 | """ 80 | port = 10417 81 | max_retries = 5 82 | retry_delay = 2 # 减少重试间隔以加快启动 83 | 84 | for i in range(max_retries): 85 | if not await check_port_occupied(port): 86 | uvicorn_command = [ 87 | sys.executable, "-m", "uvicorn", 88 | "main:app", 89 | "--host", "0.0.0.0", 90 | "--port", str(port), 91 | "--log-config", "uvicorn_log_config.json" 92 | ] 93 | subprocess.Popen(uvicorn_command) 94 | logger.info("Uvicorn server started in a separate process.") 95 | return 96 | else: 97 | await kill_process_by_port(port) 98 | logger.info(f"端口 {port} 被占用,已尝试结束占用进程,重试 ({i + 1}/{max_retries})...") 99 | await asyncio.sleep(retry_delay) 100 | 101 | logger.error(f"尝试启动 Uvicorn 服务器失败,端口 {port} 一直被占用。") 102 | 103 | def run_bot(): 104 | """ 105 | 运行机器人客户端,保留同步逻辑 106 | """ 107 | try: 108 | intents = botpy.Intents(public_messages=True, public_guild_messages=True) 109 | client = MyClient(intents=intents) 110 | 111 | logger.info(">>> PLUGIN MANAGER INITIALIZED") 112 | logger.info(" ↳ 插件已成功加载并注册") 113 | 114 | # 检查配置文件中的必要参数 115 | if "appid" not in test_config or "secret" not in test_config: 116 | logger.critical(" 机器人的 appid 或 secret 缺失") 117 | logger.critical(" ↳ 请检查 config.yaml 文件") 118 | sys.exit(1) 119 | 120 | logger.info(">>> CLIENT RUNNING...") 121 | client.run(appid=test_config["appid"], secret=test_config["secret"]) 122 | 123 | except Exception as e: 124 | # 捕获所有未处理的异常并记录 125 | logger.error(f" 在 run_bot 中捕获到未处理的异常: {e}", exc_info=True) 126 | handle_critical_error(sys.exc_info()) 127 | 128 | async def main(): 129 | print("") 130 | print(" _ _ _ _ ") 131 | print(" / \\ _ __ ___ _ _ / \\ | |_ __ ___ ___ _ __ __| |") 132 | print(" / _ \\ | '_ ` _ \\| | | | / _ \\ | | '_ ` _ \\ / _ \\| '_ \\ / _` |") 133 | print(" / ___ \\| | | | | | |_| |/ ___ \\| | | | | | | (_) | | | | (_| |") 134 | print(" /_/ \\_|_| |_| |_|\\__, /_/ \\_|_|_| |_| |_|\\___/|_| |_|\\__,_|") 135 | print(" |___/ ") 136 | print("") 137 | 138 | logger.info(">>> SYSTEM INITIATING...") 139 | 140 | # 并行启动 Uvicorn 服务器和机器人客户端 141 | uvicorn_task = asyncio.create_task(start_uvicorn()) 142 | 143 | # 运行机器人客户端(同步) 144 | bot_task = asyncio.to_thread(run_bot) 145 | 146 | await asyncio.gather(uvicorn_task, bot_task) 147 | 148 | if __name__ == "__main__": 149 | # 在主线程中创建事件循环 150 | asyncio.run(main()) 151 | -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # AmyAlmond Chatbot 4 | 5 | [![License](https://img.shields.io/badge/license-MPL2-red.svg)](hhttps://opensource.org/license/mpl-2-0) 6 | [![Python Version](https://img.shields.io/badge/python-3.8%2B-blue)](https://www.python.org/downloads/) 7 | [![GitHub Stars](https://img.shields.io/github/stars/shuakami/amyalmond_bot.svg)](https://github.com/shuakami/amyalmond_bot/stargazers) 8 | [![Build Status](https://img.shields.io/badge/build-passing-brightgreen.svg)](https://github.com/shuakami/amyalmond_bot) 9 | [![Version](https://img.shields.io/badge/version-1.3.0_(Stable_923001)-yellow.svg)](https://github.com/shuakami/amyalmond_bot/releases) 10 | 11 | [English](README_en.md) | 简体中文 12 | 13 | ⭐ Your go-to chatbot for supercharging group chats ⭐ 14 | 15 | [Features](#功能特性) • [Screenshots](#先看效果) • [Docs](#安装部署开发) • [Contribute](#开发与贡献) • [License](#许可证) 16 |
17 | 18 | ## Features 19 | 20 | AmyAlmond is an LLM API-powered smart chatbot designed to seamlessly integrate into QQ groups and channels. 21 | 22 | By leveraging LLM API, AmyAlmond offers context-aware intelligent responses, enhancing user interaction and supporting long-term memory management. Whether it’s automating replies or boosting user engagement, she handles complex conversations like a breeze. 23 | 24 | - 🌈 She uses the **LLM API** to generate human-like responses based on conversation context, with customizable prompts. 25 | - 💗 Integrated with QQ’s official Python SDK, so you don’t have to worry about being blocked. 26 | - 🔥 Automatically recognizes and remembers user names, providing a personalized interaction experience. 27 | - 🧠 Equipped with **long-term and short-term memory**, she can record and recall important information, ensuring continuity in conversations. 28 | - 🐳 Administrators can control her behavior with specific commands. 29 | - ⭐ **Full configuration hot-reloading** reduces restart times, boosting efficiency. 30 | - 🪝 Detailed logs and code comments make debugging and monitoring a breeze. 31 | 32 | ## Curious about the results? 33 | 34 | ![效果图_对话注册](/dist/background-en/chat-demo.png) 35 | ![效果图_记忆上下文](/dist/background-en/chat-memory-demo.png) 36 | 37 | ## Installation/Deployment/Development 38 | 39 | 40 | Documentation Database 41 | 42 | 43 |
44 | Click the image to jump in 45 |
46 | 47 | ## Contributing 48 | 49 | We'd love to have you on board! Whether it’s adding new features, fixing bugs, or improving documentation, your contributions are welcome! 50 | 51 | ### Branch Strategy 52 | 53 | We follow the Git Flow branching model: 54 | 55 | - **main**: The stable branch, always ready for production. 56 | - **develop**: The development branch, where all new features are integrated. 57 | - **feature/**: Feature branches, created from `develop`, merged back once the feature is complete. 58 | - **hotfix/**: Hotfix branches, used to quickly patch bugs, merged back into `main` and `develop`. 59 | 60 | ### How to Contribute 61 | 62 | 1. **Fork this repo** 63 | Fork the project to your GitHub account. 64 | 65 | 2. **Create a branch** 66 | Create a new feature branch for your changes: 67 | ```bash 68 | git checkout -b feature/AmazingFeature 69 | ``` 70 | 71 | 3. **Commit your changes** 72 | Commit your code with clear and concise messages: 73 | ```bash 74 | git commit -m 'Add some AmazingFeature' 75 | ``` 76 | 77 | 4. **Push to GitHub** 78 | Push your branch to GitHub: 79 | ```bash 80 | git push origin feature/AmazingFeature 81 | ``` 82 | 83 | 5. **Create a Pull Request** 84 | Create a Pull Request on GitHub, describing your changes and their impact. 85 | 86 | ## License 87 | 88 | [![License: MPL 2.0](https://img.shields.io/badge/License-MPL_2.0-brightgreen.svg)](https://opensource.org/licenses/MPL-2.0) 89 | 90 | AmyAlmond is licensed under the [MPL 2.0 License](LICENSE). You are free to use, modify, and distribute this project, but you must open source any modified versions and retain the original author's copyright notice. 91 | 92 | ## Disclaimer 93 | 94 | This project is for learning and research purposes only. The developers are not responsible for any consequences resulting from the use of this project. Please ensure compliance with relevant laws and respect others' intellectual property rights when using this project. 95 | 96 | ## Roadmap 97 | 98 | Check out our [Project Board](https://github.com/users/shuakami/projects/1) for the latest updates! 99 | 100 | q(≧▽≦q) You've read this far—how about dropping us a ⭐️? 101 | 102 | 103 | 109 | 115 | amyalmond_bot Chart 119 | -------------------------------------------------------------------------------- /core/ace/secure.py: -------------------------------------------------------------------------------- 1 | # core/ace/secure.py 2 | import json 3 | import os 4 | import time 5 | import random 6 | import string 7 | from threading import Thread 8 | 9 | 10 | class VerificationCode: 11 | def __init__(self, data=None): 12 | if data: 13 | self.code = data.get("code") 14 | self.generated_time = data.get("generated_time") 15 | self.used_codes = set(data.get("used_codes", [])) 16 | self.last_verified_time = data.get("last_verified_time", 0) 17 | self.last_rejected_time = data.get("last_rejected_time", 0) 18 | else: 19 | self.code = None 20 | self.generated_time = None 21 | self.used_codes = set() 22 | self.last_verified_time = 0 23 | self.last_rejected_time = 0 24 | 25 | def to_dict(self): 26 | return { 27 | "code": self.code, 28 | "generated_time": self.generated_time, 29 | "used_codes": list(self.used_codes), 30 | "last_verified_time": self.last_verified_time, 31 | "last_rejected_time": self.last_rejected_time, 32 | } 33 | 34 | def generate_code(self): 35 | # 生成7天内不重复的验证码 36 | while True: 37 | new_code = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6)) 38 | if new_code not in self.used_codes: 39 | self.code = new_code 40 | self.generated_time = time.time() 41 | self.used_codes.add(new_code) 42 | # 定期清理 used_codes,防止无限增长 43 | if len(self.used_codes) > 1000: 44 | self.used_codes = set(list(self.used_codes)[-500:]) # 保留最近500个验证码 45 | break 46 | 47 | def is_valid(self): 48 | if self.code is None: 49 | return False 50 | return time.time() - self.generated_time < 300 # 5分钟有效期 51 | 52 | def is_rejected_recently(self): 53 | return time.time() - getattr(self, 'last_rejected_time', 0) < 3600 # 1小时内拒绝过 54 | 55 | def mark_rejected(self): 56 | self.last_rejected_time = time.time() 57 | 58 | def mark_verified(self): 59 | self.last_verified_time = time.time() 60 | 61 | def is_verified_recently(self): 62 | return time.time() - self.last_verified_time < 604800 # 7天内验证过 63 | 64 | class SecureInterface: 65 | def __init__(self): 66 | self.secure_file = "configs/secure.json" 67 | self.verification_code = self._load_verification_code() 68 | 69 | def _load_verification_code(self): 70 | if os.path.exists(self.secure_file): 71 | try: 72 | with open(self.secure_file, "r") as f: 73 | data = json.load(f) 74 | return VerificationCode(data) 75 | except json.JSONDecodeError: 76 | # 文件为空或格式错误,返回默认的 VerificationCode 对象 77 | return VerificationCode() 78 | else: 79 | return VerificationCode() 80 | 81 | def _save_verification_code(self): 82 | # 获取文件所在的目录路径 83 | directory = os.path.dirname(self.secure_file) 84 | 85 | # 如果目录不存在,则创建目录 86 | if not os.path.exists(directory): 87 | os.makedirs(directory) 88 | 89 | # 打开文件并写入数据 90 | with open(self.secure_file, "w") as f: 91 | json.dump(self.verification_code.to_dict(), f) 92 | 93 | def _show_verification_dialog(self): 94 | if self.verification_code.is_valid(): 95 | code = self.verification_code.code 96 | else: 97 | self.verification_code.generate_code() 98 | code = self.verification_code.code 99 | 100 | print(f" 验证码: {code}") 101 | print(f"您的关键API在被请求,触发了ACE模块拦截。您可以:") 102 | print("1. 输入验证码以允许本次请求") 103 | print("2. 拒绝本次请求") 104 | print("3. 强制拒绝本次及后续1小时内的所有请求") 105 | 106 | while True: 107 | user_input = input("请选择 (1/2/3): ") 108 | if user_input in ["1", "2", "3"]: 109 | break 110 | print("无效的选择,请重新输入") 111 | 112 | if user_input == "1": 113 | verification_code_input = input("请输入验证码: ") 114 | if verification_code_input == code: 115 | self.verification_code.mark_verified() # 标记为已验证 116 | return True 117 | else: 118 | print("验证码错误") 119 | return False 120 | elif user_input == "2": 121 | return False 122 | elif user_input == "3": 123 | self.last_rejected_time = time.time() 124 | return False 125 | 126 | def verify_request(self): 127 | if time.time() - self.verification_code.last_rejected_time < 3600: 128 | return False # 1小时内强制拒绝过,直接拒绝 129 | 130 | if self.verification_code.is_verified_recently(): 131 | return True # 7天内验证过,直接允许 132 | 133 | # 在新线程中打开对话框,避免阻塞主线程 134 | verification_result = False 135 | 136 | def verification_thread(): 137 | nonlocal verification_result 138 | verification_result = self._show_verification_dialog() 139 | 140 | thread = Thread(target=verification_thread) 141 | thread.start() 142 | thread.join() 143 | 144 | if verification_result: 145 | self.verification_code.code = None # 验证成功后清空验证码,7天内不再重复验证 146 | self._save_verification_code() # 保存数据到文件 147 | return True 148 | else: 149 | self._save_verification_code() # 保存数据到文件 150 | return False -------------------------------------------------------------------------------- /core/api/controllers/es_controller.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Body, Depends 2 | from core.ace.secure import SecureInterface 3 | from pydantic import BaseModel 4 | from core.utils.logger import get_logger 5 | from core.db.elasticsearch_index_manager import ElasticsearchIndexManager 6 | 7 | logger = get_logger() 8 | router = APIRouter() 9 | 10 | # 依赖注入 ElasticsearchIndexManager 实例 11 | async def get_es(): 12 | es = ElasticsearchIndexManager() 13 | try: 14 | yield es 15 | finally: 16 | pass # Elasticsearch 连接不需要手动关闭 17 | 18 | class UpdateDocumentModel(BaseModel): 19 | update: dict 20 | 21 | @router.get("/indices") 22 | async def get_all_indices(es: ElasticsearchIndexManager = Depends(get_es)): 23 | """获取所有索引名称""" 24 | secure_interface = SecureInterface() 25 | if not secure_interface.verify_request(): 26 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 27 | 28 | try: 29 | indices = es.get_all_indices() 30 | return {"status": "success", "indices": indices} 31 | except Exception as e: 32 | logger.error(f"获取索引列表时出错: {e}") 33 | raise HTTPException(status_code=500, detail=str(e)) 34 | 35 | @router.get("/mapping/{index_name}") 36 | async def get_index_mapping(index_name: str, es: ElasticsearchIndexManager = Depends(get_es)): 37 | """获取指定索引的映射""" 38 | secure_interface = SecureInterface() 39 | if not secure_interface.verify_request(): 40 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 41 | 42 | try: 43 | mapping = es.get_index_mapping(index_name) 44 | if mapping: 45 | return {"status": "success", "mapping": mapping} 46 | else: 47 | return {"status": "not_found", "message": f"索引 '{index_name}' 不存在"} 48 | except Exception as e: 49 | logger.error(f"获取索引映射时出错: {e}") 50 | raise HTTPException(status_code=500, detail=str(e)) 51 | 52 | @router.get("/documents/{index_name}") 53 | async def get_all_documents(index_name: str, es: ElasticsearchIndexManager = Depends(get_es)): 54 | """获取指定索引中的所有文档""" 55 | secure_interface = SecureInterface() 56 | if not secure_interface.verify_request(): 57 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 58 | 59 | try: 60 | # 使用 match_all 查询获取所有文档 61 | query = {"query": {"match_all": {}}} 62 | documents = es.search(index_name, query) 63 | return {"status": "success", "documents": documents} 64 | except Exception as e: 65 | logger.error(f"获取文档列表时出错: {e}") 66 | raise HTTPException(status_code=500, detail=str(e)) 67 | 68 | @router.post("/documents/{index_name}") 69 | async def insert_document(index_name: str, document: dict = Body(...), es: ElasticsearchIndexManager = Depends(get_es)): 70 | """向指定索引中插入文档""" 71 | secure_interface = SecureInterface() 72 | if not secure_interface.verify_request(): 73 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 74 | 75 | try: 76 | # 使用 Elasticsearch 的 index API 插入文档 77 | result = es.es.index(index=index_name, body=document) # 直接使用 es.es 78 | return {"status": "success", "inserted_id": result["_id"]} 79 | except Exception as e: 80 | logger.error(f"插入文档时出错: {e}") 81 | raise HTTPException(status_code=500, detail=str(e)) 82 | 83 | @router.get("/documents/{index_name}/find") 84 | async def find_document(index_name: str, query: dict = Body(...), es: ElasticsearchIndexManager = Depends(get_es)): 85 | """根据查询条件查找文档""" 86 | secure_interface = SecureInterface() 87 | if not secure_interface.verify_request(): 88 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 89 | 90 | try: 91 | documents = es.search(index_name, query) 92 | if documents: 93 | return {"status": "success", "documents": documents} 94 | else: 95 | return {"status": "not_found", "message": "未找到符合条件的文档"} 96 | except Exception as e: 97 | logger.error(f"查找文档时出错: {e}") 98 | raise HTTPException(status_code=500, detail=str(e)) 99 | 100 | @router.put("/documents/{index_name}/update/{document_id}") 101 | async def update_document(index_name: str, document_id: str, update_data: UpdateDocumentModel = Body(...), es: ElasticsearchIndexManager = Depends(get_es)): 102 | """根据文档ID更新文档""" 103 | secure_interface = SecureInterface() 104 | if not secure_interface.verify_request(): 105 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 106 | 107 | try: 108 | # 使用 Elasticsearch 的 update API 更新文档 109 | result = es.es.update(index=index_name, id=document_id, body=update_data.update) # 直接使用 es.es 110 | return {"status": "success", "result": result} 111 | except Exception as e: 112 | logger.error(f"更新文档时出错: {e}") 113 | raise HTTPException(status_code=500, detail=str(e)) 114 | 115 | @router.delete("/documents/{index_name}/delete/{document_id}") 116 | async def delete_document(index_name: str, document_id: str, es: ElasticsearchIndexManager = Depends(get_es)): 117 | """根据文档ID删除文档""" 118 | secure_interface = SecureInterface() 119 | if not secure_interface.verify_request(): 120 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 121 | 122 | try: 123 | success = es.delete_document(index_name, document_id) 124 | if success: 125 | return {"status": "success", "message": f"文档 '{document_id}' 已删除"} 126 | else: 127 | return {"status": "not_found", "message": f"索引 '{index_name}' 或文档 '{document_id}' 不存在"} 128 | except Exception as e: 129 | logger.error(f"删除文档时出错: {e}") 130 | raise HTTPException(status_code=500, detail=str(e)) -------------------------------------------------------------------------------- /tools/db_tools.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import subprocess 4 | 5 | 6 | from game import play_game 7 | 8 | 9 | def run_external_script(script_relative_path): 10 | """分离运行外部脚本""" 11 | script_dir = os.path.dirname(os.path.abspath(__file__)) 12 | script_path = os.path.abspath(os.path.join(script_dir, script_relative_path)) 13 | try: 14 | subprocess.run([sys.executable, script_path], check=True) 15 | except subprocess.CalledProcessError as e: 16 | print(f"执行 {script_path} 时出错: {e}") 17 | return False 18 | return True 19 | 20 | 21 | def manage_mongodb(action): 22 | """管理MongoDB的安装、启动和配置""" 23 | if action == '1': # 安装 24 | print("开始安装MongoDB...") 25 | if not run_external_script("install/mongodb/mongodb_install.py"): 26 | print("MongoDB安装失败。") 27 | return False 28 | elif action == '2': # 启动 29 | print("开始启动MongoDB...") 30 | if not run_external_script("setup/mongodb/mongodb_setup.py"): 31 | print("MongoDB启动失败。") 32 | return False 33 | elif action == '3': # 配置 34 | if not run_external_script("setup/mongodb/mongodb_setup_configs.py"): 35 | print("MongoDB启动失败。") 36 | return False 37 | return True 38 | 39 | 40 | def print_help(): 41 | """打印帮助信息,详细解释各个操作及其使用方法""" 42 | help_text = """ 43 | +----------------------------------------------------------------------------------+ 44 | | 帮助文档 | 45 | +----------------------------------------------------------------------------------+ 46 | | 欢迎来到数据库管理工具的帮助页面! | 47 | | 在这里,我们将详细介绍每一个功能的作用,以及如何使用这个工具来高效地管理你的数据库。| 48 | +----------------------------------------------------------------------------------+ 49 | 50 | 1. 安装 (帮助你安装数据库) 51 | - 这个选项将引导你通过一个简单的步骤来安装MongoDB或Elasticsearch数据库。 52 | - 安装过程是自动化的,你只需要选择数据库类型,然后工具将自动运行对应的安装脚本。 53 | - 适用于初次使用该数据库或在新环境中重新搭建数据库的用户。 54 | 55 | 2. 启动 (帮助你启动数据库) 56 | - 如果你已经安装了数据库,并且需要启动它们,那么这个选项适合你。 57 | - 启动操作将执行预设的启动脚本,确保数据库正确启动,并且能够接受连接。 58 | - 通常用于在服务器重新启动后,或者需要手动启动数据库的场景。 59 | 60 | 3. 配置 (配置数据库账号密码) 61 | - 此选项用于配置数据库的基本安全设置,比如账号和密码。 62 | - 对于MongoDB和Elasticsearch,我们提供了专用的配置脚本来设置数据库的账号和密码。 63 | - 强烈建议在生产环境中进行适当的配置,以确保数据库的安全性。 64 | 65 | 4. 升级原数据库(数据迁移) 66 | - 当你需要将现有的数据库升级到新版本或迁移数据时,选择这个选项。 67 | - 该操作将自动调用升级脚本,确保数据在升级过程中不会丢失。 68 | - 请在操作之前备份你的数据,以防万一。 69 | 70 | +----------------------------------------------------------------------------------+ 71 | | 注意:每一个操作都有它特定的目的,请根据你的需求选择相应的功能。 72 | | 如果你在使用过程中遇到了问题,建议参考对应的日志文件,以获取更详细的信息。 73 | | 我们的目标是让你以最少的操作成本,完成对数据库的管理工作。 74 | +----------------------------------------------------------------------------------+ 75 | 76 | 77 | 78 | 79 | egg you cai dan o 80 | """ 81 | print(help_text) 82 | 83 | def manage_elasticsearch(action): 84 | """管理Elasticsearch的安装、启动和配置""" 85 | if action == '1': # 安装 86 | print("开始安装Elasticsearch...") 87 | if not run_external_script("install/elasticsearch/elasticsearch_install.py"): 88 | print("Elasticsearch安装失败。") 89 | return False 90 | elif action == '2': # 启动 91 | print("开始启动Elasticsearch...") 92 | if not run_external_script("setup/elasticsearch/elasticsearch_setup.py"): 93 | print("Elasticsearch启动失败。") 94 | return False 95 | elif action == '3': # 配置 96 | print("开始配置Elasticsearch...") 97 | if not run_external_script("setup/elasticsearch/elasticsearch_configs.py"): 98 | print("Elasticsearch启动失败。") 99 | return False 100 | return True 101 | 102 | 103 | if __name__ == "__main__": 104 | if len(sys.argv) > 1 and sys.argv[1].lower() == 'egg': 105 | play_game() 106 | sys.exit(0) 107 | 108 | print("+----------------------------------------+") 109 | print("| 欢迎使用数据库管理工具 |") 110 | print("+----------------------------------------+") 111 | print("| 请选择操作: |") 112 | print("| 1. 安装 (帮助你安装数据库) |") 113 | print("| 2. 启动 (帮助你启动数据库) |") 114 | print("| 3. 配置 (配置数据库账号密码) |") 115 | print("| 4. 升级原数据库(数据迁移) |") 116 | print("| 需要帮助请按在脚本后缀加h(就是h,不是-h) |") 117 | print("+----------------------------------------+") 118 | 119 | if len(sys.argv) > 1 and sys.argv[1] == 'h': 120 | print_help() 121 | sys.exit(0) 122 | 123 | choice = input("请输入数字选择操作: ") 124 | 125 | if choice == '4': 126 | print("开始升级原数据库...") 127 | if not run_external_script("upgrade/db_upgrade.py"): 128 | print("数据库升级失败。") 129 | sys.exit(1) 130 | 131 | if choice not in ['1', '2', '3']: 132 | print("无效的选择,程序将退出。") 133 | sys.exit(1) 134 | 135 | print("+----------------------------------------+") 136 | print("| 请选择数据库: |") 137 | print("| 1. MongoDB |") 138 | print("| 2. Elasticsearch |") 139 | print("+----------------------------------------+") 140 | 141 | db_choice = input("请输入数字选择数据库: ") 142 | 143 | if db_choice == '1': 144 | print(f"您选择了MongoDB,执行 {choice} 操作。") 145 | if not manage_mongodb(choice): 146 | print("MongoDB操作失败。") 147 | sys.exit(1) 148 | elif db_choice == '2': 149 | print(f"您选择了Elasticsearch,执行 {choice} 操作。") 150 | if not manage_elasticsearch(choice): 151 | print("Elasticsearch操作失败。") 152 | sys.exit(1) 153 | else: 154 | print("无效的选择,程序将退出。") 155 | sys.exit(1) 156 | 157 | print("操作完成。") 158 | sys.exit(0) 159 | -------------------------------------------------------------------------------- /tools/setup/mongodb/mongodb_setup_configs.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import getpass 3 | import sys 4 | from pathlib import Path 5 | from pymongo import MongoClient, errors 6 | 7 | # 配置文件路径 8 | MONGO_CONFIG_PATH = Path(__file__).parent.parent.parent / "configs/mongodb.yaml" 9 | PROJECT_CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "configs/config.yaml" 10 | 11 | 12 | def prompt_user_for_mongo_credentials(): 13 | print("+-----------------------------------------------------+") 14 | print("| MongoDB 配置尚未设置。请输入用户名和密码: |") 15 | print("+-----------------------------------------------------+") 16 | 17 | username = input("请输入 MongoDB 用户名: ").strip() 18 | while not username: 19 | print("用户名不能为空,请重新输入。") 20 | username = input("请输入 MongoDB 用户名: ").strip() 21 | 22 | password = getpass.getpass("请输入 MongoDB 密码: ").strip() 23 | while not password: 24 | print("密码不能为空,请重新输入。") 25 | password = getpass.getpass("请输入 MongoDB 密码: ").strip() 26 | 27 | return username, password 28 | 29 | 30 | def update_mongo_config(username, password): 31 | try: 32 | # 确保配置目录存在 33 | if not MONGO_CONFIG_PATH.parent.exists(): 34 | MONGO_CONFIG_PATH.parent.mkdir(parents=True) 35 | 36 | # 加载现有的 mongodb.yaml 配置文件,如果存在 37 | config = {} 38 | if MONGO_CONFIG_PATH.exists(): 39 | with open(MONGO_CONFIG_PATH, 'r', encoding='utf-8') as f: 40 | config = yaml.safe_load(f) or {} 41 | 42 | # 更新 mongodb.yaml 文件中的配置 43 | config['mongodb'] = {'username': username, 'password': password} 44 | 45 | # 保存更新后的配置到 mongodb.yaml 文件 46 | with open(MONGO_CONFIG_PATH, 'w', encoding='utf-8') as f: 47 | yaml.dump(config, f, allow_unicode=True) 48 | 49 | # 更新项目根目录的 config.yaml 文件 50 | if PROJECT_CONFIG_PATH.exists(): 51 | with open(PROJECT_CONFIG_PATH, 'r', encoding='utf-8') as f: 52 | project_config_lines = f.readlines() 53 | else: 54 | project_config_lines = [] 55 | 56 | # 移除旧的 MongoDB 配置 57 | new_config_lines = [] 58 | inside_mongodb_block = False 59 | for line in project_config_lines: 60 | if line.strip() == '# ---------- MongoDB Configuration ----------': 61 | inside_mongodb_block = True 62 | continue 63 | if line.strip() == '# ---------- End MongoDB Configuration ------': 64 | inside_mongodb_block = False 65 | continue 66 | if not inside_mongodb_block: 67 | new_config_lines.append(line) 68 | 69 | # 将新的 MongoDB 配置添加到文件末尾 70 | new_config_lines.append('\n') 71 | new_config_lines.append('# ---------- MongoDB Configuration ----------\n') 72 | new_config_lines.append(f'mongodb_url: "mongodb://localhost:27017"\n') 73 | new_config_lines.append(f'mongodb_username: "{username}"\n') 74 | new_config_lines.append(f'mongodb_password: "{password}"\n') 75 | new_config_lines.append('# ---------- End MongoDB Configuration ------\n') 76 | 77 | # 保存更新后的内容到 config.yaml 文件 78 | with open(PROJECT_CONFIG_PATH, 'w', encoding='utf-8') as f: 79 | f.writelines(new_config_lines) 80 | 81 | print(f"> MongoDB 配置已保存至:{MONGO_CONFIG_PATH} 和 {PROJECT_CONFIG_PATH}") 82 | print(f"> -------------------------------------------------") 83 | print(f"> 请不要擅自修改已添加的配置内容及注释,否则可能导致配置系统无法正常工作。") 84 | print(f"> -------------------------------------------------") 85 | 86 | except Exception as e: 87 | print(f"! 保存 MongoDB 配置时出错:{e}") 88 | raise 89 | 90 | 91 | def apply_mongo_config(username, password): 92 | try: 93 | client = MongoClient("mongodb://localhost:27017/") 94 | db = client.admin 95 | 96 | # 创建管理员用户 97 | db.command("createUser", username, pwd=password, roles=[{"role": "root", "db": "admin"}]) 98 | 99 | print("> MongoDB 配置已成功应用。") 100 | 101 | # 验证连接是否成功 102 | test_mongo_connection(username, password) 103 | 104 | except errors.OperationFailure as err: 105 | print(f"! 应用 MongoDB 配置时失败:{err}") 106 | # 如果报错already exists 107 | if "already exists" in str(err): 108 | print("> 用户已存在,跳过创建用户步骤。") 109 | raise 110 | except Exception as e: 111 | print(f"! 应用 MongoDB 配置时出错:{e}") 112 | # 如果报错already exists 113 | if "already exists" in str(e): 114 | print("> 用户已存在,跳过创建用户步骤。") 115 | # 检测一下链接 116 | if test_mongo_connection(username, password): 117 | print("> 使用新的用户名和密码连接 MongoDB 成功!") 118 | raise 119 | 120 | 121 | def test_mongo_connection(username, password): 122 | try: 123 | uri = f"mongodb://{username}:{password}@localhost:27017/" 124 | client = MongoClient(uri, serverSelectionTimeoutMS=5000) 125 | client.server_info() 126 | print("> 使用新的用户名和密码连接 MongoDB 成功!") 127 | return True 128 | except errors.ServerSelectionTimeoutError as err: 129 | print(f"! 使用新的用户名和密码无法连接到MongoDB服务器:{err}") 130 | return False 131 | except Exception as e: 132 | print(f"! 测试 MongoDB 连接时发生错误:{e}") 133 | return False 134 | 135 | 136 | def configure_mongodb(): 137 | try: 138 | # 提示用户输入用户名和密码 139 | username, password = prompt_user_for_mongo_credentials() 140 | 141 | # 更新配置文件 142 | update_mongo_config(username, password) 143 | 144 | # 应用到 MongoDB 并验证 145 | apply_mongo_config(username, password) 146 | 147 | except Exception as e: 148 | print(f"! 配置 MongoDB 时发生错误:{e}") 149 | sys.exit(1) 150 | 151 | 152 | if __name__ == "__main__": 153 | print("> 开始MongoDB配置...") 154 | print(f"> MongoDB 配置文件路径:{MONGO_CONFIG_PATH}") 155 | print(f"> 项目配置文件路径:{PROJECT_CONFIG_PATH}") 156 | configure_mongodb() 157 | print("> MongoDB配置完成。") 158 | sys.exit(0) 159 | -------------------------------------------------------------------------------- /core/api/controllers/db_controller.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Body, Depends 2 | from core.ace.secure import SecureInterface 3 | from pydantic import BaseModel 4 | from core.utils.logger import get_logger 5 | from core.utils.mongodb_utils import MongoDBUtils 6 | 7 | logger = get_logger() 8 | router = APIRouter() 9 | 10 | # 依赖注入 MongoDBUtils 实例 11 | async def get_db(): 12 | db = MongoDBUtils() 13 | try: 14 | yield db 15 | finally: 16 | db.close_connection() 17 | 18 | class UpdateDocumentModel(BaseModel): 19 | update: dict 20 | 21 | @router.get("/databases") 22 | async def get_all_databases(db: MongoDBUtils = Depends(get_db)): 23 | """获取所有数据库名称""" 24 | secure_interface = SecureInterface() 25 | if not secure_interface.verify_request(): 26 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 27 | 28 | try: 29 | databases = db.get_all_database_names() 30 | return {"status": "success", "databases": databases} 31 | except Exception as e: 32 | logger.error(f"获取数据库列表时出错: {e}") 33 | raise HTTPException(status_code=500, detail=str(e)) 34 | 35 | @router.get("/collections/{db_name}") 36 | async def get_all_collections(db_name: str, db: MongoDBUtils = Depends(get_db)): 37 | """获取指定数据库中的所有集合名称""" 38 | secure_interface = SecureInterface() 39 | if not secure_interface.verify_request(): 40 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 41 | 42 | try: 43 | collections = db.get_all_collection_names(db_name) 44 | return {"status": "success", "collections": collections} 45 | except Exception as e: 46 | logger.error(f"获取数据库集合列表时出错: {e}") 47 | raise HTTPException(status_code=500, detail=str(e)) 48 | 49 | @router.get("/documents/{db_name}/{collection_name}") 50 | async def get_all_documents(db_name: str, collection_name: str, db: MongoDBUtils = Depends(get_db)): 51 | """获取指定集合中的所有文档""" 52 | secure_interface = SecureInterface() 53 | if not secure_interface.verify_request(): 54 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 55 | 56 | try: 57 | # 获取集合对象 58 | collection = db.client[db_name][collection_name] 59 | documents = list(collection.find({})) 60 | 61 | # 将 ObjectId 转换为字符串 62 | for document in documents: 63 | if "_id" in document: 64 | document["_id"] = str(document["_id"]) 65 | 66 | return {"status": "success", "documents": documents} 67 | except Exception as e: 68 | logger.error(f"获取数据库文档列表时出错: {e}") 69 | raise HTTPException(status_code=500, detail=str(e)) 70 | 71 | @router.post("/documents/{db_name}/{collection_name}") 72 | async def insert_document(db_name: str, collection_name: str, document: dict = Body(...), db: MongoDBUtils = Depends(get_db)): 73 | """向指定集合中插入文档""" 74 | secure_interface = SecureInterface() 75 | if not secure_interface.verify_request(): 76 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 77 | 78 | try: 79 | # 获取集合对象 80 | collection = db.client[db_name][collection_name] 81 | result = collection.insert_one(document) 82 | return {"status": "success", "inserted_id": str(result.inserted_id)} 83 | except Exception as e: 84 | logger.error(f"插入数据库文档时出错: {e}") 85 | raise HTTPException(status_code=500, detail=str(e)) 86 | 87 | 88 | @router.post("/documents/find/{db_name}/{collection_name}") 89 | async def find_document(db_name: str, collection_name: str, query: dict = Body(...), 90 | db: MongoDBUtils = Depends(get_db)): 91 | """根据查询条件查找文档""" 92 | secure_interface = SecureInterface() 93 | if not secure_interface.verify_request(): 94 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 95 | 96 | try: 97 | # 获取集合对象 98 | collection = db.client[db_name][collection_name] 99 | # 查找文档 100 | cursor = collection.find(query) 101 | result = [] 102 | for document in cursor: 103 | if "_id" in document: 104 | document["_id"] = str(document["_id"]) # 将 ObjectId 转换为字符串 105 | result.append(document) 106 | 107 | return {"status": "success", "documents": result} 108 | 109 | except Exception as e: 110 | logger.error(f"查找数据库文档时出错: {e}") 111 | raise HTTPException(status_code=500, detail=str(e)) 112 | 113 | 114 | @router.put("/update/documents/{db_name}/{collection_name}") 115 | async def update_document(db_name: str, collection_name: str, query: dict = Body(...), update_data: UpdateDocumentModel = Body(...), db: MongoDBUtils = Depends(get_db)): 116 | """根据查询条件更新文档""" 117 | secure_interface = SecureInterface() 118 | if not secure_interface.verify_request(): 119 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 120 | 121 | try: 122 | # 获取集合对象 123 | collection = db.client[db_name][collection_name] 124 | result = collection.update_one(query, update_data.update) 125 | return {"status": "success", "matched_count": result.matched_count, "modified_count": result.modified_count} 126 | except Exception as e: 127 | logger.error(f"更新数据库文档时出错: {e}") 128 | raise HTTPException(status_code=500, detail=str(e)) 129 | 130 | @router.delete("/delete/documents/{db_name}/{collection_name}") 131 | async def delete_document(db_name: str, collection_name: str, query: dict = Body(...), db: MongoDBUtils = Depends(get_db)): 132 | """根据查询条件删除文档""" 133 | secure_interface = SecureInterface() 134 | if not secure_interface.verify_request(): 135 | return {"status": "error", "message": "验证码错误或已过期或者已经拒绝此请求"} 136 | 137 | try: 138 | # 获取集合对象 139 | collection = db.client[db_name][collection_name] 140 | result = collection.delete_one(query) 141 | return {"status": "success", "deleted_count": result.deleted_count} 142 | except Exception as e: 143 | logger.error(f"删除数据库文档时出错: {e}") 144 | raise HTTPException(status_code=500, detail=str(e)) -------------------------------------------------------------------------------- /core/api/controllers/plugin_controller.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | 3 | import requests 4 | from fastapi import APIRouter, UploadFile, File, HTTPException, Body 5 | from core.plugins.plugin_manager import PluginManager 6 | from core.plugins.tools.add_plugin import create_plugin 7 | from core.utils.logger import get_logger 8 | import tempfile 9 | import os 10 | import shutil 11 | 12 | logger = get_logger() 13 | router = APIRouter() 14 | plugin_manager = PluginManager(bot_client=None) 15 | 16 | 17 | 18 | @router.post("/install") 19 | async def install_plugin(file: UploadFile = File(...)): 20 | if not file: 21 | raise HTTPException(status_code=400, detail="上传文件不存在") 22 | 23 | try: 24 | # 使用 tempfile 来生成一个跨平台的临时文件 25 | with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file: 26 | shutil.copyfileobj(file.file, tmp_file) 27 | tmp_file_path = tmp_file.name 28 | 29 | logger.info(f"文件 {file.filename} 已上传到 {tmp_file_path}") 30 | 31 | # 获取插件的目标安装目录 32 | plugin_name = os.path.splitext(file.filename)[0] 33 | plugins_dir = os.path.join("core", "plugins", plugin_name) 34 | 35 | if os.path.exists(plugins_dir): 36 | logger.warning(f"插件目录 {plugins_dir} 已存在,正在删除...") 37 | shutil.rmtree(plugins_dir) 38 | 39 | os.makedirs(plugins_dir, exist_ok=True) 40 | 41 | # 解压到插件目录 42 | with zipfile.ZipFile(tmp_file_path, 'r') as zip_ref: 43 | zip_ref.extractall(plugins_dir) 44 | 45 | logger.info(f"插件 {file.filename} 已成功解压到 {plugins_dir}") 46 | plugin_manager.load_plugins() # 重新加载插件 47 | 48 | return {"status": "success", "message": f"插件 {file.filename} 已成功安装"} 49 | except Exception as e: 50 | logger.error(f"安装插件时出错: {e}") 51 | raise HTTPException(status_code=500, detail=str(e)) 52 | finally: 53 | # 清理上传的文件 54 | if os.path.exists(tmp_file_path): 55 | os.remove(tmp_file_path) 56 | 57 | @router.post("/url_install") 58 | async def install_plugin_from_url(url: str): 59 | """ 60 | 通过提供zip文件链接安装插件。 61 | """ 62 | tmp_file_path = None # 初始化临时文件路径 63 | try: 64 | # 处理 zip 文件链接 65 | response = requests.get(url, stream=True) 66 | response.raise_for_status() 67 | with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file: 68 | for chunk in response.iter_content(chunk_size=8192): 69 | tmp_file.write(chunk) 70 | tmp_file_path = tmp_file.name 71 | logger.info(f"文件 {url} 已下载到 {tmp_file_path}") 72 | file_name = os.path.basename(url) # 从链接中提取文件名 73 | 74 | # 获取插件的目标安装目录 75 | plugin_name = os.path.splitext(file_name)[0] 76 | plugins_dir = os.path.join("core", "plugins", plugin_name) 77 | 78 | if os.path.exists(plugins_dir): 79 | logger.warning(f"插件目录 {plugins_dir} 已存在,正在删除...") 80 | shutil.rmtree(plugins_dir) 81 | 82 | os.makedirs(plugins_dir, exist_ok=True) 83 | 84 | # 解压到插件目录 85 | with zipfile.ZipFile(tmp_file_path, 'r') as zip_ref: 86 | zip_ref.extractall(plugins_dir) 87 | 88 | logger.info(f"插件 {file_name} 已成功解压到 {plugins_dir}") 89 | plugin_manager.load_plugins() # 重新加载插件 90 | 91 | return {"status": "success", "message": f"插件 {file_name} 已成功安装"} 92 | 93 | except Exception as e: 94 | logger.error(f"安装插件时出错: {e}") 95 | raise HTTPException(status_code=500, detail=str(e)) 96 | finally: 97 | # 清理下载的临时文件 98 | if tmp_file_path and os.path.exists(tmp_file_path): 99 | os.remove(tmp_file_path) 100 | 101 | 102 | @router.post("/uninstall") 103 | async def uninstall_plugin(plugin_name: str): 104 | try: 105 | plugin_manager.uninstall_plugin(plugin_name) 106 | return {"status": "success", "message": f"插件 {plugin_name} 已成功卸载"} 107 | except Exception as e: 108 | logger.error(f"卸载插件时出错: {e}") 109 | raise HTTPException(status_code=500, detail=str(e)) 110 | 111 | 112 | # @router.post("/enable") 113 | # async def enable_plugin(plugin_name: str): 114 | # try: 115 | # plugin_manager.enable_plugin(plugin_name) 116 | # return {"status": "success", "message": f"插件 {plugin_name} 已启用"} 117 | # except Exception as e: 118 | # logger.error(f"启用插件时出错: {e}") 119 | # raise HTTPException(status_code=500, detail=str(e)) 120 | # 121 | # 122 | # @router.post("/disable") 123 | # async def disable_plugin(plugin_name: str): 124 | # try: 125 | # plugin_manager.disable_plugin(plugin_name) 126 | # return {"status": "success", "message": f"插件 {plugin_name} 已禁用"} 127 | # except Exception as e: 128 | # logger.error(f"禁用插件时出错: {e}") 129 | # raise HTTPException(status_code=500, detail=str(e)) 130 | 131 | 132 | @router.get("/list") 133 | async def get_plugin_list(): 134 | try: 135 | plugin_list = plugin_manager.get_plugin_list() 136 | return {"status": "success", "plugins": plugin_list} 137 | except Exception as e: 138 | logger.error(f"获取插件列表时出错: {e}") 139 | raise HTTPException(status_code=500, detail=str(e)) 140 | 141 | 142 | @router.post("/reload") 143 | async def reload_plugins(): 144 | """ 145 | 热重载所有插件 146 | """ 147 | try: 148 | plugin_manager.reload_plugins() 149 | return {"status": "success", "message": "插件已成功热重载"} 150 | except Exception as e: 151 | logger.error(f"热重载插件时出错: {e}") 152 | raise HTTPException(status_code=500, detail=str(e)) 153 | 154 | @router.post("/add_plugin") 155 | async def add_plugin( 156 | system_prompt: str = Body(..., embed=True), 157 | user_input: str = Body(..., embed=True) 158 | ): 159 | """ 160 | 使用LLM帮助用户创建插件的API接口 161 | 162 | 参数: 163 | system_prompt (str): 逗号分隔的系统提示词,将在服务端转换为列表 164 | user_input (str): 用户输入的插件需求 165 | 166 | 返回: 167 | dict: LLM生成的插件代码或相关信息 168 | """ 169 | try: 170 | # 调用 create_plugin 函数来生成插件代码 171 | result = await create_plugin(system_prompt, user_input) 172 | return {"status": "success", "plugin_code": result} 173 | except Exception as e: 174 | raise HTTPException(status_code=500, detail=f"创建插件失败: {e}") -------------------------------------------------------------------------------- /core/llm/plugins/openai_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import httpx 4 | from core.utils.logger import get_logger 5 | from core.llm.llm_client import LLMClient 6 | from config import REQUEST_TIMEOUT 7 | 8 | _log = get_logger() 9 | 10 | 11 | class OpenAIClient(LLMClient): 12 | """ 13 | OpenAI API 客户端,实现了 LLMClient 接口。 14 | """ 15 | 16 | async def on_message(self, message, reply_message): 17 | pass 18 | 19 | def __init__(self, openai_secret, openai_model, openai_api_url): 20 | self.openai_secret = openai_secret 21 | self.openai_model = openai_model 22 | self.openai_api_url = openai_api_url 23 | 24 | # 初始化 last_request_time 和 last_request_content 25 | self.last_request_time = 0 26 | self.last_request_content = None 27 | 28 | # 从配置文件中读取超时设置,默认为7秒 29 | self.timeout = REQUEST_TIMEOUT or 7 30 | 31 | async def get_response(self, context, user_input, system_prompt, retries=2): 32 | """ 33 | 根据给定的上下文和用户输入,从 OpenAI 模型获取回复 34 | 35 | 参数: 36 | context (list): 对话上下文,包含之前的对话内容 37 | user_input (str): 用户的输入内容 38 | system_prompt (str): 系统提示 39 | retries (int): 出现错误时的最大重试次数,默认值为2次 40 | 41 | 返回: 42 | str: OpenAI 模型生成的回复内容 43 | 44 | 异常: 45 | httpx.HTTPStatusError: 当请求 OpenAI API 出现问题时引发 46 | """ 47 | # 检查是否为重复请求 48 | if time.time() - self.last_request_time < 0.6 and user_input == self.last_request_content and "" not in user_input: 49 | _log.warning(" 检测到重复请求,已忽略:") 50 | _log.warning(f" ↳ 用户输入: {user_input}") 51 | return None 52 | 53 | payload = { 54 | "model": self.openai_model, 55 | "temperature": 0.85, 56 | "top_p": 1, 57 | "presence_penalty": 1, 58 | "max_tokens": 3450, 59 | "messages": [ 60 | {"role": "system", "content": system_prompt} 61 | ] + context + [ 62 | {"role": "user", "content": user_input} 63 | ] 64 | } 65 | 66 | headers = { 67 | "Content-Type": "application/json", 68 | "Authorization": f"Bearer {self.openai_secret}" 69 | } 70 | 71 | # 记录请求的 payload 和 headers 72 | _log.debug(" 请求参数:") 73 | _log.debug(f" ↳ Payload: {payload}") 74 | _log.debug(f" ↳ Headers: {headers}") 75 | 76 | for attempt in range(retries + 1): 77 | try: 78 | async with httpx.AsyncClient(timeout=self.timeout) as client: 79 | response = await client.post(self.openai_api_url, headers=headers, json=payload) 80 | response.raise_for_status() 81 | response_data = response.json() 82 | 83 | # 记录完整的响应数据 84 | _log.debug(" 完整响应数据:") 85 | _log.debug(f" ↳ {response_data}") 86 | 87 | reply = response_data['choices'][0]['message']['content'] if 'choices' in response_data and \ 88 | response_data['choices'][0]['message'][ 89 | 'content'] else None 90 | 91 | # 更新 last_request_time 和 last_request_content 92 | self.last_request_time = time.time() 93 | self.last_request_content = user_input 94 | 95 | if reply is None: 96 | _log.warning(" OpenAI 回复为空:") 97 | _log.warning(f" ↳ 用户输入: {user_input}") 98 | else: 99 | # 记录 OpenAI 的回复内容 100 | _log.info(" OpenAI 回复:") 101 | _log.info(f" ↳ 内容: {reply}") 102 | 103 | return reply 104 | 105 | except httpx.HTTPStatusError as e: 106 | _log.error(" 🚨请求错误:") 107 | _log.error(f" ↳ 状态码: {e.response.status_code}") 108 | _log.error(f" ↳ 错误详情: {e}") 109 | _log.error(f" ↳ 返回内容: {e.response.text}") 110 | if e.response.status_code in {503, 504, 500}: # 处理常见错误状态码 111 | _log.info(f"请求失败,状态码:{e.response.status_code}。正在尝试重试...({attempt + 1}/{retries})") 112 | if attempt < retries: 113 | await asyncio.sleep(2) # 等待2秒后重试 114 | continue 115 | return f"请求失败,状态码:{e.response.status_code}。请稍后再试。" 116 | 117 | 118 | except httpx.RequestError as e: 119 | _log.error(" 请求异常:") 120 | _log.error(f" ↳ 错误详情: {e}") 121 | _log.error(f" ↳ 错误类型: {type(e)}") 122 | if attempt < retries: 123 | _log.info(f"请求异常,正在尝试重试...({attempt + 1}/{retries})") 124 | await asyncio.sleep(2) # 等待2秒后重试 125 | continue 126 | return "请求超时或网络错误,请稍后再试。" 127 | 128 | 129 | except Exception as e: 130 | 131 | _log.error(" 未知错误:") 132 | _log.error(f" ↳ 错误详情: {e}") 133 | _log.error(f" ↳ 错误类型: {type(e)}") 134 | if attempt < retries: 135 | _log.info(f"发生未知错误,正在尝试重试...({attempt + 1}/{retries})") 136 | await asyncio.sleep(2) # 等待2秒后重试 137 | continue 138 | 139 | return "发生未知错误,请联系管理员。" 140 | 141 | return "请求失败,请稍后再试。" 142 | 143 | async def test(self): 144 | """ 145 | 测试 OpenAIClient 类的方法 146 | """ 147 | context = [ 148 | {"role": "user", "content": "你好!"} 149 | ] 150 | user_input = "你还记得我之前说了多少个“你好”吗" 151 | system_prompt = "你是一个友好的助手。" 152 | 153 | response = await self.get_response(context, user_input, system_prompt) 154 | print("API Response:", response) 155 | 156 | 157 | # 使用方法 158 | if __name__ == "__main__": 159 | # 配置 OpenAIClient 160 | openai_secret = "sk-s2lDjPP1AdigpPBO53845f5d134a406d96CbE24aEeBe2d36" 161 | openai_model = "Meta-Llama-3.1-8B-Instruct" 162 | openai_api_url = "https://ngedlktfticp.cloud.sealos.io/v1/chat/completions" 163 | # 我可以把我的秘钥公开,因为额度很小,而且是用来测试的。但你一定不要像我一样把秘钥明文写在代码中。 164 | 165 | # 创建 OpenAIClient 实例 166 | client = OpenAIClient(openai_secret=openai_secret, openai_model=openai_model, openai_api_url=openai_api_url) 167 | 168 | # 运行测试 169 | asyncio.run(client.test()) 170 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 13 | 14 | 16 | { 17 | "lastFilter": { 18 | "state": "OPEN", 19 | "assignee": "shuakami" 20 | } 21 | } 22 | { 23 | "selectedUrlAndAccountId": { 24 | "url": "https://github.com/shuakami/amyalmond_bot.git", 25 | "accountId": "d941c301-caf6-4448-9c7f-a86b9568da98" 26 | } 27 | } 28 | { 29 | "associatedIndex": 5 30 | } 31 | 32 | 33 | 34 | 35 | 36 | 39 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 1727063987284 70 | 76 | 77 | 84 | 85 | 92 | 93 | 100 | 103 | 104 | 106 | 107 | 116 | 117 | 118 | 119 | 120 | 122 | -------------------------------------------------------------------------------- /tools/upgrade/db_upgrade.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import datetime 5 | import sys 6 | import time 7 | from pymongo import MongoClient 8 | from elasticsearch import Elasticsearch, helpers 9 | 10 | # 手动指定项目根目录 11 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) 12 | # 将项目根目录添加到 Python 的搜索路径中 13 | sys.path.append(project_root) 14 | from config import MONGODB_URI, MONGODB_USERNAME, MONGODB_PASSWORD, ELASTICSEARCH_URL, ELASTICSEARCH_USERNAME, \ 15 | ELASTICSEARCH_PASSWORD 16 | 17 | # 路径配置 18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | PROJECT_ROOT = os.path.abspath(os.path.join(BASE_DIR, "../../")) 20 | DATA_DIR = os.path.join(PROJECT_ROOT, "data") 21 | BACKUP_DIR = os.path.join(BASE_DIR, "backup", "data") 22 | 23 | # 连接数据库 24 | mongo_client = MongoClient(MONGODB_URI, username=MONGODB_USERNAME, password=MONGODB_PASSWORD) 25 | mongo_db = mongo_client["amyalmond"] 26 | mongo_collection = mongo_db["conversations"] 27 | 28 | es_client = Elasticsearch( 29 | [ELASTICSEARCH_URL], 30 | basic_auth=(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD) 31 | ) 32 | 33 | 34 | def backup_mongodb(): 35 | """备份 MongoDB 数据""" 36 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 37 | backup_dir = os.path.join(BACKUP_DIR, "mongodb", timestamp) 38 | 39 | if not os.path.exists(backup_dir): 40 | os.makedirs(backup_dir) 41 | 42 | for document in mongo_collection.find(): 43 | with open(os.path.join(backup_dir, f"{document['_id']}.json"), "w", encoding="utf-8") as file: 44 | json.dump(document, file, default=str) 45 | 46 | print(f"> MongoDB 数据备份完成,备份目录: {backup_dir}") 47 | 48 | 49 | def backup_elasticsearch(): 50 | """备份 Elasticsearch 数据""" 51 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 52 | backup_dir = os.path.join(BACKUP_DIR, "elasticsearch", timestamp) 53 | 54 | if not os.path.exists(backup_dir): 55 | os.makedirs(backup_dir) 56 | 57 | query = {"query": {"match_all": {}}} 58 | results = helpers.scan(es_client, index="messages", query=query) 59 | 60 | for i, result in enumerate(results): 61 | with open(os.path.join(backup_dir, f"{i}.json"), "w", encoding="utf-8") as file: 62 | json.dump(result, file, default=str) 63 | 64 | print(f"> Elasticsearch 数据备份完成,备份目录: {backup_dir}") 65 | 66 | 67 | def backup_data(): 68 | """备份旧数据""" 69 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 70 | backup_dir = os.path.join(BACKUP_DIR, timestamp) 71 | 72 | if not os.path.exists(backup_dir): 73 | os.makedirs(backup_dir) 74 | 75 | for filename in os.listdir(DATA_DIR): 76 | file_path = os.path.join(DATA_DIR, filename) 77 | if os.path.isfile(file_path): 78 | shutil.copy(file_path, backup_dir) 79 | print(f"> 备份文件: {filename} 到 {backup_dir}") 80 | 81 | print("> 文件数据备份完成.") 82 | 83 | 84 | def prompt_clear_databases(): 85 | """提示用户是否清空数据库""" 86 | print("> 即将清空 MongoDB 和 Elasticsearch 数据库。") 87 | print("> 请确认是否要继续,10 秒后将自动清空数据库...") 88 | time.sleep(10) 89 | 90 | print("> 清空 MongoDB 'conversations' 集合...") 91 | mongo_collection.delete_many({}) 92 | print("> MongoDB 清空完成.") 93 | 94 | print("> 清空 Elasticsearch 'messages' 索引...") 95 | if es_client.indices.exists(index="messages"): 96 | es_client.options(ignore_status=[400, 404]).indices.delete(index="messages") 97 | es_client.indices.create(index="messages") 98 | print("> Elasticsearch 清空完成.") 99 | 100 | 101 | def migrate_memory_json(): 102 | """迁移 memory.json 数据到 MongoDB""" 103 | memory_json_path = os.path.join(DATA_DIR, "memory.json") 104 | 105 | if not os.path.exists(memory_json_path): 106 | print(f"! 未找到 memory.json 文件: {memory_json_path}") 107 | return 108 | 109 | with open(memory_json_path, "r", encoding="utf-8") as file: 110 | try: 111 | memory_data = json.load(file) 112 | except json.JSONDecodeError as e: 113 | print(f"! 解析 memory.json 文件时发生错误: {e}") 114 | return 115 | 116 | for group_id, conversations in memory_data.items(): 117 | for conversation in conversations: 118 | try: 119 | document = { 120 | "group_id": group_id, 121 | "role": conversation.get("role"), 122 | "content": conversation.get("content"), 123 | "timestamp": datetime.datetime.now(datetime.timezone.utc) 124 | } 125 | if not document["role"] or not document["content"]: 126 | raise ValueError("无效数据,跳过") 127 | 128 | mongo_collection.insert_one(document) 129 | print(f"> 成功迁移对话记录到 MongoDB: group_id={group_id}, role={document['role']}") 130 | 131 | except Exception as e: 132 | print(f"! 迁移数据时发生错误,跳过: {e}, 数据: {conversation}") 133 | 134 | 135 | def migrate_long_term_memory(): 136 | """迁移 long_term_memory_*.txt 数据到 Elasticsearch""" 137 | for filename in os.listdir(DATA_DIR): 138 | if filename.startswith("long_term_memory_") and filename.endswith(".txt"): 139 | group_id = filename.split("long_term_memory_")[-1].replace(".txt", "") 140 | long_term_memory_path = os.path.join(DATA_DIR, filename) 141 | 142 | with open(long_term_memory_path, "r", encoding="utf-8") as file: 143 | lines = file.readlines() 144 | 145 | actions = [] 146 | for line in lines: 147 | content = line.strip() 148 | if not content: 149 | continue 150 | 151 | action = { 152 | "_index": "messages", 153 | "_source": { 154 | "group_id": group_id, 155 | "role": "system", 156 | "content": content 157 | } 158 | } 159 | actions.append(action) 160 | 161 | try: 162 | helpers.bulk(es_client, actions) 163 | print(f"> 成功迁移长时间记忆数据到 Elasticsearch: group_id={group_id}, 文件={filename}") 164 | 165 | except Exception as e: 166 | print(f"! 迁移 Elasticsearch 数据时发生错误: {e}, 文件: {filename}") 167 | 168 | 169 | def main(): 170 | # 打印项目根目录 171 | print("+------------------------------+") 172 | print("| 开始数据库迁移... |") 173 | print("+------------------------------+") 174 | 175 | # 备份数据 176 | backup_data() 177 | backup_mongodb() 178 | backup_elasticsearch() 179 | 180 | # 提示用户是否清空数据库 181 | prompt_clear_databases() 182 | 183 | # 迁移 memory.json 数据到 MongoDB 184 | migrate_memory_json() 185 | 186 | # 迁移 long_term_memory_*.txt 数据到 Elasticsearch 187 | migrate_long_term_memory() 188 | 189 | print("+------------------------------+") 190 | print("| 数据库迁移完成 |") 191 | print("+------------------------------+") 192 | 193 | 194 | if __name__ == "__main__": 195 | main() 196 | -------------------------------------------------------------------------------- /core/bot/bot_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - core/bot/bot_client.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | 7 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 8 | Version: 1.3.0 (Stable_923001) 9 | 10 | bot_client.py 包含 AmyAlmond 机器人的主要客户端类,链接其他模块进行处理。 11 | """ 12 | import asyncio 13 | import random 14 | import subprocess 15 | import sys 16 | import watchdog.observers 17 | import botpy 18 | from botpy.message import GroupMessage 19 | 20 | from core.plugins.plugin_manager import PluginManager 21 | # user_management.py模块 - <用户管理模块化文件> 22 | from core.utils.user_management import load_user_names 23 | # utils.py模块 - <工具模块化文件> 24 | from core.utils.utils import load_system_prompt 25 | # config.py模块 - <配置管理模块化文件> 26 | from config import SYSTEM_PROMPT_FILE, test_config 27 | # file_handler.py模块 - <文件处理模块化文件> 28 | from core.utils.file_handler import ConfigFileHandler 29 | # logger.py模块 - <日志记录模块> 30 | from core.utils.logger import get_logger 31 | # message_handler.py模块 - <消息处理模块化文件> 32 | from core.bot.message_handler import MessageHandler 33 | # memory_manager.py模块 - <内存管理模块化文件> 34 | from core.memory.memory_manager import MemoryManager 35 | # keep_alive.py模块 - 36 | from core.keep_alive import keep_alive 37 | # llm_client.py模块 - 38 | from core.llm.llm_factory import LLMFactory 39 | 40 | _log = get_logger() 41 | 42 | 43 | class MyClient(botpy.Client): 44 | """ 45 | AmyAlmond 项目的主要客户端类,继承自 botpy.Client 46 | 处理机器人的各种事件和请求 47 | """ 48 | 49 | def __init__(self, *args, **kwargs): 50 | """ 51 | 初始化客户端 52 | 53 | 初始化待处理用户列表、加载系统提示、设置内存管理器和消息处理器 54 | 读取配置并验证必要的配置项是否设置 55 | 初始化文件系统观察器以监听配置文件变化 56 | """ 57 | super().__init__(*args, **kwargs) 58 | # 初始化插件管理器 59 | self.plugin_manager = PluginManager(self) 60 | 61 | # 初始化 LLM 客户端 62 | llm_factory = LLMFactory() 63 | self.llm_client = llm_factory.create_llm_client() 64 | 65 | # 加载插件 66 | self.plugin_manager.register_plugins() 67 | 68 | self.pending_users = {} 69 | self.system_prompt = load_system_prompt(SYSTEM_PROMPT_FILE) 70 | self.memory_manager = MemoryManager() 71 | self.message_handler = MessageHandler(self, self.memory_manager) 72 | 73 | # 读取配置 74 | self.openai_secret = test_config.get("openai_secret", "") 75 | self.openai_model = test_config.get("openai_model", "gpt-4o-mini") 76 | self.openai_api_url = test_config.get("openai_api_url", "https://api.openai-hk.com/v1/chat/completions") 77 | self.ADMIN_ID = test_config.get("admin_id", "") 78 | 79 | if not self.openai_secret: 80 | _log.critical(" OpenAI API 密钥缺失,请检查 config.yaml 文件") 81 | raise ValueError("OpenAI API key is missing in config.yaml") 82 | if not self.openai_model: 83 | _log.critical(" OpenAI 模型缺失,请检查 config.yaml 文件") 84 | raise ValueError("OpenAI model is missing in config.yaml") 85 | if not self.openai_api_url: 86 | _log.critical(" OpenAI API URL 缺失,请检查 config.yaml 文件") 87 | raise ValueError("OpenAI API URL is missing in config.yaml") 88 | if not self.ADMIN_ID: 89 | _log.critical(" 管理员 ID 缺失,请检查 config.yaml 文件") 90 | raise ValueError("Admin ID is missing in config.yaml") 91 | 92 | 93 | # 初始化 last_request_time 和 last_request_content 94 | self.last_request_time = 0 95 | self.last_request_content = None 96 | 97 | # 设置文件监视器 98 | self.observer = watchdog.observers.Observer() 99 | event_handler = ConfigFileHandler(self) 100 | self.observer.schedule(event_handler, path='.', recursive=False) 101 | self.observer.start() 102 | 103 | async def on_message(self, message: botpy.message): 104 | """ 105 | 当收到消息时调用 106 | 107 | Args: 108 | message (botpy.Message): 收到的消息对象 109 | """ 110 | # 通过事件总线发布 on_message 事件,让所有订阅的插件处理该消息 111 | await self.plugin_manager.event_bus.publish("on_message", message) 112 | 113 | def load_system_prompt(self): 114 | """ 115 | 加载机器人SystemPrompt 116 | """ 117 | self.system_prompt = load_system_prompt(SYSTEM_PROMPT_FILE) 118 | _log.info(f">>> SYSTEM PROMPT LOADED") 119 | _log.info(f" ↳ Prompt count: {len(self.system_prompt)}") 120 | 121 | def reload_system_prompt(self): 122 | """ 123 | 重新加载机器人SystemPrompt 124 | """ 125 | self.system_prompt = load_system_prompt(SYSTEM_PROMPT_FILE) 126 | _log.info(">>> SYSTEM PROMPT RELOADED") 127 | 128 | async def on_ready(self): 129 | """ 130 | 当机器人准备好时调用 131 | """ 132 | _log.info(f">>> ROBOT 「{self.robot.name}」 IS READY!") 133 | load_user_names() 134 | 135 | # 加载记忆 136 | _log.info(">>> MEMORY LOADING...") 137 | await self.memory_manager.load_memory() 138 | _log.info(" ↳ 记忆加载完成") 139 | 140 | # 启动 Keep-Alive 任务 141 | await asyncio.create_task(keep_alive(self.openai_api_url, self.openai_secret)) 142 | 143 | # 通知插件准备就绪 144 | await self.plugin_manager.on_ready() 145 | 146 | async def on_group_at_message_create(self, message: GroupMessage): 147 | """ 148 | 当接收到群组中提及机器人的消息时调用 149 | 150 | 参数: 151 | message (GroupMessage): 接收到的消息对象 152 | """ 153 | await self.message_handler.handle_group_message(message) 154 | 155 | async def get_gpt_response(self, context, user_input): 156 | """ 157 | 根据给定的上下文和用户输入,从 LLM 模型获取回复 158 | """ 159 | return await self.llm_client.get_response(context, user_input, self.system_prompt) 160 | 161 | async def restart_bot(self, group_id, msg_id): 162 | """ 163 | 重启机器人 164 | 165 | 参数: 166 | group_id (str): 羡组ID 167 | msg_id (str): 消息ID 168 | """ 169 | await self.api.post_group_message( 170 | group_openid=group_id, 171 | content=f"子网重启,请稍后... ({random.randint(1000, 9999)})", 172 | msg_id=msg_id 173 | ) 174 | 175 | _log.info(">>> RESTARTING BOT...") 176 | 177 | self.observer.stop() 178 | self.observer.join() 179 | 180 | _log.info(">>> BOT RESTART COMMAND RECEIVED, SHUTTING DOWN...") 181 | 182 | python = sys.executable 183 | subprocess.Popen([python] + sys.argv) 184 | 185 | sys.exit() 186 | 187 | async def hot_reload(self, group_id, msg_id): 188 | """ 189 | 热重载系统 190 | 191 | 参数: 192 | group_id (str): 群组ID 193 | msg_id (str): 消息ID 194 | """ 195 | _log.info(">>> HOT RELOAD INITIATED...") 196 | self.system_prompt = load_system_prompt(SYSTEM_PROMPT_FILE) 197 | load_user_names() 198 | _log.info(" ↳ 热重载完成,系统已更新") 199 | await self.api.post_group_message( 200 | group_openid=group_id, 201 | content="热重载完成,系统已更新。", 202 | msg_id=msg_id 203 | ) 204 | 205 | -------------------------------------------------------------------------------- /core/db/auto_tune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import psutil 4 | import time 5 | import yaml 6 | from pathlib import Path 7 | from concurrent.futures import ThreadPoolExecutor 8 | from tqdm import tqdm # 进度条库 9 | 10 | PROJECT_CONFIG_PATH = Path(__file__).resolve().parent.parent.parent / 'configs' / 'config.yaml' 11 | 12 | class AutoTuner: 13 | def __init__(self): 14 | self.cpu_count = psutil.cpu_count(logical=False) # 物理核心数 15 | self.total_memory = psutil.virtual_memory().total 16 | self.system = platform.system() 17 | self.config = {} 18 | 19 | def run_load_test(self): 20 | """ 21 | 运行负载测试,测量系统在高负载下的性能 22 | """ 23 | load_test_results = { 24 | 'cpu_usage': [], 25 | 'memory_usage': [], 26 | 'response_times': [], 27 | 'io_usage': [], 28 | 'network_usage': [] 29 | } 30 | 31 | def simulate_load(): 32 | start_time = time.time() 33 | cpu_usage = psutil.cpu_percent(interval=1) 34 | memory_usage = psutil.virtual_memory().percent 35 | io_counters = psutil.disk_io_counters() 36 | network_counters = psutil.net_io_counters() 37 | duration = time.time() - start_time 38 | return cpu_usage, memory_usage, io_counters, network_counters, duration 39 | 40 | with ThreadPoolExecutor(max_workers=self.cpu_count) as executor: 41 | futures = [executor.submit(simulate_load) for _ in range(min(50, self.cpu_count * 10))] # 动态调整并发任务数量 42 | for future in tqdm(futures, desc="Running load test", ncols=100): # 添加进度条显示 43 | cpu_usage, memory_usage, io_counters, network_counters, duration = future.result() 44 | load_test_results['cpu_usage'].append(cpu_usage) 45 | load_test_results['memory_usage'].append(memory_usage) 46 | load_test_results['io_usage'].append(io_counters.read_bytes + io_counters.write_bytes) 47 | load_test_results['network_usage'].append(network_counters.bytes_sent + network_counters.bytes_recv) 48 | load_test_results['response_times'].append(duration) 49 | 50 | # 计算平均值 51 | load_test_results['avg_cpu_usage'] = sum(load_test_results['cpu_usage']) / len(load_test_results['cpu_usage']) 52 | load_test_results['avg_memory_usage'] = sum(load_test_results['memory_usage']) / len(load_test_results['memory_usage']) 53 | load_test_results['avg_io_usage'] = sum(load_test_results['io_usage']) / len(load_test_results['io_usage']) 54 | load_test_results['avg_network_usage'] = sum(load_test_results['network_usage']) / len(load_test_results['network_usage']) 55 | load_test_results['avg_response_time'] = sum(load_test_results['response_times']) / len(load_test_results['response_times']) 56 | return load_test_results 57 | 58 | def determine_optimal_parameters(self, load_test_results): 59 | """ 60 | 根据负载测试结果和系统资源,自动调整参数,兼顾高配和低配系统 61 | """ 62 | # 动态调整基准值,分别针对高配和低配系统 63 | if self.total_memory <= 8 * 1024 ** 3: # 小于等于8GB内存 64 | memory_ratio = self.total_memory / (8 * 1024 ** 3) # 基准为8GB 65 | cpu_ratio = self.cpu_count / 4 # 基准为4核 66 | base_context_tokens = 1024 67 | base_query_terms = 10 68 | else: 69 | memory_ratio = self.total_memory / (16 * 1024 ** 3) # 基准为16GB 70 | cpu_ratio = self.cpu_count / 8 # 基准为8核 71 | base_context_tokens = 2048 72 | base_query_terms = 18 73 | 74 | # 采用非线性映射调整参数 75 | def adjust_based_on_ratio(base_value, ratio, load_threshold, response_time_threshold): 76 | if load_test_results['avg_memory_usage'] < load_threshold and load_test_results['avg_response_time'] < response_time_threshold: 77 | return int(base_value * ratio * 1.5) 78 | elif load_test_results['avg_memory_usage'] < 80 and load_test_results['avg_response_time'] < 2: 79 | return int(base_value * ratio) 80 | else: 81 | return int(base_value * ratio * 0.75) 82 | 83 | # 动态调整 max_context_tokens 基于内存、响应时间和 I/O 性能 84 | self.config['max_context_tokens'] = adjust_based_on_ratio(base_context_tokens, memory_ratio, 60, 1) 85 | 86 | # 动态调整 Elasticsearch 查询参数基于 CPU 使用率和响应时间 87 | self.config['elasticsearch_query_terms'] = adjust_based_on_ratio(base_query_terms, cpu_ratio, 50, 1) 88 | 89 | # 为低配系统设置最低值限制 90 | if self.total_memory <= 4 * 1024 ** 3: # 4GB以下内存 91 | self.config['max_context_tokens'] = max(self.config['max_context_tokens'], 512) 92 | self.config['elasticsearch_query_terms'] = max(self.config['elasticsearch_query_terms'], 4) 93 | 94 | def update_config_file(self): 95 | """ 96 | 将调整后的参数保存到 config.yaml 文件 97 | """ 98 | try: 99 | if PROJECT_CONFIG_PATH.exists(): 100 | with open(PROJECT_CONFIG_PATH, 'r', encoding='utf-8') as f: 101 | project_config_lines = f.readlines() 102 | else: 103 | project_config_lines = [] 104 | 105 | # 移除旧的配置 106 | new_config_lines = [] 107 | inside_custom_block = False 108 | for line in project_config_lines: 109 | if line.strip() == '# ---------- Auto-tuned Configuration ----------': 110 | inside_custom_block = True 111 | continue 112 | if line.strip() == '# ---------- End Auto-tuned Configuration ------': 113 | inside_custom_block = False 114 | continue 115 | if not inside_custom_block: 116 | new_config_lines.append(line) 117 | 118 | # 将新的配置添加到文件末尾 119 | new_config_lines.append('\n') 120 | new_config_lines.append('# ---------- Auto-tuned Configuration ----------\n') 121 | for key, value in self.config.items(): 122 | new_config_lines.append(f'{key}: {value}\n') 123 | new_config_lines.append('# ---------- End Auto-tuned Configuration ------\n') 124 | 125 | # 保存更新后的内容到 config.yaml 文件 126 | with open(PROJECT_CONFIG_PATH, 'w', encoding='utf-8') as f: 127 | f.writelines(new_config_lines) 128 | 129 | print(f"> 自动调整的配置已保存至:{PROJECT_CONFIG_PATH}") 130 | print(f"> -------------------------------------------------") 131 | print(f"> 请不要擅自修改已添加的配置内容及注释,否则可能导致配置系统无法正常工作。") 132 | print(f"> PLEASE DO NOT MODIFY THE ADDED CONFIGURATION CONTENTS AND COMMENTS,") 133 | print(f"> OR ELSE THE CONFIGURATION SYSTEM MAY NOT WORK PROPERLY.") 134 | print(f"> -------------------------------------------------") 135 | 136 | except Exception as e: 137 | print(f"! 保存自动调整的配置时出错:{e}") 138 | raise 139 | 140 | def tune(self): 141 | """ 142 | 执行自动调优过程 143 | """ 144 | load_test_results = self.run_load_test() 145 | self.determine_optimal_parameters(load_test_results) 146 | self.update_config_file() 147 | 148 | 149 | if __name__ == "__main__": 150 | tuner = AutoTuner() 151 | tuner.tune() 152 | -------------------------------------------------------------------------------- /core/update_manager.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | import json 3 | import os 4 | import asyncio 5 | from datetime import datetime, timedelta 6 | from core.utils.logger import get_logger 7 | from core.utils.version_utils import is_newer_version 8 | import subprocess 9 | import sys 10 | import urllib.parse 11 | import platform 12 | 13 | _log = get_logger() 14 | 15 | FETCH_RELEASE_URL = "https://bot.luoxiaohei.cn/api/fetchLatestRelease" 16 | AUTO_UPDATE_SCRIPT_URL = "https://bot.luoxiaohei.cn/auto_update.py" 17 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 18 | CURRENT_VERSION = "1.1.5 (Alpha_829001)" 19 | CONFIG_PATH = os.path.join(ROOT_DIR, "configs", "update_config.json") 20 | 21 | async def fetch_latest_release(): 22 | """ 23 | 获取最新版本的发布信息。 24 | """ 25 | async with aiohttp.ClientSession() as session: 26 | async with session.get(FETCH_RELEASE_URL) as response: 27 | if response.status == 200: 28 | return await response.json() 29 | else: 30 | _log.warning(f"获取最新版本信息失败,状态码: {response.status}") 31 | return None 32 | 33 | async def prompt_user_for_update(stable_version_info, dev_version_info): 34 | """ 35 | 询问用户是否要更新到最新版本以及选择更新的版本类型。 36 | """ 37 | stable_version = stable_version_info.get("latestVersion") 38 | dev_version = dev_version_info.get("latestVersion") 39 | print(f"检测到新版本: 稳定版: {stable_version}, 开发版: {dev_version}。您当前版本是: {CURRENT_VERSION}。") 40 | print("请选择更新选项:") 41 | print("1. 更新到最新稳定版") 42 | print("2. 更新到最新开发版") 43 | print("3. 不更新") 44 | print("4. 不更新,且7天内不再提示") 45 | 46 | user_choice = input("请输入您的选择 (1/2/3/4): ") 47 | if user_choice == "1": 48 | await handle_user_choice("stable", stable_version_info) 49 | elif user_choice == "2": 50 | await handle_user_choice("development", dev_version_info) 51 | elif user_choice == "3": 52 | _log.info("已选择不更新。") 53 | elif user_choice == "4": 54 | _log.info("已选择不更新,并在7天内不再提示。") 55 | # 保存配置以便7天内不再提示 56 | update_config = { 57 | "snooze_until": (datetime.now() + timedelta(days=7)).isoformat() 58 | } 59 | # 如果没有自动创建 60 | with open(CONFIG_PATH, 'w') as f: 61 | json.dump(update_config, f) 62 | else: 63 | print("无效选择,请重新输入。") 64 | await prompt_user_for_update(stable_version_info, dev_version_info) 65 | 66 | 67 | async def handle_user_choice(choice, version_info): 68 | """ 69 | 处理用户的更新选择。 70 | """ 71 | download_url = version_info.get("downloadUrl") 72 | full_download_url = f"https://bot.luoxiaohei.cn{download_url}" # 拼接完整的下载URL 73 | 74 | if choice in ["stable", "development"]: 75 | # 正确解析下载URL并获取文件名 76 | parsed_url = urllib.parse.urlparse(full_download_url) 77 | query_params = urllib.parse.parse_qs(parsed_url.query) 78 | actual_download_url = query_params.get('url', [None])[0] 79 | if actual_download_url: 80 | zip_file_name = os.path.basename(urllib.parse.unquote(actual_download_url)) 81 | else: 82 | _log.error("无法解析下载的URL") 83 | return 84 | 85 | zip_download_path = os.path.join(ROOT_DIR, zip_file_name) 86 | 87 | await download_file_with_progress(full_download_url, zip_download_path) 88 | await download_file_with_progress(AUTO_UPDATE_SCRIPT_URL, os.path.join(ROOT_DIR, "auto_update.py")) 89 | 90 | _log.info("更新文件已下载,准备退出程序并执行更新。") 91 | await shutdown_and_update(zip_download_path) 92 | 93 | async def download_file_with_progress(url, dest_path): 94 | """ 95 | 下载文件到指定路径,并显示下载进度条。 96 | """ 97 | async with aiohttp.ClientSession() as session: 98 | async with session.get(url) as response: 99 | if response.status == 200: 100 | total_size = int(response.headers.get('content-length', 0)) 101 | with open(dest_path, 'wb') as f: 102 | downloaded_size = 0 103 | async for data in response.content.iter_chunked(1024): 104 | f.write(data) 105 | downloaded_size += len(data) 106 | progress = (downloaded_size / total_size) * 100 107 | print(f'\r下载进度: [{progress:.2f}%]', end='') 108 | print() # 换行 109 | _log.info(f"文件已下载到 {dest_path}") 110 | else: 111 | _log.error(f"下载文件失败,状态码: {response.status}") 112 | 113 | async def handle_updates(): 114 | """ 115 | 检查是否有新版本可用,并根据版本类型(正式版或开发版)提醒用户更新。 116 | """ 117 | version_info_list = await fetch_latest_release() 118 | 119 | if version_info_list: 120 | stable_info = version_info_list.get('stable', None) 121 | dev_info = version_info_list.get('development', None) 122 | 123 | stable_version = stable_info.get("latestVersion") if stable_info else None 124 | dev_version = dev_info.get("latestVersion") if dev_info else None 125 | 126 | if stable_version or dev_version: 127 | # 检查更新配置文件 128 | if os.path.exists(CONFIG_PATH): 129 | with open(CONFIG_PATH, 'r') as f: 130 | config = json.load(f) 131 | if config.get('snooze_until'): 132 | snooze_until = datetime.fromisoformat(config.get('snooze_until')) 133 | if snooze_until > datetime.now(): 134 | _log.info("更新检查已被用户暂停,直到指定日期。") 135 | return 136 | 137 | # 如果有更高版本的更新,提示用户 138 | if (stable_info and is_newer_version(CURRENT_VERSION, stable_version)[0]) or (dev_info and is_newer_version(CURRENT_VERSION, dev_version)[0]): 139 | await prompt_user_for_update(stable_info, dev_info) 140 | else: 141 | _log.info("当前版本已是最新,无需更新。") 142 | 143 | else: 144 | _log.warning(" 无法检查更新") 145 | 146 | async def shutdown_and_update(zip_download_path): 147 | """ 148 | 关闭所有进程并执行更新。 149 | """ 150 | _log.info("正在退出所有进程以便进行更新...") 151 | await asyncio.sleep(1) # 等待其他任务完成 152 | 153 | # 获取当前操作系统类型 154 | current_os = platform.system() 155 | 156 | # 切换到项目根目录 157 | os.chdir(ROOT_DIR) 158 | 159 | # 启动更新脚本 160 | if current_os == "Windows": 161 | # Windows 下使用 PowerShell 启动新窗口运行 Python 脚本,并确保在根目录下 162 | subprocess.Popen( 163 | f'start powershell -Command "{sys.executable} {os.path.join(ROOT_DIR, "auto_update.py")} {zip_download_path}"', 164 | shell=True 165 | ) 166 | 167 | elif current_os == "Linux": 168 | # Linux 下使用终端模拟器(如 gnome-terminal)确保在根目录下 169 | subprocess.Popen( 170 | f'gnome-terminal -- bash -c "cd {ROOT_DIR} && {sys.executable} auto_update.py {zip_download_path}"', 171 | shell=True 172 | ) 173 | 174 | elif current_os == "Darwin": # macOS 的系统标识符 175 | # macOS 使用 open -a Terminal 启动脚本 176 | subprocess.Popen( 177 | f'open -a Terminal "{sys.executable} {os.path.join(ROOT_DIR, "auto_update.py")} {zip_download_path}"', 178 | shell=True 179 | ) 180 | 181 | else: 182 | _log.error(f"不支持的操作系统: {current_os}") 183 | return 184 | 185 | _log.info("更新脚本已启动,主程序即将退出。") 186 | 187 | # 使用 os._exit() 退出当前进程 188 | os._exit(0) 189 | 190 | # 主测试 191 | async def test(): 192 | await shutdown_and_update("v1.2.0-Stable_827001.zip") 193 | 194 | if __name__ == "__main__": 195 | asyncio.run(test()) 196 | -------------------------------------------------------------------------------- /tools/setup/elasticsearch/elasticsearch_setup.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import sys 3 | import subprocess 4 | import time 5 | import yaml 6 | import psutil 7 | from pathlib import Path 8 | 9 | # 配置文件路径 10 | DEFAULT_WINDOWS_PATH = Path(r"C:\Elasticsearch\8.15.0\elasticsearch-8.15.0\bin") 11 | DEFAULT_LINUX_PATH = Path("/usr/share/elasticsearch/bin") 12 | ELASTIC_CONFIG_PATH = Path(__file__).parent.parent.parent / "configs/elasticsearch.yaml" 13 | 14 | 15 | def detect_os_and_version(): 16 | if sys.platform.startswith('win'): 17 | return "Windows", sys.getwindowsversion().platform_version 18 | elif sys.platform.startswith('linux'): 19 | return "Linux", subprocess.getoutput('uname -r') 20 | else: 21 | return sys.platform, "Unknown" 22 | 23 | 24 | def is_port_open(host, port): 25 | """检查指定主机的端口是否开放""" 26 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 27 | try: 28 | s.connect((host, port)) 29 | s.shutdown(socket.SHUT_RDWR) 30 | return True 31 | except: 32 | return False 33 | finally: 34 | s.close() 35 | 36 | 37 | def check_elasticsearch_installed(): 38 | os_name, os_version = detect_os_and_version() 39 | 40 | if os_name == "Windows": 41 | return check_elasticsearch_installed_windows() 42 | elif os_name == "Linux": 43 | return check_elasticsearch_installed_linux() 44 | else: 45 | print(f"! 暂不支持的操作系统:{os_name}") 46 | sys.exit(1) 47 | 48 | 49 | def check_elasticsearch_installed_windows(): 50 | # 检查默认安装路径 51 | if DEFAULT_WINDOWS_PATH.exists(): 52 | print(f"> 检测到Elasticsearch安装在默认路径:{DEFAULT_WINDOWS_PATH}") 53 | save_elasticsearch_config(DEFAULT_WINDOWS_PATH) 54 | return True 55 | 56 | # 提示用户手动输入路径 57 | print("! 请注意,如果你现在在执行main.py,而且你没有安装的话 下面的安装路径可以直接回车或者编一个哦~") 58 | user_path = input("无法自动检测到Elasticsearch安装路径,请手动输入:") 59 | user_path = Path(user_path) 60 | if user_path.exists(): 61 | save_elasticsearch_config(user_path) 62 | return True 63 | else: 64 | print(f"! 输入的路径无效:{user_path}") 65 | sys.exit(1) 66 | 67 | 68 | def check_elasticsearch_installed_linux(): 69 | # 检查默认路径 70 | possible_paths = [DEFAULT_LINUX_PATH, Path("/usr/local/elasticsearch/bin")] 71 | for path in possible_paths: 72 | if path.exists(): 73 | print(f"> 检测到Elasticsearch安装在路径:{path}") 74 | save_elasticsearch_config(path) 75 | return True 76 | 77 | # 通过包管理器检测安装 78 | try: 79 | result = subprocess.run(["which", "elasticsearch"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 80 | if result.returncode == 0: 81 | install_path = Path(result.stdout.decode().strip()) 82 | print(f"> 通过包管理器检测到Elasticsearch安装路径:{install_path}") 83 | save_elasticsearch_config(install_path) 84 | return True 85 | except Exception as e: 86 | print(f"! 无法通过包管理器检测到Elasticsearch安装路径:{e}") 87 | 88 | print("! 无法检测到Elasticsearch安装。请检查您的安装状态。") 89 | sys.exit(1) 90 | 91 | 92 | def save_elasticsearch_config(install_path): 93 | # 确保配置目录存在 94 | if not ELASTIC_CONFIG_PATH.parent.exists(): 95 | ELASTIC_CONFIG_PATH.parent.mkdir(parents=True) 96 | 97 | config = {"elasticsearch": {"install_path": str(install_path)}} 98 | with open(ELASTIC_CONFIG_PATH, 'w', encoding='utf-8') as f: 99 | yaml.dump(config, f, allow_unicode=True) 100 | print(f"> Elasticsearch安装路径已保存到配置文件:{ELASTIC_CONFIG_PATH}") 101 | 102 | 103 | def is_elasticsearch_running(): 104 | for proc in psutil.process_iter(['pid', 'name']): 105 | if proc.info['name'] == "elasticsearch": 106 | print("> Elasticsearch正在运行") 107 | return True 108 | print("> Elasticsearch未运行") 109 | return False 110 | 111 | 112 | def start_elasticsearch(): 113 | os_name, os_version = detect_os_and_version() 114 | 115 | if os_name == "Windows": 116 | return start_elasticsearch_windows() 117 | elif os_name == "Linux": 118 | return start_elasticsearch_linux() 119 | else: 120 | print(f"! 暂不支持的操作系统:{os_name}") 121 | sys.exit(1) 122 | 123 | 124 | def start_elasticsearch_windows(): 125 | try: 126 | with open(ELASTIC_CONFIG_PATH, 'r', encoding='utf-8') as f: 127 | config = yaml.safe_load(f) 128 | install_path = Path(config['elasticsearch']['install_path']) 129 | 130 | print("> 正在尝试启动Elasticsearch服务...") 131 | 132 | # 直接运行 elasticsearch.bat 并确保它在后台运行 133 | subprocess.Popen([str(install_path / "elasticsearch.bat")], creationflags=subprocess.CREATE_NEW_CONSOLE) 134 | 135 | # 等待Elasticsearch启动(检查端口是否开放) 136 | for _ in range(20): # 尝试20次 137 | if is_port_open("127.0.0.1", 9200): 138 | print("> Elasticsearch已成功启动并监听端口9200。") 139 | return True 140 | time.sleep(2) 141 | 142 | print("! Elasticsearch启动失败,未能在预期端口上监听。") 143 | return False 144 | 145 | except Exception as e: 146 | print(f"! 启动Elasticsearch服务失败:{e}") 147 | return False 148 | 149 | 150 | def start_elasticsearch_linux(): 151 | try: 152 | print("> 正在尝试启动Elasticsearch服务...") 153 | 154 | # 检查 systemctl 是否可用 155 | if subprocess.run(["which", "systemctl"], stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0: 156 | subprocess.run(["sudo", "systemctl", "start", "elasticsearch"], check=True) 157 | else: 158 | subprocess.run(["sudo", "service", "elasticsearch", "start"], check=True) 159 | 160 | print("> Elasticsearch已成功启动。") 161 | except Exception as e: 162 | print(f"! 启动Elasticsearch服务失败:{e}") 163 | sys.exit(1) 164 | 165 | 166 | def stop_all_elasticsearch_processes(): 167 | print("> 正在停止所有Elasticsearch相关进程...") 168 | 169 | # 定义要终止的进程名称列表 170 | elasticsearch_related_processes = ["elasticsearch", "controller.exe", "OpenJDK Platform binary", "java.exe"] 171 | 172 | for proc in psutil.process_iter(['pid', 'name']): 173 | try: 174 | if proc.info['name'] in elasticsearch_related_processes: 175 | print(f"> 正在终止进程:{proc.info['name']} (PID: {proc.info['pid']})") 176 | proc.terminate() 177 | except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): 178 | print(f"! 无法终止进程:{proc.info['name']} (PID: {proc.info['pid']}) - 权限不足或进程已结束。") 179 | continue 180 | 181 | print("> 所有Elasticsearch相关进程终止命令已发送。") 182 | 183 | 184 | def check_elasticsearch_connection(): 185 | try: 186 | print("> 正在测试Elasticsearch是否已启动...") 187 | # 通过ping本地端口检查服务是否启动 188 | if is_port_open("127.0.0.1", 9200): 189 | print("> Elasticsearch已经成功启动并在9200端口监听。") 190 | return True 191 | else: 192 | print("! Elasticsearch未能在9200端口启动。") 193 | return False 194 | except Exception as e: 195 | print(f"! 检查Elasticsearch连接时发生错误:{e}") 196 | return False 197 | 198 | 199 | if __name__ == "__main__": 200 | print("> 开始Elasticsearch启动检测...") 201 | 202 | # 添加连接测试 203 | if check_elasticsearch_connection(): 204 | print("> Elasticsearch已经在运行,连接正常。") 205 | sys.exit(0) 206 | 207 | if not check_elasticsearch_installed(): 208 | print("! Elasticsearch未安装或安装检测失败。") 209 | sys.exit(1) 210 | 211 | if not is_elasticsearch_running(): 212 | print("> Elasticsearch未运行,尝试启动...") 213 | stop_all_elasticsearch_processes() 214 | start_elasticsearch() 215 | 216 | if not check_elasticsearch_connection(): 217 | print("! Elasticsearch启动失败或无法连接,请检查安装和配置。") 218 | sys.exit(1) 219 | 220 | print("> Elasticsearch准备就绪。") 221 | sys.exit(0) 222 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | AmyAlmond Project - config.py 3 | 4 | Open Source Repository: https://github.com/shuakami/amyalmond_bot 5 | Developer: Shuakami <3 LuoXiaoHei 6 | Copyright (c) 2024 Amyalmond_bot. All rights reserved. 7 | Version: 1.3.0 (Stable_923001) 8 | 9 | config.py - 配置文件读取与验证 10 | """ 11 | import os 12 | from botpy.ext.cog_yaml import read 13 | from core.utils.logger import get_logger 14 | import subprocess 15 | import time 16 | from ruamel.yaml import YAML 17 | 18 | # 获取 logger 对象 19 | logger = get_logger() 20 | 21 | # 定义目录结构 22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 23 | CONFIG_DIR = os.path.join(BASE_DIR, "configs") 24 | LOG_DIR = os.path.join(BASE_DIR, "logs") 25 | DATA_DIR = os.path.join(BASE_DIR, "data") 26 | 27 | # 确保目录存在 28 | os.makedirs(CONFIG_DIR, exist_ok=True) 29 | os.makedirs(LOG_DIR, exist_ok=True) 30 | os.makedirs(DATA_DIR, exist_ok=True) 31 | 32 | # 配置文件路径 33 | CONFIG_FILE = os.path.join(CONFIG_DIR, "config.yaml") 34 | SYSTEM_PROMPT_FILE = os.path.join(CONFIG_DIR, "system-prompt.txt") 35 | 36 | # 日志文件路径 37 | LOG_FILE = os.path.join(LOG_DIR, "bot.log") 38 | 39 | # 数据文件路径 40 | MEMORY_FILE = os.path.join(DATA_DIR, "memory.json") 41 | LONG_TERM_MEMORY_FILE = os.path.join(DATA_DIR, "long_term_memory_{}.txt") 42 | USER_NAMES_FILE = os.path.join(DATA_DIR, "user_names.json") 43 | FAISS_INDEX_PATH = "./data/faiss_index.bin" 44 | 45 | # 读取配置文件 46 | test_config = {} 47 | logger.info("") 48 | logger.info(">>> CONFIG LOADING...") 49 | if os.path.exists(CONFIG_FILE): 50 | loaded_config = read(CONFIG_FILE) 51 | if loaded_config: 52 | test_config.update(loaded_config) 53 | logger.info(" ↳ 配置文件加载成功") 54 | else: 55 | logger.critical(" 配置文件为空") 56 | logger.critical(f" ↳ 路径: {CONFIG_FILE}") 57 | logger.critical(" ↳ 请检查配置文件是否正确填写,并确保其格式为 YAML") 58 | exit(1) 59 | else: 60 | logger.critical(" 找不到配置文件") 61 | logger.critical(f" ↳ 路径: {CONFIG_FILE}") 62 | logger.critical(f" ↳ 请确保在 {CONFIG_DIR} 目录下存在 config.yaml 文件") 63 | exit(1) 64 | 65 | # 配置参数 66 | MAX_CONTEXT_TOKENS = test_config.get("max_context_tokens", None) 67 | ELASTICSEARCH_QUERY_TERMS = test_config.get("elasticsearch_query_terms", None) 68 | 69 | # 检查是否需要自动调优 70 | if MAX_CONTEXT_TOKENS is None or ELASTICSEARCH_QUERY_TERMS is None: 71 | logger.warning(" 未找到必要的配置参数,正在调用自动调优程序...") 72 | try: 73 | start_time = time.time() 74 | # 调用 auto_tune.py 自动调优 75 | result = subprocess.run(["python", "core/db/auto_tune.py"], timeout=60) 76 | elapsed_time = time.time() - start_time 77 | 78 | if result.returncode == 0: 79 | logger.info(" 自动调优完成") 80 | logger.info(f" ↳ 耗时: {elapsed_time:.2f} 秒") 81 | # 重新读取配置文件 82 | if os.path.exists(CONFIG_FILE): 83 | loaded_config = read(CONFIG_FILE) 84 | if loaded_config: 85 | test_config.update(loaded_config) 86 | MAX_CONTEXT_TOKENS = test_config.get("max_context_tokens", 2400) # 默认值 2400 87 | ELASTICSEARCH_QUERY_TERMS = test_config.get("elasticsearch_query_terms", 16) # 默认值 16 88 | else: 89 | logger.critical(" 配置文件读取失败,使用默认值") 90 | MAX_CONTEXT_TOKENS = 2400 91 | ELASTICSEARCH_QUERY_TERMS = 8 92 | else: 93 | logger.error(" 自动调优程序执行失败,使用默认值") 94 | MAX_CONTEXT_TOKENS = 2400 95 | ELASTICSEARCH_QUERY_TERMS = 16 96 | except subprocess.TimeoutExpired: 97 | logger.error(" 自动调优超时,使用默认值") 98 | MAX_CONTEXT_TOKENS = 2400 99 | ELASTICSEARCH_QUERY_TERMS = 16 100 | 101 | 102 | 103 | # 其他配置 104 | REQUEST_LIMIT_TIME_FRAME = test_config.get("request_limit_time_frame", 10) 105 | REQUEST_LIMIT_COUNT = test_config.get("request_limit_count", 7) 106 | GLOBAL_RATE_LIMIT = test_config.get("global_rate_limit", 75) 107 | 108 | MEMORY_THRESHOLD = 150 109 | FORGET_THRESHOLD = 5 110 | 111 | 112 | MEMORY_BATCH_SIZE = test_config.get("memory_batch_size", 1) 113 | REQUEST_TIMEOUT= test_config.get("request_timeout", 7) 114 | 115 | MONGODB_URI = test_config.get("mongodb_url", "") 116 | MONGODB_USERNAME = test_config.get("mongodb_username", "") 117 | MONGODB_PASSWORD = test_config.get("mongodb_password", "") 118 | 119 | ELASTICSEARCH_URL = test_config.get("elasticsearch_url", "") 120 | ELASTICSEARCH_USERNAME = test_config.get("elasticsearch_username", "") 121 | ELASTICSEARCH_PASSWORD = test_config.get("elasticsearch_password", "") 122 | 123 | OPENAI_SECRET = test_config.get("openai_secret", "") 124 | OPENAI_MODEL = test_config.get("openai_model", "gpt-4o-mini") 125 | OPENAI_API_URL = test_config.get("openai_api_url", "https://api.openai-hk.com/v1/chat/completions") 126 | 127 | ADMIN_ID = test_config.get("admin_id", "") 128 | 129 | # KEEP_ALIVE 配置 130 | OPENAI_KEEP_ALIVE = test_config.get("openai_keep_alive", True) 131 | UPDATE_KEEP_ALIVE = test_config.get("update_keep_alive", True) 132 | 133 | # 日志配置 134 | LOG_LEVEL = test_config.get("log_level", "INFO").upper() 135 | DEBUG_MODE = test_config.get("debug", False) 136 | 137 | # 验证关键配置 138 | if not MONGODB_USERNAME: 139 | logger.warning(" MongoDB 用户名缺失") 140 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}") 141 | if not MONGODB_PASSWORD: 142 | logger.warning(" MongoDB 密码缺失") 143 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}") 144 | if not MONGODB_URI: 145 | logger.warning(" MongoDB URI 缺失") 146 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}") 147 | if not OPENAI_SECRET: 148 | logger.warning(" OpenAI API 密钥缺失") 149 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}") 150 | if not OPENAI_MODEL: 151 | logger.warning(" OpenAI 模型缺失") 152 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}") 153 | if not OPENAI_API_URL: 154 | logger.warning(" OpenAI API URL 缺失") 155 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}") 156 | if not ADMIN_ID: 157 | logger.warning(" 管理员 ID 缺失") 158 | logger.warning(f" ↳ 请检查配置文件: {CONFIG_FILE}") 159 | 160 | def _write_config(): 161 | """将配置写入 config.yaml 文件,保留原始格式""" 162 | yaml = YAML() 163 | yaml.indent(mapping=2, sequence=4, offset=2) 164 | yaml.preserve_quotes = True 165 | 166 | with open(CONFIG_FILE, 'r', encoding='utf-8') as f: 167 | yaml_data = yaml.load(f) 168 | 169 | # 更新配置项的值 170 | for key, value in test_config.items(): 171 | if key in yaml_data: 172 | yaml_data[key] = value 173 | else: 174 | yaml_data[key] = value 175 | 176 | with open(CONFIG_FILE, 'w', encoding='utf-8') as f: 177 | yaml.dump(yaml_data, f) 178 | 179 | def get_all_config(): 180 | """获取所有配置""" 181 | return test_config 182 | 183 | def add_config(key, value): 184 | """添加新的配置项""" 185 | if key in test_config: 186 | logger.warning(f" 配置项 '{key}' 已存在,无法添加") 187 | return False 188 | test_config[key] = value 189 | _write_config() 190 | logger.info(f" 配置项 '{key}' 添加成功") 191 | return True 192 | 193 | def update_config(key, value): 194 | """修改或添加配置项""" 195 | test_config[key] = value # 如果 key 不存在,则添加新的配置项 196 | _write_config() 197 | logger.info(f" 配置项 '{key}' 修改成功") 198 | return True 199 | 200 | def delete_config(key): 201 | """删除配置项""" 202 | if key not in test_config: 203 | logger.warning(f" 配置项 '{key}' 不存在,无法删除") 204 | return False 205 | del test_config[key] 206 | _write_config() 207 | logger.info(f" 配置项 '{key}' 删除成功") 208 | return True 209 | 210 | 211 | 212 | # DEBUG情况下 213 | if DEBUG_MODE: 214 | if OPENAI_SECRET and OPENAI_MODEL and OPENAI_API_URL and ADMIN_ID: 215 | masked_secret = '*' * (len(OPENAI_SECRET) - 4) + OPENAI_SECRET[-4:] 216 | masked_admin_id = '*' * (len(ADMIN_ID) - 4) + ADMIN_ID[-4:] 217 | logger.info("") 218 | logger.info(" OpenAI API Configuration") 219 | logger.info(f" ↳ API Key : {masked_secret}") 220 | logger.info(f" ↳ Model : {OPENAI_MODEL}") 221 | logger.info(f" ↳ API URL : {OPENAI_API_URL}") 222 | logger.info(f" ↳ Admin ID : {masked_admin_id}") 223 | logger.info(f" ↳ Log Level : {LOG_LEVEL}") 224 | logger.info(f" ↳ Debug Mode: {'Enabled' if DEBUG_MODE else 'Disabled'}") 225 | 226 | -------------------------------------------------------------------------------- /tools/setup/mongodb/mongodb_setup.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import sys 3 | import subprocess 4 | import time 5 | import yaml 6 | import psutil 7 | from pathlib import Path 8 | from pymongo import MongoClient, errors 9 | 10 | # 配置文件路径 11 | DEFAULT_WINDOWS_PATH = Path(r"C:\Program Files\MongoDB\Server\7.0\bin") 12 | DEFAULT_LINUX_PATH = Path("/usr/bin/mongod") 13 | DEFAULT_DB_PATH = Path(r"C:\data\db") 14 | MONGO_CONFIG_PATH = Path(__file__).parent.parent.parent / "configs/mongodb.yaml" 15 | 16 | 17 | def detect_os_and_version(): 18 | if sys.platform.startswith('win'): 19 | return "Windows", sys.getwindowsversion().platform_version 20 | elif sys.platform.startswith('linux'): 21 | return "Linux", subprocess.getoutput('uname -r') 22 | else: 23 | return sys.platform, "Unknown" 24 | 25 | 26 | def is_port_open(host, port): 27 | """检查指定主机的端口是否开放""" 28 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 29 | try: 30 | s.connect((host, port)) 31 | s.shutdown(socket.SHUT_RDWR) 32 | return True 33 | except: 34 | return False 35 | finally: 36 | s.close() 37 | 38 | 39 | def check_mongodb_installed(): 40 | os_name, os_version = detect_os_and_version() 41 | 42 | if os_name == "Windows": 43 | return check_mongodb_installed_windows() 44 | elif os_name == "Linux": 45 | return check_mongodb_installed_linux() 46 | else: 47 | print(f"! 暂不支持的操作系统:{os_name}") 48 | sys.exit(1) 49 | 50 | 51 | def check_mongodb_installed_windows(): 52 | # 检查默认安装路径 53 | if DEFAULT_WINDOWS_PATH.exists(): 54 | print(f"> 检测到MongoDB安装在默认路径:{DEFAULT_WINDOWS_PATH}") 55 | save_mongodb_config(DEFAULT_WINDOWS_PATH) 56 | return True 57 | 58 | # 尝试通过注册表检测 59 | try: 60 | import winreg 61 | key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, r"SOFTWARE\MongoDB\Server") 62 | install_path, _ = winreg.QueryValueEx(key, "InstallPath") 63 | install_path = Path(install_path) 64 | print(f"> 通过注册表检测到MongoDB安装路径:{install_path}") 65 | save_mongodb_config(install_path) 66 | return True 67 | except Exception as e: 68 | print(f"! 无法通过注册表检测到MongoDB安装路径:{e}") 69 | 70 | # 如果前面的方法都失败,提示用户手动输入路径 71 | print("! 请注意,如果你现在在执行main.py,而且你没有安装的话,下面的安装路径可以直接回车或者编一个哦~") 72 | user_path = input("无法自动检测到MongoDB安装路径,请手动输入:") 73 | user_path = Path(user_path) 74 | if user_path.exists(): 75 | save_mongodb_config(user_path) 76 | return True 77 | else: 78 | print(f"! 输入的路径无效:{user_path}") 79 | sys.exit(1) 80 | 81 | 82 | def check_mongodb_installed_linux(): 83 | # 检查默认路径 84 | possible_paths = [DEFAULT_LINUX_PATH, Path("/usr/local/bin/mongod")] 85 | for path in possible_paths: 86 | if path.exists(): 87 | print(f"> 检测到MongoDB安装在路径:{path}") 88 | save_mongodb_config(path) 89 | return True 90 | 91 | # 通过包管理器检测安装 92 | try: 93 | result = subprocess.run(["which", "mongod"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 94 | if result.returncode == 0: 95 | install_path = Path(result.stdout.decode().strip()) 96 | print(f"> 通过包管理器检测到MongoDB安装路径:{install_path}") 97 | save_mongodb_config(install_path) 98 | return True 99 | 100 | # 针对Ubuntu/Debian 101 | result = subprocess.run(["dpkg", "-l", "|", "grep", "mongodb"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 102 | if result.returncode == 0: 103 | print("> 检测到MongoDB已通过dpkg安装") 104 | save_mongodb_config(Path("/usr/bin/mongod")) 105 | return True 106 | 107 | # 针对CentOS/RHEL 108 | result = subprocess.run(["rpm", "-qa", "|", "grep", "mongodb"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 109 | if result.returncode == 0: 110 | print("> 检测到MongoDB已通过rpm安装") 111 | save_mongodb_config(Path("/usr/bin/mongod")) 112 | return True 113 | except Exception as e: 114 | print(f"! 无法通过包管理器检测到MongoDB安装路径:{e}") 115 | 116 | print("! 无法检测到MongoDB安装。请检查您的安装状态。") 117 | sys.exit(1) 118 | 119 | 120 | def save_mongodb_config(install_path): 121 | # 确保配置目录存在 122 | if not MONGO_CONFIG_PATH.parent.exists(): 123 | MONGO_CONFIG_PATH.parent.mkdir(parents=True) 124 | 125 | config = {"mongodb": {"install_path": str(install_path)}} 126 | with open(MONGO_CONFIG_PATH, 'w', encoding='utf-8') as f: 127 | yaml.dump(config, f, allow_unicode=True) 128 | print(f"> MongoDB安装路径已保存到配置文件:{MONGO_CONFIG_PATH}") 129 | 130 | 131 | def is_mongodb_running(): 132 | for proc in psutil.process_iter(['pid', 'name']): 133 | if proc.info['name'] == "mongod.exe" or proc.info['name'] == "mongod": 134 | print("> MongoDB正在运行") 135 | return True 136 | print("> MongoDB未运行") 137 | return False 138 | 139 | 140 | def start_mongodb(): 141 | os_name, os_version = detect_os_and_version() 142 | 143 | if os_name == "Windows": 144 | return start_mongodb_windows() 145 | elif os_name == "Linux": 146 | return start_mongodb_linux() 147 | else: 148 | print(f"! 暂不支持的操作系统:{os_name}") 149 | sys.exit(1) 150 | 151 | 152 | def start_mongodb_windows(): 153 | try: 154 | with open(MONGO_CONFIG_PATH, 'r', encoding='utf-8') as f: 155 | config = yaml.safe_load(f) 156 | install_path = Path(config['mongodb']['install_path']) 157 | 158 | db_path = DEFAULT_DB_PATH 159 | if not db_path.exists(): 160 | db_path.mkdir(parents=True) # 确保数据目录存在 161 | 162 | print("> 正在尝试启动MongoDB服务...") 163 | 164 | subprocess.Popen( 165 | [str(install_path / "mongod.exe"), "--dbpath", str(db_path), "--quiet"], 166 | creationflags=subprocess.DETACHED_PROCESS 167 | ) 168 | # 等待MongoDB启动(检查端口是否开放) 169 | for _ in range(10): # 尝试10次 170 | if is_port_open("127.0.0.1", 27017): 171 | print("> MongoDB已成功启动并监听端口27017。") 172 | return True 173 | time.sleep(1) 174 | 175 | print("! MongoDB启动失败,未能在预期端口上监听。") 176 | return False 177 | 178 | except Exception as e: 179 | print(f"! 启动MongoDB服务失败:{e}") 180 | return False 181 | 182 | 183 | def start_mongodb_linux(): 184 | try: 185 | print("> 正在尝试启动MongoDB服务...") 186 | 187 | # 检查 systemctl 是否可用 188 | if subprocess.run(["which", "systemctl"], stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0: 189 | subprocess.run(["sudo", "systemctl", "start", "mongod"], check=True) 190 | else: 191 | subprocess.run(["sudo", "service", "mongod", "start"], check=True) 192 | 193 | print("> MongoDB已成功启动。") 194 | except Exception as e: 195 | print(f"! 启动MongoDB服务失败:{e}") 196 | sys.exit(1) 197 | 198 | 199 | def stop_all_mongodb_processes(): 200 | print("> 正在停止所有MongoDB进程...") 201 | for proc in psutil.process_iter(['pid', 'name']): 202 | if proc.info['name'] == "mongod.exe" or proc.info['name'] == "mongod": 203 | print(f"> 正在终止进程:{proc.info['pid']}") 204 | proc.terminate() 205 | 206 | print("> 所有MongoDB进程已停止。") 207 | 208 | 209 | def check_mongodb_connection(): 210 | try: 211 | print("> 正在测试与MongoDB的连接...") 212 | client = MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=5000) 213 | # 尝试连接到MongoDB服务器 214 | client.server_info() # 发送一个ping以确认连接成功 215 | print("> MongoDB连接成功!") 216 | return True 217 | except errors.ServerSelectionTimeoutError as err: 218 | print(f"! 无法连接到MongoDB服务器:{err}") 219 | return False 220 | except Exception as e: 221 | print(f"! 连接MongoDB时发生错误:{e}") 222 | return False 223 | 224 | 225 | if __name__ == "__main__": 226 | print("> 开始MongoDB启动检测...") 227 | if not check_mongodb_installed(): 228 | print("! MongoDB未安装或安装检测失败。") 229 | sys.exit(1) 230 | 231 | if not is_mongodb_running(): 232 | print("> MongoDB未运行,尝试启动...") 233 | 234 | if not is_mongodb_running(): 235 | print("MongoDB未运行,尝试启动...") 236 | stop_all_mongodb_processes() 237 | start_mongodb() 238 | 239 | if not check_mongodb_connection(): 240 | print("MongoDB启动失败或无法连接,请检查安装和配置。") 241 | sys.exit(1) 242 | 243 | print("MongoDB已成功启动并连接!系统准备就绪。") 244 | # 退出此进程 245 | sys.exit(0) 246 | --------------------------------------------------------------------------------