├── .dbs ├── interactions.db └── marine_species.db ├── .env.example ├── .gitignore ├── LICENSE ├── app.py ├── broadcast.py ├── docs ├── demo130.json └── demo_18.json ├── dspy_evaluation.py ├── dspy_inference.py ├── dspy_program ├── program_v1.0.1_20250313195723.pkl └── program_v1.0.3_20250315154834.pkl ├── dspy_query_db.py ├── graph_data_new ├── entity_vectors.json ├── graph_entity_relation_detailed.graphml └── relation_vectors.json ├── images ├── function-diagram.png ├── startup-success.jpg ├── 二维码.jpg ├── 优化样本.jpg ├── 关系信息查询.jpg ├── 实体信息查询.jpg ├── 属性信息查询.jpg ├── 版本选择.jpg ├── 统计信息查询.jpg ├── 训练所有样本.jpg ├── 非实体信息截图.jpg └── 项目技术路线.jpg ├── nanovector_db.py ├── react_tools.py ├── readme.md ├── readme_en.md ├── requirements.txt └── tools ├── entity_extraction.py └── entity_extraction_db.py /.dbs/interactions.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/.dbs/interactions.db -------------------------------------------------------------------------------- /.dbs/marine_species.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/.dbs/marine_species.db -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | 本项目使用DSPy进行意图识别,需要配置两个独立的模型: 2 | Dspy官方中文文档信息:https://www.aidoczh.com/dspy/ 3 | 1. 问答/推理模型:用于处理用户查询和推理任务 4 | 2. 后训练模型:用于基于用户的业务问答得到的反馈数据,进行后训练任务 5 | 两个模型可以使用相同或不同的配置,支持OpenAI-SDK格式的模型: 6 | - OpenAI API系列:GPT-3.5/4 7 | - DeepSeek系列:deepseek-chat/coder 8 | - 阿里云系列:Qwen/通义千问 9 | - 百度文心系列:ERNIE-Bot 10 | - Ollama本地部署 11 | - HuggingFace部署 12 | - VLLM高性能部署 13 | 14 | 15 | # 问答/推理模型配置(用于处理用户查询和推理) 16 | LLM_TYPE ="deepseek" # 模型类型 17 | API_KEY ="sk-your_api_key" # API密钥 18 | BASE_URL ="https://api.deepseek.com/v1" # API基础地址 19 | LLM_MODEL ="deepseek-chat" # 具体的模型名称 20 | 21 | # Ollama配置(本地部署方案,适合离线环境) 22 | # LLM_TYPE="ollama_chat" # 设置为使用Ollama本地模型 23 | # API_KEY="" # Ollama本地部署不需要API密钥 24 | # BASE_URL="http://localhost:11434" # Ollama服务的本地地址 25 | # LLM_MODEL="gemma3:12b" # 使用的具体模型,这里使用Gemma 12B版本 26 | 27 | # 后训练模型配置(用于模型训练和知识提取,可以和问答/推理模型配置相同,也可以用更适合的模型) 28 | Train_LLM_TYPE ="deepseek" # 后训练模型类型 29 | Train_LLM_MODEL ="deepseek-chat" # 后训练使用的具体模型 30 | Train_OPENAI_API_KEY ="sk-your_api_key" # 后训练模型的API密钥 31 | Train_OPENAI_BASE_URL ="https://api.deepseek.com/v1" # 后训练模型的API地址 32 | 33 | # 系统环境配置(核心路径和基础设置) 34 | RAG_DIR = "graph_data_new" # 知识图谱数据的存储目录 35 | LOG_LEVEL = "DEBUG" # 日志级别,可选:DEBUG, INFO, WARNING, ERROR 36 | DATABASE_URL ="sqlite:///.dbs/interactions.db" # 用户交互记录数据库的路径 37 | SPECIES_DB_URL = "./.dbs/marine_species.db" # 物种信息数据库的路径 38 | 39 | # 向量检索配置(影响检索质量和效率的关键参数) 40 | VECTOR_SEARCH_TOP_K = 3 # 向量检索返回的最大结果数量 41 | BETTER_THAN_THRESHOLD = 0.7 # 相似度筛选阈值,范围0-1,越大要求越严格 42 | GRAPHML_DIR = "graph_entity_relation_detailed.graphml" # 知识图谱的存储文件路径 43 | 44 | # Embedding配置(文本向量化相关参数) 45 | MAX_BATCH_SIZE = 100 # 批处理大小,影响处理速度和内存使用 46 | EMBEDDING_MAX_TOKEN_SIZE = 8192 # 单次能处理的最大token数量 47 | EMBEDDING_DIM = 1024 # 生成的向量维度 48 | EMBEDDING_MODEL = "text-embedding-v3" # 使用的embedding模型版本 49 | EMBEDDING_MODEL_BASE_URL ="https://dashscope.aliyuncs.com/compatible-mode/v1" # embedding服务的API地址 50 | EMBEDDING_MODEL_API_KEY ="sk-your_api_key" # embedding服务的API密钥 51 | 52 | # 检索相关配置 53 | MAX_ITERS = 10 # 最大检索迭代次数,控制循环次数,官方默认是 5 次 54 | 
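The two model sections above are consumed independently: dspy_inference.py builds the question-answering/inference LM from LLM_TYPE / LLM_MODEL / BASE_URL / API_KEY, while dspy_evaluation.py builds the post-training LM from the Train_* variables. A minimal sketch of that loading pattern (it mirrors the repo code further below; the standalone snippet itself is illustrative and not a file in the repo):

import os
import dspy
from dotenv import load_dotenv

load_dotenv(override=True)  # same loading pattern as dspy_inference.py / dspy_evaluation.py

# Inference model: handles user queries and reasoning.
inference_lm = dspy.LM(
    f'{os.getenv("LLM_TYPE")}/{os.getenv("LLM_MODEL")}',
    base_url=os.getenv("BASE_URL"),
    api_key=os.getenv("API_KEY"),
)

# Post-training model: used during optimization and evaluation.
train_lm = dspy.LM(
    f'{os.getenv("Train_LLM_TYPE")}/{os.getenv("Train_LLM_MODEL")}',
    base_url=os.getenv("Train_OPENAI_BASE_URL"),
    api_key=os.getenv("Train_OPENAI_API_KEY"),
)

dspy.configure(lm=inference_lm)  # global default; the evaluation chain binds train_lm explicitly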
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .DS_Store 6 | # C extensions 7 | *.so 8 | .venv/ 9 | .env 10 | # Ignore proprietary data 11 | graph_data_new/ 12 | dspy_program/ 13 | docs/ 14 | .dbs/ 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # UV 105 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | #uv.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 119 | #pdm.lock 120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 121 | # in version control. 122 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 123 | .pdm.toml 124 | .pdm-python 125 | .pdm-build/ 126 | 127 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | .dmypy.json 159 | dmypy.json 160 | 161 | # Pyre type checker 162 | .pyre/ 163 | 164 | # pytype static type analyzer 165 | .pytype/ 166 | 167 | # Cython debug symbols 168 | cython_debug/ 169 | 170 | # PyCharm 171 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 172 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 173 | # and can be added to the global gitignore or merged into this file. For a more nuclear 174 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 175 | #.idea/ 176 | 177 | # Ruff stuff: 178 | .ruff_cache/ 179 | 180 | # PyPI configuration file 181 | .pypirc 182 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 loukie7 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os,io,sys 2 | from dotenv import load_dotenv 3 | load_dotenv(override=True) 4 | from loguru import logger 5 | 6 | # 设置 logger 日志级别 7 | log_level = os.getenv("LOG_LEVEL", "INFO") 8 | logger.remove() # 移除默认处理器 9 | logger.add(sys.stderr, level=log_level) # 添加新的处理器并设置日志级别 10 | logger.info(f"日志级别设置为: {log_level}") 11 | 12 | from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect,Query, Body,File, UploadFile, HTTPException, BackgroundTasks 13 | from fastapi.responses import StreamingResponse 14 | from fastapi.middleware.cors import CORSMiddleware 15 | import dspy 16 | from pydantic import BaseModel, Field 17 | from typing import List,Dict, Any 18 | import tempfile 19 | import json 20 | 21 | from sqlalchemy import create_engine, Column, String, JSON, DateTime,Integer 22 | from sqlalchemy.ext.declarative import declarative_base 23 | from sqlalchemy.orm import sessionmaker 24 | from datetime import datetime 25 | import uuid 26 | import asyncio 27 | from broadcast import ConnectionManager 28 | from dspy_inference import DspyInferenceProcessor 29 | from dspy_evaluation import DspyEvaluationProcessor 30 | 31 | 32 | 33 | app = FastAPI() 34 | app.add_middleware( 35 | CORSMiddleware, 36 | allow_origins=["*"], 37 | allow_credentials=True, 38 | allow_methods=["*"], 39 | allow_headers=["*"], 40 | ) 41 | 42 | 43 | manager = ConnectionManager() 44 | # 初始化 DspyProcessor 45 | dspy_processor = DspyInferenceProcessor() 46 | # 初始化流式模型 47 | streaming_react = dspy_processor.stream_predict 48 | 49 | eval_processor = DspyEvaluationProcessor() 50 | 51 | 52 | predictor_version = "1.0.0" 53 | 54 | # 定义数据库模型 55 | Base = declarative_base() 56 | 57 | # 创建数据库引擎 58 | engine = create_engine(os.getenv("DATABASE_URL", "sqlite:///interactions.db"), echo=False) 59 | 60 | # 注意:Base.metadata.create_all 需在所有模型类定义之后调用,否则建表时尚无任何已注册的映射(调用见 Version 类定义之后) 61 | 62 | # 创建会话 63 | SessionLocal = sessionmaker(bind=engine) 64 | 65 | # 定义封装的响应模型 66 | class ResponseWrapper(BaseModel): 67 | status_code: int 68 | detail: str 69 | data: Any 70 | 71 | class Interaction(Base): 72 | __tablename__ = 'interactions' 73 | 74 | id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) 75 | timestamp = Column(DateTime, default=datetime.now) 76 | question = Column(String) 77 | model = Column(String) 78 | version = Column(String) 79 | messages = Column(JSON) 80 | retrievmethod = Column(JSON) 81 | prompt = Column(String) 82 | modelResponse = Column(String) 83 | reasoning = Column(String) 84 | processingTime = Column(Integer) 85 | tokens = Column(JSON) 86 | 87 | # 新增版本管理模型 88 | class Version(Base): 89 | __tablename__ = 'versions' 90 | 91 | version = Column(String, primary_key=True) 92 | file_path = Column(String) 93 | description = Column(String) 94 | created_at = Column(DateTime, default=datetime.now) 95 | Base.metadata.create_all(engine) 96 | # 新增请求体模型 97 | class TrainingRequest(BaseModel): 98 | ids: List[str] 99 | version: str 100 | 101 | 102 | @app.websocket("/ws") 103 | async def websocket_endpoint(websocket: WebSocket): 104 | await manager.connect(websocket) 105 | try: 106 | while True: 107 | # 保持连接(这里简单接收消息,可用于心跳检测) 108 | await websocket.receive_text() 109 | except WebSocketDisconnect: 110 | manager.disconnect(websocket) 111 | 112 | # 异步生成器:流式返回 ReAct 模块的回答,并在结束后通过 websocket 推送 prompt 历史 113 | async def stream_react_response(prompt: str): 114 | global streaming_react 115 | try: 116 | # 跟踪上一次的内容,用于增量更新 117 | last_reasoning
= "" 118 | last_answer = "" 119 | 120 | # 修改这里:直接调用 streaming_react 函数 121 | async for chunk in streaming_react(question=prompt): 122 | # 假设每个 chunk 为 Prediction 对象,包含 answer 与 reasoning 字段 123 | if chunk: 124 | # 获取当前的 reasoning 和 answer 125 | current_reasoning = getattr(chunk, "reasoning", "") or "" 126 | current_answer = getattr(chunk, "answer", "") or "" 127 | 128 | # 计算增量内容 129 | reasoning_delta = current_reasoning[len(last_reasoning):] if current_reasoning else "" 130 | answer_delta = current_answer[len(last_answer):] if current_answer else "" 131 | 132 | # 只有当有新内容时才发送 133 | if reasoning_delta or answer_delta: 134 | data = { 135 | "reasoning_delta": reasoning_delta, 136 | "answer_delta": answer_delta, 137 | "reasoning": current_reasoning, # 也可以选择只发送增量 138 | "answer": current_answer # 也可以选择只发送增量 139 | } 140 | logger.info(f"增量数据: {json.dumps(data)}") 141 | yield f"data: {json.dumps(data)}\n\n" 142 | 143 | # 更新上一次的内容 144 | last_reasoning = current_reasoning 145 | last_answer = current_answer 146 | 147 | # 流式结束后的处理... 148 | last_message = dspy_processor.get_last_message() 149 | 150 | # 检查 last_message 是否为 None 或不包含必要字段 151 | if not last_message: 152 | error_data = {"error": "无法获取消息历史", "message": "处理请求时发生错误"} 153 | logger.error(f"last_message 为空或无效") 154 | yield f"data: {json.dumps(error_data)}\n\n" 155 | yield "data: [DONE]\n\n" 156 | return 157 | 158 | # 构造一个只包含所需字段的新字典 159 | data_to_send = { 160 | "question": prompt, 161 | "prompt": last_message.get("prompt"), 162 | "messages": last_message.get("messages"), 163 | "timestamp": last_message.get("timestamp"), 164 | "uuid": last_message.get("uuid"), 165 | "model": last_message.get("model"), 166 | "version": predictor_version 167 | } 168 | 169 | # 从 response 中提取 choices 第一个元素的 message 的 content 字段 170 | try: 171 | # 检查 response 是否存在且包含必要字段 172 | if "response" in last_message and last_message["response"] and "choices" in last_message["response"]: 173 | data_to_send["content"] = last_message["response"].choices[0].message.content 174 | tokens = {} 175 | if "usage" in last_message: 176 | tokens["completion_tokens"] = last_message["usage"].get("completion_tokens", 0) 177 | tokens["prompt_tokens"] = last_message["usage"].get("prompt_tokens", 0) 178 | tokens["total_tokens"] = last_message["usage"].get("total_tokens", 0) 179 | data_to_send["tokens"] = tokens 180 | else: 181 | data_to_send["content"] = None 182 | data_to_send["tokens"] = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0} 183 | logger.warning("response 字段不存在或格式不正确") 184 | except (KeyError, IndexError, AttributeError) as e: 185 | # 如果不存在该字段则设为 None 或者按需处理 186 | data_to_send["content"] = None 187 | data_to_send["tokens"] = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0} 188 | logger.error(f"提取 content 时出错:{e}") 189 | 190 | # 将数据转换为 JSON 字符串 191 | json_message = json.dumps(data_to_send, ensure_ascii=False, indent=2) 192 | logger.info(json_message) 193 | 194 | # 修改:不再通过 websocket 广播,而是通过流式返回完整消息 195 | yield f"data: {json.dumps({'prompt_history': json_message})}\n\n" 196 | yield "data: [DONE]\n\n" 197 | 198 | except Exception as e: 199 | # 捕获所有异常,返回错误信息 200 | error_message = str(e) 201 | logger.error(f"stream_react_response 发生错误: {error_message}") 202 | error_data = {"error": "处理请求失败", "message": error_message} 203 | yield f"data: {json.dumps(error_data)}\n\n" 204 | yield "data: [DONE]\n\n" 205 | 206 | 207 | @app.post("/chat") 208 | async def chat(request: Request, prompt: str = Body(..., embed=True), stream: int = Body(None, embed=True), version: str = 
Body(None, embed=True)): 209 | 210 | global predictor_version 211 | global streaming_react # 添加全局声明 212 | try: 213 | # 创建会话 214 | session = SessionLocal() 215 | pred = dspy_processor.model 216 | 217 | predictor_version = dspy_processor.get_version() 218 | # 记录一个当前的版本号,如果版本号没有发生变化,则不需要进行操作 219 | if version and version != predictor_version: 220 | # 查询版本信息 221 | version_info = session.query(Version).filter(Version.version == version).first() 222 | if not version_info: 223 | return ResponseWrapper(status_code=404, detail="error", data={"message": f"Version {version} not found"}) 224 | 225 | # 加载指定版本的模型文件(TODO) 226 | logger.info(f"开始切换版本:{version}/{version_info.file_path}") 227 | file_path = version_info.file_path 228 | dspy_processor.load_model(file_path) 229 | # 更新 predictor_version 230 | predictor_version = version 231 | dspy_processor.set_version(version) 232 | logger.info(f"切换版本成功:{version},清除缓存") 233 | # 重新初始化 streaming_react 234 | streaming_react = dspy_processor.stream_predict # 修改这里:直接赋值函数引用,不要调用 235 | 236 | if stream == 1: 237 | # 流式返回:包装生成器,media_type 为 "text/event-stream" 238 | return StreamingResponse(stream_react_response(prompt), media_type="text/event-stream") 239 | else: 240 | # 非流式返回:直接调用 ReAct 模块,获取最终答案 241 | # 为pred设置独立的llm配置 242 | with dspy.context(lm=dspy_processor.lm): 243 | pred = dspy_processor.model 244 | dspyres = pred(question=prompt) 245 | content = dspyres.answer 246 | reasoning = dspyres.reasoning 247 | return ResponseWrapper(status_code=200, detail="success", data={"content": content, "reasoning": reasoning}) 248 | except Exception as e: 249 | return ResponseWrapper(status_code=500, detail="error", data={"message": str(e)}) 250 | finally: 251 | session.close() 252 | # 新增的 API 方法:接收数据并保存到 JSON 文件 253 | @app.post("/save_data") 254 | async def save_data(data: Dict): 255 | try: 256 | # 定义保存数据的文件路径 257 | file_path = "saved_data.json" 258 | 259 | # 检查文件是否存在,如果存在则读取现有数据 260 | if os.path.exists(file_path): 261 | with open(file_path, "r", encoding="utf-8") as file: 262 | existing_data = json.load(file) 263 | else: 264 | existing_data = [] 265 | 266 | # 将新数据添加到现有数据中 267 | existing_data.append(data) 268 | 269 | # 将更新后的数据写回文件 270 | with open(file_path, "w", encoding="utf-8") as file: 271 | json.dump(existing_data, file, ensure_ascii=False, indent=2) 272 | 273 | return ResponseWrapper(status_code=200, detail="success", data={"message": "Data saved successfully"}) 274 | except Exception as e: 275 | return ResponseWrapper(status_code=500, detail="error", data={"message": str(e)}) 276 | 277 | 278 | 279 | # 新增的 API 方法:接收数据并保存到 SQLite 数据库 280 | @app.post("/save_to_db") 281 | async def save_to_db(data: Dict): 282 | try: 283 | # 创建会话 284 | session = SessionLocal() 285 | 286 | # 检查是否已存在相同ID 287 | if data.get("id"): 288 | existing = session.query(Interaction).get(data["id"]) 289 | if existing: 290 | return ResponseWrapper( 291 | status_code=400, 292 | detail="error", 293 | data={"message": f"相同记录 {data['id']} 已存在"} 294 | ) 295 | 296 | 297 | # 格式化 messages 和 retrievmethod 字段 298 | formatted_messages = json.dumps(data.get("messages"), ensure_ascii=False, indent=2) 299 | formatted_retrievmethod = json.dumps(data.get("retrievmethod"), ensure_ascii=False, indent=2) 300 | 301 | 302 | # 创建 Interaction 实例 303 | interaction = Interaction( 304 | id=data.get("id"), 305 | timestamp=datetime.fromisoformat(data.get("timestamp")), 306 | question=data.get("question"), 307 | model=data.get("model"), 308 | version=data.get("version"), 309 | messages=json.loads(formatted_messages), 310 |
retrievmethod=json.loads(formatted_retrievmethod), 311 | prompt=data.get("prompt"), 312 | modelResponse=data.get("modelResponse"), 313 | reasoning=data.get("reasoning"), 314 | processingTime=data.get("processingTime"), 315 | tokens=data.get("tokens") 316 | ) 317 | 318 | # 添加到会话并提交 319 | session.add(interaction) 320 | session.commit() 321 | 322 | return ResponseWrapper(status_code=200, detail="success", data={"message": "Data saved successfully to database"}) 323 | except Exception as e: 324 | session.rollback() 325 | return ResponseWrapper(status_code=500, detail="error", data={"message": str(e)}) 326 | finally: 327 | session.close() 328 | 329 | @app.delete("/interactions/{interaction_id}", response_model=ResponseWrapper) 330 | async def delete_interaction(interaction_id: str): 331 | try: 332 | session = SessionLocal() 333 | 334 | # 查询要删除的记录 335 | interaction = session.query(Interaction).filter(Interaction.id == interaction_id).first() 336 | 337 | if not interaction: 338 | return ResponseWrapper( 339 | status_code=404, 340 | detail="error", 341 | data={"message": f"ID为 {interaction_id} 的记录不存在"} 342 | ) 343 | 344 | # 执行删除 345 | session.delete(interaction) 346 | session.commit() 347 | 348 | return ResponseWrapper( 349 | status_code=200, 350 | detail="success", 351 | data={"message": "记录删除成功", "deleted_id": interaction_id} 352 | ) 353 | except Exception as e: 354 | session.rollback() 355 | return ResponseWrapper( 356 | status_code=500, 357 | detail="error", 358 | data={"message": f"删除失败: {str(e)}"} 359 | ) 360 | finally: 361 | session.close() 362 | 363 | # 新增的 API 方法:接收数据并更新 SQLite 数据库中的记录 364 | @app.post("/editdata") 365 | async def edit_data(data: Dict): 366 | try: 367 | # 创建会话 368 | session = SessionLocal() 369 | 370 | # 获取 messageId 和更新字段 371 | message_id = data.get("messageId") 372 | update_fields = data.get("updateFields", {}) 373 | 374 | # 根据 messageId 查找记录 375 | interaction = session.query(Interaction).filter(Interaction.id == message_id).first() 376 | 377 | if not interaction: 378 | return ResponseWrapper(status_code=404, detail="error", data={"message": "Record not found"}) 379 | 380 | # 更新指定的字段 381 | for field, value in update_fields.items(): 382 | if hasattr(interaction, field): 383 | if field in ["messages", "retrievmethod"]: 384 | # 格式化 JSON 字段 385 | setattr(interaction, field, json.loads(json.dumps(value, ensure_ascii=False, indent=2))) 386 | else: 387 | setattr(interaction, field, value) 388 | else: 389 | return ResponseWrapper(status_code=400, detail="error", data={"message": f"Field '{field}' does not exist"}) 390 | 391 | # 提交更改 392 | session.commit() 393 | 394 | return ResponseWrapper(status_code=200, detail="success", data={"message": "Data updated successfully"}) 395 | except Exception as e: 396 | session.rollback() 397 | return ResponseWrapper(status_code=500, detail="error", data={"message": str(e)}) 398 | finally: 399 | session.close() 400 | 401 | @app.get("/interactions/{interaction_id}", response_model=ResponseWrapper) 402 | async def get_interaction_by_id(interaction_id: str): 403 | try: 404 | session = SessionLocal() 405 | interaction = session.query(Interaction).filter(Interaction.id == interaction_id).first() 406 | 407 | if not interaction: 408 | return ResponseWrapper( 409 | status_code=404, 410 | detail="error", 411 | data={"message": f"ID为 {interaction_id} 的记录不存在"} 412 | ) 413 | 414 | interaction_data = { 415 | "id": interaction.id, 416 | "timestamp": interaction.timestamp.isoformat(), 417 | "question": interaction.question, 418 | "model": interaction.model, 
419 | "version": interaction.version, 420 | "messages": interaction.messages, 421 | "retrievmethod": interaction.retrievmethod, 422 | "prompt": interaction.prompt, 423 | "modelResponse": interaction.modelResponse, 424 | "reasoning": interaction.reasoning, 425 | "processingTime": interaction.processingTime, 426 | "tokens": interaction.tokens 427 | } 428 | 429 | return ResponseWrapper( 430 | status_code=200, 431 | detail="success", 432 | data=interaction_data 433 | ) 434 | except Exception as e: 435 | return ResponseWrapper( 436 | status_code=500, 437 | detail="error", 438 | data={"message": f"查询失败: {str(e)}"} 439 | ) 440 | finally: 441 | session.close() 442 | 443 | @app.get("/interactions", response_model=ResponseWrapper) 444 | async def get_interactions_by_version( 445 | version: str = Query(None), 446 | page: int = Query(1, ge=1, description="页码,从1开始"), 447 | page_size: int = Query(10, ge=1, le=100, description="每页数量") 448 | ): 449 | try: 450 | session = SessionLocal() 451 | 452 | # 获取最新版本(如果未指定) 453 | # latest_version = session.query(Version.version)\ 454 | # .order_by(Version.created_at.desc())\ 455 | # .first() 456 | # if not latest_version: 457 | # return ResponseWrapper(status_code=404, detail="error", data={"message": "无可用版本"}) 458 | # version = latest_version[0] 459 | # 修改后的代码片段 460 | if not version: 461 | # 移除获取最新版本逻辑,直接构建无版本过滤的查询 462 | base_query = session.query( 463 | Interaction.id, 464 | Interaction.question, 465 | Interaction.version, 466 | Interaction.model, 467 | Interaction.processingTime, 468 | Interaction.timestamp 469 | ).order_by( 470 | Interaction.timestamp.desc() 471 | ) 472 | else: 473 | # 当指定版本时保持原有过滤逻辑 474 | base_query = session.query( 475 | Interaction.id, 476 | Interaction.question, 477 | Interaction.version, 478 | Interaction.model, 479 | Interaction.processingTime, 480 | Interaction.timestamp 481 | ).filter( 482 | Interaction.version == version 483 | ).order_by( 484 | Interaction.timestamp.desc() 485 | ) 486 | 487 | # 分页处理 488 | total_count = base_query.count() 489 | total_pages = (total_count + page_size - 1) // page_size 490 | 491 | interactions = base_query.offset( 492 | (page - 1) * page_size 493 | ).limit( 494 | page_size 495 | ).all() 496 | 497 | # 构建响应数据 498 | interaction_list = [ 499 | { 500 | "id": row.id, 501 | "question": row.question, 502 | "version": row.version, 503 | "model": row.model, 504 | "processingTime": row.processingTime, 505 | "timestamp": row.timestamp.isoformat() 506 | } 507 | for row in interactions 508 | ] 509 | 510 | return ResponseWrapper( 511 | status_code=200, 512 | detail="success", 513 | data={ 514 | "version": version, 515 | "pagination": { 516 | "total": total_count, 517 | "total_pages": total_pages, 518 | "current_page": page, 519 | "page_size": page_size 520 | }, 521 | "interactions": interaction_list 522 | } 523 | ) 524 | except Exception as e: 525 | return ResponseWrapper(status_code=500, detail="error", data={"message": str(e)}) 526 | finally: 527 | session.close() 528 | 529 | # 全局优化任务跟踪 530 | optimization_tasks = {} 531 | 532 | # 异步优化任务 533 | async def run_dspy_optimization(training_data: List[Dict], version: str, ids: List[str]): 534 | task_id = f"optimization_task_{version}_{datetime.now().strftime('%Y%m%d%H%M%S')}" 535 | try: 536 | from dspy.teleprompt import BootstrapFewShot 537 | from dspy.evaluate import Evaluate 538 | from dspy.evaluate.metrics import answer_exact_match 539 | 540 | # 更新状态并发送开始消息 541 | logger.info(f"开始优化任务 {task_id},数据量: {len(training_data)},版本: {version}") 542 | optimization_tasks[task_id] = 
"loading_data" 543 | await manager.broadcast(json.dumps({ 544 | "type": "optimization_status", 545 | "data": { 546 | "task_id": task_id, 547 | "status": "loading_data", 548 | "progress": 5, 549 | "message": "正在准备训练数据..." 550 | } 551 | })) 552 | 553 | # 创建训练集 554 | trainset = [dspy.Example(question=x["question"],reasoning=x["reasoning"], answer=x["modelResponse"]).with_inputs("question") for x in training_data] 555 | logger.info(f"任务 {task_id}: 已创建训练集,共 {len(trainset)} 条数据") 556 | 557 | # 更新状态 558 | optimization_tasks[task_id] = "preparing_model" 559 | await manager.broadcast(json.dumps({ 560 | "type": "optimization_status", 561 | "data": { 562 | "task_id": task_id, 563 | "status": "preparing_model", 564 | "progress": 10, 565 | "message": "正在准备模型..." 566 | } 567 | })) 568 | 569 | # 从最新版本加载预测模型 570 | session = SessionLocal() 571 | 572 | # 修改这里:使用 dspy_processor 的 model 而不是 eval_processor 的 model 573 | # 因为 DspyEvaluationProcessor 没有 model 属性 574 | predict = dspy_processor.model 575 | logger.info(f"任务 {task_id}: 已加载模型") 576 | 577 | # 设置优化器 578 | teleprompter = BootstrapFewShot(metric=eval_processor.llm_biological_metric, max_labeled_demos=15) 579 | 580 | # 更新状态 581 | optimization_tasks[task_id] = "optimizing" 582 | await manager.broadcast(json.dumps({ 583 | "type": "optimization_status", 584 | "data": { 585 | "task_id": task_id, 586 | "status": "optimizing", 587 | "progress": 15, 588 | "message": "正在进行模型优化..." 589 | } 590 | })) 591 | 592 | # 编译优化 593 | logger.info(f"任务 {task_id}: 开始编译优化") 594 | compiled_predictor = teleprompter.compile(predict, trainset=trainset) 595 | logger.info(f"任务 {task_id}: 编译优化完成") 596 | 597 | # 更新状态 598 | optimization_tasks[task_id] = "saving_model" 599 | await manager.broadcast(json.dumps({ 600 | "type": "optimization_status", 601 | "data": { 602 | "task_id": task_id, 603 | "status": "saving_model", 604 | "progress": 50, 605 | "message": "正在保存优化后的模型..." 
606 | } 607 | })) 608 | 609 | # 确保目录存在 610 | os.makedirs("dspy_program", exist_ok=True) 611 | last_version = session.query(Version.version).order_by(Version.created_at.desc()).first().version 612 | 613 | 614 | # 保存优化后的模型 615 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 616 | output_path = f"dspy_program/program_v{last_version}_{timestamp}.pkl" 617 | compiled_predictor.save(output_path, save_program=False) 618 | logger.info(f"任务 {task_id}: 已保存模型到 {output_path}") 619 | 620 | # 解析当前版本号,生成新版本号 621 | # 从数据库获取最新版本号,原生新增 622 | major, minor, patch = map(int, last_version.split('.')) 623 | new_version = f"{major}.{minor}.{patch + 1}" 624 | 625 | # 描述信息 626 | description = f"基于 {version} 版本,使用 {len(ids)} 条数据优化生成的新版本" 627 | 628 | # 创建新版本 629 | new_version_instance = Version( 630 | version=new_version, 631 | file_path=output_path, 632 | description=description 633 | ) 634 | 635 | session.add(new_version_instance) 636 | session.commit() 637 | logger.info(f"任务 {task_id}: 已创建新版本 {new_version}") 638 | 639 | # 更新状态为完成 640 | optimization_tasks[task_id] = "completed" 641 | 642 | # 通过 WebSocket 广播版本更新消息 643 | await manager.broadcast(json.dumps({ 644 | "type": "version_update", 645 | "data": { 646 | "old_version": version, 647 | "new_version": new_version, 648 | "description": description, 649 | "model_path": output_path, 650 | "training_ids": ids, 651 | "progress": 100, 652 | "message": f"优化完成,已创建新版本{new_version}" 653 | } 654 | })) 655 | logger.info(f"任务 {task_id}: 优化任务完成") 656 | 657 | except Exception as e: 658 | # 记录错误并通过 WebSocket 发送失败消息 659 | error_message = str(e) 660 | logger.error(f"任务 {task_id} 失败: {error_message}") 661 | optimization_tasks[task_id] = f"failed: {error_message}" 662 | 663 | await manager.broadcast(json.dumps({ 664 | "type": "optimization_failed", 665 | "data": { 666 | "version": version, 667 | "error": error_message, 668 | "task_id": task_id, 669 | "progress": 0, 670 | "message": f"优化失败: {error_message}" 671 | } 672 | })) 673 | finally: 674 | if 'session' in locals(): 675 | session.close() 676 | 677 | @app.post("/addtraining", response_model=ResponseWrapper) 678 | async def add_training(request: TrainingRequest, background_tasks: BackgroundTasks): # 使用新模型 679 | session = None 680 | try: 681 | # 获取ID列表 682 | ids = request.ids 683 | version = request.version 684 | 685 | # 参数校验 686 | if not ids: 687 | return ResponseWrapper( 688 | status_code=400, 689 | detail="error", 690 | data={"message": "未提供有效ID列表"} 691 | ) 692 | if not version: 693 | return ResponseWrapper( 694 | status_code=400, 695 | detail="error", 696 | data={"message": "必须提供版本号参数"} 697 | ) 698 | 699 | session = SessionLocal() 700 | 701 | # 查询数据库并收集数据 702 | training_data = [] 703 | for interaction_id in ids: 704 | interaction = session.query(Interaction).get(interaction_id) 705 | if interaction: 706 | training_data.append({ 707 | "id": interaction.id, 708 | "question": interaction.question, 709 | "reasoning": interaction.reasoning, 710 | "modelResponse": interaction.modelResponse, 711 | "timestamp": interaction.timestamp.isoformat() 712 | }) 713 | 714 | if not training_data: 715 | return ResponseWrapper( 716 | status_code=404, 717 | detail="error", 718 | data={"message": "未找到匹配的记录"} 719 | ) 720 | 721 | # 生成任务ID 722 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 723 | task_id = f"optimization_task_{version}_{timestamp}" 724 | 725 | # 在后台启动优化任务前先设置状态 726 | optimization_tasks[task_id] = "pending" 727 | 728 | # 将训练数据和任务信息保存为全局变量,以便后台任务使用 729 | # 这样可以避免在后台任务中重新查询数据库 730 | task_info = { 731 | "training_data": 
training_data, 732 | "version": version, 733 | "ids": [item["id"] for item in training_data], 734 | "task_id": task_id 735 | } 736 | 737 | # 添加后台任务 - 使用普通函数而不是异步函数 738 | background_tasks.add_task( 739 | start_optimization_task, 740 | task_info 741 | ) 742 | 743 | # 立即返回响应,不等待优化任务完成 744 | logger.info(f"已创建优化任务 {task_id},将在后台处理 {len(training_data)} 条数据") 745 | return ResponseWrapper( 746 | status_code=200, 747 | detail="success", 748 | data={ 749 | "message": f"成功收集 {len(training_data)} 条训练数据,已创建后台优化任务", 750 | "task_id": task_id, 751 | "exported_ids": [item["id"] for item in training_data], 752 | "version": version, # 返回版本号用于验证 753 | "optimization_status": "pending" # 返回初始优化状态 754 | } 755 | ) 756 | 757 | except Exception as e: 758 | logger.error(f"创建优化任务失败: {str(e)}") 759 | return ResponseWrapper( 760 | status_code=500, 761 | detail="error", 762 | data={"message": f"处理失败: {str(e)}"} 763 | ) 764 | finally: 765 | if session: 766 | session.close() 767 | 768 | # 新增函数:启动优化任务的普通函数 769 | def start_optimization_task(task_info): 770 | """启动优化任务的普通函数,用于后台任务""" 771 | # 创建一个新的事件循环 772 | loop = asyncio.new_event_loop() 773 | asyncio.set_event_loop(loop) 774 | 775 | # 在新的事件循环中运行异步任务 776 | try: 777 | # 发送初始通知 778 | loop.run_until_complete(manager.broadcast(json.dumps({ 779 | "type": "optimization_created", 780 | "data": { 781 | "task_id": task_info["task_id"], 782 | "status": "pending", 783 | "progress": 0, 784 | "message": f"已创建优化任务,准备处理 {len(task_info['training_data'])} 条数据", 785 | "version": task_info["version"], 786 | "ids": task_info["ids"] 787 | } 788 | }))) 789 | 790 | # 设置状态为 running 791 | optimization_tasks[task_info["task_id"]] = "running" 792 | 793 | # 执行实际的优化任务 794 | loop.run_until_complete(run_dspy_optimization( 795 | task_info["training_data"], 796 | task_info["version"], 797 | task_info["ids"] 798 | )) 799 | except Exception as e: 800 | logger.error(f"优化任务执行失败: {str(e)}") 801 | # 设置任务状态为失败 802 | optimization_tasks[task_info["task_id"]] = f"failed: {str(e)}" 803 | # 发送失败通知 804 | loop.run_until_complete(manager.broadcast(json.dumps({ 805 | "type": "optimization_failed", 806 | "data": { 807 | "version": task_info["version"], 808 | "error": str(e), 809 | "task_id": task_info["task_id"], 810 | "progress": 0, 811 | "message": f"优化失败: {str(e)}" 812 | } 813 | }))) 814 | finally: 815 | # 关闭事件循环 816 | loop.close() 817 | 818 | # 新增的 API 方法:创建新版本 819 | @app.post("/create_version") 820 | async def create_version(file_path: str = Body(..., embed=True), old_version: str = Body(..., embed=True), description: str = Body(..., embed=True)): 821 | try: 822 | # 创建会话 823 | session = SessionLocal() 824 | 825 | # 解析旧版本号 826 | major, minor, patch = map(int, old_version.split('.')) 827 | 828 | # 递增版本号 829 | new_version = f"{major}.{minor}.{patch + 1}" 830 | 831 | # 检查新版本号是否已存在 832 | existing_version = session.query(Version).filter(Version.version == new_version).first() 833 | if existing_version: 834 | return ResponseWrapper(status_code=400, detail="error", data={"message": f"Version {new_version} already exists"}) 835 | 836 | # 创建新版本实例 837 | new_version_instance = Version( 838 | version=new_version, 839 | file_path=file_path, 840 | description=description 841 | ) 842 | 843 | # 添加到会话并提交 844 | session.add(new_version_instance) 845 | session.commit() 846 | 847 | return ResponseWrapper(status_code=200, detail="success", data={"message": "Version created successfully", "new_version": new_version}) 848 | except Exception as e: 849 | session.rollback() 850 | return ResponseWrapper(status_code=500, detail="error", 
data={"message": str(e)}) 851 | finally: 852 | session.close() 853 | 854 | @app.get("/versions", response_model=ResponseWrapper) 855 | async def get_versions(): 856 | try: 857 | # 创建会话 858 | session = SessionLocal() 859 | 860 | # 查询所有版本并按创建时间排序 861 | versions = session.query(Version).order_by(Version.created_at.asc()).all() 862 | 863 | # 提取版本号 864 | version_list = [{"version": version.version, "file_path": version.file_path, "description": version.description, "created_at": version.created_at} for version in versions] 865 | 866 | return ResponseWrapper(status_code=200, detail="success", data={"versions": version_list}) 867 | except Exception as e: 868 | return ResponseWrapper(status_code=500, detail="error", data={"message": str(e)}) 869 | finally: 870 | session.close() 871 | 872 | @app.get("/health",response_model=ResponseWrapper) 873 | async def health_check(): 874 | return ResponseWrapper(status_code=200, detail="success", data={"status": "healthy"}) 875 | 876 | # 添加一个 API 端点查询优化任务状态 877 | @app.get("/optimization_status/{task_id:path}", response_model=ResponseWrapper) 878 | async def get_optimization_status(task_id: str): 879 | try: 880 | if task_id in optimization_tasks: 881 | status = optimization_tasks[task_id] 882 | return ResponseWrapper( 883 | status_code=200, 884 | detail="success", 885 | data={ 886 | "task_id": task_id, 887 | "status": status 888 | } 889 | ) 890 | else: 891 | return ResponseWrapper( 892 | status_code=404, 893 | detail="error", 894 | data={"message": f"未找到对应的优化任务: {task_id}"} 895 | ) 896 | except Exception as e: 897 | return ResponseWrapper( 898 | status_code=500, 899 | detail="error", 900 | data={"message": f"查询失败: {str(e)}"} 901 | ) 902 | 903 | if __name__ == "__main__": 904 | import uvicorn 905 | uvicorn.run(app, host="0.0.0.0", port=8080) 906 | -------------------------------------------------------------------------------- /broadcast.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect 2 | from loguru import logger 3 | # 定义一个 WebSocket 连接管理器 4 | class ConnectionManager: 5 | def __init__(self): 6 | self.active_connections = [] 7 | 8 | async def connect(self, websocket: WebSocket): 9 | await websocket.accept() 10 | self.active_connections.append(websocket) 11 | 12 | def disconnect(self, websocket: WebSocket): 13 | self.active_connections.remove(websocket) 14 | 15 | async def broadcast(self, message: str): 16 | dead_connections = [] 17 | for connection in self.active_connections: 18 | try: 19 | await connection.send_text(message) 20 | logger.info(f"已向客户端推送消息: {message}") 21 | except Exception as e: 22 | dead_connections.append(connection) 23 | logger.error(f"广播消息时出错: {str(e)}") 24 | continue 25 | 26 | # 清理已断开的连接 27 | for dead_connection in dead_connections: 28 | try: 29 | self.active_connections.remove(dead_connection) 30 | except ValueError: 31 | pass -------------------------------------------------------------------------------- /docs/demo_18.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "拉丁学名": "Eptatretus burgeri", 4 | "命名年份": "1855", 5 | "作者": "Girard", 6 | "中文学名": "蒲氏黏盲鳗", 7 | "界": "动物界", 8 | "门": "脊索动物门", 9 | "纲": "圆口纲", 10 | "目": "盲鳗目", 11 | "科": "盲鳗科", 12 | "属": "黏盲鳗属", 13 | "种": "蒲氏黏盲鳗", 14 | "自然分布地": "我国台湾东北海域、黄海南部、东海,以及日本南部海域", 15 | "生活习性": "蒲氏黏盲鳗栖息于水深达200米的海域,肉食性,营寄生生活。繁殖期在秋、冬季,亲鱼游向深水产卵。卵粒大,长椭球形,附着于海藻、礁石上发育。", 16 | "生物特征": 
"蒲氏黏盲鳗体鳗形,口圆形,端位,具短须。眼退化,埋于皮下。分泌黏液多。体黄褐色,以鱼体背中线处具1条白色纵带为特征。体长为40-60厘米。鱼体含有一种芳胺类物质,即黏盲鳗素,对心脏有刺激起博等作用。" 17 | }, 18 | { 19 | "拉丁学名": "Eptatretus okinoseanus", 20 | "命名年份": "1904", 21 | "中文学名": "紫黏盲鳗", 22 | "作者": "Dean", 23 | "界": "动物界", 24 | "门": "脊索动物门", 25 | "纲": "圆口纲", 26 | "目": "盲鳗目", 27 | "科": "盲鳗科", 28 | "属": "黏盲鳗属", 29 | "种": "紫黏盲鳗", 30 | "自然分布地": "我国南海北部、东海,以及日本南部海域", 31 | "生活习性": "紫黏盲鳗栖息于水深200至600米的海域。其生态习性与蒲氏黏盲鳗相似,肉食性,营寄生生活。繁殖期在秋、冬季,亲鱼游向深水产卵。卵粒大,长椭球形,附着于海藻、礁石上发育。", 32 | "生物特征": "紫黏盲鳗体鳗形,体长为60至80厘米。体紫黑色,外鳃孔8对,眼表皮、外鳃孔和腹侧正中皮褶白色。" 33 | }, 34 | { 35 | "拉丁学名": "Eptatretus rubicundus", 36 | "命名年份": "2010", 37 | "作者": "Lee et Mok", 38 | "中文学名": "红盲鳗", 39 | "界": "动物界", 40 | "门": "脊索动物门", 41 | "纲": "圆口纲", 42 | "目": "盲鳗目", 43 | "科": "盲鳗科", 44 | "属": "红盲鳗属", 45 | "种": "红盲鳗", 46 | "自然分布地": "中国台湾海域", 47 | "生活习性": "红身黏盲鳗生活在深海区域,通常在海底泥沙中栖息。它们通过分泌大量黏液来防御捕食者。繁殖季节和繁殖特点尚不明确。", 48 | "生物特征": "红身黏盲鳗体形较粗壮,呈鳗形。鼻管末端有1对小瓣膜,无眼点。体全身粉红色,背中线处无白色纵带。具5对外鳃孔,与鳃区黏液孔均呈一直线排列。黏液孔多达100-102个。口漏斗状,外缘具3枚多尖齿和7枚单尖齿,内缘具2枚多尖齿及7枚单尖齿。" 49 | }, 50 | { 51 | "拉丁学名": "Eptatretus nelsoni", 52 | "命名年份": "1994", 53 | "作者": "Huang et Mok", 54 | "中文学名": "纽氏黏盲鳗", 55 | "界": "动物界", 56 | "门": "脊索动物门", 57 | "纲": "圆口纲", 58 | "目": "盲鳗目", 59 | "科": "盲鳗科", 60 | "属": "黏盲鳗属", 61 | "种": "纽氏黏盲鳗", 62 | "自然分布地": "中国台湾海域", 63 | "生活习性": "纽氏黏盲鳗为深水半寄生种,通常生活在深海区域。它们通过分泌黏液来防御捕食者。繁殖季节和繁殖特点尚不明确。", 64 | "生物特征": "纽氏黏盲鳗体细长,眼退化,埋于皮下。口漏斗状,有口须。鳃孔每侧4个,呈堆状。每个鳃孔周缘均呈白色,左侧最后鳃孔扩大。鳃孔前有黏液孔19个,鳃孔上无黏液孔,躯干部有黏液孔35个,尾部有黏液孔8个。齿式8+3/2+7。体长约20 cm。" 65 | }, 66 | { 67 | "拉丁学名": "Eptatretus yangi", 68 | "命名年份": "1958", 69 | "作者": "Yangi Teng", 70 | "中文学名": "杨氏黏盲鳗", 71 | "界": "动物界", 72 | "门": "脊索动物门", 73 | "纲": "圆口纲", 74 | "目": "盲鳗目", 75 | "科": "盲鳗科", 76 | "属": "黏盲鳗属", 77 | "种": "杨氏黏盲鳗", 78 | "自然分布地": "中国台湾海域", 79 | "生活习性": "杨氏黏盲鳗栖息在较浅的海域,水深20-50米。它们是肉食性,营寄生生活。通常通过拖网捕获。", 80 | "生物特征": "杨氏黏盲鳗体鳗形。外鳃孔每侧5个,相互接近,不规则地排成一堆。鳃孔前有黏液孔16-23个,鳃孔上无黏液孔,躯干部有黏液孔42-47个,尾部有黏液孔8-11个。齿式(6-8)+3/2+(6-8)。体灰褐色,腹部灰色。体长约30 cm。" 81 | }, 82 | { 83 | "拉丁学名": "Eptatretus cheni", 84 | "命名年份": "1975", 85 | "作者": "Shen et Tao", 86 | "中文学名": "陈氏黏盲鳗", 87 | "界": "动物界", 88 | "门": "脊索动物门", 89 | "纲": "圆口纲", 90 | "目": "盲鳗目", 91 | "科": "盲鳗科", 92 | "属": "黏盲鳗属", 93 | "种": "陈氏黏盲鳗", 94 | "自然分布地": "中国台湾海域", 95 | "生活习性": "陈氏黏盲鳗为深水半寄生种,通常生活在深海区域。它们通过分泌黏液来防御捕食者。繁殖季节和繁殖特点尚不明确。", 96 | "生物特征": "陈氏黏盲鳗体鳗形,眼已退化,埋于皮下。口漏斗状,口缘有须。鳃孔每侧5个,呈直线排列。各鳃孔间距离短。鳃孔前有26个黏液孔,鳃孔上无黏液孔,躯干部有45-47个黏液孔,尾部有7-8个黏液孔。齿式(9-11)+3/2+11。体呈暗灰色。体长约16 cm。" 97 | }, 98 | { 99 | "拉丁学名": "Eptatretus sheni", 100 | "命名年份": "1994", 101 | "作者": "Huang et Mok", 102 | "中文学名": "沈氏黏盲鳗", 103 | "界": "动物界", 104 | "门": "脊索动物门", 105 | "纲": "圆口纲", 106 | "目": "盲鳗目", 107 | "科": "盲鳗科", 108 | "属": "黏盲鳗属", 109 | "种": "沈氏黏盲鳗", 110 | "自然分布地": "中国台湾海域", 111 | "生活习性": "沈氏黏盲鳗为深水半寄生种,通常生活在深海区域。它们通过分泌黏液来防御捕食者。繁殖季节和繁殖特点尚不明确。", 112 | "生物特征": "沈氏黏盲鳗体细长似鳗,眼退化,埋于皮下,眼点较明显。口漏斗状,周缘具须。鳃孔每侧6个,呈直线状紧密排列。每个鳃孔周边均有白环。鳃孔前有13-18个黏液孔,鳃孔上有0-2个黏液孔,躯干部有39-46个黏液孔,尾部有8-12个黏液孔。齿式11+3/3+10。体略带褐色。体长约45 cm。" 113 | }, 114 | { 115 | "拉丁学名": " Eptatretus taiwanae", 116 | "命名年份": "1975", 117 | "作者": "Shen et Tao", 118 | "中文学名": "台湾黏盲鳗", 119 | "界": "动物界", 120 | "门": "脊索动物门", 121 | "纲": "圆口纲", 122 | "目": "盲鳗目", 123 | "科": "盲鳗科", 124 | "属": "黏盲鳗属", 125 | "种": "台湾黏盲鳗", 126 | "自然分布地": "中国台湾海域", 127 | "生活习性": "台湾黏盲鳗为深水半寄生种,通常生活在深海区域。它们通过分泌黏液来防御捕食者。繁殖季节和繁殖特点尚不明确。", 128 | "生物特征": "台湾黏盲鳗体细长似鳗,眼退化,埋于皮下。口漏斗状,周缘有须。鳃孔每侧6个,排列成2列,有的不甚规则。鳃孔前有16-19个黏液孔,鳃孔上无黏液孔,躯干部有36-42个黏液孔,尾部有6-9个黏液孔。齿式(6-8)+3/2+(6-7)。体淡红褐色,背缘褐色。" 129 | }, 130 | { 131 | "拉丁学名": "Eptatretus atami", 132 | "命名年份": "1904", 133 | "作者": "Dean", 134 | "中文学名": "阿塔氏黏盲鳗", 135 | "界": "动物界", 136 | 
"门": "脊索动物门", 137 | "纲": "圆口纲", 138 | "目": "盲鳗目", 139 | "科": "盲鳗科", 140 | "属": "黏盲鳗属", 141 | "种": "阿塔氏黏盲鳗", 142 | "自然分布地": "分布于黄海南部海域,以及日本青森以南海域。", 143 | "生活习性": "阿塔氏黏盲鳗产卵期为4-8月份;怀卵量少,为15-30粒。卵径为25-26 mm;卵两端密生附着丝。属肉食性,以其他鱼类和底栖动物为食。", 144 | "生物特征": "阿塔氏黏盲鳗体鳗形。和蒲氏黏盲鳗相似,鳃孔亦为6对,但外鳃孔相互靠近,并呈两列不规则排列。齿式(10-13)+3/2+(10-13)。从头部到尾部腹面并排有两列黏液孔,能分泌大量黏液。体茶褐色,外鳃孔周缘白色。体长50 cm左右。肉可食,皮可制革。" 145 | }, 146 | { 147 | "拉丁学名": "Myxine formosana", 148 | "命名年份": "2001", 149 | "作者": "Mok et Kuo", 150 | "中文学名": "台湾盲鳗", 151 | "界": "动物界", 152 | "门": "脊索动物门", 153 | "纲": "圆口纲", 154 | "目": "盲鳗目", 155 | "科": "盲鳗科", 156 | "属": "盲鳗属", 157 | "种": "台湾盲鳗", 158 | "自然分布地": "分布于我国东海、台湾海域", 159 | "生活习性": "台湾盲鳗为半寄生深水种,生活在深水区域。体长70 cm左右。", 160 | "生物特征": "台湾盲鳗体鳗形,眼退化。外鼻孔1个,开口于吻端。口漏斗状,口缘有短须。每侧1个鳃孔,常具白缘。鳃囊5个。鳃孔前有26-32个黏液孔,鳃孔上无黏液孔,躯干部有54-58个黏液孔,尾部有14个黏液孔。齿较多,齿式(8-12)+3/2+(8-12)。体暗灰色。" 161 | }, 162 | { 163 | "拉丁学名": "Lethenteron camtschaticum ", 164 | "命名年份": "1811", 165 | "作者": "Tilesius", 166 | "中文学名": "东亚叉牙七鳃鳗", 167 | "界": "动物界", 168 | "门": "脊索动物门", 169 | "纲": "圆口纲", 170 | "目": "七鳃鳗目", 171 | "科": "七鳃鳗科", 172 | "属": "叉牙七鳃鳗属", 173 | "种": "东亚叉牙七鳃鳗", 174 | "自然分布地": "分布于中国的黑龙江、图们江流域,偶见于鸭绿江口及江苏近岸水域;以及日本海。", 175 | "生活习性": "东亚叉牙七鳃鳗为洄游性鱼形动物。其幼体至成体栖息于海中,营半寄生生活。性成熟个体于冬季上溯至河口,翌年5-6月在河中筑巢产卵,每次产8万-10万粒,卵黏附于沙砾上发育。亲体产卵后死亡。", 176 | "生物特征": "东亚叉牙七鳃鳗体鳗形。口腹位,呈吸盘状,周缘具穗状突起。眼发达,位于头前部。鳃孔7对,位于眼后,故又有“八目鳗”之称。两背鳍略分离,第2背鳍较高而长,末端附近呈黑色。尾鳍矛状,褐色或黑色。体青绿色,腹部浅黄色或灰白色。体长为50-60 cm。" 177 | }, 178 | { 179 | "拉丁学名": "Lampetra reissneri", 180 | "命名年份": "1869", 181 | "作者": "Dybowski", 182 | "中文学名": "雷氏七鳃鳗", 183 | "界": "动物界", 184 | "门": "脊索动物门", 185 | "纲": "圆口纲", 186 | "目": "七鳃鳗目", 187 | "科": "七鳃鳗科", 188 | "属": "七鳃鳗属", 189 | "种": "雷氏七鳃鳗", 190 | "自然分布地": "中国的黄海北部、黑龙江、松花江、图们江流域,以及日本海等", 191 | "生活习性": "雷氏七鳃鳗其生态习性仍不甚明了。孟庆闻等(1995)、尼科尔斯基(1960)、朱元鼎等(2001)都认为其是不进行洄游的淡水种类,并有在兴凯湖等繁殖的具体报告。而中坊徹次(1993)、刘瑞玉(2008)明确其为江海洄游种类。益田一(1984)记述其幼体从夏到冬完成变态,随后降海潜泥底生活;翌春产卵。", 192 | "生物特征": "雷氏七鳃鳗形态特征与日本七鳃鳗相似,只是吻较宽短,上口齿板齿钝尖,两背鳍连续,尾鳍色较淡。体长可达20 cm,但以小型个体较多见。" 193 | }, 194 | { 195 | "拉丁学名": "Chimaera phantasma", 196 | "命名年份": "1900", 197 | "作者": "Jordan et Snyder", 198 | "中文学名": "黑线银鲛", 199 | "界": "动物界", 200 | "门": "脊索动物门", 201 | "纲": "软骨鱼纲", 202 | "目": "银鲛目", 203 | "科": "银鲛科", 204 | "属": "银鲛属", 205 | "种": "黑线银鲛", 206 | "自然分布地": "分布于我国东海、黄海、台湾海域,以及日本北海道以南海域、朝鲜半岛西南部海域。", 207 | "生活习性": "黑线银鲛属冷温性较深水分布种,栖息水深90-500米。冬季向近海洄游。卵生,卵大且呈纺锤形。主食软体动物。", 208 | "生物特征": "黑线银鲛头高而侧扁。体侧扁,延长,向后细小。尾呈鞭状。雄性的眼前上方具一柄状额鳍脚,腹鳍内侧具一三叉形鳍脚。吻短。口横裂。上颌前齿板喙状;侧齿板宽大,呈三角形。背鳍2个,以低鳍膜相连,第1背鳍具一扁长硬棘。臀鳍低平,后端尖突,与尾鳍下叶分隔处有一缺刻。侧线小波曲状。体银白色,头上部、第1至第2背鳍上部、背侧上部褐色。侧线下方,胸鳍、腹鳍间有一黑色纵带。全长可达1米。" 209 | }, 210 | { 211 | "拉丁学名": "Hydrolagus mitsukurii", 212 | "命名年份": "1904", 213 | "作者": "Jordan et Snyder", 214 | "中文学名": "箕作兔银鲛", 215 | "界": "动物界", 216 | "门": "脊索动物门", 217 | "纲": "软骨鱼纲", 218 | "目": "银鲛目", 219 | "科": "银鲛科", 220 | "属": "兔银鲛属", 221 | "种": "箕作兔银鲛", 222 | "自然分布地": "分布于我国东海、南海,以及日本南部海域、冲绳海漕。", 223 | "生活习性": "箕作兔银鲛栖息水深600-900米。", 224 | "生物特征": "箕作兔银鲛为兔银鲛属鱼类,以臀鳍缺失与银鲛属鱼类相区别。上颌具6枚齿板。第1背鳍鳍棘后缘呈锯齿状,其长度几乎与头长相等。头前部高耸。尾鳍丝状部显著比头长。雄鱼交尾器分2支。体褐色,腹部色略浅,各鳍呈褐色。体侧具若干与侧线平行的浅色纵带。全长58-85厘米。" 225 | }, 226 | { 227 | "拉丁学名": "Hydrolagus purpurescens", 228 | "命名年份": "1905", 229 | "作者": "Gilbert", 230 | "中文学名": "紫银鲛", 231 | "界": "动物界", 232 | "门": "脊索动物门", 233 | "纲": "软骨鱼纲", 234 | "目": "银鲛目", 235 | "科": "银鲛科", 236 | "属": "兔银鲛属", 237 | "种": "紫银鲛", 238 | "自然分布地": "分布于我国南海,以及日本岩手海域、美国夏威夷海域等。", 239 | "生活习性": "紫银鲛栖息水深1120-1920米。", 240 | "生物特征": "紫银鲛无臀鳍。尾鳍下叶前无缺刻。头前部缓尖。第1背鳍鳍棘后缘光滑。尾鳍丝状部较短。侧线有小波纹状弯曲,但侧线上部无短横带。雄鱼交尾器分3支。体褐色,略带紫色,无斑点。全长约80厘米。" 241 | }, 242 | { 243 | "拉丁学名": "Hydrolagus ogilbyi", 244 | 
"命名年份": "1898", 245 | "作者": "Waite", 246 | "中文学名": "奥氏兔银鲛", 247 | "界": "动物界", 248 | "门": "脊索动物门", 249 | "纲": "软骨鱼纲", 250 | "目": "银鲛目", 251 | "科": "兔银鲛科", 252 | "属": "兔银鲛属", 253 | "种": "奥氏兔银鲛", 254 | "自然分布地": "中国黄海、东海、南海、台湾海域,日本南部海域,澳大利亚海域,印度-西太平洋", 255 | "生活习性": "奥氏兔银鲛海洋底栖鱼类,栖息水深120-350米。游泳能力弱,以小型底栖动物为食。偶见于底拖网渔获。", 256 | "生物特征": "奥氏兔银鲛有低平臀鳍,并与尾鳍下叶相连,无缺刻。侧线波纹状,侧线上方具许多短横纹。体淡褐色,腹部色浅。背部有一宽纵带。侧线灰色,各鳍略呈褐色。全长75-95厘米。该鱼背鳍硬棘中空,具毒腺。" 257 | }, 258 | { 259 | "拉丁学名": "Rhinochimaera pacifica", 260 | "命名年份": "1895", 261 | "作者": "Mitsukuri", 262 | "中文学名": "太平洋长吻银鲛", 263 | "界": "动物界", 264 | "门": "脊索动物门", 265 | "纲": "软骨鱼纲", 266 | "目": "银鲛目", 267 | "科": "长吻银鲛科", 268 | "属": "长吻银鲛属", 269 | "种": "太平洋长吻银鲛", 270 | "自然分布地": "中国东海、南海,日本北海道以南海域,新西兰海域,秘鲁海域", 271 | "生活习性": "太平洋长吻银鲛深海底层鱼类,通常栖息水深750-1100米。卵生。", 272 | "生物特征": "太平洋长吻银鲛吻尖长,呈剑状,基部近侧扁。口小,横裂,有齿板3对。躯体侧扁,延长。无臀鳍。雄鱼交尾器棒状。尾鳍上叶为肉质,其边缘有1列30-50个齿状突起。侧线几乎平直。体与各鳍均为黑褐色。全长可达1.3米。" 273 | }, 274 | { 275 | "拉丁学名": "Rhinochimaera africana", 276 | "命名年份": "1990", 277 | "作者": "Stehmann et Ebert", 278 | "中文学名": "非洲长吻银鲛", 279 | "界": "动物界", 280 | "门": "脊索动物门", 281 | "纲": "软骨鱼纲", 282 | "目": "银鲛目", 283 | "科": "长吻银鲛科", 284 | "属": "长吻银鲛属", 285 | "种": "非洲长吻银鲛", 286 | "自然分布地": "中国台湾海域,印度-太平洋深水域", 287 | "生活习性": "非洲长吻银鲛为深海鱼类。", 288 | "生物特征": "非洲长吻银鲛体延长,侧扁。吻长,平扁,吻长大于或等于头长的2倍。眼上侧线管与眼下侧线管于吻腹面交会,交会点到吻端较到鼻管为近。第1背鳍鳍棘较软条为长。胸鳍大。尾鳍上叶无小棘。体黑色。" 289 | } 290 | ] -------------------------------------------------------------------------------- /dspy_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dspy 3 | from dotenv import load_dotenv 4 | from loguru import logger 5 | 6 | # 确保环境变量已加载 7 | load_dotenv(override=True) 8 | 9 | class DspyEvaluationProcessor: 10 | def __init__(self): 11 | # 初始化评估用的语言模型 12 | self.eval_lm = dspy.LM( 13 | f'{os.getenv("Train_LLM_TYPE")}/{os.getenv("Train_LLM_MODEL")}', 14 | base_url=os.getenv("Train_OPENAI_BASE_URL"), 15 | api_key=os.getenv("Train_OPENAI_API_KEY"), 16 | stream=True # 直接在创建模型时启用流式模式 17 | ) 18 | # 移除全局配置,避免影响其他模块 19 | # dspy.configure(lm=self.eval_lm) 20 | 21 | # 评估相关功能 22 | class BiologicalRetrievalEvaluation(dspy.Signature): 23 | """评估生物检索任务的推理步骤质量""" 24 | question = dspy.InputField(desc="用户的查询问题") 25 | standard_reasoning = dspy.InputField(desc="标准的推理步骤") 26 | predicted_reasoning = dspy.InputField(desc="模型产生的推理步骤") 27 | evaluation_score = dspy.OutputField(desc="评分(0-100)") 28 | evaluation_feedback = dspy.OutputField(desc="详细的评分解释,包括各个方面的得分") 29 | 30 | class LLMBiologicalEvaluator(dspy.Module): 31 | def __init__(self, eval_lm): 32 | super().__init__() 33 | # 使用传入的评估模型 34 | self.eval_lm = eval_lm 35 | 36 | # 使用思维链方式增强评估能力,直接提供指令,并使用专用评估模型 37 | self.eval_chain = dspy.ChainOfThought( 38 | DspyEvaluationProcessor.BiologicalRetrievalEvaluation, 39 | instructions=""" 40 | 您是一位专业的生物检索质量评估专家。您的任务是评估模型产生的生物检索推理步骤质量。 41 | 42 | 请根据以下标准进行评分(总分100分): 43 | 44 | 1. 检索条件识别准确性 (20分) 45 | - 是否正确识别了所有检索条件 46 | - 是否正确区分了精确条件和模糊条件 47 | 48 | 2. 需求识别准确性 (10分) 49 | - 是否正确识别了查询中的所有需求 50 | 51 | 3. 检索策略合理性 (40分) 52 | - 是否先执行精确检索,后执行模糊检索 (10分) 53 | - 后续检索步骤是否基于前面步骤的结果 (10分) 54 | - 筛选顺序是否从限制性强的条件开始 (10分) 55 | - 前面步骤检索到的内容是否把全部信息传递给后面检索所使用的工具 (10分) 56 | 57 | 4. 结果整合正确性和完整性 (30分) 58 | - 答案准确性,与标准答案相比,核心事实是否一致,即使表达方式不同,只要核心信息正确也应得高分 (25分) 59 | - 提取所有需要汇总的信息 (5分) 60 | 61 | 评估须知: 62 | 1. 在评估答案准确性时,请比对预测结果与标准答案的内容,理解语义等价性而非只做字符匹配 63 | 2. 即使表达方式不同,只要内容实质相同,也应给予高分 64 | 3. 同一事实的不同表述方式应被视为正确,如"共有3种"和"总共有三种"表达的是相同含义 65 | 4. 
对每个评分维度提供详细分析和具体理由 66 | """ 67 | ) 68 | # 显式设置评估链使用评估模型 69 | self.eval_chain.lm = self.eval_lm 70 | 71 | def forward(self, example, prediction): 72 | """评估预测与标准答案的匹配程度 73 | 74 | Args: 75 | example: 包含标准答案的示例 76 | prediction: 模型的预测结果 77 | 78 | Returns: 79 | float: 0-1之间的分数 80 | """ 81 | # 如果没有推理步骤,使用简单的答案匹配 82 | if not hasattr(example, 'reasoning') or not hasattr(prediction, 'reasoning'): 83 | return 1.0 if dspy.evaluate.answer_exact_match(example, prediction) else 0.0 84 | 85 | # 准备标准推理步骤 86 | standard_reasoning = "\n".join(example.reasoning) if isinstance(example.reasoning, list) else example.reasoning 87 | 88 | # 获取预测的推理步骤 89 | predicted_reasoning = prediction.reasoning if hasattr(prediction, 'reasoning') else "" 90 | 91 | try: 92 | # 直接使用评估链,不再使用 context 管理器 93 | # 因为我们已经在创建模型时启用了流式模式,并显式设置了评估链使用评估模型 94 | evaluation = self.eval_chain( 95 | question=example.question, 96 | standard_reasoning=standard_reasoning, 97 | predicted_reasoning=predicted_reasoning 98 | ) 99 | 100 | # 将分数从0-100转换为0-1 101 | try: 102 | score = float(evaluation.evaluation_score) / 100.0 103 | # 边界处理 104 | score = max(0.0, min(1.0, score)) 105 | logger.info(f"评估结果: {score} (问题: {example.question[:30]}...)") 106 | return score 107 | except: 108 | # 如果分数转换失败,默认返回0.5 109 | logger.warning(f"评估分数转换失败,使用默认分数0.5") 110 | return 0.5 111 | except Exception as e: 112 | logger.error(f"评估过程出错: {str(e)}") 113 | # 出错时返回默认分数 114 | return 0.5 115 | 116 | def llm_biological_metric(self, example, pred, trace=None, frac=1.0): 117 | """使用大模型评估函数""" 118 | try: 119 | # 创建评估器实例,传入评估模型 120 | evaluator = self.LLMBiologicalEvaluator(self.eval_lm) 121 | 122 | # 确保在 litellm 客户端级别启用流式模式 123 | if hasattr(evaluator.eval_lm, 'client'): 124 | evaluator.eval_lm.client.stream = True 125 | logger.info("已在评估模型客户端级别启用流式模式") 126 | 127 | # 执行评估 128 | result = evaluator(example, pred) 129 | return result 130 | except Exception as e: 131 | logger.error(f"评估指标计算出错: {str(e)}") 132 | # 出错时返回默认分数 133 | return 0.5 -------------------------------------------------------------------------------- /dspy_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dspy 3 | from react_tools import ReActTools, GraphVectorizer 4 | from dotenv import load_dotenv 5 | from dspy_query_db import MarineSpeciesQuery 6 | import json 7 | from loguru import logger 8 | 9 | 10 | # 确保环境变量已加载 11 | load_dotenv(override=True) 12 | 13 | MAX_ITERS = int(os.getenv("MAX_ITERS","10")) 14 | 15 | class DspyInferenceProcessor: 16 | def __init__(self): 17 | self.ragtool = ReActTools() 18 | self.graphvectorizer = GraphVectorizer() 19 | self.query_processor = MarineSpeciesQuery(os.getenv("SPECIES_DB_URL","./.dbs/marine_species.db")) 20 | # 初始化语言模型 21 | self.lm = dspy.LM( 22 | f'{os.getenv("LLM_TYPE")}/{os.getenv("LLM_MODEL")}', 23 | base_url=os.getenv("BASE_URL"), 24 | api_key=os.getenv("API_KEY") 25 | ) 26 | 27 | # 配置 dspy 使用该语言模型 28 | dspy.configure(lm=self.lm) 29 | 30 | # 初始化版本号 31 | self.predictor_version = "1.0.0" 32 | # 初始化 RactModel 33 | self.model = self.RactModel(self) 34 | # 使用 streamify 包装,获得支持流式返回的模块 35 | self.streaming_model = dspy.streamify(self.model) 36 | 37 | def find_nodes_by_node_type(self, start_node, trget_node_type): 38 | ''' 39 | 此方法会根据传入的节点名称,在图数据中以该节点为起点查找包含指定节点类型的节点列表,并返回节点数量与节点列表。 40 | start_node 为开始查找的树节点名称,只允许单个节点、 41 | trget_node_type 目标节点类型,只允许单个类型名称 42 | 返回值为从该节点开始,包含指定属性名的节点数量与节点列表 43 | 已知图数据中存在一系列的海洋生物相关信息: 44 | 1. 生物分类学图数据:包括"拉丁学名", "命名年份", "作者", "中文学名", 45 | 2. 生物科属数据:"界", "门", "纲", "目", "科", "属", "种"(种即是中文学名),它们的从属关系是: 界 -> 门 -> 纲 -> 目 -> 科 -> 属 -> 种。 46 | 3. 生物特征图数据:包括"自然分布地", "生物特征","生活习性"等。 47 | 本方法可以根据给定的节点名称,在图数据中以此节点为起点查找包含指定属性的节点或节点列表,例如1:"盲鳗科" "种" 则会返回 盲鳗科所有的种,例如2:"盲鳗科" "界" 则会返回 盲鳗科对应的界。 48 | 4. 因为本方法需要的参数是精准的节点属性名称(或节点类型名),建议查询的节点类型属于"自然分布地", "生物特征", "生活习性"等时,或查询返回为空时、查询失败时,先通过get_unique_vector_query_results方法获取准确的节点名称,再通过本方法获取对应的节点信息。 49 | 50 | Args: 51 | start_node: 开始查找的节点名称 52 | trget_node_type: 目标节点类型 53 | Returns: 54 | count: 节点数量 55 | nodes: 节点列表 56 | ''' 57 | nodes = self.ragtool.find_nodes_by_node_type(start_node, trget_node_type) 58 | # 如果nodes为空,则返回0;不为空时,返回节点数量与节点列表 59 | if not nodes: 60 | return 0,[] 61 | count = len(nodes) 62 | return count,nodes 63 | 64 | def batch_find_nodes_by_node_type(self, start_nodes, trget_node_type): 65 | ''' 66 | 此方法会根据传入包含多个开始节点的列表,批量查询指定目标节点类型的节点列表,返回多条查询的结果集。 67 | Args: 68 | start_nodes: 开始查找的节点名称列表 69 | trget_node_type: 目标节点类型 70 | Returns: 71 | traget_nodes_list: 多条查询结果的列表 72 | ''' 73 | # 字典格式为,key为节点名称,value为包含指定属性名的节点数量与节点列表 74 | traget_nodes_list = {} 75 | for node in start_nodes: 76 | count, nodes = self.find_nodes_by_node_type(node, trget_node_type) 77 | traget_nodes_list[node] = {"count": count, "nodes": nodes} 78 | return traget_nodes_list 79 | 80 | def get_unique_vector_query_results(self, query, node_type=None, search_type="all", top_k=1, better_than_threshold=0.65): 81 | """通过向量搜索,获取与查询最相关的实体或关系 82 | Args: 83 | query: 搜索查询文本 84 | node_type: 实体类型筛选条件,如果为None则不筛选。可选值包括: 85 | - species (种、中文名) 86 | - 界 87 | - 门 88 | - 纲 89 | - 目 90 | - 科 91 | - 属 92 | - 位置 93 | - 繁殖特征 94 | - 行为特征 95 | - 体型 96 | - 体色 97 | - 体长 98 | - 特殊特征 99 | search_type: 搜索类型,'all'/'entity'/'relation' 100 | top_k: 返回结果的数量 101 | better_than_threshold: 相似度阈值,只返回相似度高于此值的结果 102 | Returns: 103 | list: 搜索结果,精准的实体名列表 104 | """ 105 | try: 106 | # 添加超时控制 107 | import asyncio 108 | from concurrent.futures import ThreadPoolExecutor, TimeoutError 109 | 110 | # 使用线程池执行可能耗时的操作 111 | with ThreadPoolExecutor() as executor: 112 | # 设置超时时间(例如10秒) 113 | future = executor.submit(self.graphvectorizer.search, query, node_type, search_type, top_k, better_than_threshold) 114 | try: 115 | result = future.result(timeout=10) # 10秒超时 116 | return result 117 | except TimeoutError: 118 | logger.error(f"向量搜索超时: query={query}, node_type={node_type}") 119 | return [] # 超时返回空列表 120 | except Exception as e: 121 | # 捕获所有异常,确保不会导致整个流程崩溃 122 | logger.error(f"向量搜索出错: {str(e)}, query={query}, node_type={node_type}") 123 | return [] # 出错返回空列表 124 | 125 | def get_node_attribute(self,node_id): 126 | ''' 127 | 根据节点id获取所有属性,包括中文学名、拉丁学名、命名年份、作者、node_type 128 | Args: 129 | node_id: 节点id 130 | Returns: 131 | list: 属性列表 132 | ''' 133 | return self.ragtool.get_node_attribute(node_id) 134 | def get_adjacent_node_descriptions(self, nodenames): 135 | ''' 136 | 此方法会根据传入的节点列表,获取每个节点相邻所有节点描述,合并到一个列表中返回,非精准检索,谨慎使用 137 | Args: 138 | nodenames: 节点名称列表 139 | Returns: 140 | list: 相邻节点描述列表 141 | ''' 142 | return self.ragtool.get_adjacent_node_descriptions(nodenames) 143 | 144 | def nodes_count(self, nodes): 145 | ''' 146 | 此方法会根据传入的节点列表,统计数量,返回数量 147 | Args: 148 | nodes: 节点列表 149 | Returns: 150 | int: 节点数量 151 | ''' 152 | if not nodes: 153 | return 0 154 | return len(nodes) 155 | 156 | def MarineSpeciesQuery(self,query): 157 | """根据自然语言查询数据库 158 | Args: 159 | natural_language_query: 用户的自然语言查询 160 | 161 | Returns: 162 | 查询结果和解释 163 | """ 164 | result = self.query_processor.query_database(query) 165 | return self.query_processor.format_query_results(result) 166 | 167 | # 定义签名类 168
168 |     class MarineBiologyKnowledgeQueryAnswer(dspy.Signature):
169 |         """
170 |         针对复杂检索问题的增强签名。
171 |         此签名能够:
172 |         1. 分析用户问题,提取精确检索条件和模糊检索条件
173 |         2. 确定检索顺序和优先级策略
174 |         3. 对多实体结果进行遍历查询
175 |         4. 按照检索需求有序组织答案
176 |         """
177 |         # 输入字段
178 |         question = dspy.InputField(desc="用户的原始问题")
179 |         # 输出字段
180 |         answer = dspy.OutputField(desc="根据检索结果综合形成的完整答案,确保涵盖所有检索需求,使用中文回复")
181 | 
182 |     # 建议添加的问题分类签名
183 |     class QuestionClassifier(dspy.Signature):
184 |         """对用户问题进行分类"""
185 |         question = dspy.InputField(desc="用户的原始问题")
186 |         question_type = dspy.OutputField(desc="问题类型,可能的值包括:实体查询/关系查询/属性查询/统计查询等")
187 |         search_strategy = dspy.OutputField(desc="建议的检索策略:向量检索/图检索/混合检索")
188 |         key_entities = dspy.OutputField(desc="问题中的关键实体列表")
189 | 
190 |     # 定义 RactModel 类
191 |     class RactModel(dspy.Module):
192 |         def __init__(self, processor):
193 |             super().__init__()
194 |             # 保存外部类的引用
195 |             self.processor = processor
196 |             # 利用 ReAct 将工具函数集成进来
197 |             self.react = dspy.ReAct(
198 |                 DspyInferenceProcessor.MarineBiologyKnowledgeQueryAnswer,
199 |                 max_iters=MAX_ITERS,
200 |                 tools=[
201 |                     processor.find_nodes_by_node_type,
202 |                     processor.batch_find_nodes_by_node_type,
203 |                     processor.get_unique_vector_query_results,
204 |                     processor.get_node_attribute,
205 |                     processor.get_adjacent_node_descriptions,
206 |                     processor.nodes_count
207 |                 ]
208 |             )
209 | 
210 |         def forward(self, question):
211 |             return self.react(question=question)
212 | 
213 |     def get_last_message(self):
214 |         """获取最后一条消息历史"""
215 |         return self.lm.history[-1] if self.lm.history else None
216 | 
217 |     def load_model(self, file_path):
218 |         """加载指定版本的模型"""
219 |         result = self.model.load(file_path)
220 |         # 加载模型后清除缓存
221 |         dspy.settings.configure(cache=None)
222 |         return result
223 | 
224 |     def set_version(self, version):
225 |         """设置当前预测器版本"""
226 |         self.predictor_version = version
227 | 
228 |     def get_version(self):
229 |         """获取当前预测器版本"""
230 |         return self.predictor_version
231 | 
232 |     def predict(self, question):
233 |         """非流式预测"""
234 |         return self.model(question=question)
235 | 
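# ---------------------------------------------------------------------------
# [编者示例] 两种预测入口的最小调用示意(非源文件内容,仅供理解)。
# stream_predict 返回的是一个异步生成器对象,需在事件循环中用 async for 消费;
# processor 同样是假设已构造好的 DspyInferenceProcessor 实例。
#
#     result = processor.predict("东海分布着哪些盲鳗科生物?")
#     print(result.answer)
#
#     import asyncio
#     async def consume():
#         stream = processor.stream_predict("东海分布着哪些盲鳗科生物?")
#         async for prediction in stream:
#             print(prediction.reasoning, prediction.answer)
#     asyncio.run(consume())
# ---------------------------------------------------------------------------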
236 |     def stream_predict(self, question):
237 |         """流式预测,实现真正的增量输出"""
238 |         try:
239 |             # 创建一个跟踪状态的对象
240 |             class StreamState:
241 |                 def __init__(self):
242 |                     self.last_answer = ""
243 |                     self.last_reasoning = ""
244 |                     self.is_first_chunk = True
245 | 
246 |             state = StreamState()
247 | 
248 |             # 使用 dspy 的流式模型获取结果
249 |             async def real_stream():
250 |                 # 首先发送一个空的初始状态
251 |                 if state.is_first_chunk:
252 |                     initial_prediction = type('Prediction', (), {
253 |                         'answer': '',
254 |                         'reasoning': '思考中...'
255 |                     })
256 |                     state.is_first_chunk = False
257 |                     yield initial_prediction
258 | 
259 |                 # 启动非流式预测(在后台运行)
260 |                 import asyncio
261 |                 from concurrent.futures import ThreadPoolExecutor
262 | 
263 |                 # 创建一个执行器来运行阻塞的预测
264 |                 with ThreadPoolExecutor() as executor:
265 |                     # 提交预测任务到线程池
266 |                     future = executor.submit(self.predict, question)
267 | 
268 |                     # 每隔一小段时间检查一次结果,模拟流式输出
269 |                     while not future.done():
270 |                         await asyncio.sleep(0.2)  # 等待200毫秒
271 |                         # 发送思考中的状态
272 |                         thinking_prediction = type('Prediction', (), {
273 |                             'answer': state.last_answer,
274 |                             'reasoning': state.last_reasoning + "."  # 添加一个点表示思考
275 |                         })
276 |                         state.last_reasoning += "."
277 |                         yield thinking_prediction
278 | 
279 |                     # 获取最终结果
280 |                     try:
281 |                         final_result = future.result()
282 |                         # 如果最终结果可用,分段返回
283 |                         if hasattr(final_result, 'answer') and hasattr(final_result, 'reasoning'):
284 |                             # 将答案和推理过程分成多个部分
285 |                             answer_parts = self._split_text(final_result.answer, 10)  # 分成约10个部分
286 |                             reasoning_parts = self._split_text(final_result.reasoning, 5)  # 分成约5个部分
287 | 
288 |                             # 先返回完整的推理过程
289 |                             for i, reasoning_part in enumerate(reasoning_parts):
290 |                                 current_reasoning = "".join(reasoning_parts[:i+1])
291 |                                 prediction = type('Prediction', (), {
292 |                                     'answer': state.last_answer,
293 |                                     'reasoning': current_reasoning
294 |                                 })
295 |                                 state.last_reasoning = current_reasoning
296 |                                 yield prediction
297 |                                 await asyncio.sleep(0.1)  # 短暂停顿
298 | 
299 |                             # 然后逐步返回答案
300 |                             for i, answer_part in enumerate(answer_parts):
301 |                                 current_answer = "".join(answer_parts[:i+1])
302 |                                 prediction = type('Prediction', (), {
303 |                                     'answer': current_answer,
304 |                                     'reasoning': final_result.reasoning
305 |                                 })
306 |                                 state.last_answer = current_answer
307 |                                 yield prediction
308 |                                 await asyncio.sleep(0.1)  # 短暂停顿
309 |                     except Exception as e:
310 |                         logger.error(f"获取预测结果时出错: {str(e)}")
311 |                         error_prediction = type('Prediction', (), {
312 |                             'answer': '处理您的请求时出现错误',
313 |                             'reasoning': f'发生错误: {str(e)}'
314 |                         })
315 |                         yield error_prediction
316 | 
317 |             return real_stream()
318 |         except Exception as e:
319 |             logger.error(f"流式预测出错: {str(e)}")
320 |             # 如果流式预测失败,尝试使用非流式预测
321 |             try:
322 |                 logger.info("尝试使用非流式预测作为备选方案")
323 |                 result = self.predict(question)
324 |                 # 将非流式结果转换为可迭代对象以模拟流式返回
325 |                 async def mock_stream():
326 |                     yield result
327 |                 return mock_stream()
328 |             except Exception as e2:
329 |                 logger.error(f"备选预测也失败: {str(e2)}")
330 |                 raise e  # 重新抛出原始异常
331 | 
332 |     def _split_text(self, text, num_parts):
333 |         """将文本分成大约 num_parts 个部分"""
334 |         if not text:
335 |             return [""]
336 | 
337 |         # 计算每部分的大致长度
338 |         part_length = max(1, len(text) // num_parts)
339 |         parts = []
340 | 
341 |         for i in range(0, len(text), part_length):
342 |             parts.append(text[i:i + part_length])
343 | 
344 |         return parts
--------------------------------------------------------------------------------
/dspy_program/program_v1.0.1_20250313195723.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/dspy_program/program_v1.0.1_20250313195723.pkl
--------------------------------------------------------------------------------
/dspy_program/program_v1.0.3_20250315154834.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/dspy_program/program_v1.0.3_20250315154834.pkl
--------------------------------------------------------------------------------
/dspy_query_db.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import json
  3 | import re
  4 | import dspy
  5 | import sqlite3
  6 | from dotenv import load_dotenv
  7 | # 设置DSPy的语言模型
  8 | def setup_dspy():
  9 |     load_dotenv(override=True)
 10 | 
 11 |     if os.getenv("Train_LLM_MODEL"):
 12 |         Train = dspy.LM(
 13 |             f'{os.getenv("Train_LLM_TYPE", "deepseek")}/{os.getenv("Train_LLM_MODEL")}',
 14 |             base_url=os.getenv("Train_OPENAI_BASE_URL"),
 15 |             api_key=os.getenv("Train_OPENAI_API_KEY")
 16 |         )
 17 |         dspy.settings.configure(lm=Train)
 18 |     else:
 19 |         # 默认使用OpenAI
 20 |         dspy.settings.configure(lm=dspy.LM("openai/gpt-3.5-turbo"))  # 模型名可按需替换
 21 | 
 22 | 
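# ---------------------------------------------------------------------------
# [编者示例] dspy.LM 接受 "<provider>/<model>" 形式的模型标识(非源文件内容)。
# 以本仓库 .env 中的训练模型配置为例,上面的分支大致等价于:
#
#     lm = dspy.LM(
#         "deepseek/deepseek-chat",  # provider/model,实际取自环境变量
#         base_url=os.getenv("Train_OPENAI_BASE_URL"),
#         api_key=os.getenv("Train_OPENAI_API_KEY"),
#     )
#     dspy.settings.configure(lm=lm)
# ---------------------------------------------------------------------------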
 23 | # 在已有的签名定义之后添加
 24 | class NaturalLanguageToSQL(dspy.Signature):
 25 |     """将自然语言查询转换为SQL语句。注意:返回纯SQL文本,不要包含```sql或```等代码块标记。
 26 |     重要:保持原始查询中的中文词汇不变,不要自动转换为拉丁文或英文。
 27 |     当查询涉及到地理位置(distributions表中的location字段)时,必须使用LIKE语句而不是精确匹配,
 28 |     例如:WHERE location LIKE '%东海%' 而不是 WHERE location = '东海'"""
 29 |     query = dspy.InputField(description="用户的自然语言查询")
 30 |     db_schema = dspy.InputField(description="数据库的表结构信息")
 31 |     sql = dspy.OutputField(description="生成的SQL查询语句,必须是纯SQL文本,对地理位置使用LIKE操作符")
 32 |     explanation = dspy.OutputField(description="SQL查询的解释")
 33 | 
 34 | # 在已有的提取器类之后添加
 35 | class SQLGenerator(dspy.Module):
 36 |     def __init__(self):
 37 |         super().__init__()
 38 |         self.generator = dspy.ChainOfThought(NaturalLanguageToSQL)
 39 | 
 40 |     def forward(self, query, db_schema):
 41 |         return self.generator(query=query, db_schema=db_schema)
 42 | 
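# ---------------------------------------------------------------------------
# [编者示例] 直接驱动 SQLGenerator 的最小示意(非源文件内容,schema为示意值)。
#
#     setup_dspy()
#     gen = SQLGenerator()
#     schema = '[{"table": "species", "columns": [{"name": "family", "type": "TEXT"}]}]'
#     out = gen(query="盲鳗科有哪些生物?", db_schema=schema)
#     print(out.sql)          # 预期为纯SQL文本,不带```sql代码块标记
#     print(out.explanation)  # 生成SQL的解释
# ---------------------------------------------------------------------------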
{e}") 152 | return { 153 | "success": False, 154 | "query": natural_language_query, 155 | "sql": sql, # 使用清理后的SQL 156 | "explanation": result.explanation, 157 | "error": str(e) 158 | } 159 | 160 | def format_query_results(self, query_result): 161 | """格式化查询结果 162 | 163 | Args: 164 | query_result: 查询结果字典 165 | 166 | Returns: 167 | 格式化的结果字符串 168 | """ 169 | if not query_result["success"]: 170 | return f"查询失败: {query_result['error']}\n原始SQL: {query_result['sql']}" 171 | 172 | output = [] 173 | output.append(f"查询: {query_result['query']}") 174 | output.append(f"SQL: {query_result['sql']}") 175 | output.append(f"解释: {query_result['explanation']}") 176 | output.append(f"找到 {query_result['row_count']} 条结果:") 177 | 178 | if query_result['row_count'] > 0: 179 | # 计算每列的最大宽度 180 | widths = {} 181 | for col in query_result['column_names']: 182 | widths[col] = len(col) 183 | 184 | for row in query_result['results']: 185 | for col in query_result['column_names']: 186 | val = str(row[col]) if row[col] is not None else 'NULL' 187 | widths[col] = max(widths[col], len(val)) 188 | 189 | # 生成表头 190 | header = " | ".join(col.ljust(widths[col]) for col in query_result['column_names']) 191 | separator = "-+-".join("-" * widths[col] for col in query_result['column_names']) 192 | 193 | output.append(header) 194 | output.append(separator) 195 | 196 | # 生成数据行 197 | for row in query_result['results']: 198 | row_str = " | ".join( 199 | str(row[col]).ljust(widths[col]) if row[col] is not None else 'NULL'.ljust(widths[col]) 200 | for col in query_result['column_names'] 201 | ) 202 | output.append(row_str) 203 | 204 | return "\n".join(output) 205 | 206 | 207 | if __name__ == "__main__": 208 | # 直接使用查询处理器示例 209 | query_processor = MarineSpeciesQuery("marine_species.db") 210 | result = query_processor.query_database("分布在东海的盲鳗科哪些生物?有多少?") 211 | formatted_result = query_processor.format_query_results(result) 212 | print(formatted_result) 213 | -------------------------------------------------------------------------------- /images/function-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/function-diagram.png -------------------------------------------------------------------------------- /images/startup-success.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/startup-success.jpg -------------------------------------------------------------------------------- /images/二维码.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/二维码.jpg -------------------------------------------------------------------------------- /images/优化样本.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/优化样本.jpg -------------------------------------------------------------------------------- /images/关系信息查询.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/关系信息查询.jpg -------------------------------------------------------------------------------- /images/实体信息查询.jpg: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/实体信息查询.jpg
--------------------------------------------------------------------------------
/images/属性信息查询.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/属性信息查询.jpg
--------------------------------------------------------------------------------
/images/版本选择.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/版本选择.jpg
--------------------------------------------------------------------------------
/images/统计信息查询.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/统计信息查询.jpg
--------------------------------------------------------------------------------
/images/训练所有样本.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/训练所有样本.jpg
--------------------------------------------------------------------------------
/images/非实体信息截图.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/非实体信息截图.jpg
--------------------------------------------------------------------------------
/images/项目技术路线.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loukie7/Datacapsule/08a6ff167a89234868a3970bc42f93bef41058a3/images/项目技术路线.jpg
--------------------------------------------------------------------------------
/nanovector_db.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | from dotenv import load_dotenv
  3 | from openai import OpenAI
  4 | import os
  5 | import json
  6 | from pathlib import Path
  7 | import networkx as nx
  8 | from loguru import logger
  9 | 
 10 | class NanoVectorDB:
 11 |     def __init__(self, db_path: str):
 12 |         """初始化向量数据库
 13 | 
 14 |         Args:
 15 |             db_path: 数据库文件存储路径
 16 |         """
 17 |         self.db_path = Path(db_path)
 18 |         self.db_path.mkdir(parents=True, exist_ok=True)
 19 | 
 20 |         self.entity_vectors_file = self.db_path / 'entity_vectors.json'
 21 |         self.relation_vectors_file = self.db_path / 'relation_vectors.json'
 22 |         logger.info(f"初始化向量数据库: {self.db_path}/entity_vectors.json, relation_vectors.json")
 23 |         # 初始化存储文件
 24 |         if not self.entity_vectors_file.exists():
 25 |             logger.info(f"文件不存在,开始创建向量数据库: {self.db_path}/entity_vectors.json")
 26 |             self._save_vectors(self.entity_vectors_file, [])
 27 |         if not self.relation_vectors_file.exists():
 28 |             logger.info(f"文件不存在,开始创建向量数据库: {self.db_path}/relation_vectors.json")
 29 |             self._save_vectors(self.relation_vectors_file, [])
 30 |         logger.info(f"开始缓存向量数据: {self.db_path}/entity_vectors.json, relation_vectors.json")
 31 |         # 缓存向量数据
 32 |         self.entity_vectors_cache = self._load_vectors(self.entity_vectors_file)
 33 |         self.relation_vectors_cache = self._load_vectors(self.relation_vectors_file)
 34 |         logger.info(f"已缓存实体向量 {len(self.entity_vectors_cache)} 条,关系向量 {len(self.relation_vectors_cache)} 条")
 35 | 
 36 |     def _save_vectors(self, file_path: Path, vectors: list):
 37 |         """保存向量数据到文件"""
 38 |         with open(file_path, 'w', encoding='utf-8') as f:
 39 |             json.dump(vectors, f, ensure_ascii=False)
 40 | 
 41 |     def _load_vectors(self, file_path: Path) -> list:
 42 |         """从文件加载向量数据"""
 43 |         logger.info(f"开始加载向量数据库: {file_path}")
 44 |         with open(file_path, 'r', encoding='utf-8') as f:
 45 |             data = json.load(f)
 46 |             logger.info(f"成功加载向量数据库: {file_path}")
 47 |             return data
 48 | 
 49 |     def add_entity(self, entity_id: str, entity_type: str, entity_name: str, embedding: list):
 50 |         """添加实体向量"""
 51 |         self.entity_vectors_cache.append({
 52 |             'entity_id': entity_id,
 53 |             'entity_type': entity_type,
 54 |             'entity_name': entity_name,
 55 |             'embedding': embedding
 56 |         })
 57 |         self._save_vectors(self.entity_vectors_file, self.entity_vectors_cache)
 58 | 
 59 |     def add_relation(self, source_id: str, target_id: str, relation_type: str, embedding: list):
 60 |         """添加关系向量"""
 61 |         self.relation_vectors_cache.append({
 62 |             'source_id': source_id,
 63 |             'target_id': target_id,
 64 |             'relation_type': relation_type,
 65 |             'embedding': embedding
 66 |         })
 67 |         self._save_vectors(self.relation_vectors_file, self.relation_vectors_cache)
 68 | 
 69 |     def search_entities(self, query_embedding: list, k: int = 5) -> list:
 70 |         """搜索最相似的实体"""
 71 |         results = []
 72 | 
 73 |         for entity in self.entity_vectors_cache:
 74 |             similarity = 1 - self._cosine_distance(query_embedding, entity['embedding'])
 75 |             results.append({
 76 |                 'type': 'entity',
 77 |                 'id': entity['entity_id'],
 78 |                 'entity_type': entity['entity_type'],
 79 |                 'name': entity['entity_name'],
 80 |                 'similarity': similarity
 81 |             })
 82 | 
 83 |         results.sort(key=lambda x: x['similarity'], reverse=True)
 84 |         return results[:k]
 85 | 
 86 |     def search_relations(self, query_embedding: list, k: int = 5) -> list:
 87 |         """搜索最相似的关系"""
 88 |         results = []
 89 | 
 90 |         for relation in self.relation_vectors_cache:
 91 |             similarity = 1 - self._cosine_distance(query_embedding, relation['embedding'])
 92 |             results.append({
 93 |                 'type': 'relation',
 94 |                 'source': relation['source_id'],
 95 |                 'target': relation['target_id'],
 96 |                 'relation_type': relation['relation_type'],
 97 |                 'similarity': similarity
 98 |             })
 99 | 
100 |         results.sort(key=lambda x: x['similarity'], reverse=True)
101 |         return results[:k]
102 | 
103 |     def _cosine_distance(self, v1: list, v2: list) -> float:
104 |         """计算余弦距离"""
105 |         v1_array = np.array(v1)
106 |         v2_array = np.array(v2)
107 |         dot_product = np.dot(v1_array, v2_array)
108 |         norm_v1 = np.linalg.norm(v1_array)
109 |         norm_v2 = np.linalg.norm(v2_array)
110 |         return 1 - dot_product / (norm_v1 * norm_v2)
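# ---------------------------------------------------------------------------
# [编者示例] NanoVectorDB 的最小使用示意(非源文件内容;路径与向量均为示意值,
# 真实向量应由embedding模型生成,维度见EMBEDDING_DIM配置)。
#
#     db = NanoVectorDB("./graph_data_new")
#     db.add_entity("盲鳗科", "科", "盲鳗科", [0.1] * 1024)
#     hits = db.search_entities([0.1] * 1024, k=3)
#     print(hits[0]['name'], hits[0]['similarity'])
# ---------------------------------------------------------------------------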
--------------------------------------------------------------------------------
/react_tools.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | from openai import OpenAI
  4 | from loguru import logger
  5 | import networkx as nx
  6 | import dspy
  7 | from typing import List
  8 | from nanovector_db import NanoVectorDB
  9 | 
 10 | 
 11 | MAX_BATCH_SIZE = int(os.getenv("MAX_BATCH_SIZE", "100"))
 12 | VECTOR_SEARCH_TOP_K = int(os.getenv("VECTOR_SEARCH_TOP_K", "3"))
 13 | BETTER_THAN_THRESHOLD = float(os.getenv("BETTER_THAN_THRESHOLD", "0.7"))
 14 | WORKING_DIR = os.getenv("RAG_DIR", "graph_data")
 15 | 
 16 | client = OpenAI(base_url=os.getenv("EMBEDDING_MODEL_BASE_URL"), api_key=os.getenv("EMBEDDING_MODEL_API_KEY"))
 17 | 
 18 | # 定义节点类型的层级顺序
 19 | NODE_HIERARCHY = {
 20 |     "界": 1,
 21 |     "门": 2,
 22 |     "纲": 3,
 23 |     "目": 4,
 24 |     "科": 5,
 25 |     "属": 6,
 26 |     "种": 7,
 27 |     "中文学名": 7,
 28 |     "自然分布地": 8,
 29 |     "生活习性": 8,
 30 |     "生物特征": 8,
 31 |     "经济性": 8,
 32 |     "保护信息": 8,
 33 |     "食性": 8,
 34 |     "繁殖特征": 8,
 35 |     "行为特征": 8,
 36 |     "体型": 8,
 37 |     "体色": 8,
 38 |     "体长": 8,
 39 |     "特殊特征": 8
 40 | }
 41 | 
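# ---------------------------------------------------------------------------
# [编者示例] NODE_HIERARCHY 的数值仅用于比较层级高低:数值越小,分类层级越高
# (界最高);同为7的"种"与"中文学名"视作同级;8为叶子属性层(非源文件内容)。
#
#     assert NODE_HIERARCHY["界"] < NODE_HIERARCHY["科"] < NODE_HIERARCHY["种"]
#     assert NODE_HIERARCHY["种"] == NODE_HIERARCHY["中文学名"]
#     # 遍历时即据此决定向上(目标层级更小)还是向下(目标层级更大)递归
# ---------------------------------------------------------------------------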
 42 | class ReActTools:
 43 |     def __init__(self):
 44 |         logger.info("ReActTools initialized")
 45 |         GRAPHML_DIR = os.getenv("GRAPHML_DIR", "graph_chunk_entity_relation_clean.graphml")
 46 |         logger.info("init-ReActTools")
 47 |         logger.info(f"{WORKING_DIR}/{GRAPHML_DIR}")
 48 |         graphml_path = f"{WORKING_DIR}/{GRAPHML_DIR}"
 49 |         self.nx = nx.read_graphml(graphml_path) if os.path.exists(graphml_path) else None
 50 | 
 51 |         # 判断是否正确加载到网络图
 52 |         if self.nx and self.nx.number_of_nodes() > 0:
 53 |             logger.info(f"NetworkX graph loaded successfully! have nodes: {self.nx.number_of_nodes()}")
 54 |             self.nx_nodes = self.nx.nodes(data=True)
 55 |             self.entity_type_map = {}
 56 |             for node in self.nx_nodes:
 57 |                 item = node[1]
 58 |                 id = node[0]
 59 |                 entity_type = item.get('node_type')
 60 |                 if entity_type:  # 只处理包含entity_type的节点
 61 |                     if entity_type not in self.entity_type_map:
 62 |                         self.entity_type_map[entity_type] = {}
 63 |                     self.entity_type_map[entity_type][id] = item
 64 |                 else:
 65 |                     logger.warning(f"Warning: Node {id} missing node_type attribute")
 66 |         else:
 67 |             logger.error("NetworkX graph is empty!")
 68 | 
 69 |         self.dim = int(os.getenv("EMBEDDING_DIM", 1536))
 70 |         self.vectorizer = GraphVectorizer(WORKING_DIR)
 71 | 
 72 |     def openai_embedding_function(self, texts: List[str]):
 73 | 
 74 |         response = client.embeddings.create(
 75 |             input=texts,
 76 |             model=os.getenv("EMBEDDING_MODEL")
 77 |         )
 78 |         return [x.embedding for x in response.data]
 79 | 
 80 |     def find_nodes_by_node_type(self, start_node, attr_name):
 81 |         '''
 82 |         根据开始节点名查找具有指定属性的节点,返回节点集合,节点不存在时返回空集合
 83 |         '''
 84 |         logger.info(f"开始查找 - 起始节点: '{start_node}', 目标属性: '{attr_name}'")
 85 |         checked_nodes = []
 86 |         nodes = set()
 87 |         self.find_neighbors_recursive(start_node, attr_name, nodes, checked_nodes, depth=0)
 88 |         logger.info(f"查找完成 - 找到 {len(nodes)} 个节点: {nodes}")
 89 |         return nodes
 90 | 
 91 | 
 92 |     def find_neighbors_recursive(self, node, target, nodes, checked_nodes, depth=0):
 93 |         """
 94 |         递归查询某一节点的邻居,并根据目标进行逐层判断,确保递进朝一个方向。
 95 |         :param node: 当前节点
 96 |         :param target: 目标节点的类型
 97 |         :param nodes: 已找到的目标节点列表
 98 |         :param checked_nodes: 已检查的节点列表
 99 |         :param depth: 当前递归深度(用于日志缩进)
100 |         """
101 |         indent = " " * depth
102 |         logger.debug(f"{indent}检查节点: '{node}' (递归深度: {depth}, 已检查节点数: {len(checked_nodes)})")
103 |         checked_nodes.append(node)  # 标记当前节点已检查
104 | 
105 |         # 添加异常处理,检查节点是否存在
106 |         try:
107 |             if node not in self.nx.nodes:
108 |                 logger.warning(f"{indent}节点 '{node}' 不存在于图中")
109 |                 return
110 | 
111 |             source_node_type = self.nx.nodes[node].get("node_type")
112 |             if not source_node_type:
113 |                 logger.warning(f"{indent}节点 '{node}' 没有node_type属性")
114 |                 return
115 | 
116 |             logger.debug(f"{indent}当前节点类型: '{source_node_type}'")
117 |         except Exception as e:
118 |             logger.error(f"{indent}处理节点 '{node}' 时出错: {str(e)}")
119 |             return
120 | 
121 |         # 获取当前节点和目标节点的层级
122 |         source_level = NODE_HIERARCHY.get(source_node_type, float('inf'))
123 |         target_level = NODE_HIERARCHY.get(target, float('inf'))
124 |         logger.debug(f"{indent}层级比较 - 当前节点: {source_level}, 目标节点: {target_level}")
125 | 
126 |         if source_level == target_level:
127 |             logger.info(f"{indent}找到目标节点! '{node}' (类型: {source_node_type})")
128 |             nodes.add(node)
129 |             return
130 | 
131 |         # 获取邻居节点
132 |         try:
133 |             # 获取所有相邻节点(包括入边和出边)
134 |             neighbors = list(self.nx.neighbors(node))  # 获取出边邻居
135 |             predecessors = list(self.nx.predecessors(node))  # 获取入边邻居
136 |             all_neighbors = list(set(neighbors + predecessors))  # 合并并去重
137 |             logger.debug(f"{indent}找到 {len(all_neighbors)} 个邻居节点(包括入边和出边)")
138 |         except Exception as e:
139 |             logger.error(f"{indent}获取节点 '{node}' 的邻居时出错: {str(e)}")
140 |             return
141 | 
142 |         for neighbor in all_neighbors:
143 |             # 跳过已检查的节点
144 |             if neighbor in checked_nodes:
145 |                 logger.debug(f"{indent}跳过已检查的节点: '{neighbor}'")
146 |                 continue
147 | 
148 |             try:
149 |                 neighbor_type = self.nx.nodes[neighbor].get("node_type")
150 |                 if not neighbor_type:
151 |                     logger.debug(f"{indent}邻居节点 '{neighbor}' 没有node_type属性,跳过")
152 |                     continue
153 | 
154 |                 neighbor_level = NODE_HIERARCHY.get(neighbor_type, float('inf'))
155 |                 logger.debug(f"{indent}检查邻居: '{neighbor}' (类型: {neighbor_type}, 层级: {neighbor_level})")
156 | 
157 |                 # 如果是目标节点,则添加到结果列表
158 |                 if neighbor_type == target or (neighbor_level == 7 and neighbor_level == target_level):
159 |                     logger.info(f"{indent}找到目标节点! '{neighbor}' (类型: {neighbor_type})")
160 |                     nodes.add(neighbor)
161 |                     # 如果目标比当前节点层级高,停止递归并返回目标节点
162 |                     if target_level <= source_level:
163 |                         logger.debug(f"{indent}目标层级({target_level})小于等于当前层级({source_level}),停止递归")
164 |                         return
165 |                 else:
166 |                     if NODE_HIERARCHY.get(neighbor_type, float('inf')) <= 7:
167 |                         if target_level < source_level and neighbor_level < source_level:
168 |                             logger.debug(f"{indent}向上递归: '{neighbor}' (当前层级: {source_level}, 邻居层级: {neighbor_level}, 目标层级: {target_level})")
169 |                             self.find_neighbors_recursive(neighbor, target, nodes, checked_nodes, depth+1)
170 |                         elif target_level > source_level and neighbor_level > source_level:
171 |                             logger.debug(f"{indent}向下递归: '{neighbor}' (当前层级: {source_level}, 邻居层级: {neighbor_level}, 目标层级: {target_level})")
172 |                             self.find_neighbors_recursive(neighbor, target, nodes, checked_nodes, depth+1)
173 |                         else:
174 |                             logger.debug(f"{indent}不符合递归条件,跳过邻居: '{neighbor}'")
175 |                     else:
176 |                         logger.debug(f"{indent}邻居层级 > 7,跳过: '{neighbor}' (层级: {neighbor_level})")
177 |             except Exception as e:
178 |                 logger.warning(f"{indent}处理邻居节点 '{neighbor}' 时出错: {str(e)}")
179 |                 continue
180 | 
181 |         logger.debug(f"{indent}完成节点 '{node}' 的所有邻居检查")
182 | 
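# ---------------------------------------------------------------------------
# [编者示例] find_nodes_by_node_type 的最小调用示意(非源文件内容,节点名为示意值)。
# 该方法基于 NODE_HIERARCHY 决定递归方向:目标层级小于当前层级则向上,反之向下。
#
#     tools = ReActTools()
#     species = tools.find_nodes_by_node_type("盲鳗科", "种")   # 向下:科 -> 种
#     kingdom = tools.find_nodes_by_node_type("盲鳗科", "界")   # 向上:科 -> 界
#     print(len(species), kingdom)
# ---------------------------------------------------------------------------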
183 |     # 查询指定节点所有属性
184 |     def get_node_attribute(self, node_id):
185 |         '''
186 |         根据节点id获取所有属性,包括中文学名、拉丁学名、命名年份、作者、node_type
187 |         :param node_id: 节点id
188 |         :return: 节点属性字典
189 |         '''
190 |         return self.nx.nodes[node_id]
191 | 
192 |     def get_adjacent_node_descriptions(self, nodenames):
193 |         '''
194 |         根据列表中的节点名获取所有相邻节点的description
195 |         :param nodenames: 节点名称列表
196 |         :return: 所有相邻节点描述信息的集合
197 |         '''
198 |         result = set()
199 |         for nodename in nodenames:
200 |             # 获取出边邻居
201 |             for neighbor in self.nx.neighbors(nodename):
202 |                 description = self.nx.nodes[neighbor].get("description")
203 |                 if description:
204 |                     result.add(description)
205 |             # 获取入边邻居
206 |             for predecessor in self.nx.predecessors(nodename):
207 |                 description = self.nx.nodes[predecessor].get("description")
208 |                 if description:
209 |                     result.add(description)
210 |         return list(result)
211 | 
212 | class GraphVectorizer:
213 |     def __init__(self, db_path: str = None, openai_api_key: str = None):
214 |         """初始化向量化器
215 | 
216 |         Args:
217 |             db_path: 向量数据库存储路径
218 |             openai_api_key: OpenAI API密钥(当前未使用,密钥统一从环境变量读取)
219 |         """
220 |         if db_path is None:
221 |             db_path = WORKING_DIR
222 |         self.db = NanoVectorDB(db_path)
223 | 
224 | 
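# ---------------------------------------------------------------------------
# [编者示例] 属性查询与相邻描述查询的示意(非源文件内容,节点名为示意值)。
#
#     attrs = tools.get_node_attribute("盲鳗科")                 # 该节点的属性字典
#     descs = tools.get_adjacent_node_descriptions(["盲鳗科"])   # 相邻节点description列表
#     print(attrs.get("node_type"), len(descs))
# ---------------------------------------------------------------------------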
225 |     def _get_embedding(self, text: str) -> list[float]:
226 |         """获取文本的向量表示"""
227 |         response = client.embeddings.create(
228 |             model=os.getenv("EMBEDDING_MODEL"),
229 |             input=text,
230 |             encoding_format="float"
231 |         )
232 |         return response.data[0].embedding
233 | 
234 |     def vectorize_graph(self, graph_file: str):
235 |         """将知识图谱中的实体和关系向量化并存储
236 | 
237 |         Args:
238 |             graph_file: GraphML文件路径
239 |         """
240 |         # 读取图谱
241 |         G = nx.read_graphml(graph_file)
242 | 
243 |         # 向量化并存储实体
244 |         for node, attrs in G.nodes(data=True):
245 |             # 构建实体描述文本
246 |             entity_desc = f"实体ID: {node}"
247 |             if 'node_type' in attrs:
248 |                 entity_desc += f", 类型: {attrs['node_type']}"
249 |             if 'name' in attrs:
250 |                 entity_desc += f", 名称: {attrs['name']}"
251 | 
252 |             # 获取实体向量
253 |             embedding = self._get_embedding(entity_desc)
254 | 
255 |             # 存储实体向量
256 |             self.db.add_entity(
257 |                 entity_id=node,
258 |                 entity_type=attrs.get('node_type'),
259 |                 entity_name=attrs.get('name'),
260 |                 embedding=embedding
261 |             )
262 | 
263 |         # 向量化并存储关系
264 |         for source, target, attrs in G.edges(data=True):
265 |             # 构建关系描述文本
266 |             relation_desc = f"关系: 从 {source} 到 {target}"
267 |             if 'relation' in attrs:
268 |                 relation_desc += f", 类型: {attrs['relation']}"
269 | 
270 |             # 获取关系向量
271 |             embedding = self._get_embedding(relation_desc)
272 | 
273 |             # 存储关系向量
274 |             self.db.add_relation(
275 |                 source_id=source,
276 |                 target_id=target,
277 |                 relation_type=attrs.get('relation'),
278 |                 embedding=embedding
279 |             )
280 | 
281 |     def search(self, query: str, node_type: str = None, search_type: str = 'all', top_k: int = 5, better_than_threshold: float = BETTER_THAN_THRESHOLD):
282 |         """搜索与查询最相关的实体或关系
283 |         Args:
284 |             query: 搜索查询文本
285 |             node_type: 实体类型筛选条件,如果为None则不筛选。可选值包括:
286 |                 - species (种、中文名)
287 |                 - 界
288 |                 - 门
289 |                 - 纲
290 |                 - 目
291 |                 - 科
292 |                 - 属
293 |                 - 自然分布地
294 |                 - 食性
295 |                 - 繁殖特征
296 |                 - 生活习性
297 |                 - 体型
298 |                 - 体色
299 |                 - 体长
300 |                 - 特殊特征
301 |             search_type: 搜索类型,'all'/'entity'/'relation'
302 |             top_k: 返回的结果数量
303 |             better_than_threshold: 相似度阈值,只返回相似度高于此值的结果
304 | 
305 |         Returns:
306 |             list: 搜索结果,精准的实体名列表
307 |         """
308 |         # 获取查询向量
309 |         query_embedding = self._get_embedding(query)
310 |         results = []
311 | 
312 |         if search_type in ['all', 'entity']:
313 |             entities = self.db.search_entities(query_embedding, k=100)  # 获取更多结果用于筛选
314 |             # 按node_type筛选
315 |             if node_type:
316 |                 entities = [e for e in entities if e['entity_type'] == node_type]
317 |             results.extend(entities)
318 | 
319 |         if search_type in ['all', 'relation']:
320 |             results.extend(self.db.search_relations(query_embedding, k=100))  # 获取更多结果用于筛选
321 | 
322 |         # 按相似度阈值筛选
323 |         results = [r for r in results if r['similarity'] >= better_than_threshold]
324 | 
325 |         # 按相似度排序
326 |         results.sort(key=lambda x: x['similarity'], reverse=True)
327 |         return results[:top_k]
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
  1 | 