├── img ├── 1.png ├── 2.png └── 3.png ├── requirements.txt ├── src ├── tools │ ├── __init__.py │ └── search.py ├── state │ ├── __init__.py │ └── state.py ├── llms │ ├── __init__.py │ ├── base.py │ ├── openai_llm.py │ └── deepseek.py ├── __init__.py ├── nodes │ ├── __init__.py │ ├── base_node.py │ ├── report_structure_node.py │ ├── formatting_node.py │ ├── search_node.py │ └── summary_node.py ├── utils │ ├── __init__.py │ ├── text_processing.py │ └── config.py ├── prompts │ ├── __init__.py │ └── prompts.py └── agent.py ├── config.py ├── .gitignore ├── LICENSE ├── examples ├── basic_usage.py ├── advanced_usage.py └── streamlit_app.py └── README.md /img/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/666ghj/DeepSearchAgent-Demo/HEAD/img/1.png -------------------------------------------------------------------------------- /img/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/666ghj/DeepSearchAgent-Demo/HEAD/img/2.png -------------------------------------------------------------------------------- /img/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/666ghj/DeepSearchAgent-Demo/HEAD/img/3.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai>=1.0.0 2 | requests>=2.25.0 3 | tavily-python>=0.3.0 4 | streamlit>=1.28.0 5 | pydantic>=2.0.0 6 | rich>=13.0.0 7 | -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 工具调用模块 3 | 提供外部工具接口,如网络搜索等 4 | """ 5 | 6 | from .search import tavily_search, SearchResult 7 | 8 | __all__ = ["tavily_search", "SearchResult"] 9 | 
-------------------------------------------------------------------------------- /src/state/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 状态管理模块 3 | 定义Deep Search Agent的状态数据结构 4 | """ 5 | 6 | from .state import State, Paragraph, Research, Search 7 | 8 | __all__ = ["State", "Paragraph", "Research", "Search"] 9 | -------------------------------------------------------------------------------- /src/llms/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | LLM调用模块 3 | 支持多种大语言模型的统一接口 4 | """ 5 | 6 | from .base import BaseLLM 7 | from .deepseek import DeepSeekLLM 8 | from .openai_llm import OpenAILLM 9 | 10 | __all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"] 11 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Deep Search Agent 3 | 一个无框架的深度搜索AI代理实现 4 | """ 5 | 6 | from .agent import DeepSearchAgent, create_agent 7 | from .utils.config import Config, load_config 8 | 9 | __version__ = "1.0.0" 10 | __author__ = "Deep Search Agent Team" 11 | 12 | __all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"] 13 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Deep Search Agent 配置文件 2 | # 请在这里填入您的API密钥 3 | 4 | # DeepSeek API Key 5 | DEEPSEEK_API_KEY = "your_deepseek_api_key_here" 6 | 7 | # OpenAI API Key (可选) 8 | OPENAI_API_KEY = "your_openai_api_key_here" 9 | 10 | # Tavily搜索API Key 11 | TAVILY_API_KEY = "your_tavily_api_key_here" 12 | 13 | # 配置参数 14 | DEFAULT_LLM_PROVIDER = "deepseek" 15 | DEEPSEEK_MODEL = "deepseek-chat" 16 | OPENAI_MODEL = "gpt-4o-mini" 17 | 18 | MAX_REFLECTIONS = 2 19 | SEARCH_RESULTS_PER_QUERY = 3 20 | SEARCH_CONTENT_MAX_LENGTH = 20000 21 | OUTPUT_DIR = "reports" 
22 | SAVE_INTERMEDIATE_STATES = True 23 | -------------------------------------------------------------------------------- /src/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 节点处理模块 3 | 实现Deep Search Agent的各个处理步骤 4 | """ 5 | 6 | from .base_node import BaseNode 7 | from .report_structure_node import ReportStructureNode 8 | from .search_node import FirstSearchNode, ReflectionNode 9 | from .summary_node import FirstSummaryNode, ReflectionSummaryNode 10 | from .formatting_node import ReportFormattingNode 11 | 12 | __all__ = [ 13 | "BaseNode", 14 | "ReportStructureNode", 15 | "FirstSearchNode", 16 | "ReflectionNode", 17 | "FirstSummaryNode", 18 | "ReflectionSummaryNode", 19 | "ReportFormattingNode" 20 | ] 21 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 工具函数模块 3 | 提供文本处理、JSON解析等辅助功能 4 | """ 5 | 6 | from .text_processing import ( 7 | clean_json_tags, 8 | clean_markdown_tags, 9 | remove_reasoning_from_output, 10 | extract_clean_response, 11 | update_state_with_search_results, 12 | format_search_results_for_prompt 13 | ) 14 | 15 | from .config import Config, load_config 16 | 17 | __all__ = [ 18 | "clean_json_tags", 19 | "clean_markdown_tags", 20 | "remove_reasoning_from_output", 21 | "extract_clean_response", 22 | "update_state_with_search_results", 23 | "format_search_results_for_prompt", 24 | "Config", 25 | "load_config" 26 | ] 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | pip-wheel-metadata/ 20 | 
share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # 虚拟环境 27 | venv/ 28 | env/ 29 | ENV/ 30 | env.bak/ 31 | venv.bak/ 32 | 33 | # 环境变量文件 34 | .env 35 | *.env 36 | 37 | # IDE 38 | .vscode/ 39 | .idea/ 40 | *.swp 41 | *.swo 42 | *~ 43 | 44 | # 日志文件 45 | *.log 46 | logs/ 47 | 48 | # 临时文件 49 | *.tmp 50 | *.temp 51 | temp/ 52 | tmp/ 53 | 54 | # 系统文件 55 | .DS_Store 56 | Thumbs.db 57 | 58 | # 项目特定 59 | reports/ 60 | custom_reports/ 61 | streamlit_reports/ 62 | *.json 63 | *.md 64 | !README.md 65 | !LICENSE 66 | 67 | # Streamlit 68 | .streamlit/ 69 | 70 | # 测试覆盖率 71 | .coverage 72 | htmlcov/ 73 | .pytest_cache/ 74 | -------------------------------------------------------------------------------- /src/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prompt模块 3 | 定义Deep Search Agent各个阶段使用的系统提示词 4 | """ 5 | 6 | from .prompts import ( 7 | SYSTEM_PROMPT_REPORT_STRUCTURE, 8 | SYSTEM_PROMPT_FIRST_SEARCH, 9 | SYSTEM_PROMPT_FIRST_SUMMARY, 10 | SYSTEM_PROMPT_REFLECTION, 11 | SYSTEM_PROMPT_REFLECTION_SUMMARY, 12 | SYSTEM_PROMPT_REPORT_FORMATTING, 13 | output_schema_report_structure, 14 | output_schema_first_search, 15 | output_schema_first_summary, 16 | output_schema_reflection, 17 | output_schema_reflection_summary, 18 | input_schema_report_formatting 19 | ) 20 | 21 | __all__ = [ 22 | "SYSTEM_PROMPT_REPORT_STRUCTURE", 23 | "SYSTEM_PROMPT_FIRST_SEARCH", 24 | "SYSTEM_PROMPT_FIRST_SUMMARY", 25 | "SYSTEM_PROMPT_REFLECTION", 26 | "SYSTEM_PROMPT_REFLECTION_SUMMARY", 27 | "SYSTEM_PROMPT_REPORT_FORMATTING", 28 | "output_schema_report_structure", 29 | "output_schema_first_search", 30 | "output_schema_first_summary", 31 | "output_schema_reflection", 32 | "output_schema_reflection_summary", 33 | "input_schema_report_formatting" 34 | ] 35 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Deep Search Agent Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
class BaseLLM(ABC):
    """Abstract interface that every LLM backend in this package implements."""

    def __init__(self, api_key: str, model_name: Optional[str] = None):
        """
        Remember the credentials and the requested model.

        Args:
            api_key: API key for the provider.
            model_name: Model to use; None leaves the choice to the backend's default.
        """
        self.model_name = model_name
        self.api_key = api_key

    @abstractmethod
    def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """
        Ask the model for a completion.

        Args:
            system_prompt: System prompt text.
            user_prompt: User input text.
            **kwargs: Extra generation options such as temperature or max_tokens.

        Returns:
            The text generated by the model.
        """

    @abstractmethod
    def get_default_model(self) -> str:
        """
        Name the model used when the caller did not pick one.

        Returns:
            The default model name.
        """

    def validate_response(self, response: str) -> str:
        """
        Normalize a raw model response.

        Args:
            response: Raw response text, possibly None.

        Returns:
            The response with surrounding whitespace removed, or "" for None.
        """
        return response.strip() if response is not None else ""
class BaseNode(ABC):
    """Common base for all processing nodes: holds the LLM client and a display name."""

    def __init__(self, llm_client: BaseLLM, node_name: str = ""):
        """
        Set up the node.

        Args:
            llm_client: LLM client the node uses.
            node_name: Name shown in log lines; falls back to the class name when empty.
        """
        self.llm_client = llm_client
        if node_name:
            self.node_name = node_name
        else:
            self.node_name = self.__class__.__name__

    @abstractmethod
    def run(self, input_data: Any, **kwargs) -> Any:
        """
        Execute the node's processing logic.

        Args:
            input_data: Input for this node.
            **kwargs: Extra options.

        Returns:
            The processing result.
        """

    def validate_input(self, input_data: Any) -> bool:
        """
        Check the input before processing.

        The base implementation accepts everything; subclasses may override it.

        Args:
            input_data: Input to check.

        Returns:
            True when the input is acceptable.
        """
        return True

    def process_output(self, output: Any) -> Any:
        """
        Post-process the raw output.

        The base implementation returns the output unchanged.

        Args:
            output: Raw output.

        Returns:
            The processed output.
        """
        return output

    def log_info(self, message: str):
        """Print an informational message prefixed with the node name."""
        print(f"[{self.node_name}] {message}")

    def log_error(self, message: str):
        """Print an error message prefixed with the node name."""
        print(f"[{self.node_name}] 错误: {message}")
-> str: 39 | """ 40 | 调用OpenAI API生成回复 41 | 42 | Args: 43 | system_prompt: 系统提示词 44 | user_prompt: 用户输入 45 | **kwargs: 其他参数,如temperature、max_tokens等 46 | 47 | Returns: 48 | OpenAI生成的回复文本 49 | """ 50 | try: 51 | # 构建消息 52 | messages = [ 53 | {"role": "system", "content": system_prompt}, 54 | {"role": "user", "content": user_prompt} 55 | ] 56 | 57 | # 设置默认参数 58 | params = { 59 | "model": self.default_model, 60 | "messages": messages, 61 | "temperature": kwargs.get("temperature", 0.7), 62 | "max_tokens": kwargs.get("max_tokens", 4000) 63 | } 64 | 65 | # 调用API 66 | response = self.client.chat.completions.create(**params) 67 | 68 | # 提取回复内容 69 | if response.choices and response.choices[0].message: 70 | content = response.choices[0].message.content 71 | return self.validate_response(content) 72 | else: 73 | return "" 74 | 75 | except Exception as e: 76 | print(f"OpenAI API调用错误: {str(e)}") 77 | raise e 78 | 79 | def get_model_info(self) -> Dict[str, Any]: 80 | """ 81 | 获取当前模型信息 82 | 83 | Returns: 84 | 模型信息字典 85 | """ 86 | return { 87 | "provider": "OpenAI", 88 | "model": self.default_model, 89 | "api_base": "https://api.openai.com" 90 | } 91 | -------------------------------------------------------------------------------- /src/llms/deepseek.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepSeek LLM实现 3 | 使用DeepSeek API进行文本生成 4 | """ 5 | 6 | import os 7 | from typing import Optional, Dict, Any 8 | from openai import OpenAI 9 | from .base import BaseLLM 10 | 11 | 12 | class DeepSeekLLM(BaseLLM): 13 | """DeepSeek LLM实现类""" 14 | 15 | def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None): 16 | """ 17 | 初始化DeepSeek客户端 18 | 19 | Args: 20 | api_key: DeepSeek API密钥,如果不提供则从环境变量读取 21 | model_name: 模型名称,默认使用deepseek-chat 22 | """ 23 | if api_key is None: 24 | api_key = os.getenv("DEEPSEEK_API_KEY") 25 | if not api_key: 26 | raise ValueError("DeepSeek API Key未找到!请设置DEEPSEEK_API_KEY环境变量或在初始化时提供") 27 | 28 
| super().__init__(api_key, model_name) 29 | 30 | # 初始化OpenAI客户端,使用DeepSeek的endpoint 31 | self.client = OpenAI( 32 | api_key=self.api_key, 33 | base_url="https://api.deepseek.com" 34 | ) 35 | 36 | self.default_model = model_name or self.get_default_model() 37 | 38 | def get_default_model(self) -> str: 39 | """获取默认模型名称""" 40 | return "deepseek-chat" 41 | 42 | def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str: 43 | """ 44 | 调用DeepSeek API生成回复 45 | 46 | Args: 47 | system_prompt: 系统提示词 48 | user_prompt: 用户输入 49 | **kwargs: 其他参数,如temperature、max_tokens等 50 | 51 | Returns: 52 | DeepSeek生成的回复文本 53 | """ 54 | try: 55 | # 构建消息 56 | messages = [ 57 | {"role": "system", "content": system_prompt}, 58 | {"role": "user", "content": user_prompt} 59 | ] 60 | 61 | # 设置默认参数 62 | params = { 63 | "model": self.default_model, 64 | "messages": messages, 65 | "temperature": kwargs.get("temperature", 0.7), 66 | "max_tokens": kwargs.get("max_tokens", 4000), 67 | "stream": False 68 | } 69 | 70 | # 调用API 71 | response = self.client.chat.completions.create(**params) 72 | 73 | # 提取回复内容 74 | if response.choices and response.choices[0].message: 75 | content = response.choices[0].message.content 76 | return self.validate_response(content) 77 | else: 78 | return "" 79 | 80 | except Exception as e: 81 | print(f"DeepSeek API调用错误: {str(e)}") 82 | raise e 83 | 84 | def get_model_info(self) -> Dict[str, Any]: 85 | """ 86 | 获取当前模型信息 87 | 88 | Returns: 89 | 模型信息字典 90 | """ 91 | return { 92 | "provider": "DeepSeek", 93 | "model": self.default_model, 94 | "api_base": "https://api.deepseek.com" 95 | } 96 | -------------------------------------------------------------------------------- /examples/advanced_usage.py: -------------------------------------------------------------------------------- 1 | """ 2 | 高级使用示例 3 | 演示Deep Search Agent的高级功能 4 | """ 5 | 6 | import os 7 | import sys 8 | 9 | # 添加项目根目录到Python路径 10 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) 11 | 12 
| from src import DeepSearchAgent, Config 13 | from src.utils.config import print_config 14 | 15 | 16 | def advanced_example(): 17 | """高级使用示例""" 18 | print("=" * 60) 19 | print("Deep Search Agent - 高级使用示例") 20 | print("=" * 60) 21 | 22 | try: 23 | # 自定义配置 24 | print("正在创建自定义配置...") 25 | config = Config( 26 | # 使用OpenAI而不是DeepSeek 27 | default_llm_provider="openai", 28 | openai_model="gpt-4o-mini", 29 | # 自定义搜索参数 30 | max_search_results=5, # 更多搜索结果 31 | max_reflections=3, # 更多反思次数 32 | max_content_length=15000, 33 | # 自定义输出 34 | output_dir="custom_reports", 35 | save_intermediate_states=True 36 | ) 37 | 38 | # 从环境变量设置API密钥 39 | config.openai_api_key = os.getenv("OPENAI_API_KEY") 40 | config.tavily_api_key = os.getenv("TAVILY_API_KEY") 41 | 42 | if not config.validate(): 43 | print("配置验证失败,请检查API密钥设置") 44 | return 45 | 46 | print_config(config) 47 | 48 | # 创建Agent 49 | print("正在初始化Agent...") 50 | agent = DeepSearchAgent(config) 51 | 52 | # 执行多个研究任务 53 | queries = [ 54 | "深度学习在医疗领域的应用", 55 | "区块链技术的最新发展", 56 | "可持续能源技术趋势" 57 | ] 58 | 59 | for i, query in enumerate(queries, 1): 60 | print(f"\n{'='*60}") 61 | print(f"执行研究任务 {i}/{len(queries)}: {query}") 62 | print(f"{'='*60}") 63 | 64 | try: 65 | # 执行研究 66 | final_report = agent.research(query, save_report=True) 67 | 68 | # 保存状态(示例) 69 | state_file = f"custom_reports/state_task_{i}.json" 70 | agent.save_state(state_file) 71 | 72 | print(f"任务 {i} 完成") 73 | print(f"报告长度: {len(final_report)} 字符") 74 | 75 | # 显示进度 76 | progress = agent.get_progress_summary() 77 | print(f"完成进度: {progress['progress_percentage']:.1f}%") 78 | 79 | except Exception as e: 80 | print(f"任务 {i} 失败: {str(e)}") 81 | continue 82 | 83 | print(f"\n{'='*60}") 84 | print("所有研究任务完成!") 85 | print(f"{'='*60}") 86 | 87 | except Exception as e: 88 | print(f"高级示例运行失败: {str(e)}") 89 | 90 | 91 | def state_management_example(): 92 | """状态管理示例""" 93 | print("\n" + "=" * 60) 94 | print("状态管理示例") 95 | print("=" * 60) 96 | 97 | try: 98 | # 创建配置 99 | config = 
@dataclass
class SearchResult:
    """One hit returned by a search backend."""
    # Title of the hit.
    title: str
    # Address of the hit.
    url: str
    # Text content attached to the hit.
    content: str
    # Optional score supplied with the hit.
    score: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary, in declaration order."""
        field_names = ("title", "url", "content", "score")
        return {name: getattr(self, name) for name in field_names}
def tavily_search(query: str, max_results: int = 5, include_raw_content: bool = True,
                  timeout: int = 240, api_key: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Convenience wrapper that runs a Tavily search and returns plain dictionaries.

    Args:
        query: Search query.
        max_results: Maximum number of results.
        include_raw_content: Whether to ask for raw page content.
        timeout: Timeout in seconds.
        api_key: Tavily API key; when given, a dedicated client is created with it,
            otherwise the shared module-level client is used.

    Returns:
        A list of result dictionaries; an empty list when anything goes wrong.
    """
    try:
        # A caller-supplied key gets its own client; otherwise reuse the global one.
        searcher = TavilySearch(api_key) if api_key else get_tavily_client()
        hits = searcher.search(query, max_results, include_raw_content, timeout)
        # Dictionaries keep the output compatible with existing callers.
        return [hit.to_dict() for hit in hits]
    except Exception as e:
        print(f"搜索功能调用错误: {str(e)}")
        return []
super().__init__(llm_client, "ReportStructureNode") 32 | self.query = query 33 | 34 | def validate_input(self, input_data: Any) -> bool: 35 | """验证输入数据""" 36 | return isinstance(self.query, str) and len(self.query.strip()) > 0 37 | 38 | def run(self, input_data: Any = None, **kwargs) -> List[Dict[str, str]]: 39 | """ 40 | 调用LLM生成报告结构 41 | 42 | Args: 43 | input_data: 输入数据(这里不使用,使用初始化时的query) 44 | **kwargs: 额外参数 45 | 46 | Returns: 47 | 报告结构列表 48 | """ 49 | try: 50 | self.log_info(f"正在为查询生成报告结构: {self.query}") 51 | 52 | # 调用LLM 53 | response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query) 54 | 55 | # 处理响应 56 | processed_response = self.process_output(response) 57 | 58 | self.log_info(f"成功生成 {len(processed_response)} 个段落结构") 59 | return processed_response 60 | 61 | except Exception as e: 62 | self.log_error(f"生成报告结构失败: {str(e)}") 63 | raise e 64 | 65 | def process_output(self, output: str) -> List[Dict[str, str]]: 66 | """ 67 | 处理LLM输出,提取报告结构 68 | 69 | Args: 70 | output: LLM原始输出 71 | 72 | Returns: 73 | 处理后的报告结构列表 74 | """ 75 | try: 76 | # 清理响应文本 77 | cleaned_output = remove_reasoning_from_output(output) 78 | cleaned_output = clean_json_tags(cleaned_output) 79 | 80 | # 解析JSON 81 | try: 82 | report_structure = json.loads(cleaned_output) 83 | except JSONDecodeError: 84 | # 使用更强大的提取方法 85 | report_structure = extract_clean_response(cleaned_output) 86 | if "error" in report_structure: 87 | raise ValueError("JSON解析失败") 88 | 89 | # 验证结构 90 | if not isinstance(report_structure, list): 91 | raise ValueError("报告结构应该是一个列表") 92 | 93 | # 验证每个段落 94 | validated_structure = [] 95 | for i, paragraph in enumerate(report_structure): 96 | if not isinstance(paragraph, dict): 97 | continue 98 | 99 | title = paragraph.get("title", f"段落 {i+1}") 100 | content = paragraph.get("content", "") 101 | 102 | validated_structure.append({ 103 | "title": title, 104 | "content": content 105 | }) 106 | 107 | return validated_structure 108 | 109 | except Exception as e: 110 | 
def clean_json_tags(text: str) -> str:
    """
    Strip Markdown code-fence markers from text that wraps a JSON payload.

    The language tag on the opening fence is matched case-insensitively, so
    "```json", "```JSON" and "```Json" are all removed together with the
    fence. A case-sensitive match would leave the tag in front of the payload
    and make it unparseable.

    Args:
        text: Raw text, typically an LLM response.

    Returns:
        The text without fence markers and without surrounding whitespace.
    """
    # Opening fence plus its "json" tag and the whitespace that follows it.
    text = re.sub(r'```json\s*', '', text, flags=re.IGNORECASE)
    # Any remaining fence markers (closing fences, untagged opening fences).
    text = text.replace('```', '')

    return text.strip()
def remove_reasoning_from_output(text: str) -> str:
    """
    Drop any preamble (reasoning, explanation, other chatter) that precedes JSON.

    Everything before the first "{" or "[" is removed. Only that prefix is
    touched: applying keyword patterns such as "reasoning:" or "分析:" to the
    whole text would also match inside the JSON payload and delete parts of
    it, e.g. turning {"a": "reasoning: x", "b": {"c": 1}} into invalid JSON.

    Args:
        text: Raw text, typically an LLM response.

    Returns:
        The text starting at the first "{" or "[", stripped of surrounding
        whitespace. Text that contains neither bracket is returned stripped
        but otherwise unchanged.
    """
    first_bracket = re.search(r'[\{\[]', text)
    if first_bracket:
        text = text[first_bracket.start():]

    return text.strip()
def truncate_content(content: str, max_length: int = 20000) -> str:
    """
    Shorten content to at most max_length characters plus a trailing "...".

    Args:
        content: Text to shorten.
        max_length: Maximum number of characters kept from the original.

    Returns:
        The content unchanged when it already fits; otherwise the kept prefix
        followed by "...". The cut is moved back to the last space when that
        space lies in the final 20% of the kept prefix.
    """
    if len(content) > max_length:
        head = content[:max_length]
        cut = head.rfind(' ')

        # Only back up to a word boundary when it costs little of the text.
        if cut > max_length * 0.8:
            head = head[:cut]

        return head + "..."

    return content
def validate_input(self, input_data: Any) -> bool:
    """
    Check that the input is a list of finished paragraphs.

    Args:
        input_data: A list of dicts, or a JSON string encoding such a list.
            Every dict must contain "title" and "paragraph_latest_state".

    Returns:
        True when the (decoded) input is a list whose items all satisfy the
        requirement above; False otherwise, including for malformed JSON.
    """
    if isinstance(input_data, str):
        try:
            input_data = json.loads(input_data)
        except json.JSONDecodeError:
            # Only malformed JSON counts as invalid input. A bare `except:`
            # would also swallow KeyboardInterrupt and SystemExit.
            return False

    if not isinstance(input_data, list):
        return False

    return all(
        isinstance(item, dict) and "title" in item and "paragraph_latest_state" in item
        for item in input_data
    )
def process_output(self, output: str) -> str:
    """
    Clean the LLM reply and make sure it is a usable Markdown report.

    Args:
        output: Raw LLM output.

    Returns:
        The cleaned Markdown report. A default top-level heading is prepended
        when the report does not start with '#'. A fixed failure report is
        returned when the cleaned text is empty or cleaning raises.
    """
    try:
        # The reply is Markdown, not JSON. The JSON preamble stripper removes
        # everything before the first '[' or '{', so any link or citation in
        # the report would cut off all the text in front of it. Only the
        # Markdown fence markers are removed here.
        cleaned_output = clean_markdown_tags(output).strip()

        # Make sure the report has some content
        if not cleaned_output:
            return "# 报告生成失败\n\n无法生成有效的报告内容。"

        # Add a default title when the report does not start with a heading
        if not cleaned_output.startswith('#'):
            cleaned_output = "# 深度研究报告\n\n" + cleaned_output

        return cleaned_output

    except Exception as e:
        # Best-effort boundary: formatting problems must not abort the run.
        self.log_error(f"处理输出失败: {str(e)}")
        return "# 报告处理失败\n\n报告格式化过程中发生错误。"
paragraph.get("title", f"段落 {i}") 138 | content = paragraph.get("paragraph_latest_state", "") 139 | 140 | if content: 141 | report_lines.extend([ 142 | f"## {title}", 143 | "", 144 | content, 145 | "", 146 | "---", 147 | "" 148 | ]) 149 | 150 | # 添加结论 151 | if len(paragraphs_data) > 1: 152 | report_lines.extend([ 153 | "## 结论", 154 | "", 155 | "本报告通过深度搜索和研究,对相关主题进行了全面分析。" 156 | "以上各个方面的内容为理解该主题提供了重要参考。", 157 | "" 158 | ]) 159 | 160 | return "\n".join(report_lines) 161 | 162 | except Exception as e: 163 | self.log_error(f"手动格式化失败: {str(e)}") 164 | return "# 报告生成失败\n\n无法完成报告格式化。" 165 | -------------------------------------------------------------------------------- /src/prompts/prompts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Deep Search Agent 的所有提示词定义 3 | 包含各个阶段的系统提示词和JSON Schema定义 4 | """ 5 | 6 | import json 7 | 8 | # ===== JSON Schema 定义 ===== 9 | 10 | # 报告结构输出Schema 11 | output_schema_report_structure = { 12 | "type": "array", 13 | "items": { 14 | "type": "object", 15 | "properties": { 16 | "title": {"type": "string"}, 17 | "content": {"type": "string"} 18 | } 19 | } 20 | } 21 | 22 | # 首次搜索输入Schema 23 | input_schema_first_search = { 24 | "type": "object", 25 | "properties": { 26 | "title": {"type": "string"}, 27 | "content": {"type": "string"} 28 | } 29 | } 30 | 31 | # 首次搜索输出Schema 32 | output_schema_first_search = { 33 | "type": "object", 34 | "properties": { 35 | "search_query": {"type": "string"}, 36 | "reasoning": {"type": "string"} 37 | } 38 | } 39 | 40 | # 首次总结输入Schema 41 | input_schema_first_summary = { 42 | "type": "object", 43 | "properties": { 44 | "title": {"type": "string"}, 45 | "content": {"type": "string"}, 46 | "search_query": {"type": "string"}, 47 | "search_results": { 48 | "type": "array", 49 | "items": {"type": "string"} 50 | } 51 | } 52 | } 53 | 54 | # 首次总结输出Schema 55 | output_schema_first_summary = { 56 | "type": "object", 57 | "properties": { 58 | "paragraph_latest_state": {"type": 
"string"} 59 | } 60 | } 61 | 62 | # 反思输入Schema 63 | input_schema_reflection = { 64 | "type": "object", 65 | "properties": { 66 | "title": {"type": "string"}, 67 | "content": {"type": "string"}, 68 | "paragraph_latest_state": {"type": "string"} 69 | } 70 | } 71 | 72 | # 反思输出Schema 73 | output_schema_reflection = { 74 | "type": "object", 75 | "properties": { 76 | "search_query": {"type": "string"}, 77 | "reasoning": {"type": "string"} 78 | } 79 | } 80 | 81 | # 反思总结输入Schema 82 | input_schema_reflection_summary = { 83 | "type": "object", 84 | "properties": { 85 | "title": {"type": "string"}, 86 | "content": {"type": "string"}, 87 | "search_query": {"type": "string"}, 88 | "search_results": { 89 | "type": "array", 90 | "items": {"type": "string"} 91 | }, 92 | "paragraph_latest_state": {"type": "string"} 93 | } 94 | } 95 | 96 | # 反思总结输出Schema 97 | output_schema_reflection_summary = { 98 | "type": "object", 99 | "properties": { 100 | "updated_paragraph_latest_state": {"type": "string"} 101 | } 102 | } 103 | 104 | # 报告格式化输入Schema 105 | input_schema_report_formatting = { 106 | "type": "array", 107 | "items": { 108 | "type": "object", 109 | "properties": { 110 | "title": {"type": "string"}, 111 | "paragraph_latest_state": {"type": "string"} 112 | } 113 | } 114 | } 115 | 116 | # ===== 系统提示词定义 ===== 117 | 118 | # 生成报告结构的系统提示词 119 | SYSTEM_PROMPT_REPORT_STRUCTURE = f""" 120 | 你是一位深度研究助手。给定一个查询,你需要规划一个报告的结构和其中包含的段落。最多五个段落。 121 | 确保段落的排序合理有序。 122 | 一旦大纲创建完成,你将获得工具来分别为每个部分搜索网络并进行反思。 123 | 请按照以下JSON模式定义格式化输出: 124 | 125 | 126 | {json.dumps(output_schema_report_structure, indent=2, ensure_ascii=False)} 127 | 128 | 129 | 标题和内容属性将用于更深入的研究。 130 | 确保输出是一个符合上述输出JSON模式定义的JSON对象。 131 | 只返回JSON对象,不要有解释或额外文本。 132 | """ 133 | 134 | # 每个段落第一次搜索的系统提示词 135 | SYSTEM_PROMPT_FIRST_SEARCH = f""" 136 | 你是一位深度研究助手。你将获得报告中的一个段落,其标题和预期内容将按照以下JSON模式定义提供: 137 | 138 | 139 | {json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)} 140 | 141 | 142 | 你可以使用一个网络搜索工具,该工具接受'search_query'作为参数。 143 | 
你的任务是思考这个主题,并提供最佳的网络搜索查询来丰富你当前的知识。 144 | 请按照以下JSON模式定义格式化输出(文字请使用中文): 145 | 146 | 147 | {json.dumps(output_schema_first_search, indent=2, ensure_ascii=False)} 148 | 149 | 150 | 确保输出是一个符合上述输出JSON模式定义的JSON对象。 151 | 只返回JSON对象,不要有解释或额外文本。 152 | """ 153 | 154 | # 每个段落第一次总结的系统提示词 155 | SYSTEM_PROMPT_FIRST_SUMMARY = f""" 156 | 你是一位深度研究助手。你将获得搜索查询、搜索结果以及你正在研究的报告段落,数据将按照以下JSON模式定义提供: 157 | 158 | 159 | {json.dumps(input_schema_first_summary, indent=2, ensure_ascii=False)} 160 | 161 | 162 | 你的任务是作为研究者,使用搜索结果撰写与段落主题一致的内容,并适当地组织结构以便纳入报告中。 163 | 请按照以下JSON模式定义格式化输出: 164 | 165 | 166 | {json.dumps(output_schema_first_summary, indent=2, ensure_ascii=False)} 167 | 168 | 169 | 确保输出是一个符合上述输出JSON模式定义的JSON对象。 170 | 只返回JSON对象,不要有解释或额外文本。 171 | """ 172 | 173 | # 反思(Reflect)的系统提示词 174 | SYSTEM_PROMPT_REFLECTION = f""" 175 | 你是一位深度研究助手。你负责为研究报告构建全面的段落。你将获得段落标题、计划内容摘要,以及你已经创建的段落最新状态,所有这些都将按照以下JSON模式定义提供: 176 | 177 | 178 | {json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)} 179 | 180 | 181 | 你可以使用一个网络搜索工具,该工具接受'search_query'作为参数。 182 | 你的任务是反思段落文本的当前状态,思考是否遗漏了主题的某些关键方面,并提供最佳的网络搜索查询来丰富最新状态。 183 | 请按照以下JSON模式定义格式化输出: 184 | 185 | 186 | {json.dumps(output_schema_reflection, indent=2, ensure_ascii=False)} 187 | 188 | 189 | 确保输出是一个符合上述输出JSON模式定义的JSON对象。 190 | 只返回JSON对象,不要有解释或额外文本。 191 | """ 192 | 193 | # 总结反思的系统提示词 194 | SYSTEM_PROMPT_REFLECTION_SUMMARY = f""" 195 | 你是一位深度研究助手。 196 | 你将获得搜索查询、搜索结果、段落标题以及你正在研究的报告段落的预期内容。 197 | 你正在迭代完善这个段落,并且段落的最新状态也会提供给你。 198 | 数据将按照以下JSON模式定义提供: 199 | 200 | 201 | {json.dumps(input_schema_reflection_summary, indent=2, ensure_ascii=False)} 202 | 203 | 204 | 你的任务是根据搜索结果和预期内容丰富段落的当前最新状态。 205 | 不要删除最新状态中的关键信息,尽量丰富它,只添加缺失的信息。 206 | 适当地组织段落结构以便纳入报告中。 207 | 请按照以下JSON模式定义格式化输出: 208 | 209 | 210 | {json.dumps(output_schema_reflection_summary, indent=2, ensure_ascii=False)} 211 | 212 | 213 | 确保输出是一个符合上述输出JSON模式定义的JSON对象。 214 | 只返回JSON对象,不要有解释或额外文本。 215 | """ 216 | 217 | # 最终研究报告格式化的系统提示词 218 | SYSTEM_PROMPT_REPORT_FORMATTING = f""" 219 | 
@dataclass
class Config:
    """
    Settings for the Deep Search Agent.

    Holds API keys, model selection, search limits, agent limits and output
    options. Build it directly or load it from a file with `from_file`.
    """
    # API keys
    deepseek_api_key: Optional[str] = None
    openai_api_key: Optional[str] = None
    tavily_api_key: Optional[str] = None

    # Model selection
    default_llm_provider: str = "deepseek"  # "deepseek" or "openai"
    deepseek_model: str = "deepseek-chat"
    openai_model: str = "gpt-4o-mini"

    # Search settings
    max_search_results: int = 3
    search_timeout: int = 240
    max_content_length: int = 20000

    # Agent settings
    max_reflections: int = 2
    max_paragraphs: int = 5

    # Output settings
    output_dir: str = "reports"
    save_intermediate_states: bool = True

    def validate(self) -> bool:
        """
        Check that the API keys required by the selected provider are set.

        Returns:
            True when the selected LLM provider's key and the Tavily key are
            both present; False otherwise (a message is printed).
        """
        if self.default_llm_provider == "deepseek" and not self.deepseek_api_key:
            print("错误: DeepSeek API Key未设置")
            return False

        if self.default_llm_provider == "openai" and not self.openai_api_key:
            print("错误: OpenAI API Key未设置")
            return False

        if not self.tavily_api_key:
            print("错误: Tavily API Key未设置")
            return False

        return True

    @staticmethod
    def _unquote(value: str) -> str:
        """Remove one pair of matching quotes around a .env value."""
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            return value[1:-1]
        return value

    @staticmethod
    def _parse_bool(value: str) -> bool:
        """Interpret a .env string as a boolean (true/1/yes/on, any case)."""
        return value.strip().lower() in ("true", "1", "yes", "on")

    @classmethod
    def from_file(cls, config_file: str) -> "Config":
        """
        Build a configuration from a Python module or a .env-style file.

        Args:
            config_file: Path to the file. A path ending in '.py' is executed
                as a module and its upper-case constants are read. Any other
                path is parsed as KEY=VALUE lines.

        Returns:
            A Config populated from the file; missing keys use the defaults.

        Raises:
            ImportError: If a '.py' file cannot be loaded as a module.
            ValueError: If a numeric .env value is not a valid integer.
        """
        if config_file.endswith('.py'):
            import importlib.util

            # NOTE: this executes the file, so it must be a trusted local
            # configuration file.
            spec = importlib.util.spec_from_file_location("config", config_file)
            if spec is None or spec.loader is None:
                raise ImportError(f"无法加载配置文件: {config_file}")
            config_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(config_module)

            return cls(
                deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
                openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
                tavily_api_key=getattr(config_module, "TAVILY_API_KEY", None),
                default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
                deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
                openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
                max_search_results=getattr(config_module, "SEARCH_RESULTS_PER_QUERY", 3),
                search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240),
                max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000),
                max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
                max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5),
                output_dir=getattr(config_module, "OUTPUT_DIR", "reports"),
                save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True)
            )

        # .env-style file: KEY=VALUE per line, '#' starts a comment line.
        config_dict = {}

        if os.path.exists(config_file):
            with open(config_file, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line or line.startswith('#') or '=' not in line:
                        continue
                    key, value = line.split('=', 1)
                    # Values are often written as KEY="value"; keeping the
                    # quotes would corrupt API keys and numeric settings.
                    config_dict[key.strip()] = cls._unquote(value.strip())

        return cls(
            deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
            openai_api_key=config_dict.get("OPENAI_API_KEY"),
            tavily_api_key=config_dict.get("TAVILY_API_KEY"),
            default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
            deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
            openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
            max_search_results=int(config_dict.get("SEARCH_RESULTS_PER_QUERY", "3")),
            search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
            max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")),
            max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
            max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")),
            output_dir=config_dict.get("OUTPUT_DIR", "reports"),
            save_intermediate_states=cls._parse_bool(
                config_dict.get("SAVE_INTERMEDIATE_STATES", "true")
            )
        )
def validate_input(self, input_data: Any) -> bool:
    """
    Check that the input describes a paragraph with a title and content.

    Args:
        input_data: A dict, or a JSON string encoding a dict, that must
            contain the keys "title" and "content".

    Returns:
        True when the (decoded) input is a dict with both keys; False for
        malformed JSON, non-dict JSON values and any other input type.
    """
    if isinstance(input_data, str):
        try:
            input_data = json.loads(input_data)
        except JSONDecodeError:
            return False

    # JSON such as `123` or `["title", "content"]` decodes successfully but
    # is not a paragraph object. A membership test on it would raise
    # TypeError or give a false positive.
    if not isinstance(input_data, dict):
        return False

    return "title" in input_data and "content" in input_data
def process_output(self, output: str) -> Dict[str, str]:
    """
    Extract the search query and its reasoning from the LLM reply.

    Args:
        output: Raw LLM output.

    Returns:
        A dict with "search_query" and "reasoning". When the reply cannot be
        parsed or carries no query, a default query is returned instead.
    """
    fallback = {
        "search_query": "相关主题研究",
        "reasoning": "由于解析失败,使用默认搜索查询"
    }

    try:
        # Strip the preamble first, then the code-fence markers
        text = clean_json_tags(remove_reasoning_from_output(output))

        try:
            parsed = json.loads(text)
        except JSONDecodeError:
            # Fall back to the more tolerant extraction helper
            parsed = extract_clean_response(text)
            if "error" in parsed:
                raise ValueError("JSON解析失败")

        query = parsed.get("search_query", "")
        rationale = parsed.get("reasoning", "")

        if not query:
            raise ValueError("未找到搜索查询")

        return {"search_query": query, "reasoning": rationale}

    except Exception as e:
        self.log_error(f"处理输出失败: {str(e)}")
        return fallback
"paragraph_latest_state"] 147 | return all(field in input_data for field in required_fields) 148 | return False 149 | 150 | def run(self, input_data: Any, **kwargs) -> Dict[str, str]: 151 | """ 152 | 调用LLM反思并生成搜索查询 153 | 154 | Args: 155 | input_data: 包含title、content和paragraph_latest_state的字符串或字典 156 | **kwargs: 额外参数 157 | 158 | Returns: 159 | 包含search_query和reasoning的字典 160 | """ 161 | try: 162 | if not self.validate_input(input_data): 163 | raise ValueError("输入数据格式错误,需要包含title、content和paragraph_latest_state字段") 164 | 165 | # 准备输入数据 166 | if isinstance(input_data, str): 167 | message = input_data 168 | else: 169 | message = json.dumps(input_data, ensure_ascii=False) 170 | 171 | self.log_info("正在进行反思并生成新搜索查询") 172 | 173 | # 调用LLM 174 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message) 175 | 176 | # 处理响应 177 | processed_response = self.process_output(response) 178 | 179 | self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}") 180 | return processed_response 181 | 182 | except Exception as e: 183 | self.log_error(f"反思生成搜索查询失败: {str(e)}") 184 | raise e 185 | 186 | def process_output(self, output: str) -> Dict[str, str]: 187 | """ 188 | 处理LLM输出,提取搜索查询和推理 189 | 190 | Args: 191 | output: LLM原始输出 192 | 193 | Returns: 194 | 包含search_query和reasoning的字典 195 | """ 196 | try: 197 | # 清理响应文本 198 | cleaned_output = remove_reasoning_from_output(output) 199 | cleaned_output = clean_json_tags(cleaned_output) 200 | 201 | # 解析JSON 202 | try: 203 | result = json.loads(cleaned_output) 204 | except JSONDecodeError: 205 | # 使用更强大的提取方法 206 | result = extract_clean_response(cleaned_output) 207 | if "error" in result: 208 | raise ValueError("JSON解析失败") 209 | 210 | # 验证和清理结果 211 | search_query = result.get("search_query", "") 212 | reasoning = result.get("reasoning", "") 213 | 214 | if not search_query: 215 | raise ValueError("未找到搜索查询") 216 | 217 | return { 218 | "search_query": search_query, 219 | "reasoning": reasoning 220 | } 221 | 222 | except 
def main():
    """
    Render the Streamlit page and start a research run on request.

    Draws the sidebar configuration, the query input, the status panel and
    the start button. When the button is pressed and the inputs pass the
    checks below, a Config is built and execute_research is called.
    """
    st.set_page_config(
        page_title="Deep Search Agent",
        page_icon="🔍",
        layout="wide"
    )

    st.title("Deep Search Agent")
    st.markdown("基于DeepSeek的无框架深度搜索AI代理")

    # Sidebar configuration
    with st.sidebar:
        st.header("配置")

        # API keys
        st.subheader("API密钥")
        deepseek_key = st.text_input("DeepSeek API Key", type="password",
                                     value="")
        tavily_key = st.text_input("Tavily API Key", type="password",
                                   value="")

        # Advanced settings
        st.subheader("高级配置")
        max_reflections = st.slider("反思次数", 1, 5, 2)
        max_search_results = st.slider("搜索结果数", 1, 10, 3)
        max_content_length = st.number_input("最大内容长度", 1000, 50000, 20000)

        # Model selection
        llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"])

        if llm_provider == "deepseek":
            model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"])
        else:
            model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"])
            # openai_key is bound only in this branch; every later use is
            # guarded by an `llm_provider == "openai"` check.
            openai_key = st.text_input("OpenAI API Key", type="password",
                                       value="")

    # Main area
    col1, col2 = st.columns([2, 1])

    with col1:
        st.header("研究查询")
        query = st.text_area(
            "请输入您要研究的问题",
            placeholder="例如:2025年人工智能发展趋势",
            height=100
        )

        # Preset example queries
        st.subheader("示例查询")
        example_queries = [
            "2025年人工智能发展趋势",
            "深度学习在医疗领域的应用",
            "区块链技术的最新发展",
            "可持续能源技术趋势",
            "量子计算的发展现状"
        ]

        selected_example = st.selectbox("选择示例查询", ["自定义"] + example_queries)
        # Any selection other than the first ("自定义") entry replaces
        # whatever was typed into the text area.
        if selected_example != "自定义":
            query = selected_example

    with col2:
        st.header("状态信息")
        # Show progress only when an agent from an earlier run is stored in
        # the session and it has a `state` attribute.
        if 'agent' in st.session_state and hasattr(st.session_state.agent, 'state'):
            progress = st.session_state.agent.get_progress_summary()
            st.metric("总段落数", progress['total_paragraphs'])
            st.metric("已完成", progress['completed_paragraphs'])
            st.progress(progress['progress_percentage'] / 100)
        else:
            st.info("尚未开始研究")

    # Start button, placed in the middle of three equal columns
    col1, col2, col3 = st.columns([1, 1, 1])
    with col2:
        start_research = st.button("开始研究", type="primary", use_container_width=True)

    # Validate the inputs
    if start_research:
        if not query.strip():
            st.error("请输入研究查询")
            return

        if not deepseek_key and llm_provider == "deepseek":
            st.error("请提供DeepSeek API Key")
            return

        if not tavily_key:
            st.error("请提供Tavily API Key")
            return

        if llm_provider == "openai" and not openai_key:
            st.error("请提供OpenAI API Key")
            return

        # Build the configuration; only the selected provider's key and model
        # come from the widgets, the other provider gets None / its default.
        config = Config(
            deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None,
            openai_api_key=openai_key if llm_provider == "openai" else None,
            tavily_api_key=tavily_key,
            default_llm_provider=llm_provider,
            deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat",
            openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini",
            max_reflections=max_reflections,
            max_search_results=max_search_results,
            max_content_length=max_content_length,
            output_dir="streamlit_reports"
        )

        # Run the research
        execute_research(query, config)
def display_results(agent: DeepSearchAgent, final_report: str):
    """
    Show the finished research in three tabs: report, details and downloads.

    Args:
        agent: The agent whose state holds the paragraphs and search history.
        final_report: The final Markdown report.
    """
    def preview(text, limit):
        # Cut long text to `limit` characters and mark the cut with "..."
        return text[:limit] + "..." if len(text) > limit else text

    st.header("研究结果")

    report_tab, detail_tab, download_tab = st.tabs(["最终报告", "详细信息", "下载"])

    with report_tab:
        st.markdown(final_report)

    with detail_tab:
        # Per-paragraph details
        st.subheader("段落详情")
        for number, paragraph in enumerate(agent.state.paragraphs, 1):
            research = paragraph.research
            with st.expander(f"段落 {number}: {paragraph.title}"):
                st.write("**预期内容:**", paragraph.content)
                st.write("**最终内容:**", preview(research.latest_summary, 300))
                st.write("**搜索次数:**", research.get_search_count())
                st.write("**反思次数:**", research.reflection_iteration)

        # Search history, flattened across all paragraphs
        st.subheader("搜索历史")
        history = [
            entry
            for paragraph in agent.state.paragraphs
            for entry in paragraph.research.search_history
        ]

        for number, entry in enumerate(history, 1):
            with st.expander(f"搜索 {number}: {entry.query}"):
                st.write("**URL:**", entry.url)
                st.write("**标题:**", entry.title)
                st.write("**内容预览:**", preview(entry.content, 200))
                if entry.score:
                    st.write("**相关度评分:**", entry.score)

    with download_tab:
        st.subheader("下载报告")

        # Markdown report
        report_name = f"deep_search_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
        st.download_button(
            label="下载Markdown报告",
            data=final_report,
            file_name=report_name,
            mime="text/markdown"
        )

        # Agent state as JSON
        state_json = agent.state.to_json()
        state_name = f"deep_search_state_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        st.download_button(
            label="下载状态文件",
            data=state_json,
            file_name=state_name,
            mime="application/json"
        )
def validate_input(self, input_data: Any) -> bool:
    """
    Check that the input carries everything needed for a first summary.

    Args:
        input_data: A dict, or a JSON string encoding a dict, that must
            contain "title", "content", "search_query" and "search_results".

    Returns:
        True when the (decoded) input is a dict with all four keys; False
        for malformed JSON, non-dict JSON values and any other input type.
    """
    required_fields = ("title", "content", "search_query", "search_results")

    if isinstance(input_data, str):
        try:
            input_data = json.loads(input_data)
        except JSONDecodeError:
            return False

    # A JSON number would raise TypeError on the membership test, and a JSON
    # string or list could pass it by accident. Only objects are valid.
    if not isinstance(input_data, dict):
        return False

    return all(field in input_data for field in required_fields)
cleaned_output = remove_reasoning_from_output(output) 96 | cleaned_output = clean_json_tags(cleaned_output) 97 | 98 | # 解析JSON 99 | try: 100 | result = json.loads(cleaned_output) 101 | except JSONDecodeError: 102 | # 如果不是JSON格式,直接返回清理后的文本 103 | return cleaned_output 104 | 105 | # 提取段落内容 106 | if isinstance(result, dict): 107 | paragraph_content = result.get("paragraph_latest_state", "") 108 | if paragraph_content: 109 | return paragraph_content 110 | 111 | # 如果提取失败,返回原始清理后的文本 112 | return cleaned_output 113 | 114 | except Exception as e: 115 | self.log_error(f"处理输出失败: {str(e)}") 116 | return "段落总结生成失败" 117 | 118 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: 119 | """ 120 | 更新段落的最新总结到状态 121 | 122 | Args: 123 | input_data: 输入数据 124 | state: 当前状态 125 | paragraph_index: 段落索引 126 | **kwargs: 额外参数 127 | 128 | Returns: 129 | 更新后的状态 130 | """ 131 | try: 132 | # 生成总结 133 | summary = self.run(input_data, **kwargs) 134 | 135 | # 更新状态 136 | if 0 <= paragraph_index < len(state.paragraphs): 137 | state.paragraphs[paragraph_index].research.latest_summary = summary 138 | self.log_info(f"已更新段落 {paragraph_index} 的首次总结") 139 | else: 140 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") 141 | 142 | state.update_timestamp() 143 | return state 144 | 145 | except Exception as e: 146 | self.log_error(f"状态更新失败: {str(e)}") 147 | raise e 148 | 149 | 150 | class ReflectionSummaryNode(StateMutationNode): 151 | """根据反思搜索结果更新段落总结的节点""" 152 | 153 | def __init__(self, llm_client): 154 | """ 155 | 初始化反思总结节点 156 | 157 | Args: 158 | llm_client: LLM客户端 159 | """ 160 | super().__init__(llm_client, "ReflectionSummaryNode") 161 | 162 | def validate_input(self, input_data: Any) -> bool: 163 | """验证输入数据""" 164 | if isinstance(input_data, str): 165 | try: 166 | data = json.loads(input_data) 167 | required_fields = ["title", "content", "search_query", "search_results", "paragraph_latest_state"] 168 | return all(field in data for field in required_fields) 
169 | except JSONDecodeError: 170 | return False 171 | elif isinstance(input_data, dict): 172 | required_fields = ["title", "content", "search_query", "search_results", "paragraph_latest_state"] 173 | return all(field in input_data for field in required_fields) 174 | return False 175 | 176 | def run(self, input_data: Any, **kwargs) -> str: 177 | """ 178 | 调用LLM更新段落内容 179 | 180 | Args: 181 | input_data: 包含完整反思信息的数据 182 | **kwargs: 额外参数 183 | 184 | Returns: 185 | 更新后的段落内容 186 | """ 187 | try: 188 | if not self.validate_input(input_data): 189 | raise ValueError("输入数据格式错误") 190 | 191 | # 准备输入数据 192 | if isinstance(input_data, str): 193 | message = input_data 194 | else: 195 | message = json.dumps(input_data, ensure_ascii=False) 196 | 197 | self.log_info("正在生成反思总结") 198 | 199 | # 调用LLM 200 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message) 201 | 202 | # 处理响应 203 | processed_response = self.process_output(response) 204 | 205 | self.log_info("成功生成反思总结") 206 | return processed_response 207 | 208 | except Exception as e: 209 | self.log_error(f"生成反思总结失败: {str(e)}") 210 | raise e 211 | 212 | def process_output(self, output: str) -> str: 213 | """ 214 | 处理LLM输出,提取更新后的段落内容 215 | 216 | Args: 217 | output: LLM原始输出 218 | 219 | Returns: 220 | 更新后的段落内容 221 | """ 222 | try: 223 | # 清理响应文本 224 | cleaned_output = remove_reasoning_from_output(output) 225 | cleaned_output = clean_json_tags(cleaned_output) 226 | 227 | # 解析JSON 228 | try: 229 | result = json.loads(cleaned_output) 230 | except JSONDecodeError: 231 | # 如果不是JSON格式,直接返回清理后的文本 232 | return cleaned_output 233 | 234 | # 提取更新后的段落内容 235 | if isinstance(result, dict): 236 | updated_content = result.get("updated_paragraph_latest_state", "") 237 | if updated_content: 238 | return updated_content 239 | 240 | # 如果提取失败,返回原始清理后的文本 241 | return cleaned_output 242 | 243 | except Exception as e: 244 | self.log_error(f"处理输出失败: {str(e)}") 245 | return "反思总结生成失败" 246 | 247 | def mutate_state(self, input_data: Any, 
state: State, paragraph_index: int, **kwargs) -> State: 248 | """ 249 | 将更新后的总结写入状态 250 | 251 | Args: 252 | input_data: 输入数据 253 | state: 当前状态 254 | paragraph_index: 段落索引 255 | **kwargs: 额外参数 256 | 257 | Returns: 258 | 更新后的状态 259 | """ 260 | try: 261 | # 生成更新后的总结 262 | updated_summary = self.run(input_data, **kwargs) 263 | 264 | # 更新状态 265 | if 0 <= paragraph_index < len(state.paragraphs): 266 | state.paragraphs[paragraph_index].research.latest_summary = updated_summary 267 | state.paragraphs[paragraph_index].research.increment_reflection() 268 | self.log_info(f"已更新段落 {paragraph_index} 的反思总结") 269 | else: 270 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") 271 | 272 | state.update_timestamp() 273 | return state 274 | 275 | except Exception as e: 276 | self.log_error(f"状态更新失败: {str(e)}") 277 | raise e 278 | -------------------------------------------------------------------------------- /src/state/state.py: -------------------------------------------------------------------------------- 1 | """ 2 | Deep Search Agent状态管理 3 | 定义所有状态数据结构和操作方法 4 | """ 5 | 6 | from dataclasses import dataclass, field 7 | from typing import List, Dict, Any, Optional 8 | import json 9 | from datetime import datetime 10 | 11 | 12 | @dataclass 13 | class Search: 14 | """单个搜索结果的状态""" 15 | query: str = "" # 搜索查询 16 | url: str = "" # 搜索结果的链接 17 | title: str = "" # 搜索结果标题 18 | content: str = "" # 搜索返回的内容 19 | score: Optional[float] = None # 相关度评分 20 | timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) 21 | 22 | def to_dict(self) -> Dict[str, Any]: 23 | """转换为字典格式""" 24 | return { 25 | "query": self.query, 26 | "url": self.url, 27 | "title": self.title, 28 | "content": self.content, 29 | "score": self.score, 30 | "timestamp": self.timestamp 31 | } 32 | 33 | @classmethod 34 | def from_dict(cls, data: Dict[str, Any]) -> "Search": 35 | """从字典创建Search对象""" 36 | return cls( 37 | query=data.get("query", ""), 38 | url=data.get("url", ""), 39 | title=data.get("title", ""), 40 | 
@dataclass
class Search:
    """One search hit recorded while researching a paragraph."""
    query: str = ""  # query that produced this hit
    url: str = ""  # link to the result
    title: str = ""  # result title
    content: str = ""  # text returned by the search
    score: Optional[float] = None  # relevance score, when one was supplied
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO-8601 creation time

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain dict."""
        return {
            "query": self.query,
            "url": self.url,
            "title": self.title,
            "content": self.content,
            "score": self.score,
            "timestamp": self.timestamp
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Search":
        """Build a Search from a dict; missing keys fall back to the field defaults."""
        return cls(
            query=data.get("query", ""),
            url=data.get("url", ""),
            title=data.get("title", ""),
            content=data.get("content", ""),
            score=data.get("score"),
            timestamp=data.get("timestamp", datetime.now().isoformat())
        )


@dataclass
class Research:
    """Research progress of a single paragraph."""
    search_history: List[Search] = field(default_factory=list)  # every search hit collected so far
    latest_summary: str = ""  # most recent summary of the paragraph
    reflection_iteration: int = 0  # number of reflection rounds performed
    is_completed: bool = False  # whether research on the paragraph is finished

    def add_search(self, search: Search):
        """Append one search record to the history."""
        self.search_history.append(search)

    def add_search_results(self, query: str, results: Optional[List[Dict[str, Any]]]):
        """
        Append a batch of raw search results to the history.

        Args:
            query: Query that produced the results.
            results: Result dicts (keys ``url``, ``title``, ``content``,
                ``score`` are read; missing keys use defaults).  ``None`` is
                treated as "no results" so a failed search cannot crash the
                caller.
        """
        for result in results or []:
            self.add_search(Search(
                query=query,
                url=result.get("url", ""),
                title=result.get("title", ""),
                content=result.get("content", ""),
                score=result.get("score")
            ))

    def get_search_count(self) -> int:
        """Return the number of recorded search hits."""
        return len(self.search_history)

    def increment_reflection(self):
        """Count one more reflection round."""
        self.reflection_iteration += 1

    def mark_completed(self):
        """Flag the research as finished."""
        self.is_completed = True

    def to_dict(self) -> Dict[str, Any]:
        """Return the research progress as a plain dict."""
        return {
            "search_history": [search.to_dict() for search in self.search_history],
            "latest_summary": self.latest_summary,
            "reflection_iteration": self.reflection_iteration,
            "is_completed": self.is_completed
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Research":
        """Build a Research from a dict; a missing or null history becomes an empty list."""
        # `or []` also covers an explicit JSON null, which .get() would pass through.
        history_data = data.get("search_history") or []
        return cls(
            search_history=[Search.from_dict(item) for item in history_data],
            latest_summary=data.get("latest_summary", ""),
            reflection_iteration=data.get("reflection_iteration", 0),
            is_completed=data.get("is_completed", False)
        )


@dataclass
class Paragraph:
    """State of one paragraph of the report."""
    title: str = ""  # paragraph title
    content: str = ""  # planned content from the initial outline
    research: Research = field(default_factory=Research)  # research progress
    order: int = 0  # position of the paragraph in the report

    def is_completed(self) -> bool:
        """Return True when research is flagged complete and a summary exists."""
        return self.research.is_completed and bool(self.research.latest_summary)

    def get_final_content(self) -> str:
        """Return the latest summary, or the planned content if no summary exists yet."""
        return self.research.latest_summary or self.content

    def to_dict(self) -> Dict[str, Any]:
        """Return the paragraph as a plain dict."""
        return {
            "title": self.title,
            "content": self.content,
            "research": self.research.to_dict(),
            "order": self.order
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Paragraph":
        """Build a Paragraph from a dict; missing research data yields a fresh Research."""
        research_data = data.get("research", {})
        research = Research.from_dict(research_data) if research_data else Research()

        return cls(
            title=data.get("title", ""),
            content=data.get("content", ""),
            research=research,
            order=data.get("order", 0)
        )


@dataclass
class State:
    """State of the whole report."""
    query: str = ""  # original user query
    report_title: str = ""  # report title
    paragraphs: List[Paragraph] = field(default_factory=list)  # paragraphs in report order
    final_report: str = ""  # final report text
    is_completed: bool = False  # whether the report is finished
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO-8601 creation time
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO-8601 last-change time

    def add_paragraph(self, title: str, content: str) -> int:
        """
        Append a paragraph to the report.

        Args:
            title: Paragraph title.
            content: Planned paragraph content.

        Returns:
            Index of the new paragraph.
        """
        order = len(self.paragraphs)
        self.paragraphs.append(Paragraph(title=title, content=content, order=order))
        self.update_timestamp()
        return order

    def get_paragraph(self, index: int) -> Optional[Paragraph]:
        """Return the paragraph at ``index``, or None when the index is out of range."""
        if 0 <= index < len(self.paragraphs):
            return self.paragraphs[index]
        return None

    def get_completed_paragraphs_count(self) -> int:
        """Return how many paragraphs are completed."""
        return sum(1 for p in self.paragraphs if p.is_completed())

    def get_total_paragraphs_count(self) -> int:
        """Return the total number of paragraphs."""
        return len(self.paragraphs)

    def is_all_paragraphs_completed(self) -> bool:
        """Return True when there is at least one paragraph and all are completed."""
        return all(p.is_completed() for p in self.paragraphs) if self.paragraphs else False

    def mark_completed(self):
        """Flag the whole report as finished and refresh the timestamp."""
        self.is_completed = True
        self.update_timestamp()

    def update_timestamp(self):
        """Set ``updated_at`` to the current time."""
        self.updated_at = datetime.now().isoformat()

    def get_progress_summary(self) -> Dict[str, Any]:
        """Return paragraph counts, completion percentage, and the report's status and timestamps."""
        completed = self.get_completed_paragraphs_count()
        total = self.get_total_paragraphs_count()

        return {
            "total_paragraphs": total,
            "completed_paragraphs": completed,
            "progress_percentage": (completed / total * 100) if total > 0 else 0,
            "is_completed": self.is_completed,
            "created_at": self.created_at,
            "updated_at": self.updated_at
        }

    def to_dict(self) -> Dict[str, Any]:
        """Return the whole state as a plain dict."""
        return {
            "query": self.query,
            "report_title": self.report_title,
            "paragraphs": [p.to_dict() for p in self.paragraphs],
            "final_report": self.final_report,
            "is_completed": self.is_completed,
            "created_at": self.created_at,
            "updated_at": self.updated_at
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialise the state to a JSON string (non-ASCII characters kept as-is)."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "State":
        """Build a State from a dict; a missing or null paragraph list becomes empty."""
        # `or []` also covers an explicit JSON null, which .get() would pass through.
        paragraphs_data = data.get("paragraphs") or []

        return cls(
            query=data.get("query", ""),
            report_title=data.get("report_title", ""),
            paragraphs=[Paragraph.from_dict(p_data) for p_data in paragraphs_data],
            final_report=data.get("final_report", ""),
            is_completed=data.get("is_completed", False),
            created_at=data.get("created_at", datetime.now().isoformat()),
            updated_at=data.get("updated_at", datetime.now().isoformat())
        )

    @classmethod
    def from_json(cls, json_str: str) -> "State":
        """Build a State from a JSON string produced by ``to_json``."""
        return cls.from_dict(json.loads(json_str))

    def save_to_file(self, filepath: str):
        """Write the state as UTF-8 JSON to ``filepath``."""
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(self.to_json())

    @classmethod
    def load_from_file(cls, filepath: str) -> "State":
        """Read a UTF-8 JSON file written by ``save_to_file`` and rebuild the State."""
        with open(filepath, 'r', encoding='utf-8') as f:
            json_str = f.read()
        return cls.from_json(json_str)
report_title=data.get("report_title", ""), 235 | paragraphs=paragraphs, 236 | final_report=data.get("final_report", ""), 237 | is_completed=data.get("is_completed", False), 238 | created_at=data.get("created_at", datetime.now().isoformat()), 239 | updated_at=data.get("updated_at", datetime.now().isoformat()) 240 | ) 241 | 242 | @classmethod 243 | def from_json(cls, json_str: str) -> "State": 244 | """从JSON字符串创建State对象""" 245 | data = json.loads(json_str) 246 | return cls.from_dict(data) 247 | 248 | def save_to_file(self, filepath: str): 249 | """保存状态到文件""" 250 | with open(filepath, 'w', encoding='utf-8') as f: 251 | f.write(self.to_json()) 252 | 253 | @classmethod 254 | def load_from_file(cls, filepath: str) -> "State": 255 | """从文件加载状态""" 256 | with open(filepath, 'r', encoding='utf-8') as f: 257 | json_str = f.read() 258 | return cls.from_json(json_str) 259 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Search Agent 2 | 3 | [![Python](https://img.shields.io/badge/Python-3.9+-blue.svg)](https://python.org) 4 | [![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE) 5 | [![DeepSeek](https://img.shields.io/badge/LLM-DeepSeek-red.svg)](https://platform.deepseek.com/) 6 | [![Tavily](https://img.shields.io/badge/Search-Tavily-yellow.svg)](https://tavily.com/) 7 | 8 | 一个**无框架**的深度搜索AI代理实现,能够通过多轮搜索和反思生成高质量的研究报告。 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | ## 特性 17 | 18 | - **无框架设计**: 从零实现,不依赖LangChain等重型框架 19 | - **多LLM支持**: 支持DeepSeek、OpenAI等主流大语言模型 20 | - **智能搜索**: 集成Tavily搜索引擎,提供高质量网络搜索 21 | - **反思机制**: 多轮反思优化,确保研究深度和完整性 22 | - **状态管理**: 完整的研究过程状态跟踪和恢复 23 | - **Web界面**: Streamlit友好界面,易于使用 24 | - **Markdown输出**: 美观的Markdown格式研究报告 25 | 26 | ## 工作原理 27 | 28 | Deep Search Agent采用分阶段的研究方法: 29 | 30 | ```mermaid 31 | graph TD 32 | A[用户查询] --> B[生成报告结构] 33 | B --> C[遍历每个段落] 34 | C --> D[初始搜索] 35 | D --> E[生成初始总结] 36 | E --> 
F[反思循环] 37 | F --> G[反思搜索] 38 | G --> H[更新总结] 39 | H --> I{达到反思次数?} 40 | I -->|否| F 41 | I -->|是| J{所有段落完成?} 42 | J -->|否| C 43 | J -->|是| K[格式化最终报告] 44 | K --> L[输出报告] 45 | ``` 46 | 47 | ### 核心流程 48 | 49 | 1. **结构生成**: 根据查询生成报告大纲和段落结构 50 | 2. **初始研究**: 为每个段落生成搜索查询并获取相关信息 51 | 3. **初始总结**: 基于搜索结果生成段落初稿 52 | 4. **反思优化**: 多轮反思,发现遗漏并补充搜索 53 | 5. **最终整合**: 将所有段落整合为完整的Markdown报告 54 | 55 | ## 快速开始 56 | 57 | ### 1. 环境准备 58 | 59 | 确保您的系统安装了Python 3.9或更高版本: 60 | 61 | ```bash 62 | python --version 63 | ``` 64 | 65 | ### 2. 克隆项目 66 | 67 | ```bash 68 | git clone <repository-url> 69 | cd Demo\ DeepSearch\ Agent 70 | ``` 71 | 72 | ### 3. 安装依赖 73 | 74 | ```bash 75 | # 激活虚拟环境(推荐) 76 | conda activate pytorch_python11 # 或者使用其他虚拟环境 77 | 78 | # 安装依赖 79 | pip install -r requirements.txt 80 | ``` 81 | 82 | ### 4. 配置API密钥 83 | 84 | 项目根目录下已有`config.py`配置文件,请直接编辑此文件设置您的API密钥: 85 | 86 | ```python 87 | # Deep Search Agent 配置文件 88 | # 请在这里填入您的API密钥 89 | 90 | # DeepSeek API Key 91 | DEEPSEEK_API_KEY = "your_deepseek_api_key_here" 92 | 93 | # OpenAI API Key (可选) 94 | OPENAI_API_KEY = "your_openai_api_key_here" 95 | 96 | # Tavily搜索API Key 97 | TAVILY_API_KEY = "your_tavily_api_key_here" 98 | 99 | # 配置参数 100 | DEFAULT_LLM_PROVIDER = "deepseek" 101 | DEEPSEEK_MODEL = "deepseek-chat" 102 | OPENAI_MODEL = "gpt-4o-mini" 103 | 104 | MAX_REFLECTIONS = 2 105 | SEARCH_RESULTS_PER_QUERY = 3 106 | SEARCH_CONTENT_MAX_LENGTH = 20000 107 | OUTPUT_DIR = "reports" 108 | SAVE_INTERMEDIATE_STATES = True 109 | ``` 110 | 111 | ### 5. 开始使用 112 | 113 | 现在您可以开始使用Deep Search Agent了!
114 | 115 | ## 使用方法 116 | 117 | ### 方式一:运行示例脚本 118 | 119 | **基本使用示例**: 120 | ```bash 121 | python examples/basic_usage.py 122 | ``` 123 | 这个示例展示了最简单的使用方式,执行一个预设的研究查询并显示结果。 124 | 125 | **高级使用示例**: 126 | ```bash 127 | python examples/advanced_usage.py 128 | ``` 129 | 这个示例展示了更复杂的使用场景,包括: 130 | - 自定义配置参数 131 | - 执行多个研究任务 132 | - 状态管理和恢复 133 | - 不同模型的使用 134 | 135 | ### 方式二:Web界面 136 | 137 | 启动Streamlit Web界面: 138 | ```bash 139 | streamlit run examples/streamlit_app.py 140 | ``` 141 | Web界面无需配置文件,直接在界面中输入API密钥即可使用。 142 | 143 | ### 方式三:编程方式 144 | 145 | ```python 146 | from src import DeepSearchAgent, load_config 147 | 148 | # 加载配置 149 | config = load_config() 150 | 151 | # 创建Agent 152 | agent = DeepSearchAgent(config) 153 | 154 | # 执行研究 155 | query = "2025年人工智能发展趋势" 156 | final_report = agent.research(query, save_report=True) 157 | 158 | print(final_report) 159 | ``` 160 | 161 | ### 方式四:自定义配置(编程方式) 162 | 163 | 如果需要在代码中动态设置配置,可以使用以下方式: 164 | 165 | ```python 166 | from src import DeepSearchAgent, Config 167 | 168 | # 自定义配置 169 | config = Config( 170 | default_llm_provider="deepseek", 171 | deepseek_model="deepseek-chat", 172 | max_reflections=3, # 增加反思次数 173 | max_search_results=5, # 增加搜索结果数 174 | output_dir="my_reports" # 自定义输出目录 175 | ) 176 | 177 | # 设置API密钥 178 | config.deepseek_api_key = "your_api_key" 179 | config.tavily_api_key = "your_tavily_key" 180 | 181 | agent = DeepSearchAgent(config) 182 | ``` 183 | 184 | ## 项目结构 185 | 186 | ``` 187 | Demo DeepSearch Agent/ 188 | ├── src/ # 核心代码 189 | │ ├── llms/ # LLM调用模块 190 | │ │ ├── base.py # LLM基类 191 | │ │ ├── deepseek.py # DeepSeek实现 192 | │ │ └── openai_llm.py # OpenAI实现 193 | │ ├── nodes/ # 处理节点 194 | │ │ ├── base_node.py # 节点基类 195 | │ │ ├── report_structure_node.py # 结构生成 196 | │ │ ├── search_node.py # 搜索节点 197 | │ │ ├── summary_node.py # 总结节点 198 | │ │ └── formatting_node.py # 格式化节点 199 | │ ├── prompts/ # 提示词模块 200 | │ │ └── prompts.py # 所有提示词定义 201 | │ ├── state/ # 状态管理 202 | │ │ └── state.py # 状态数据结构 203 | │ ├── 
tools/ # 工具调用 204 | │ │ └── search.py # 搜索工具 205 | │ ├── utils/ # 工具函数 206 | │ │ ├── config.py # 配置管理 207 | │ │ └── text_processing.py # 文本处理 208 | │ └── agent.py # 主Agent类 209 | ├── examples/ # 使用示例 210 | │ ├── basic_usage.py # 基本使用示例 211 | │ ├── advanced_usage.py # 高级使用示例 212 | │ └── streamlit_app.py # Web界面 213 | ├── reports/ # 输出报告目录 214 | ├── requirements.txt # 依赖列表 215 | ├── config.py # 配置文件 216 | └── README.md # 项目文档 217 | ``` 218 | 219 | ## 代码结构 220 | 221 | ```mermaid 222 | graph TB 223 | subgraph "用户层" 224 | A[用户查询] 225 | B[Web界面] 226 | C[命令行接口] 227 | end 228 | 229 | subgraph "主控制层" 230 | D[DeepSearchAgent] 231 | end 232 | 233 | subgraph "处理节点层" 234 | E[ReportStructureNode
报告结构生成] 235 | F[FirstSearchNode
初始搜索] 236 | G[FirstSummaryNode
初始总结] 237 | H[ReflectionNode
反思搜索] 238 | I[ReflectionSummaryNode
反思总结] 239 | J[ReportFormattingNode
报告格式化] 240 | end 241 | 242 | subgraph "LLM层" 243 | K[DeepSeekLLM] 244 | L[OpenAILLM] 245 | M[BaseLLM抽象类] 246 | end 247 | 248 | subgraph "工具层" 249 | N[Tavily搜索] 250 | O[文本处理工具] 251 | P[配置管理] 252 | end 253 | 254 | subgraph "状态管理层" 255 | Q[State状态对象] 256 | R[Paragraph段落对象] 257 | S[Research研究对象] 258 | T[Search搜索记录] 259 | end 260 | 261 | subgraph "数据持久化" 262 | U[JSON状态文件] 263 | V[Markdown报告] 264 | W[日志文件] 265 | end 266 | 267 | A --> D 268 | B --> D 269 | C --> D 270 | 271 | D --> E 272 | D --> F 273 | D --> G 274 | D --> H 275 | D --> I 276 | D --> J 277 | 278 | E --> K 279 | E --> L 280 | F --> K 281 | F --> L 282 | G --> K 283 | G --> L 284 | H --> K 285 | H --> L 286 | I --> K 287 | I --> L 288 | J --> K 289 | J --> L 290 | 291 | K --> M 292 | L --> M 293 | 294 | F --> N 295 | H --> N 296 | 297 | D --> O 298 | D --> P 299 | 300 | D --> Q 301 | Q --> R 302 | R --> S 303 | S --> T 304 | 305 | Q --> U 306 | D --> V 307 | D --> W 308 | 309 | style A fill:#e1f5fe 310 | style D fill:#f3e5f5 311 | style E fill:#fff3e0 312 | style F fill:#fff3e0 313 | style G fill:#fff3e0 314 | style H fill:#fff3e0 315 | style I fill:#fff3e0 316 | style J fill:#fff3e0 317 | style K fill:#e8f5e8 318 | style L fill:#e8f5e8 319 | style N fill:#fce4ec 320 | style Q fill:#f1f8e9 321 | ``` 322 | 323 | ## API 参考 324 | 325 | ### DeepSearchAgent 326 | 327 | 主要的Agent类,提供完整的深度搜索功能。 328 | 329 | ```python 330 | class DeepSearchAgent: 331 | def __init__(self, config: Optional[Config] = None) 332 | def research(self, query: str, save_report: bool = True) -> str 333 | def get_progress_summary(self) -> Dict[str, Any] 334 | def load_state(self, filepath: str) 335 | def save_state(self, filepath: str) 336 | ``` 337 | 338 | ### Config 339 | 340 | 配置管理类,控制Agent的行为参数。 341 | 342 | ```python 343 | class Config: 344 | # API密钥 345 | deepseek_api_key: Optional[str] 346 | openai_api_key: Optional[str] 347 | tavily_api_key: Optional[str] 348 | 349 | # 模型配置 350 | default_llm_provider: str = "deepseek" 351 | 
deepseek_model: str = "deepseek-chat" 352 | openai_model: str = "gpt-4o-mini" 353 | 354 | # 搜索配置 355 | max_search_results: int = 3 356 | search_timeout: int = 240 357 | max_content_length: int = 20000 358 | 359 | # Agent配置 360 | max_reflections: int = 2 361 | max_paragraphs: int = 5 362 | ``` 363 | 364 | ## 示例 365 | 366 | ### 示例1:基本研究 367 | 368 | ```python 369 | from src import create_agent 370 | 371 | # 快速创建Agent 372 | agent = create_agent() 373 | 374 | # 执行研究 375 | report = agent.research("量子计算的发展现状") 376 | print(report) 377 | ``` 378 | 379 | ### 示例2:自定义研究参数 380 | 381 | ```python 382 | from src import DeepSearchAgent, Config 383 | 384 | config = Config( 385 | max_reflections=4, # 更深度的反思 386 | max_search_results=8, # 更多搜索结果 387 | max_paragraphs=6 # 更长的报告 388 | ) 389 | 390 | agent = DeepSearchAgent(config) 391 | report = agent.research("人工智能的伦理问题") 392 | ``` 393 | 394 | ### 示例3:状态管理 395 | 396 | ```python 397 | # 开始研究 398 | agent = DeepSearchAgent() 399 | report = agent.research("区块链技术应用") 400 | 401 | # 保存状态 402 | agent.save_state("blockchain_research.json") 403 | 404 | # 稍后恢复状态 405 | new_agent = DeepSearchAgent() 406 | new_agent.load_state("blockchain_research.json") 407 | 408 | # 检查进度 409 | progress = new_agent.get_progress_summary() 410 | print(f"研究进度: {progress['progress_percentage']}%") 411 | ``` 412 | 413 | ## 高级功能 414 | 415 | ### 多模型支持 416 | 417 | ```python 418 | # 使用DeepSeek 419 | config = Config(default_llm_provider="deepseek") 420 | 421 | # 使用OpenAI 422 | config = Config(default_llm_provider="openai", openai_model="gpt-4o") 423 | ``` 424 | 425 | ### 自定义输出 426 | 427 | ```python 428 | config = Config( 429 | output_dir="custom_reports", # 自定义输出目录 430 | save_intermediate_states=True # 保存中间状态 431 | ) 432 | ``` 433 | 434 | ## 常见问题 435 | 436 | ### Q: 支持哪些LLM? 437 | 438 | A: 目前支持: 439 | - **DeepSeek**: 推荐使用,性价比高 440 | - **OpenAI**: GPT-4o、GPT-4o-mini等 441 | - 可以通过继承`BaseLLM`类轻松添加其他模型 442 | 443 | ### Q: 如何获取API密钥? 
444 | 445 | A: 446 | - **DeepSeek**: 访问 [DeepSeek平台](https://platform.deepseek.com/) 注册获取 447 | - **Tavily**: 访问 [Tavily](https://tavily.com/) 注册获取(每月1000次免费) 448 | - **OpenAI**: 访问 [OpenAI平台](https://platform.openai.com/) 获取 449 | 450 | 获取密钥后,直接编辑项目根目录的`config.py`文件填入即可。 451 | 452 | ### Q: 研究报告质量如何提升? 453 | 454 | A: 可以通过以下方式优化: 455 | - 增加`max_reflections`参数(更多反思轮次) 456 | - 增加`max_search_results`参数(更多搜索结果) 457 | - 调整`max_content_length`参数(更长的搜索内容) 458 | - 使用更强大的LLM模型 459 | 460 | ### Q: 如何自定义提示词? 461 | 462 | A: 修改`src/prompts/prompts.py`文件中的系统提示词,可以根据需要调整Agent的行为。 463 | 464 | ### Q: 支持其他搜索引擎吗? 465 | 466 | A: 当前主要支持Tavily,但可以通过修改`src/tools/search.py`添加其他搜索引擎支持。 467 | 468 | ## 贡献 469 | 470 | 欢迎贡献代码!请遵循以下步骤: 471 | 472 | 1. Fork本项目 473 | 2. 创建特性分支 (`git checkout -b feature/AmazingFeature`) 474 | 3. 提交更改 (`git commit -m 'Add some AmazingFeature'`) 475 | 4. 推送到分支 (`git push origin feature/AmazingFeature`) 476 | 5. 开启Pull Request 477 | 478 | ## 许可证 479 | 480 | 本项目采用MIT许可证 - 查看 [LICENSE](LICENSE) 文件了解详情。 481 | 482 | ## 致谢 483 | 484 | - 感谢 [DeepSeek](https://www.deepseek.com/) 提供优秀的LLM服务 485 | - 感谢 [Tavily](https://tavily.com/) 提供高质量的搜索API 486 | 487 | --- 488 | 489 | 如果这个项目对您有帮助,请给个Star! 
# File: src/agent.py
"""
Deep Search Agent main class.

Wires the LLM client, the processing nodes and the search tool together
into the complete deep-research pipeline.
"""

import json
import os
from datetime import datetime
from typing import Optional, Dict, Any, List

from .llms import DeepSeekLLM, OpenAILLM, BaseLLM
from .nodes import (
    ReportStructureNode,
    FirstSearchNode,
    ReflectionNode,
    FirstSummaryNode,
    ReflectionSummaryNode,
    ReportFormattingNode
)
from .state import State
from .tools import tavily_search
from .utils import Config, load_config, format_search_results_for_prompt


class DeepSearchAgent:
    """Runs the research pipeline: outline, per-paragraph search/summary/reflection, final report."""

    def __init__(self, config: Optional[Config] = None):
        """
        Initialise the agent.

        Args:
            config: Configuration object; loaded via ``load_config()`` when omitted.

        Raises:
            ValueError: If the configured LLM provider is not supported.
        """
        self.config = config or load_config()

        self.llm_client = self._initialize_llm()
        self._initialize_nodes()

        self.state = State()

        # Reports and state files are written here, so make sure it exists.
        os.makedirs(self.config.output_dir, exist_ok=True)

        print("Deep Search Agent 已初始化")
        print(f"使用LLM: {self.llm_client.get_model_info()}")

    def _initialize_llm(self) -> BaseLLM:
        """Create the LLM client selected by ``config.default_llm_provider``."""
        provider = self.config.default_llm_provider
        if provider == "deepseek":
            return DeepSeekLLM(
                api_key=self.config.deepseek_api_key,
                model_name=self.config.deepseek_model
            )
        if provider == "openai":
            return OpenAILLM(
                api_key=self.config.openai_api_key,
                model_name=self.config.openai_model
            )
        raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")

    def _initialize_nodes(self):
        """Create the reusable processing nodes (the structure node is built per query)."""
        self.first_search_node = FirstSearchNode(self.llm_client)
        self.reflection_node = ReflectionNode(self.llm_client)
        self.first_summary_node = FirstSummaryNode(self.llm_client)
        self.reflection_summary_node = ReflectionSummaryNode(self.llm_client)
        self.report_formatting_node = ReportFormattingNode(self.llm_client)

    def research(self, query: str, save_report: bool = True) -> str:
        """
        Run the full research pipeline for one query.

        Args:
            query: Research question.
            save_report: Whether to write the report (and, if configured, the state) to disk.

        Returns:
            The final report text.

        Raises:
            Exception: Any error from the pipeline is printed and re-raised.
        """
        print(f"\n{'='*60}")
        print(f"开始深度研究: {query}")
        print(f"{'='*60}")

        # Start every run from a clean state.
        # NOTE(review): the structure node is handed the current state; reusing
        # the previous run's State would presumably let its paragraphs leak
        # into this report - confirm against ReportStructureNode.mutate_state.
        self.state = State(query=query)

        try:
            # Step 1: build the outline
            self._generate_report_structure(query)

            # Step 2: research every paragraph
            self._process_paragraphs()

            # Step 3: assemble the final report
            final_report = self._generate_final_report()

            # Step 4: persist it
            if save_report:
                self._save_report(final_report)

            print(f"\n{'='*60}")
            print("深度研究完成!")
            print(f"{'='*60}")

            return final_report

        except Exception as e:
            print(f"研究过程中发生错误: {str(e)}")
            raise

    def _generate_report_structure(self, query: str):
        """Ask the structure node for an outline and store it in ``self.state``."""
        print("\n[步骤 1] 生成报告结构...")

        report_structure_node = ReportStructureNode(self.llm_client, query)
        self.state = report_structure_node.mutate_state(state=self.state)

        print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
        for i, paragraph in enumerate(self.state.paragraphs, 1):
            print(f" {i}. {paragraph.title}")

    def _process_paragraphs(self):
        """Run initial search/summary and the reflection loop for every paragraph."""
        total_paragraphs = len(self.state.paragraphs)

        for i in range(total_paragraphs):
            print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
            print("-" * 50)

            self._initial_search_and_summary(i)
            self._reflection_loop(i)

            self.state.paragraphs[i].research.mark_completed()

            progress = (i + 1) / total_paragraphs * 100
            print(f"段落处理完成 ({progress:.1f}%)")

    def _search(self, search_query: str) -> list:
        """Run one web search with the configured limits; always returns a list."""
        results = tavily_search(
            search_query,
            max_results=self.config.max_search_results,
            timeout=self.config.search_timeout,
            api_key=self.config.tavily_api_key
        )
        # Normalise "no results" so callers can iterate and store it safely.
        return results or []

    def _initial_search_and_summary(self, paragraph_index: int):
        """Generate a query for the paragraph, search, record the hits and write the first summary."""
        paragraph = self.state.paragraphs[paragraph_index]

        search_input = {
            "title": paragraph.title,
            "content": paragraph.content
        }

        print(" - 生成搜索查询...")
        search_output = self.first_search_node.run(search_input)
        search_query = search_output["search_query"]
        reasoning = search_output["reasoning"]

        print(f" - 搜索查询: {search_query}")
        print(f" - 推理: {reasoning}")

        print(" - 执行网络搜索...")
        search_results = self._search(search_query)

        if search_results:
            print(f" - 找到 {len(search_results)} 个搜索结果")
            for j, result in enumerate(search_results, 1):
                # A hit without a usable title must not abort the run.
                title = result.get("title") or ""
                print(f" {j}. {title[:50]}...")
        else:
            print(" - 未找到搜索结果")

        paragraph.research.add_search_results(search_query, search_results)

        print(" - 生成初始总结...")
        summary_input = {
            "title": paragraph.title,
            "content": paragraph.content,
            "search_query": search_query,
            "search_results": format_search_results_for_prompt(
                search_results, self.config.max_content_length
            )
        }

        self.state = self.first_summary_node.mutate_state(
            summary_input, self.state, paragraph_index
        )

        print(" - 初始总结完成")

    def _reflection_loop(self, paragraph_index: int):
        """Run ``config.max_reflections`` rounds of reflect -> search -> revise for one paragraph."""
        paragraph = self.state.paragraphs[paragraph_index]

        for reflection_i in range(self.config.max_reflections):
            print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")

            reflection_input = {
                "title": paragraph.title,
                "content": paragraph.content,
                "paragraph_latest_state": paragraph.research.latest_summary
            }

            reflection_output = self.reflection_node.run(reflection_input)
            search_query = reflection_output["search_query"]
            reasoning = reflection_output["reasoning"]

            print(f" 反思查询: {search_query}")
            print(f" 反思推理: {reasoning}")

            search_results = self._search(search_query)

            if search_results:
                print(f" 找到 {len(search_results)} 个反思搜索结果")

            paragraph.research.add_search_results(search_query, search_results)

            reflection_summary_input = {
                "title": paragraph.title,
                "content": paragraph.content,
                "search_query": search_query,
                "search_results": format_search_results_for_prompt(
                    search_results, self.config.max_content_length
                ),
                "paragraph_latest_state": paragraph.research.latest_summary
            }

            self.state = self.reflection_summary_node.mutate_state(
                reflection_summary_input, self.state, paragraph_index
            )

            print(f" 反思 {reflection_i + 1} 完成")

    def _generate_final_report(self) -> str:
        """Format all paragraph summaries into the final report and mark the state complete."""
        print("\n[步骤 3] 生成最终报告...")

        report_data = [
            {
                "title": paragraph.title,
                "paragraph_latest_state": paragraph.research.latest_summary
            }
            for paragraph in self.state.paragraphs
        ]

        try:
            final_report = self.report_formatting_node.run(report_data)
        except Exception as e:
            # Deliberate best-effort: fall back to the node's non-LLM formatter.
            print(f"LLM格式化失败,使用备用方法: {str(e)}")
            final_report = self.report_formatting_node.format_report_manually(
                report_data, self.state.report_title
            )

        self.state.final_report = final_report
        self.state.mark_completed()

        print("最终报告生成完成")
        return final_report

    def _save_report(self, report_content: str):
        """Write the report (and optionally the state) into ``config.output_dir``."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Keep only filename-safe characters from the query and cap the length.
        query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip()
        query_safe = query_safe.replace(' ', '_')[:30]

        filename = f"deep_search_report_{query_safe}_{timestamp}.md"
        filepath = os.path.join(self.config.output_dir, filename)

        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(report_content)

        print(f"报告已保存到: {filepath}")

        if self.config.save_intermediate_states:
            state_filename = f"state_{query_safe}_{timestamp}.json"
            state_filepath = os.path.join(self.config.output_dir, state_filename)
            self.state.save_to_file(state_filepath)
            print(f"状态已保存到: {state_filepath}")

    def get_progress_summary(self) -> Dict[str, Any]:
        """Return the progress summary of the current state."""
        return self.state.get_progress_summary()

    def load_state(self, filepath: str):
        """Replace the current state with the one stored in ``filepath``."""
        self.state = State.load_from_file(filepath)
        print(f"状态已从 {filepath} 加载")

    def save_state(self, filepath: str):
        """Write the current state to ``filepath``."""
        self.state.save_to_file(filepath)
        print(f"状态已保存到 {filepath}")


def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
    """
    Convenience factory for a DeepSearchAgent.

    Args:
        config_file: Path of the configuration file passed to ``load_config``.

    Returns:
        A DeepSearchAgent built from the loaded configuration.
    """
    config = load_config(config_file)
    return DeepSearchAgent(config)