├── img
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── requirements.txt
├── src
│   ├── tools
│   │   ├── __init__.py
│   │   └── search.py
│   ├── state
│   │   ├── __init__.py
│   │   └── state.py
│   ├── llms
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── openai_llm.py
│   │   └── deepseek.py
│   ├── __init__.py
│   ├── nodes
│   │   ├── __init__.py
│   │   ├── base_node.py
│   │   ├── report_structure_node.py
│   │   ├── formatting_node.py
│   │   ├── search_node.py
│   │   └── summary_node.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── text_processing.py
│   │   └── config.py
│   ├── prompts
│   │   ├── __init__.py
│   │   └── prompts.py
│   └── agent.py
├── config.py
├── .gitignore
├── LICENSE
├── examples
│   ├── basic_usage.py
│   ├── advanced_usage.py
│   └── streamlit_app.py
└── README.md
/img/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/666ghj/DeepSearchAgent-Demo/HEAD/img/1.png
--------------------------------------------------------------------------------
/img/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/666ghj/DeepSearchAgent-Demo/HEAD/img/2.png
--------------------------------------------------------------------------------
/img/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/666ghj/DeepSearchAgent-Demo/HEAD/img/3.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai>=1.0.0
2 | requests>=2.25.0
3 | tavily-python>=0.3.0
4 | streamlit>=1.28.0
5 | pydantic>=2.0.0
6 | rich>=13.0.0
7 |
--------------------------------------------------------------------------------
/src/tools/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 工具调用模块
3 | 提供外部工具接口,如网络搜索等
4 | """
5 |
6 | from .search import tavily_search, SearchResult
7 |
8 | __all__ = ["tavily_search", "SearchResult"]
9 |
--------------------------------------------------------------------------------
/src/state/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 状态管理模块
3 | 定义Deep Search Agent的状态数据结构
4 | """
5 |
6 | from .state import State, Paragraph, Research, Search
7 |
8 | __all__ = ["State", "Paragraph", "Research", "Search"]
9 |
--------------------------------------------------------------------------------
/src/llms/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | LLM调用模块
3 | 支持多种大语言模型的统一接口
4 | """
5 |
6 | from .base import BaseLLM
7 | from .deepseek import DeepSeekLLM
8 | from .openai_llm import OpenAILLM
9 |
10 | __all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"]
11 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Deep Search Agent
3 | 一个无框架的深度搜索AI代理实现
4 | """
5 |
6 | from .agent import DeepSearchAgent, create_agent
7 | from .utils.config import Config, load_config
8 |
9 | __version__ = "1.0.0"
10 | __author__ = "Deep Search Agent Team"
11 |
12 | __all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"]
13 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# Deep Search Agent configuration.
# Put your API keys below.
# NOTE(review): config.py is not listed in .gitignore, so real keys written here
# can be committed by accident -- consider keeping them in environment variables.

# ---- LLM providers -------------------------------------------------------
# Which backend to use by default ("deepseek" or "openai").
DEFAULT_LLM_PROVIDER = "deepseek"

# DeepSeek
DEEPSEEK_API_KEY = "your_deepseek_api_key_here"
DEEPSEEK_MODEL = "deepseek-chat"

# OpenAI (optional)
OPENAI_API_KEY = "your_openai_api_key_here"
OPENAI_MODEL = "gpt-4o-mini"

# ---- Web search (Tavily) -------------------------------------------------
TAVILY_API_KEY = "your_tavily_api_key_here"
SEARCH_RESULTS_PER_QUERY = 3
SEARCH_CONTENT_MAX_LENGTH = 20000

# ---- Research loop and output --------------------------------------------
MAX_REFLECTIONS = 2
OUTPUT_DIR = "reports"
SAVE_INTERMEDIATE_STATES = True
23 |
--------------------------------------------------------------------------------
/src/nodes/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 节点处理模块
3 | 实现Deep Search Agent的各个处理步骤
4 | """
5 |
6 | from .base_node import BaseNode
7 | from .report_structure_node import ReportStructureNode
8 | from .search_node import FirstSearchNode, ReflectionNode
9 | from .summary_node import FirstSummaryNode, ReflectionSummaryNode
10 | from .formatting_node import ReportFormattingNode
11 |
12 | __all__ = [
13 | "BaseNode",
14 | "ReportStructureNode",
15 | "FirstSearchNode",
16 | "ReflectionNode",
17 | "FirstSummaryNode",
18 | "ReflectionSummaryNode",
19 | "ReportFormattingNode"
20 | ]
21 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 工具函数模块
3 | 提供文本处理、JSON解析等辅助功能
4 | """
5 |
6 | from .text_processing import (
7 | clean_json_tags,
8 | clean_markdown_tags,
9 | remove_reasoning_from_output,
10 | extract_clean_response,
11 | update_state_with_search_results,
12 | format_search_results_for_prompt
13 | )
14 |
15 | from .config import Config, load_config
16 |
17 | __all__ = [
18 | "clean_json_tags",
19 | "clean_markdown_tags",
20 | "remove_reasoning_from_output",
21 | "extract_clean_response",
22 | "update_state_with_search_results",
23 | "format_search_results_for_prompt",
24 | "Config",
25 | "load_config"
26 | ]
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | pip-wheel-metadata/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
26 | # 虚拟环境
27 | venv/
28 | env/
29 | ENV/
30 | env.bak/
31 | venv.bak/
32 |
33 | # 环境变量文件
34 | .env
35 | *.env
36 |
37 | # IDE
38 | .vscode/
39 | .idea/
40 | *.swp
41 | *.swo
42 | *~
43 |
44 | # 日志文件
45 | *.log
46 | logs/
47 |
48 | # 临时文件
49 | *.tmp
50 | *.temp
51 | temp/
52 | tmp/
53 |
54 | # 系统文件
55 | .DS_Store
56 | Thumbs.db
57 |
58 | # 项目特定
59 | reports/
60 | custom_reports/
61 | streamlit_reports/
62 | *.json
63 | *.md
64 | !README.md
65 | !LICENSE
66 |
67 | # Streamlit
68 | .streamlit/
69 |
70 | # 测试覆盖率
71 | .coverage
72 | htmlcov/
73 | .pytest_cache/
74 |
--------------------------------------------------------------------------------
/src/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt模块
3 | 定义Deep Search Agent各个阶段使用的系统提示词
4 | """
5 |
6 | from .prompts import (
7 | SYSTEM_PROMPT_REPORT_STRUCTURE,
8 | SYSTEM_PROMPT_FIRST_SEARCH,
9 | SYSTEM_PROMPT_FIRST_SUMMARY,
10 | SYSTEM_PROMPT_REFLECTION,
11 | SYSTEM_PROMPT_REFLECTION_SUMMARY,
12 | SYSTEM_PROMPT_REPORT_FORMATTING,
13 | output_schema_report_structure,
14 | output_schema_first_search,
15 | output_schema_first_summary,
16 | output_schema_reflection,
17 | output_schema_reflection_summary,
18 | input_schema_report_formatting
19 | )
20 |
21 | __all__ = [
22 | "SYSTEM_PROMPT_REPORT_STRUCTURE",
23 | "SYSTEM_PROMPT_FIRST_SEARCH",
24 | "SYSTEM_PROMPT_FIRST_SUMMARY",
25 | "SYSTEM_PROMPT_REFLECTION",
26 | "SYSTEM_PROMPT_REFLECTION_SUMMARY",
27 | "SYSTEM_PROMPT_REPORT_FORMATTING",
28 | "output_schema_report_structure",
29 | "output_schema_first_search",
30 | "output_schema_first_summary",
31 | "output_schema_reflection",
32 | "output_schema_reflection_summary",
33 | "input_schema_report_formatting"
34 | ]
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Deep Search Agent Team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/llms/base.py:
--------------------------------------------------------------------------------
1 | """
2 | LLM基础抽象类
3 | 定义所有LLM实现需要遵循的接口标准
4 | """
5 |
6 | from abc import ABC, abstractmethod
7 | from typing import Optional, Dict, Any
8 |
9 |
class BaseLLM(ABC):
    """Abstract interface that every LLM backend in this package implements."""

    def __init__(self, api_key: str, model_name: Optional[str] = None):
        """
        Store the credentials and the requested model.

        Args:
            api_key: Provider API key.
            model_name: Explicit model name; None means "use the backend's default model".
        """
        self.api_key, self.model_name = api_key, model_name

    @abstractmethod
    def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """
        Send one system + user prompt pair to the model and return its reply.

        Args:
            system_prompt: System prompt text.
            user_prompt: User message text.
            **kwargs: Backend-specific generation options (e.g. temperature, max_tokens).

        Returns:
            The generated reply text.
        """

    @abstractmethod
    def get_default_model(self) -> str:
        """
        Return the name of the model this backend uses when none is given.

        Returns:
            Default model name.
        """

    def validate_response(self, response: str) -> str:
        """
        Normalize a raw completion.

        Args:
            response: Raw model output; may be None.

        Returns:
            "" for None, otherwise the text with surrounding whitespace stripped.
        """
        return "" if response is None else response.strip()
62 |
--------------------------------------------------------------------------------
/examples/basic_usage.py:
--------------------------------------------------------------------------------
1 | """
2 | 基本使用示例
3 | 演示如何使用Deep Search Agent进行基本的深度搜索
4 | """
5 |
6 | import os
7 | import sys
8 |
9 | # 添加项目根目录到Python路径
10 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
11 |
12 | from src import DeepSearchAgent, load_config
13 | from src.utils.config import print_config
14 |
15 |
def basic_example():
    """
    Run one end-to-end research task using the configuration returned by load_config().

    The task loads the config, builds a DeepSearchAgent, researches a fixed query with
    save_report=True, prints a preview of the report and a progress summary, and prints
    troubleshooting hints if anything raises.
    """
    banner = "=" * 60
    print(banner)
    print("Deep Search Agent - 基本使用示例")
    print(banner)

    try:
        print("正在加载配置...")
        config = load_config()
        print_config(config)

        print("正在初始化Agent...")
        agent = DeepSearchAgent(config)

        query = "2025年人工智能发展趋势"
        print(f"开始研究: {query}")
        final_report = agent.research(query, save_report=True)

        print("\n" + banner)
        print("研究完成!最终报告预览:")
        print(banner)
        # Only the first 500 characters are shown; longer reports get an ellipsis.
        preview_limit = 500
        if len(final_report) > preview_limit:
            print(final_report[:preview_limit] + "...")
        else:
            print(final_report)

        progress = agent.get_progress_summary()
        print("\n进度信息:")
        print(f"- 总段落数: {progress['total_paragraphs']}")
        print(f"- 已完成段落: {progress['completed_paragraphs']}")
        print(f"- 完成进度: {progress['progress_percentage']:.1f}%")
        print(f"- 是否完成: {progress['is_completed']}")

    except Exception as e:
        # Top-level boundary of an example script: report the error and list likely causes.
        print(f"示例运行失败: {str(e)}")
        print("请检查:")
        hints = (
            "1. 是否安装了所有依赖:pip install -r requirements.txt",
            "2. 是否设置了必要的API密钥",
            "3. 网络连接是否正常",
            "4. 配置文件是否正确",
        )
        for hint in hints:
            print(hint)


if __name__ == "__main__":
    basic_example()
63 |
--------------------------------------------------------------------------------
/src/nodes/base_node.py:
--------------------------------------------------------------------------------
1 | """
2 | 节点基类
3 | 定义所有处理节点的基础接口
4 | """
5 |
6 | from abc import ABC, abstractmethod
7 | from typing import Any, Dict, Optional
8 | from ..llms.base import BaseLLM
9 | from ..state.state import State
10 |
11 |
class BaseNode(ABC):
    """Common base class for every processing step (node) of the agent pipeline."""

    def __init__(self, llm_client: BaseLLM, node_name: str = ""):
        """
        Args:
            llm_client: LLM backend this node sends prompts to.
            node_name: Label used in log lines; when empty, the concrete class name is used.
        """
        self.llm_client = llm_client
        self.node_name = node_name if node_name else self.__class__.__name__

    @abstractmethod
    def run(self, input_data: Any, **kwargs) -> Any:
        """
        Execute this node's processing logic.

        Args:
            input_data: Node-specific input.
            **kwargs: Extra node-specific options.

        Returns:
            The node's result.
        """

    def validate_input(self, input_data: Any) -> bool:
        """Hook for input validation; the base implementation accepts everything."""
        return True

    def process_output(self, output: Any) -> Any:
        """Hook for post-processing raw output; the base implementation returns it unchanged."""
        return output

    def log_info(self, message: str):
        """Print an informational line prefixed with the node name."""
        print(f"[{self.node_name}] {message}")

    def log_error(self, message: str):
        """Print an error line prefixed with the node name and an error marker."""
        print(f"[{self.node_name}] 错误: {message}")
71 |
72 |
class StateMutationNode(BaseNode):
    """Base class for nodes that also write their result into a State object."""

    @abstractmethod
    def mutate_state(self, input_data: Any, state: State, **kwargs) -> State:
        """
        Apply this node's effect to ``state``.

        Args:
            input_data: Node-specific input.
            state: The current research state.
            **kwargs: Extra node-specific options.

        Returns:
            The updated State.
        """
90 |
--------------------------------------------------------------------------------
/src/llms/openai_llm.py:
--------------------------------------------------------------------------------
1 | """
2 | OpenAI LLM实现
3 | 使用OpenAI API进行文本生成
4 | """
5 |
6 | import os
7 | from typing import Optional, Dict, Any
8 | from openai import OpenAI
9 | from .base import BaseLLM
10 |
11 |
class OpenAILLM(BaseLLM):
    """LLM backend that calls the OpenAI chat-completions API."""

    def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
        """
        Create the OpenAI client.

        Args:
            api_key: OpenAI API key; when None it is read from the OPENAI_API_KEY environment variable.
            model_name: Model to use; defaults to get_default_model() ("gpt-4o-mini").

        Raises:
            ValueError: If no key is given and OPENAI_API_KEY is unset or empty.
        """
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OpenAI API Key未找到!请设置OPENAI_API_KEY环境变量或在初始化时提供")

        # Resolve the model once so BaseLLM.model_name and default_model always agree
        # (previously model_name stayed None whenever the default model was used).
        resolved_model = model_name or self.get_default_model()
        super().__init__(api_key, resolved_model)

        self.client = OpenAI(api_key=self.api_key)
        self.default_model = resolved_model

    def get_default_model(self) -> str:
        """Return the default OpenAI model name."""
        return "gpt-4o-mini"

    def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """
        Send one system + user prompt pair to OpenAI and return the reply.

        Args:
            system_prompt: System prompt text.
            user_prompt: User message text.
            **kwargs: Only ``temperature`` (default 0.7) and ``max_tokens`` (default 4000)
                are read; other keys are ignored.

        Returns:
            The stripped reply text, or "" when the API returns no choices.

        Raises:
            Exception: Any error from the OpenAI client is printed and re-raised unchanged.
        """
        try:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]

            params = {
                "model": self.default_model,
                "messages": messages,
                "temperature": kwargs.get("temperature", 0.7),
                "max_tokens": kwargs.get("max_tokens", 4000)
            }

            response = self.client.chat.completions.create(**params)

            if response.choices and response.choices[0].message:
                # content may be None; validate_response maps that to "".
                content = response.choices[0].message.content
                return self.validate_response(content)
            return ""

        except Exception as e:
            print(f"OpenAI API调用错误: {str(e)}")
            # Bare raise keeps the original traceback intact.
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """
        Describe the configured backend.

        Returns:
            Dict with the provider name, the model in use and the API base URL.
        """
        return {
            "provider": "OpenAI",
            "model": self.default_model,
            "api_base": "https://api.openai.com"
        }
91 |
--------------------------------------------------------------------------------
/src/llms/deepseek.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepSeek LLM实现
3 | 使用DeepSeek API进行文本生成
4 | """
5 |
6 | import os
7 | from typing import Optional, Dict, Any
8 | from openai import OpenAI
9 | from .base import BaseLLM
10 |
11 |
class DeepSeekLLM(BaseLLM):
    """LLM backend for DeepSeek, reached through the OpenAI SDK pointed at DeepSeek's endpoint."""

    def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
        """
        Create the DeepSeek client.

        Args:
            api_key: DeepSeek API key; when None it is read from the DEEPSEEK_API_KEY environment variable.
            model_name: Model to use; defaults to get_default_model() ("deepseek-chat").

        Raises:
            ValueError: If no key is given and DEEPSEEK_API_KEY is unset or empty.
        """
        if api_key is None:
            api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            raise ValueError("DeepSeek API Key未找到!请设置DEEPSEEK_API_KEY环境变量或在初始化时提供")

        # Resolve the model once so BaseLLM.model_name and default_model always agree
        # (previously model_name stayed None whenever the default model was used).
        resolved_model = model_name or self.get_default_model()
        super().__init__(api_key, resolved_model)

        # The OpenAI client is reused with DeepSeek's base_url.
        self.client = OpenAI(
            api_key=self.api_key,
            base_url="https://api.deepseek.com"
        )

        self.default_model = resolved_model

    def get_default_model(self) -> str:
        """Return the default DeepSeek model name."""
        return "deepseek-chat"

    def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """
        Send one system + user prompt pair to DeepSeek and return the reply.

        Args:
            system_prompt: System prompt text.
            user_prompt: User message text.
            **kwargs: Only ``temperature`` (default 0.7) and ``max_tokens`` (default 4000)
                are read; other keys are ignored.

        Returns:
            The stripped reply text, or "" when the API returns no choices.

        Raises:
            Exception: Any error from the client is printed and re-raised unchanged.
        """
        try:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]

            params = {
                "model": self.default_model,
                "messages": messages,
                "temperature": kwargs.get("temperature", 0.7),
                "max_tokens": kwargs.get("max_tokens", 4000),
                "stream": False
            }

            response = self.client.chat.completions.create(**params)

            if response.choices and response.choices[0].message:
                # content may be None; validate_response maps that to "".
                content = response.choices[0].message.content
                return self.validate_response(content)
            return ""

        except Exception as e:
            print(f"DeepSeek API调用错误: {str(e)}")
            # Bare raise keeps the original traceback intact.
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """
        Describe the configured backend.

        Returns:
            Dict with the provider name, the model in use and the API base URL.
        """
        return {
            "provider": "DeepSeek",
            "model": self.default_model,
            "api_base": "https://api.deepseek.com"
        }
96 |
--------------------------------------------------------------------------------
/examples/advanced_usage.py:
--------------------------------------------------------------------------------
1 | """
2 | 高级使用示例
3 | 演示Deep Search Agent的高级功能
4 | """
5 |
6 | import os
7 | import sys
8 |
9 | # 添加项目根目录到Python路径
10 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
11 |
12 | from src import DeepSearchAgent, Config
13 | from src.utils.config import print_config
14 |
15 |
def advanced_example():
    """
    Run several research tasks with a hand-built Config (OpenAI backend, custom limits).

    API keys come from the OPENAI_API_KEY and TAVILY_API_KEY environment variables. Each
    task's report is saved, and its agent state is written to
    ``<output_dir>/state_task_<n>.json``. A failing task is reported and skipped.
    """
    print("=" * 60)
    print("Deep Search Agent - 高级使用示例")
    print("=" * 60)

    # Single source of truth: reports (via Config) and saved states go to the same directory.
    output_dir = "custom_reports"

    try:
        print("正在创建自定义配置...")
        config = Config(
            # Use OpenAI instead of DeepSeek
            default_llm_provider="openai",
            openai_model="gpt-4o-mini",
            # Custom search parameters
            max_search_results=5,  # more search results
            max_reflections=3,  # more reflection rounds
            max_content_length=15000,
            # Custom output
            output_dir=output_dir,
            save_intermediate_states=True
        )

        # API keys are taken from the environment
        config.openai_api_key = os.getenv("OPENAI_API_KEY")
        config.tavily_api_key = os.getenv("TAVILY_API_KEY")

        if not config.validate():
            print("配置验证失败,请检查API密钥设置")
            return

        print_config(config)

        print("正在初始化Agent...")
        agent = DeepSearchAgent(config)

        queries = [
            "深度学习在医疗领域的应用",
            "区块链技术的最新发展",
            "可持续能源技术趋势"
        ]

        # Ensure the state files have a destination directory.
        os.makedirs(output_dir, exist_ok=True)

        for i, query in enumerate(queries, 1):
            print(f"\n{'='*60}")
            print(f"执行研究任务 {i}/{len(queries)}: {query}")
            print(f"{'='*60}")

            try:
                final_report = agent.research(query, save_report=True)

                state_file = os.path.join(output_dir, f"state_task_{i}.json")
                agent.save_state(state_file)

                print(f"任务 {i} 完成")
                print(f"报告长度: {len(final_report)} 字符")

                progress = agent.get_progress_summary()
                print(f"完成进度: {progress['progress_percentage']:.1f}%")

            except Exception as e:
                # One failed task should not abort the remaining ones.
                print(f"任务 {i} 失败: {str(e)}")
                continue

        print(f"\n{'='*60}")
        print("所有研究任务完成!")
        print(f"{'='*60}")

    except Exception as e:
        print(f"高级示例运行失败: {str(e)}")
89 |
90 |
def state_management_example():
    """
    Demonstrate saving an agent's state to JSON and restoring it into a fresh agent.

    The config comes from Config.from_env(). After one research run, the state is saved to
    custom_reports/quantum_computing_state.json, loaded into a new DeepSearchAgent, and a
    summary of the restored state is printed.
    """
    print("\n" + "=" * 60)
    print("状态管理示例")
    print("=" * 60)

    try:
        config = Config.from_env()
        if not config.validate():
            print("配置验证失败")
            return

        agent = DeepSearchAgent(config)

        query = "量子计算的发展现状"
        print(f"开始研究: {query}")

        # Only the research side effect on agent.state matters here; the report text is unused.
        agent.research(query)

        state_file = os.path.join("custom_reports", "quantum_computing_state.json")
        # This example may run on its own (or after advanced_example failed early),
        # so the target directory is not guaranteed to exist yet.
        os.makedirs(os.path.dirname(state_file), exist_ok=True)
        agent.save_state(state_file)
        print(f"状态已保存到: {state_file}")

        print("\n创建新Agent并加载状态...")
        new_agent = DeepSearchAgent(config)
        new_agent.load_state(state_file)

        progress = new_agent.get_progress_summary()
        print("加载的状态信息:")
        print(f"- 查询: {new_agent.state.query}")
        print(f"- 报告标题: {new_agent.state.report_title}")
        print(f"- 段落数: {progress['total_paragraphs']}")
        print(f"- 完成状态: {progress['is_completed']}")

    except Exception as e:
        print(f"状态管理示例失败: {str(e)}")


if __name__ == "__main__":
    advanced_example()
    state_management_example()
138 |
--------------------------------------------------------------------------------
/src/tools/search.py:
--------------------------------------------------------------------------------
1 | """
2 | 搜索工具实现
3 | 支持多种搜索引擎,主要使用Tavily搜索
4 | """
5 |
6 | import os
7 | from typing import List, Dict, Any, Optional
8 | from dataclasses import dataclass
9 | from tavily import TavilyClient
10 |
11 |
@dataclass
class SearchResult:
    """A single hit returned by a web-search backend."""

    title: str
    url: str
    content: str
    # Relevance score reported by the backend, when it provides one.
    score: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the hit as a plain dict with keys title, url, content, score (in that order)."""
        field_names = ("title", "url", "content", "score")
        return {name: getattr(self, name) for name in field_names}
28 |
29 |
class TavilySearch:
    """Thin wrapper around TavilyClient that turns raw hits into SearchResult objects."""

    def __init__(self, api_key: Optional[str] = None):
        """
        Args:
            api_key: Tavily API key; when None it is read from the TAVILY_API_KEY environment variable.

        Raises:
            ValueError: If no usable key is available.
        """
        key = os.getenv("TAVILY_API_KEY") if api_key is None else api_key
        if not key:
            raise ValueError("Tavily API Key未找到!请设置TAVILY_API_KEY环境变量或在初始化时提供")

        self.client = TavilyClient(api_key=key)

    def search(self, query: str, max_results: int = 5, include_raw_content: bool = True,
               timeout: int = 240) -> List[SearchResult]:
        """
        Run a Tavily query.

        Args:
            query: Search query.
            max_results: Maximum number of hits to request.
            include_raw_content: Forwarded to Tavily. NOTE(review): any ``raw_content`` in the
                response is not copied into SearchResult, so requesting it is currently unused.
            timeout: Timeout in seconds, forwarded to the client.

        Returns:
            Hits as SearchResult objects. Any error is printed, and an empty list is returned
            (best-effort).
        """
        try:
            response = self.client.search(
                query=query,
                max_results=max_results,
                include_raw_content=include_raw_content,
                timeout=timeout
            )

            hits = response['results'] if 'results' in response else []
            return [
                SearchResult(
                    title=hit.get('title', ''),
                    url=hit.get('url', ''),
                    content=hit.get('content', ''),
                    score=hit.get('score')
                )
                for hit in hits
            ]

        except Exception as e:
            print(f"搜索错误: {str(e)}")
            return []
87 |
88 |
# Module-wide TavilySearch instance shared by tavily_search(); created on first use.
_tavily_client = None


def get_tavily_client() -> TavilySearch:
    """Return the shared TavilySearch client, creating it (from TAVILY_API_KEY) on first call."""
    global _tavily_client
    if _tavily_client is not None:
        return _tavily_client
    _tavily_client = TavilySearch()
    return _tavily_client
99 |
100 |
def tavily_search(query: str, max_results: int = 5, include_raw_content: bool = True,
                  timeout: int = 240, api_key: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Convenience wrapper: search Tavily and return plain dicts.

    Args:
        query: Search query.
        max_results: Maximum number of hits.
        include_raw_content: Forwarded to TavilySearch.search.
        timeout: Timeout in seconds.
        api_key: When given, a one-off client is built with this key; otherwise the shared
            client from get_tavily_client() is used.

    Returns:
        One dict per hit (keys title, url, content, score). Any error is printed, and an
        empty list is returned.
    """
    try:
        client = TavilySearch(api_key) if api_key else get_tavily_client()
        hits = client.search(query, max_results, include_raw_content, timeout)
        # Plain dicts keep the historical return format callers rely on.
        return [hit.to_dict() for hit in hits]

    except Exception as e:
        print(f"搜索功能调用错误: {str(e)}")
        return []
132 |
133 |
def test_search(query: str = "人工智能发展趋势 2025", max_results: int = 3):
    """
    Manual smoke test: run a real Tavily search and print the hits.

    Args:
        query: Query to send.
        max_results: Maximum number of hits to request.
    """
    print("\n=== 测试Tavily搜索功能 ===")
    print(f"搜索查询: {query}")
    print(f"最大结果数: {max_results}")

    try:
        results = tavily_search(query, max_results=max_results)

        if results:
            print(f"\n找到 {len(results)} 个结果:")
            for i, result in enumerate(results, 1):
                print(f"\n结果 {i}:")
                print(f"标题: {result['title']}")
                print(f"链接: {result['url']}")
                print(f"内容摘要: {result['content'][:200]}...")
                # Compare against None so a legitimate score of 0.0 is still shown.
                if result.get('score') is not None:
                    print(f"相关度评分: {result['score']}")
        else:
            print("未找到搜索结果")

    except Exception as e:
        print(f"搜索测试失败: {str(e)}")


if __name__ == "__main__":
    # Run the manual smoke test
    test_search()
168 |
--------------------------------------------------------------------------------
/src/nodes/report_structure_node.py:
--------------------------------------------------------------------------------
1 | """
2 | 报告结构生成节点
3 | 负责根据查询生成报告的整体结构
4 | """
5 |
6 | import json
7 | from typing import Dict, Any, List
8 | from json.decoder import JSONDecodeError
9 |
10 | from .base_node import StateMutationNode
11 | from ..state.state import State
12 | from ..prompts import SYSTEM_PROMPT_REPORT_STRUCTURE
13 | from ..utils.text_processing import (
14 | remove_reasoning_from_output,
15 | clean_json_tags,
16 | extract_clean_response
17 | )
18 |
19 |
class ReportStructureNode(StateMutationNode):
    """Ask the LLM for a report outline and seed the State with one paragraph per section."""

    def __init__(self, llm_client, query: str):
        """
        Args:
            llm_client: LLM backend used to generate the outline.
            query: The user's research query.
        """
        super().__init__(llm_client, "ReportStructureNode")
        self.query = query

    def validate_input(self, input_data: Any) -> bool:
        """Return True when the query given at construction is a non-blank string (input_data is ignored)."""
        return isinstance(self.query, str) and len(self.query.strip()) > 0

    def run(self, input_data: Any = None, **kwargs) -> List[Dict[str, str]]:
        """
        Call the LLM and turn its reply into a report outline.

        Args:
            input_data: Unused; the query passed to __init__ is sent instead.
            **kwargs: Unused.

        Returns:
            List of {"title", "content"} dicts (never empty; see process_output).

        Raises:
            Exception: Errors from the LLM call are logged and re-raised.
        """
        try:
            self.log_info(f"正在为查询生成报告结构: {self.query}")

            response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
            processed_response = self.process_output(response)

            self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
            return processed_response

        except Exception as e:
            self.log_error(f"生成报告结构失败: {str(e)}")
            raise

    def process_output(self, output: str) -> List[Dict[str, str]]:
        """
        Parse the LLM reply into a validated outline.

        Args:
            output: Raw LLM reply, expected to contain a JSON list of {"title", "content"} objects.

        Returns:
            The validated outline. Whenever parsing fails or yields no usable section, the
            fixed two-section default outline is returned instead.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            try:
                report_structure = json.loads(cleaned_output)
            except JSONDecodeError:
                # Fall back to the more lenient extractor.
                report_structure = extract_clean_response(cleaned_output)
                # The extractor signals failure with a dict carrying an "error" key.
                if isinstance(report_structure, dict) and "error" in report_structure:
                    raise ValueError("JSON解析失败")

            if not isinstance(report_structure, list):
                raise ValueError("报告结构应该是一个列表")

            validated_structure = []
            for i, paragraph in enumerate(report_structure):
                if not isinstance(paragraph, dict):
                    continue

                # `or` also replaces explicit nulls / empty strings emitted by the LLM.
                title = paragraph.get("title") or f"段落 {i+1}"
                content = paragraph.get("content") or ""

                validated_structure.append({
                    "title": str(title),
                    "content": str(content)
                })

            # An empty outline would leave the agent with nothing to research.
            if not validated_structure:
                raise ValueError("报告结构为空")

            return validated_structure

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return self._default_structure()

    def _default_structure(self) -> List[Dict[str, str]]:
        """Return the two-section fallback outline used when the LLM output is unusable."""
        return [
            {
                "title": "概述",
                "content": f"对'{self.query}'的总体概述和背景介绍"
            },
            {
                "title": "详细分析",
                "content": f"深入分析'{self.query}'的相关内容"
            }
        ]

    def mutate_state(self, input_data: Any = None, state: State = None, **kwargs) -> State:
        """
        Generate the outline and write it into ``state``.

        Args:
            input_data: Forwarded to run().
            state: State to update; a new State is created when None.
            **kwargs: Forwarded to run().

        Returns:
            The updated state: the query is set, a default title is set if none exists, and
            one paragraph is appended per outline section.

        Raises:
            Exception: Errors from run() or from State are logged and re-raised.
        """
        if state is None:
            state = State()

        try:
            report_structure = self.run(input_data, **kwargs)

            state.query = self.query
            if not state.report_title:
                state.report_title = f"关于'{self.query}'的深度研究报告"

            for paragraph_data in report_structure:
                state.add_paragraph(
                    title=paragraph_data["title"],
                    content=paragraph_data["content"]
                )

            self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
            return state

        except Exception as e:
            self.log_error(f"状态更新失败: {str(e)}")
            raise
160 |
--------------------------------------------------------------------------------
/src/utils/text_processing.py:
--------------------------------------------------------------------------------
1 | """
2 | 文本处理工具函数
3 | 用于清理LLM输出、解析JSON等
4 | """
5 |
6 | import re
7 | import json
8 | from typing import Dict, Any, List
9 | from json.decoder import JSONDecodeError
10 |
11 |
def clean_json_tags(text: str) -> str:
    """
    Strip Markdown code-fence markers (```json ... ```) from an LLM reply.

    The opening ```json tag is matched case-insensitively, so a ```JSON fence no longer
    leaves a stray "JSON" word in front of the payload.

    Args:
        text: Raw text.

    Returns:
        The text with every ``` marker (and the json language tag) removed, then stripped.
    """
    # Opening fence together with its language tag and the whitespace that follows it.
    text = re.sub(r'```json\s*', '', text, flags=re.IGNORECASE)
    # Every remaining fence marker, including the closing one.
    text = text.replace('```', '')

    return text.strip()
28 |
29 |
def clean_markdown_tags(text: str) -> str:
    """
    Strip Markdown code-fence markers (```markdown / ```md ... ```) from an LLM reply.

    The opening tag is matched case-insensitively and accepts the common ``md`` alias, so
    neither leaves a stray language word in front of the text.

    NOTE(review): every ``` in the text is removed, including fences that belong to code
    blocks inside the document itself.

    Args:
        text: Raw text.

    Returns:
        The text with the fence markers removed, then stripped.
    """
    # Opening fence with its language tag; \b prevents matching a longer tag such as "mdx".
    text = re.sub(r'```(?:markdown|md)\b\s*', '', text, flags=re.IGNORECASE)
    # Every remaining fence marker, including the closing one.
    text = text.replace('```', '')

    return text.strip()
46 |
47 |
def remove_reasoning_from_output(text: str) -> str:
    """
    Drop any preamble (reasoning, explanations, chatter) that precedes the JSON payload.

    Everything before the first '{' or '[' is removed, and text after that point is left
    untouched. Keyword-looking text inside JSON string values, e.g.
    '"content": "分析: 见 [1]"', is therefore preserved; unanchored keyword patterns used to
    delete it. Text containing no '{' or '[' is only stripped.

    Args:
        text: Raw text.

    Returns:
        The text starting at the first JSON opener, with surrounding whitespace stripped.
    """
    first_opener = re.search(r'[\[{]', text)
    if first_opener:
        text = text[first_opener.start():]

    return text.strip()
69 |
70 |
def extract_clean_response(text: str) -> Any:
    """
    Extract and parse the JSON payload of an LLM reply as leniently as possible.

    The strategies are tried in order:
      1. parse the whole cleaned text;
      2. parse the first complete JSON object/array at the start and ignore trailing chatter
         (so '[{...}] some note' yields the list, not its first inner object);
      3. greedy regex for an object, then for an array.

    Args:
        text: Raw reply text.

    Returns:
        The parsed JSON value (dict or list). When everything fails, the dict
        {"error": "JSON解析失败", "raw_text": <cleaned text>} is returned; callers test for
        the "error" key.
    """
    cleaned_text = clean_json_tags(text)
    cleaned_text = remove_reasoning_from_output(cleaned_text)

    try:
        return json.loads(cleaned_text)
    except JSONDecodeError:
        pass

    # A well-formed value followed by extra text; only containers are accepted, so stray
    # scalars do not replace the error dict callers test for.
    try:
        value, _end = json.JSONDecoder().raw_decode(cleaned_text)
        if isinstance(value, (dict, list)):
            return value
    except JSONDecodeError:
        pass

    match = re.search(r'\{.*\}', cleaned_text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except JSONDecodeError:
            pass

    match = re.search(r'\[.*\]', cleaned_text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except JSONDecodeError:
            pass

    print(f"无法解析JSON响应: {cleaned_text[:200]}...")
    return {"error": "JSON解析失败", "raw_text": cleaned_text}
112 |
113 |
def update_state_with_search_results(search_results: List[Dict[str, Any]],
                                     paragraph_index: int, state: Any,
                                     search_query: str = "") -> Any:
    """
    Record a batch of search results on one paragraph of the state.

    Calls ``state.paragraphs[paragraph_index].research.add_search_results``
    with the query and the results. An out-of-range index is ignored.

    Args:
        search_results: list of search result dicts
        paragraph_index: index into ``state.paragraphs``
        state: state object exposing ``paragraphs``
        search_query: the query that produced ``search_results``. When it is
            omitted, the legacy placeholder "搜索查询" is recorded for non-empty
            results (and "" for empty ones), as before this parameter existed.

    Returns:
        The same (mutated) state object.
    """
    if not 0 <= paragraph_index < len(state.paragraphs):
        return state

    if search_query:
        current_query = search_query
    elif search_results:
        # Backward-compatible placeholder when the caller does not pass the query.
        current_query = "搜索查询"
    else:
        current_query = ""

    state.paragraphs[paragraph_index].research.add_search_results(
        current_query, search_results
    )
    return state
140 |
141 |
def validate_json_schema(data: Dict[str, Any], required_fields: List[str]) -> bool:
    """
    Check that ``data`` is a mapping containing every required key.

    Non-mapping inputs (lists, strings, None, ...) return False. Without this
    check, ``field in data`` on a list tests element membership (a false
    positive) and on None raises TypeError.

    Args:
        data: parsed JSON value to validate
        required_fields: keys that must all be present

    Returns:
        True if ``data`` is a mapping with all ``required_fields``.
    """
    from collections.abc import Mapping

    if not isinstance(data, Mapping):
        return False
    return all(field in data for field in required_fields)
154 |
155 |
def truncate_content(content: str, max_length: int = 20000) -> str:
    """
    Shorten ``content`` to roughly ``max_length`` characters.

    Text that already fits is returned unchanged. Otherwise the first
    ``max_length`` characters are kept, cut back to the last space when that
    space lies beyond 80% of the limit, and "..." is appended.

    Args:
        content: text to shorten
        max_length: character budget before the ellipsis

    Returns:
        The original or truncated text.
    """
    if len(content) > max_length:
        head = content[:max_length]
        boundary = head.rfind(' ')
        # Back off to a word boundary only if it does not waste >20% of the budget.
        kept = head[:boundary] if boundary > max_length * 0.8 else head
        return f"{kept}..."
    return content
178 |
179 |
def format_search_results_for_prompt(search_results: List[Dict[str, Any]],
                                     max_length: int = 20000) -> List[str]:
    """
    Turn raw search results into prompt-ready text snippets.

    Results whose ``content`` is missing or empty are dropped; every other
    content string is shortened with ``truncate_content``.

    Args:
        search_results: search result dicts (``content`` key is read)
        max_length: per-result character budget

    Returns:
        Truncated content strings, in input order.
    """
    contents = (result.get('content', '') for result in search_results)
    return [truncate_content(text, max_length) for text in contents if text]
201 |
--------------------------------------------------------------------------------
/src/nodes/formatting_node.py:
--------------------------------------------------------------------------------
1 | """
2 | 报告格式化节点
3 | 负责将最终研究结果格式化为美观的Markdown报告
4 | """
5 |
6 | import json
7 | from typing import List, Dict, Any
8 |
9 | from .base_node import BaseNode
10 | from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
11 | from ..utils.text_processing import (
12 | remove_reasoning_from_output,
13 | clean_markdown_tags
14 | )
15 |
16 |
class ReportFormattingNode(BaseNode):
    """Node that asks the LLM to turn the finished paragraphs into a Markdown report."""

    def __init__(self, llm_client):
        """
        Create the report formatting node.

        Args:
            llm_client: LLM client; ``run`` calls ``invoke(system_prompt, message)`` on it
        """
        super().__init__(llm_client, "ReportFormattingNode")

    def validate_input(self, input_data: Any) -> bool:
        """
        Check that the input is a list of dicts with ``title`` and
        ``paragraph_latest_state``.

        Args:
            input_data: the list itself, or its JSON-encoded string form

        Returns:
            True if the (decoded) input has the expected shape.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                return False
        return isinstance(input_data, list) and all(
            isinstance(item, dict) and "title" in item and "paragraph_latest_state" in item
            for item in input_data
        )

    def run(self, input_data: Any, **kwargs) -> str:
        """
        Call the LLM to produce the Markdown report.

        Args:
            input_data: list of paragraph dicts (or its JSON string)
            **kwargs: unused extra arguments

        Returns:
            The cleaned Markdown report.

        Raises:
            ValueError: if ``input_data`` fails validation; LLM errors are
                logged and re-raised unchanged.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误,需要包含title和paragraph_latest_state的列表")

            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在格式化最终报告")

            response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_FORMATTING, message)
            processed_response = self.process_output(response)

            self.log_info("成功生成格式化报告")
            return processed_response

        except Exception as e:
            self.log_error(f"报告格式化失败: {str(e)}")
            raise

    def process_output(self, output: str) -> str:
        """
        Clean the raw LLM Markdown and guarantee a top-level heading.

        Only the wrapping code fence is removed. The JSON helper
        ``remove_reasoning_from_output`` is deliberately NOT applied: it
        discards everything before the first ``{`` or ``[``, which would delete
        most of a report containing a link or a citation such as ``[1]``.

        Args:
            output: raw LLM output

        Returns:
            The Markdown report, or a short failure report if the output is
            empty or cleaning raises.
        """
        try:
            cleaned_output = clean_markdown_tags(output).strip()

            if not cleaned_output:
                return "# 报告生成失败\n\n无法生成有效的报告内容。"

            # Ensure the report starts with a heading.
            if not cleaned_output.startswith('#'):
                cleaned_output = "# 深度研究报告\n\n" + cleaned_output

            return cleaned_output.strip()

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return "# 报告处理失败\n\n报告格式化过程中发生错误。"

    def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
                               report_title: str = "深度研究报告") -> str:
        """
        Build the Markdown report without the LLM (fallback path).

        Args:
            paragraphs_data: dicts with ``title`` and ``paragraph_latest_state``
            report_title: top-level heading text

        Returns:
            The assembled Markdown report, or a short failure report if
            assembly raises.
        """
        try:
            self.log_info("使用手动格式化方法")

            report_lines = [f"# {report_title}", "", "---", ""]

            for i, paragraph in enumerate(paragraphs_data, 1):
                title = paragraph.get("title", f"段落 {i}")
                content = paragraph.get("paragraph_latest_state", "")

                # Paragraphs that never produced content are omitted entirely.
                if content:
                    report_lines.extend([f"## {title}", "", content, "", "---", ""])

            # A generic conclusion only makes sense for multi-paragraph reports.
            if len(paragraphs_data) > 1:
                report_lines.extend([
                    "## 结论",
                    "",
                    "本报告通过深度搜索和研究,对相关主题进行了全面分析。"
                    "以上各个方面的内容为理解该主题提供了重要参考。",
                    "",
                ])

            return "\n".join(report_lines)

        except Exception as e:
            self.log_error(f"手动格式化失败: {str(e)}")
            return "# 报告生成失败\n\n无法完成报告格式化。"
165 |
--------------------------------------------------------------------------------
/src/prompts/prompts.py:
--------------------------------------------------------------------------------
1 | """
2 | Deep Search Agent 的所有提示词定义
3 | 包含各个阶段的系统提示词和JSON Schema定义
4 | """
5 |
6 | import json
7 |
# ===== JSON Schema definitions =====
# Each schema is built from small helpers that return fresh dicts, so no two
# schemas share mutable sub-objects. Property order matters because the
# schemas are rendered into the prompts with json.dumps.


def _string_properties(*names):
    """Map each property name, in order, to a string-typed schema."""
    return {name: {"type": "string"} for name in names}


def _object_schema(properties):
    """Wrap a property mapping into an object schema."""
    return {"type": "object", "properties": properties}


def _string_array():
    """Schema for an array of strings."""
    return {"type": "array", "items": {"type": "string"}}


# Report structure output: array of {title, content}
output_schema_report_structure = {
    "type": "array",
    "items": _object_schema(_string_properties("title", "content")),
}

# First search: input / output
input_schema_first_search = _object_schema(_string_properties("title", "content"))
output_schema_first_search = _object_schema(_string_properties("search_query", "reasoning"))

# First summary: input / output
input_schema_first_summary = _object_schema({
    **_string_properties("title", "content", "search_query"),
    "search_results": _string_array(),
})
output_schema_first_summary = _object_schema(_string_properties("paragraph_latest_state"))

# Reflection: input / output
input_schema_reflection = _object_schema(
    _string_properties("title", "content", "paragraph_latest_state")
)
output_schema_reflection = _object_schema(_string_properties("search_query", "reasoning"))

# Reflection summary: input / output
input_schema_reflection_summary = _object_schema({
    **_string_properties("title", "content", "search_query"),
    "search_results": _string_array(),
    **_string_properties("paragraph_latest_state"),
})
output_schema_reflection_summary = _object_schema(
    _string_properties("updated_paragraph_latest_state")
)

# Report formatting input: array of {title, paragraph_latest_state}
input_schema_report_formatting = {
    "type": "array",
    "items": _object_schema(_string_properties("title", "paragraph_latest_state")),
}
115 |
116 | # ===== 系统提示词定义 =====
117 |
# System prompt for planning the report structure (paragraph outline).
# The output schema is embedded so the model sees the exact shape expected;
# the outline is a JSON *array* of {title, content} objects.
SYSTEM_PROMPT_REPORT_STRUCTURE = f"""
你是一位深度研究助手。给定一个查询,你需要规划一个报告的结构和其中包含的段落。最多五个段落。
确保段落的排序合理有序。
一旦大纲创建完成,你将获得工具来分别为每个部分搜索网络并进行反思。
请按照以下JSON模式定义格式化输出:


{json.dumps(output_schema_report_structure, indent=2, ensure_ascii=False)}


标题和内容属性将用于更深入的研究。
确保输出是一个符合上述输出JSON模式定义的JSON数组。
只返回JSON数组,不要有解释或额外文本。
"""
133 |
# System prompt for the first search query of each paragraph.
# Both the input and the output schema are embedded; the output schema was
# previously referenced by the text but never included.
SYSTEM_PROMPT_FIRST_SEARCH = f"""
你是一位深度研究助手。你将获得报告中的一个段落,其标题和预期内容将按照以下JSON模式定义提供:


{json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)}


你可以使用一个网络搜索工具,该工具接受'search_query'作为参数。
你的任务是思考这个主题,并提供最佳的网络搜索查询来丰富你当前的知识。
请按照以下JSON模式定义格式化输出(文字请使用中文):


{json.dumps(output_schema_first_search, indent=2, ensure_ascii=False)}


确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
153 |
# System prompt for the first summary of each paragraph.
# The output schema (paragraph_latest_state) is now embedded in the prompt.
SYSTEM_PROMPT_FIRST_SUMMARY = f"""
你是一位深度研究助手。你将获得搜索查询、搜索结果以及你正在研究的报告段落,数据将按照以下JSON模式定义提供:


{json.dumps(input_schema_first_summary, indent=2, ensure_ascii=False)}


你的任务是作为研究者,使用搜索结果撰写与段落主题一致的内容,并适当地组织结构以便纳入报告中。
请按照以下JSON模式定义格式化输出:


{json.dumps(output_schema_first_summary, indent=2, ensure_ascii=False)}


确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
172 |
# System prompt for the reflection step (find gaps, propose a new search query).
# The output schema (search_query, reasoning) is now embedded in the prompt.
SYSTEM_PROMPT_REFLECTION = f"""
你是一位深度研究助手。你负责为研究报告构建全面的段落。你将获得段落标题、计划内容摘要,以及你已经创建的段落最新状态,所有这些都将按照以下JSON模式定义提供:


{json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)}


你可以使用一个网络搜索工具,该工具接受'search_query'作为参数。
你的任务是反思段落文本的当前状态,思考是否遗漏了主题的某些关键方面,并提供最佳的网络搜索查询来丰富最新状态。
请按照以下JSON模式定义格式化输出:


{json.dumps(output_schema_reflection, indent=2, ensure_ascii=False)}


确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
192 |
# System prompt for the post-reflection summary (enrich the paragraph).
# The output schema (updated_paragraph_latest_state) is now embedded.
SYSTEM_PROMPT_REFLECTION_SUMMARY = f"""
你是一位深度研究助手。
你将获得搜索查询、搜索结果、段落标题以及你正在研究的报告段落的预期内容。
你正在迭代完善这个段落,并且段落的最新状态也会提供给你。
数据将按照以下JSON模式定义提供:


{json.dumps(input_schema_reflection_summary, indent=2, ensure_ascii=False)}


你的任务是根据搜索结果和预期内容丰富段落的当前最新状态。
不要删除最新状态中的关键信息,尽量丰富它,只添加缺失的信息。
适当地组织段落结构以便纳入报告中。
请按照以下JSON模式定义格式化输出:


{json.dumps(output_schema_reflection_summary, indent=2, ensure_ascii=False)}


确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
216 |
# System prompt for formatting the final Markdown report.
# Like the other prompts, it asks for the result only (no preamble and no code
# fence), so the output starts with the report's own heading.
SYSTEM_PROMPT_REPORT_FORMATTING = f"""
你是一位深度研究助手。你已经完成了研究并构建了报告中所有段落的最终版本。
你将获得以下JSON格式的数据:


{json.dumps(input_schema_report_formatting, indent=2, ensure_ascii=False)}


你的任务是将报告格式化为美观的形式,并以Markdown格式返回。
如果没有结论段落,请根据其他段落的最新状态在报告末尾添加一个结论。
使用段落标题来创建报告的标题。
只返回Markdown格式的报告,不要有解释或额外文本,也不要用代码块包裹。
"""
230 |
--------------------------------------------------------------------------------
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | 配置管理模块
3 | 处理环境变量和配置参数
4 | """
5 |
6 | import os
7 | from dataclasses import dataclass
8 | from typing import Optional
9 |
10 |
@dataclass
class Config:
    """Runtime configuration for the Deep Search Agent.

    Holds API keys, model names, search limits, agent limits and output
    settings. Build it with :meth:`from_file` (a ``.py`` module or a
    ``KEY=VALUE`` env-style file) and check it with :meth:`validate`.
    """
    # API keys
    deepseek_api_key: Optional[str] = None
    openai_api_key: Optional[str] = None
    tavily_api_key: Optional[str] = None

    # Model settings
    default_llm_provider: str = "deepseek"  # "deepseek" or "openai"
    deepseek_model: str = "deepseek-chat"
    openai_model: str = "gpt-4o-mini"

    # Search settings
    max_search_results: int = 3
    search_timeout: int = 240
    max_content_length: int = 20000

    # Agent settings
    max_reflections: int = 2
    max_paragraphs: int = 5

    # Output settings
    output_dir: str = "reports"
    save_intermediate_states: bool = True

    def validate(self) -> bool:
        """Return True if the API keys needed by the selected provider and Tavily are set.

        Prints an error message and returns False on the first missing key.
        """
        if self.default_llm_provider == "deepseek" and not self.deepseek_api_key:
            print("错误: DeepSeek API Key未设置")
            return False

        if self.default_llm_provider == "openai" and not self.openai_api_key:
            print("错误: OpenAI API Key未设置")
            return False

        if not self.tavily_api_key:
            print("错误: Tavily API Key未设置")
            return False

        return True

    @staticmethod
    def _read_env_file(config_file: str) -> dict:
        """Parse a ``KEY=VALUE`` file into a dict (empty if the file is missing).

        Blank lines and ``#`` comment lines are skipped. One pair of matching
        surrounding quotes is removed from each value (``KEY="sk-..."``), and a
        UTF-8 BOM (common for files saved by Windows editors) is ignored, so it
        does not end up glued to the first key.
        """
        values = {}
        if not os.path.exists(config_file):
            return values
        with open(config_file, 'r', encoding='utf-8-sig') as handle:
            for raw_line in handle:
                line = raw_line.strip()
                if not line or line.startswith('#') or '=' not in line:
                    continue
                key, value = line.split('=', 1)
                value = value.strip()
                if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                    value = value[1:-1]
                values[key.strip()] = value
        return values

    @staticmethod
    def _parse_bool(value: str) -> bool:
        """Interpret common truthy spellings ("true", "1", "yes", "on"), case-insensitively."""
        return value.strip().lower() in ("true", "1", "yes", "on")

    @classmethod
    def from_file(cls, config_file: str) -> "Config":
        """Create a Config from a ``.py`` module or an env-style ``KEY=VALUE`` file.

        Args:
            config_file: path to the configuration file. A ``.py`` file is
                executed as a module and its upper-case attributes are read;
                any other file is parsed as ``KEY=VALUE`` lines.

        Returns:
            A new Config; unspecified settings keep their defaults.

        Raises:
            ValueError: if a numeric setting in an env-style file is not an integer.
        """
        if config_file.endswith('.py'):
            import importlib.util

            # Load the Python config file as a throwaway module.
            spec = importlib.util.spec_from_file_location("config", config_file)
            config_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(config_module)

            return cls(
                deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
                openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
                tavily_api_key=getattr(config_module, "TAVILY_API_KEY", None),
                default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
                deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
                openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
                max_search_results=getattr(config_module, "SEARCH_RESULTS_PER_QUERY", 3),
                search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240),
                max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000),
                max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
                max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5),
                output_dir=getattr(config_module, "OUTPUT_DIR", "reports"),
                save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True)
            )

        config_dict = cls._read_env_file(config_file)
        return cls(
            deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
            openai_api_key=config_dict.get("OPENAI_API_KEY"),
            tavily_api_key=config_dict.get("TAVILY_API_KEY"),
            default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
            deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
            openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
            max_search_results=int(config_dict.get("SEARCH_RESULTS_PER_QUERY", "3")),
            search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
            max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")),
            max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
            max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")),
            output_dir=config_dict.get("OUTPUT_DIR", "reports"),
            save_intermediate_states=cls._parse_bool(config_dict.get("SAVE_INTERMEDIATE_STATES", "true"))
        )
108 |
109 |
def load_config(config_file: Optional[str] = None) -> Config:
    """
    Locate, load and validate the configuration.

    Args:
        config_file: explicit path to a config file. When omitted, the first
            existing file among ``config.py``, ``config.env`` and ``.env`` in
            the current directory is used.

    Returns:
        A validated Config.

    Raises:
        FileNotFoundError: if the explicit file is missing, or no default
            file exists.
        ValueError: if ``Config.validate`` reports missing API keys.
    """
    if config_file:
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"配置文件不存在: {config_file}")
        file_to_load = config_file
    else:
        candidates = ("config.py", "config.env", ".env")
        file_to_load = next((path for path in candidates if os.path.exists(path)), None)
        if file_to_load is None:
            raise FileNotFoundError("未找到配置文件,请创建 config.py 文件")
        print(f"已找到配置文件: {file_to_load}")

    config = Config.from_file(file_to_load)
    if not config.validate():
        raise ValueError("配置验证失败,请检查配置文件中的API密钥")
    return config
143 |
144 |
def print_config(config: Config):
    """Print the configuration as a table, showing only whether each API key is set."""
    def key_status(secret) -> str:
        # Never print the key itself, only whether one is configured.
        return '已设置' if secret else '未设置'

    rows = [
        ("LLM提供商", config.default_llm_provider),
        ("DeepSeek模型", config.deepseek_model),
        ("OpenAI模型", config.openai_model),
        ("最大搜索结果数", config.max_search_results),
        ("搜索超时", f"{config.search_timeout}秒"),
        ("最大内容长度", config.max_content_length),
        ("最大反思次数", config.max_reflections),
        ("最大段落数", config.max_paragraphs),
        ("输出目录", config.output_dir),
        ("保存中间状态", config.save_intermediate_states),
        ("DeepSeek API Key", key_status(config.deepseek_api_key)),
        ("OpenAI API Key", key_status(config.openai_api_key)),
        ("Tavily API Key", key_status(config.tavily_api_key)),
    ]

    print("\n=== 当前配置 ===")
    for label, value in rows:
        print(f"{label}: {value}")
    print("==================\n")
164 |
--------------------------------------------------------------------------------
/src/nodes/search_node.py:
--------------------------------------------------------------------------------
1 | """
2 | 搜索节点实现
3 | 负责生成搜索查询和反思查询
4 | """
5 |
6 | import json
7 | from typing import Dict, Any
8 | from json.decoder import JSONDecodeError
9 |
10 | from .base_node import BaseNode
11 | from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
12 | from ..utils.text_processing import (
13 | remove_reasoning_from_output,
14 | clean_json_tags,
15 | extract_clean_response
16 | )
17 |
18 |
class FirstSearchNode(BaseNode):
    """Node that asks the LLM for a paragraph's first web-search query."""

    def __init__(self, llm_client):
        """
        Create the first-search node.

        Args:
            llm_client: LLM client; ``run`` calls ``invoke(system_prompt, message)`` on it
        """
        super().__init__(llm_client, "FirstSearchNode")

    def validate_input(self, input_data: Any) -> bool:
        """
        Check that the input is a JSON object with ``title`` and ``content``.

        Args:
            input_data: a dict, or its JSON-encoded string form

        Returns:
            True if the (decoded) input is a dict containing both keys.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except JSONDecodeError:
                return False
        # JSON may decode to a list/number/string: `in` would then test element
        # membership (false positive) or raise TypeError.
        return isinstance(input_data, dict) and "title" in input_data and "content" in input_data

    def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
        """
        Call the LLM to produce a search query and its reasoning.

        Args:
            input_data: dict (or JSON string) with ``title`` and ``content``
            **kwargs: unused extra arguments

        Returns:
            Dict with ``search_query`` and ``reasoning``.

        Raises:
            ValueError: if ``input_data`` fails validation; LLM errors are
                logged and re-raised unchanged.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误,需要包含title和content字段")

            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在生成首次搜索查询")

            response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
            processed_response = self.process_output(response)

            self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
            return processed_response

        except Exception as e:
            self.log_error(f"生成首次搜索查询失败: {str(e)}")
            raise

    def process_output(self, output: str) -> Dict[str, str]:
        """
        Extract ``search_query`` and ``reasoning`` from the raw LLM output.

        Args:
            output: raw LLM output

        Returns:
            Dict with ``search_query`` and ``reasoning``. On any parsing
            problem a fixed fallback query is returned instead of raising.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            try:
                result = json.loads(cleaned_output)
            except JSONDecodeError:
                # Fall back to the more tolerant extractor.
                result = extract_clean_response(cleaned_output)
                if "error" in result:
                    raise ValueError("JSON解析失败")

            if not isinstance(result, dict):
                raise ValueError("响应不是JSON对象")

            search_query = result.get("search_query", "")
            reasoning = result.get("reasoning", "")

            if not search_query:
                raise ValueError("未找到搜索查询")

            return {
                "search_query": search_query,
                "reasoning": reasoning
            }

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            # Fallback query so the pipeline can continue.
            return {
                "search_query": "相关主题研究",
                "reasoning": "由于解析失败,使用默认搜索查询"
            }
122 |
123 |
class ReflectionNode(BaseNode):
    """Node that reflects on a paragraph draft and proposes a new search query."""

    # Keys the input object must carry.
    _REQUIRED_FIELDS = ("title", "content", "paragraph_latest_state")

    def __init__(self, llm_client):
        """
        Create the reflection node.

        Args:
            llm_client: LLM client; ``run`` calls ``invoke(system_prompt, message)`` on it
        """
        super().__init__(llm_client, "ReflectionNode")

    def validate_input(self, input_data: Any) -> bool:
        """
        Check that the input is a JSON object with ``title``, ``content`` and
        ``paragraph_latest_state``.

        Args:
            input_data: a dict, or its JSON-encoded string form

        Returns:
            True if the (decoded) input is a dict containing all three keys.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except JSONDecodeError:
                return False
        # JSON may decode to a list/number/string: `in` would then test element
        # membership (false positive) or raise TypeError.
        return isinstance(input_data, dict) and all(
            field in input_data for field in self._REQUIRED_FIELDS
        )

    def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
        """
        Call the LLM to reflect on the paragraph and produce a new search query.

        Args:
            input_data: dict (or JSON string) with ``title``, ``content`` and
                ``paragraph_latest_state``
            **kwargs: unused extra arguments

        Returns:
            Dict with ``search_query`` and ``reasoning``.

        Raises:
            ValueError: if ``input_data`` fails validation; LLM errors are
                logged and re-raised unchanged.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误,需要包含title、content和paragraph_latest_state字段")

            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在进行反思并生成新搜索查询")

            response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
            processed_response = self.process_output(response)

            self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
            return processed_response

        except Exception as e:
            self.log_error(f"反思生成搜索查询失败: {str(e)}")
            raise

    def process_output(self, output: str) -> Dict[str, str]:
        """
        Extract ``search_query`` and ``reasoning`` from the raw LLM output.

        Args:
            output: raw LLM output

        Returns:
            Dict with ``search_query`` and ``reasoning``. On any parsing
            problem a fixed fallback query is returned instead of raising.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            try:
                result = json.loads(cleaned_output)
            except JSONDecodeError:
                # Fall back to the more tolerant extractor.
                result = extract_clean_response(cleaned_output)
                if "error" in result:
                    raise ValueError("JSON解析失败")

            if not isinstance(result, dict):
                raise ValueError("响应不是JSON对象")

            search_query = result.get("search_query", "")
            reasoning = result.get("reasoning", "")

            if not search_query:
                raise ValueError("未找到搜索查询")

            return {
                "search_query": search_query,
                "reasoning": reasoning
            }

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            # Fallback query so the reflection loop can continue.
            return {
                "search_query": "深度研究补充信息",
                "reasoning": "由于解析失败,使用默认反思搜索查询"
            }
229 |
--------------------------------------------------------------------------------
/examples/streamlit_app.py:
--------------------------------------------------------------------------------
1 | """
2 | Streamlit Web界面
3 | 为Deep Search Agent提供友好的Web界面
4 | """
5 |
6 | import os
7 | import sys
8 | import streamlit as st
9 | from datetime import datetime
10 | import json
11 |
12 | # 添加src目录到Python路径
13 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
14 |
15 | from src import DeepSearchAgent, Config
16 |
17 |
def main():
    """Render the Streamlit UI and start a research run when requested.

    The sidebar collects API keys and tuning options. The main area collects
    the query and shows progress of any agent already stored in
    ``st.session_state``. Pressing the start button validates the inputs,
    builds a ``Config`` and hands off to ``execute_research``.
    """
    st.set_page_config(page_title="Deep Search Agent", page_icon="🔍", layout="wide")
    st.title("Deep Search Agent")
    st.markdown("基于DeepSeek的无框架深度搜索AI代理")

    # The OpenAI key widget only exists on the OpenAI branch below; default it
    # so the validation step can always read the name.
    openai_key = ""

    with st.sidebar:
        st.header("配置")

        st.subheader("API密钥")
        deepseek_key = st.text_input("DeepSeek API Key", type="password", value="")
        tavily_key = st.text_input("Tavily API Key", type="password", value="")

        st.subheader("高级配置")
        max_reflections = st.slider("反思次数", 1, 5, 2)
        max_search_results = st.slider("搜索结果数", 1, 10, 3)
        max_content_length = st.number_input("最大内容长度", 1000, 50000, 20000)

        llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"])
        if llm_provider == "deepseek":
            model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"])
        else:
            model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"])
            openai_key = st.text_input("OpenAI API Key", type="password", value="")

    query_col, status_col = st.columns([2, 1])

    with query_col:
        st.header("研究查询")
        query = st.text_area(
            "请输入您要研究的问题",
            placeholder="例如:2025年人工智能发展趋势",
            height=100,
        )

        st.subheader("示例查询")
        examples = [
            "2025年人工智能发展趋势",
            "深度学习在医疗领域的应用",
            "区块链技术的最新发展",
            "可持续能源技术趋势",
            "量子计算的发展现状",
        ]
        picked = st.selectbox("选择示例查询", ["自定义"] + examples)
        # A chosen example replaces whatever was typed in the text area.
        if picked != "自定义":
            query = picked

    with status_col:
        st.header("状态信息")
        has_agent = 'agent' in st.session_state and hasattr(st.session_state.agent, 'state')
        if not has_agent:
            st.info("尚未开始研究")
        else:
            summary = st.session_state.agent.get_progress_summary()
            st.metric("总段落数", summary['total_paragraphs'])
            st.metric("已完成", summary['completed_paragraphs'])
            st.progress(summary['progress_percentage'] / 100)

    _, button_col, _ = st.columns([1, 1, 1])
    with button_col:
        start_research = st.button("开始研究", type="primary", use_container_width=True)

    if not start_research:
        return

    # First failing check wins; order matches what the user should fix first.
    checks = [
        (not query.strip(), "请输入研究查询"),
        (llm_provider == "deepseek" and not deepseek_key, "请提供DeepSeek API Key"),
        (not tavily_key, "请提供Tavily API Key"),
        (llm_provider == "openai" and not openai_key, "请提供OpenAI API Key"),
    ]
    for failed, message in checks:
        if failed:
            st.error(message)
            return

    uses_deepseek = llm_provider == "deepseek"
    uses_openai = llm_provider == "openai"
    config = Config(
        deepseek_api_key=deepseek_key if uses_deepseek else None,
        openai_api_key=openai_key if uses_openai else None,
        tavily_api_key=tavily_key,
        default_llm_provider=llm_provider,
        deepseek_model=model_name if uses_deepseek else "deepseek-chat",
        openai_model=model_name if uses_openai else "gpt-4o-mini",
        max_reflections=max_reflections,
        max_search_results=max_search_results,
        max_content_length=max_content_length,
        output_dir="streamlit_reports",
    )

    execute_research(query, config)
130 |
131 |
def execute_research(query: str, config: Config):
    """Run the research pipeline stage by stage with live progress in the UI.

    Calls the agent's stages directly (structure, per-paragraph search/summary
    and reflection, final report, save) so the progress bar can advance
    between them. The agent is stored in ``st.session_state`` for the status
    panel. Any exception is shown with ``st.error`` rather than propagated.
    """
    try:
        bar = st.progress(0)
        status = st.empty()

        status.text("正在初始化Agent...")
        agent = DeepSearchAgent(config)
        st.session_state.agent = agent
        bar.progress(10)

        status.text("正在生成报告结构...")
        agent._generate_report_structure(query)
        bar.progress(20)

        total = len(agent.state.paragraphs)

        def paragraph_progress(done):
            # Paragraph processing occupies the 20%..80% band of the bar.
            return int(20 + done / total * 60)

        for idx in range(total):
            status.text(f"正在处理段落 {idx+1}/{total}: {agent.state.paragraphs[idx].title}")

            agent._initial_search_and_summary(idx)
            bar.progress(paragraph_progress(idx + 0.5))

            agent._reflection_loop(idx)
            agent.state.paragraphs[idx].research.mark_completed()
            bar.progress(paragraph_progress(idx + 1))

        status.text("正在生成最终报告...")
        final_report = agent._generate_final_report()
        bar.progress(90)

        status.text("正在保存报告...")
        agent._save_report(final_report)
        bar.progress(100)

        status.text("研究完成!")
        display_results(agent, final_report)

    except Exception as e:
        st.error(f"研究过程中发生错误: {str(e)}")
185 |
186 |
def display_results(agent: DeepSearchAgent, final_report: str):
    """Show the finished research in three tabs: report, details, downloads."""
    def clip(text, limit):
        # Preview helper: the first `limit` characters plus "..." when longer.
        return text[:limit] + "..." if len(text) > limit else text

    st.header("研究结果")
    report_tab, detail_tab, download_tab = st.tabs(["最终报告", "详细信息", "下载"])

    with report_tab:
        st.markdown(final_report)

    with detail_tab:
        st.subheader("段落详情")
        for number, paragraph in enumerate(agent.state.paragraphs, start=1):
            with st.expander(f"段落 {number}: {paragraph.title}"):
                st.write("**预期内容:**", paragraph.content)
                st.write("**最终内容:**", clip(paragraph.research.latest_summary, 300))
                st.write("**搜索次数:**", paragraph.research.get_search_count())
                st.write("**反思次数:**", paragraph.research.reflection_iteration)

        st.subheader("搜索历史")
        searches = [
            record
            for paragraph in agent.state.paragraphs
            for record in paragraph.research.search_history
        ]
        for number, record in enumerate(searches, start=1):
            with st.expander(f"搜索 {number}: {record.query}"):
                st.write("**URL:**", record.url)
                st.write("**标题:**", record.title)
                st.write("**内容预览:**", clip(record.content, 200))
                if record.score:
                    st.write("**相关度评分:**", record.score)

    with download_tab:
        st.subheader("下载报告")

        st.download_button(
            label="下载Markdown报告",
            data=final_report,
            file_name=f"deep_search_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md",
            mime="text/markdown",
        )

        st.download_button(
            label="下载状态文件",
            data=agent.state.to_json(),
            file_name=f"deep_search_state_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json",
        )
244 |
245 |
if __name__ == '__main__':
    # Script entry point (e.g. `streamlit run examples/streamlit_app.py`).
    main()
248 |
--------------------------------------------------------------------------------
/src/nodes/summary_node.py:
--------------------------------------------------------------------------------
1 | """
2 | 总结节点实现
3 | 负责根据搜索结果生成和更新段落内容
4 | """
5 |
6 | import json
7 | from typing import Dict, Any, List
8 | from json.decoder import JSONDecodeError
9 |
10 | from .base_node import StateMutationNode
11 | from ..state.state import State
12 | from ..prompts import SYSTEM_PROMPT_FIRST_SUMMARY, SYSTEM_PROMPT_REFLECTION_SUMMARY
13 | from ..utils.text_processing import (
14 | remove_reasoning_from_output,
15 | clean_json_tags,
16 | extract_clean_response,
17 | format_search_results_for_prompt
18 | )
19 |
20 |
class FirstSummaryNode(StateMutationNode):
    """Node that writes the first summary of a paragraph from its initial search results."""

    # Keys the input payload must contain (given as a dict or as a JSON object string).
    _REQUIRED_FIELDS = ("title", "content", "search_query", "search_results")

    def __init__(self, llm_client):
        """
        Initialize the first-summary node.

        Args:
            llm_client: LLM client; `run` calls its `invoke(system_prompt, message)`.
        """
        super().__init__(llm_client, "FirstSummaryNode")

    def validate_input(self, input_data: Any) -> bool:
        """
        Check that the input carries every required field.

        Args:
            input_data: a dict, or a JSON string encoding an object.

        Returns:
            True only if the (decoded) input is a dict containing all of
            `_REQUIRED_FIELDS`; False for malformed JSON or any other type.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except JSONDecodeError:
                return False
        # A JSON string may decode to a str, list or number; `in` on those would be a
        # substring/element test (or raise TypeError), so only mappings are accepted.
        if not isinstance(input_data, dict):
            return False
        return all(field in input_data for field in self._REQUIRED_FIELDS)

    def run(self, input_data: Any, **kwargs) -> str:
        """
        Ask the LLM for the paragraph's first summary.

        Args:
            input_data: dict or JSON string with title, content, search_query
                and search_results.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The summary text extracted by `process_output`.

        Raises:
            ValueError: if `input_data` fails `validate_input`.
            Any exception raised by the LLM client is logged and re-raised.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误")

            # The LLM receives the payload as JSON text; keep non-ASCII characters readable.
            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在生成首次段落总结")

            response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
            processed_response = self.process_output(response)

            self.log_info("成功生成首次段落总结")
            return processed_response

        except Exception as e:
            self.log_error(f"生成首次总结失败: {str(e)}")
            raise

    def process_output(self, output: str) -> str:
        """
        Extract the paragraph summary from raw LLM output.

        The output is cleaned of reasoning text and JSON code-fence tags. If it then
        parses as a JSON object with a non-empty string `paragraph_latest_state`, that
        string is returned; otherwise the cleaned text itself is returned.

        Args:
            output: raw LLM output.

        Returns:
            The summary text, or a fixed failure message if cleaning itself fails.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            try:
                result = json.loads(cleaned_output)
            except JSONDecodeError:
                # Not JSON: treat the cleaned text as the summary.
                return cleaned_output

            if isinstance(result, dict):
                paragraph_content = result.get("paragraph_latest_state", "")
                # Only accept a real string so callers always receive `str`.
                if isinstance(paragraph_content, str) and paragraph_content:
                    return paragraph_content

            return cleaned_output

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return "段落总结生成失败"

    def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
        """
        Generate the first summary and store it on the given paragraph.

        Args:
            input_data: payload passed through to `run`.
            state: state to update in place.
            paragraph_index: index into `state.paragraphs`.
            **kwargs: forwarded to `run`.

        Returns:
            The same `state` object, updated.

        Raises:
            ValueError: if `paragraph_index` is out of range (checked before the
                LLM is called) or the input is invalid.
        """
        try:
            # Validate the target first so a bad index doesn't cost an LLM call.
            if not 0 <= paragraph_index < len(state.paragraphs):
                raise ValueError(f"段落索引 {paragraph_index} 超出范围")

            summary = self.run(input_data, **kwargs)

            state.paragraphs[paragraph_index].research.latest_summary = summary
            self.log_info(f"已更新段落 {paragraph_index} 的首次总结")

            state.update_timestamp()
            return state

        except Exception as e:
            self.log_error(f"状态更新失败: {str(e)}")
            raise
148 |
149 |
class ReflectionSummaryNode(StateMutationNode):
    """Node that rewrites a paragraph summary using the results of a reflection search."""

    # Keys the input payload must contain (given as a dict or as a JSON object string).
    _REQUIRED_FIELDS = (
        "title", "content", "search_query", "search_results", "paragraph_latest_state"
    )

    def __init__(self, llm_client):
        """
        Initialize the reflection-summary node.

        Args:
            llm_client: LLM client; `run` calls its `invoke(system_prompt, message)`.
        """
        super().__init__(llm_client, "ReflectionSummaryNode")

    def validate_input(self, input_data: Any) -> bool:
        """
        Check that the input carries every required field.

        Args:
            input_data: a dict, or a JSON string encoding an object.

        Returns:
            True only if the (decoded) input is a dict containing all of
            `_REQUIRED_FIELDS`; False for malformed JSON or any other type.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except JSONDecodeError:
                return False
        # A JSON string may decode to a str, list or number; `in` on those would be a
        # substring/element test (or raise TypeError), so only mappings are accepted.
        if not isinstance(input_data, dict):
            return False
        return all(field in input_data for field in self._REQUIRED_FIELDS)

    def run(self, input_data: Any, **kwargs) -> str:
        """
        Ask the LLM to update the paragraph using the reflection search results.

        Args:
            input_data: dict or JSON string with title, content, search_query,
                search_results and paragraph_latest_state.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The updated paragraph text extracted by `process_output`.

        Raises:
            ValueError: if `input_data` fails `validate_input`.
            Any exception raised by the LLM client is logged and re-raised.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误")

            # The LLM receives the payload as JSON text; keep non-ASCII characters readable.
            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在生成反思总结")

            response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
            processed_response = self.process_output(response)

            self.log_info("成功生成反思总结")
            return processed_response

        except Exception as e:
            self.log_error(f"生成反思总结失败: {str(e)}")
            raise

    def process_output(self, output: str) -> str:
        """
        Extract the updated paragraph text from raw LLM output.

        The output is cleaned of reasoning text and JSON code-fence tags. If it then
        parses as a JSON object with a non-empty string
        `updated_paragraph_latest_state`, that string is returned; otherwise the
        cleaned text itself is returned.

        Args:
            output: raw LLM output.

        Returns:
            The updated paragraph text, or a fixed failure message if cleaning fails.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            try:
                result = json.loads(cleaned_output)
            except JSONDecodeError:
                # Not JSON: treat the cleaned text as the updated paragraph.
                return cleaned_output

            if isinstance(result, dict):
                updated_content = result.get("updated_paragraph_latest_state", "")
                # Only accept a real string so callers always receive `str`.
                if isinstance(updated_content, str) and updated_content:
                    return updated_content

            return cleaned_output

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return "反思总结生成失败"

    def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
        """
        Generate the reflection summary, store it, and bump the reflection counter.

        Args:
            input_data: payload passed through to `run`.
            state: state to update in place.
            paragraph_index: index into `state.paragraphs`.
            **kwargs: forwarded to `run`.

        Returns:
            The same `state` object, updated.

        Raises:
            ValueError: if `paragraph_index` is out of range (checked before the
                LLM is called) or the input is invalid.
        """
        try:
            # Validate the target first so a bad index doesn't cost an LLM call.
            if not 0 <= paragraph_index < len(state.paragraphs):
                raise ValueError(f"段落索引 {paragraph_index} 超出范围")

            updated_summary = self.run(input_data, **kwargs)

            research = state.paragraphs[paragraph_index].research
            research.latest_summary = updated_summary
            research.increment_reflection()
            self.log_info(f"已更新段落 {paragraph_index} 的反思总结")

            state.update_timestamp()
            return state

        except Exception as e:
            self.log_error(f"状态更新失败: {str(e)}")
            raise
278 |
--------------------------------------------------------------------------------
/src/state/state.py:
--------------------------------------------------------------------------------
1 | """
2 | Deep Search Agent状态管理
3 | 定义所有状态数据结构和操作方法
4 | """
5 |
6 | from dataclasses import dataclass, field
7 | from typing import List, Dict, Any, Optional
8 | import json
9 | from datetime import datetime
10 |
11 |
@dataclass
class Search:
    """One search hit together with the query that produced it."""
    query: str = ""  # query that produced this hit
    url: str = ""  # link of the hit
    title: str = ""  # title of the hit
    content: str = ""  # content text returned by the search
    score: Optional[float] = None  # relevance score; None when not provided
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO-8601 creation time

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every field into a plain dict, in declaration order."""
        return dict(
            query=self.query,
            url=self.url,
            title=self.title,
            content=self.content,
            score=self.score,
            timestamp=self.timestamp,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Search":
        """Rebuild a Search from `to_dict` output; missing keys fall back to defaults."""
        lookup = data.get
        fallback_time = datetime.now().isoformat()
        return cls(
            query=lookup("query", ""),
            url=lookup("url", ""),
            title=lookup("title", ""),
            content=lookup("content", ""),
            score=lookup("score"),
            timestamp=lookup("timestamp", fallback_time),
        )
44 |
45 |
@dataclass
class Research:
    """Research progress for one paragraph: recorded search hits, current summary, reflection count."""
    search_history: List[Search] = field(default_factory=list)  # every recorded search hit
    latest_summary: str = ""  # most recent summary of the paragraph
    reflection_iteration: int = 0  # number of reflection rounds applied so far
    is_completed: bool = False  # whether research on this paragraph is finished

    def add_search(self, search: Search):
        """Append one search record to the history."""
        self.search_history.append(search)

    def add_search_results(self, query: str, results: Optional[List[Dict[str, Any]]]):
        """
        Record a batch of raw search results under the query that produced them.

        Args:
            query: the search query.
            results: result dicts; the keys read are url, title, content and score,
                all optional. `None` or an empty list records nothing.
        """
        # Tolerate a missing result set instead of failing with TypeError on iteration.
        for result in results or []:
            self.add_search(Search(
                query=query,
                url=result.get("url", ""),
                title=result.get("title", ""),
                content=result.get("content", ""),
                score=result.get("score")
            ))

    def get_search_count(self) -> int:
        """Return how many search hits have been recorded."""
        return len(self.search_history)

    def increment_reflection(self):
        """Count one more reflection round."""
        self.reflection_iteration += 1

    def mark_completed(self):
        """Mark research on this paragraph as finished."""
        self.is_completed = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, including the serialized search history."""
        return {
            "search_history": [search.to_dict() for search in self.search_history],
            "latest_summary": self.latest_summary,
            "reflection_iteration": self.reflection_iteration,
            "is_completed": self.is_completed
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Research":
        """Rebuild from `to_dict` output; missing keys (or a null history) use defaults."""
        # `or []` also covers an explicit JSON null for search_history.
        search_history = [Search.from_dict(search_data) for search_data in data.get("search_history") or []]
        return cls(
            search_history=search_history,
            latest_summary=data.get("latest_summary", ""),
            reflection_iteration=data.get("reflection_iteration", 0),
            is_completed=data.get("is_completed", False)
        )
101 |
102 |
@dataclass
class Paragraph:
    """One section of the report: its planned outline plus the research done for it."""
    title: str = ""  # section heading
    content: str = ""  # planned/expected content from the initial outline
    research: Research = field(default_factory=Research)  # research progress for this section
    order: int = 0  # position of the section within the report

    def is_completed(self) -> bool:
        """Done only when research is marked complete AND a summary has been written."""
        research = self.research
        has_summary = bool(research.latest_summary)
        return research.is_completed and has_summary

    def get_final_content(self) -> str:
        """Prefer the latest summary; fall back to the planned content when it is empty."""
        summary = self.research.latest_summary
        return summary if summary else self.content

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, including the nested research state."""
        return dict(
            title=self.title,
            content=self.content,
            research=self.research.to_dict(),
            order=self.order,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Paragraph":
        """Rebuild from `to_dict` output; an absent/empty research entry yields a fresh Research."""
        research_payload = data.get("research", {})
        if research_payload:
            research = Research.from_dict(research_payload)
        else:
            research = Research()

        return cls(
            title=data.get("title", ""),
            content=data.get("content", ""),
            research=research,
            order=data.get("order", 0),
        )
140 |
141 |
@dataclass
class State:
    """State of the whole report: query, title, paragraphs, final text and timestamps."""
    query: str = ""  # original user query
    report_title: str = ""  # report title
    paragraphs: List[Paragraph] = field(default_factory=list)  # sections in report order
    final_report: str = ""  # final formatted report text
    is_completed: bool = False  # whether the whole report is finished
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO-8601 creation time
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO-8601 last-update time

    def add_paragraph(self, title: str, content: str) -> int:
        """
        Append a new paragraph and refresh `updated_at`.

        Args:
            title: paragraph heading.
            content: planned content for the paragraph.

        Returns:
            Index of the new paragraph (also stored as its `order`).
        """
        new_index = len(self.paragraphs)
        self.paragraphs.append(Paragraph(title=title, content=content, order=new_index))
        self.update_timestamp()
        return new_index

    def get_paragraph(self, index: int) -> Optional[Paragraph]:
        """Return the paragraph at `index`, or None when the index is out of range (negatives included)."""
        in_range = 0 <= index < len(self.paragraphs)
        return self.paragraphs[index] if in_range else None

    def get_completed_paragraphs_count(self) -> int:
        """Count paragraphs whose `is_completed()` is true."""
        return len([p for p in self.paragraphs if p.is_completed()])

    def get_total_paragraphs_count(self) -> int:
        """Return the number of paragraphs."""
        return len(self.paragraphs)

    def is_all_paragraphs_completed(self) -> bool:
        """True when there is at least one paragraph and every paragraph is completed."""
        if not self.paragraphs:
            return False
        return all(p.is_completed() for p in self.paragraphs)

    def mark_completed(self):
        """Mark the whole report as finished and refresh `updated_at`."""
        self.is_completed = True
        self.update_timestamp()

    def update_timestamp(self):
        """Set `updated_at` to the current local time in ISO-8601 form."""
        self.updated_at = datetime.now().isoformat()

    def get_progress_summary(self) -> Dict[str, Any]:
        """Return paragraph counts, completion percentage, completion flag and timestamps."""
        total = self.get_total_paragraphs_count()
        completed = self.get_completed_paragraphs_count()
        percentage = (completed / total * 100) if total > 0 else 0

        return dict(
            total_paragraphs=total,
            completed_paragraphs=completed,
            progress_percentage=percentage,
            is_completed=self.is_completed,
            created_at=self.created_at,
            updated_at=self.updated_at,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the full state, including nested paragraphs, to a plain dict."""
        return dict(
            query=self.query,
            report_title=self.report_title,
            paragraphs=[p.to_dict() for p in self.paragraphs],
            final_report=self.final_report,
            is_completed=self.is_completed,
            created_at=self.created_at,
            updated_at=self.updated_at,
        )

    def to_json(self, indent: int = 2) -> str:
        """Serialize to a JSON string, keeping non-ASCII characters unescaped."""
        payload = self.to_dict()
        return json.dumps(payload, indent=indent, ensure_ascii=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "State":
        """Rebuild from `to_dict` output; missing keys use defaults (timestamps default to now)."""
        get = data.get
        restored_paragraphs = [Paragraph.from_dict(item) for item in get("paragraphs", [])]

        return cls(
            query=get("query", ""),
            report_title=get("report_title", ""),
            paragraphs=restored_paragraphs,
            final_report=get("final_report", ""),
            is_completed=get("is_completed", False),
            created_at=get("created_at", datetime.now().isoformat()),
            updated_at=get("updated_at", datetime.now().isoformat()),
        )

    @classmethod
    def from_json(cls, json_str: str) -> "State":
        """Rebuild a State from a JSON string produced by `to_json`."""
        return cls.from_dict(json.loads(json_str))

    def save_to_file(self, filepath: str):
        """Write the JSON form of this state to `filepath` (UTF-8, overwriting)."""
        serialized = self.to_json()
        with open(filepath, 'w', encoding='utf-8') as handle:
            handle.write(serialized)

    @classmethod
    def load_from_file(cls, filepath: str) -> "State":
        """Read a UTF-8 JSON state file written by `save_to_file`."""
        with open(filepath, 'r', encoding='utf-8') as handle:
            raw_json = handle.read()
        return cls.from_json(raw_json)
259 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Search Agent
2 |
3 | [Python 3.9+](https://python.org)
4 | [License: MIT](LICENSE)
5 | [LLM: DeepSeek](https://platform.deepseek.com/)
6 | [Search: Tavily](https://tavily.com/)
7 |
8 | 一个**无框架**的深度搜索AI代理实现,能够通过多轮搜索和反思生成高质量的研究报告。
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## 特性
17 |
18 | - **无框架设计**: 从零实现,不依赖LangChain等重型框架
19 | - **多LLM支持**: 支持DeepSeek、OpenAI等主流大语言模型
20 | - **智能搜索**: 集成Tavily搜索引擎,提供高质量网络搜索
21 | - **反思机制**: 多轮反思优化,确保研究深度和完整性
22 | - **状态管理**: 完整的研究过程状态跟踪和恢复
23 | - **Web界面**: Streamlit友好界面,易于使用
24 | - **Markdown输出**: 美观的Markdown格式研究报告
25 |
26 | ## 工作原理
27 |
28 | Deep Search Agent采用分阶段的研究方法:
29 |
30 | ```mermaid
31 | graph TD
32 | A[用户查询] --> B[生成报告结构]
33 | B --> C[遍历每个段落]
34 | C --> D[初始搜索]
35 | D --> E[生成初始总结]
36 | E --> F[反思循环]
37 | F --> G[反思搜索]
38 | G --> H[更新总结]
39 | H --> I{达到反思次数?}
40 | I -->|否| F
41 | I -->|是| J{所有段落完成?}
42 | J -->|否| C
43 | J -->|是| K[格式化最终报告]
44 | K --> L[输出报告]
45 | ```
46 |
47 | ### 核心流程
48 |
49 | 1. **结构生成**: 根据查询生成报告大纲和段落结构
50 | 2. **初始研究**: 为每个段落生成搜索查询并获取相关信息
51 | 3. **初始总结**: 基于搜索结果生成段落初稿
52 | 4. **反思优化**: 多轮反思,发现遗漏并补充搜索
53 | 5. **最终整合**: 将所有段落整合为完整的Markdown报告
54 |
55 | ## 快速开始
56 |
57 | ### 1. 环境准备
58 |
59 | 确保您的系统安装了Python 3.9或更高版本:
60 |
61 | ```bash
62 | python --version
63 | ```
64 |
65 | ### 2. 克隆项目
66 |
67 | ```bash
68 | git clone https://github.com/666ghj/DeepSearchAgent-Demo.git
69 | cd DeepSearchAgent-Demo
70 | ```
71 |
72 | ### 3. 安装依赖
73 |
74 | ```bash
75 | # 激活虚拟环境(推荐)
76 | conda activate your_env_name  # 替换为您自己的虚拟环境名称,也可使用 venv 等其他虚拟环境
77 |
78 | # 安装依赖
79 | pip install -r requirements.txt
80 | ```
81 |
82 | ### 4. 配置API密钥
83 |
84 | 项目根目录下已有`config.py`配置文件,请直接编辑此文件设置您的API密钥:
85 |
86 | ```python
87 | # Deep Search Agent 配置文件
88 | # 请在这里填入您的API密钥
89 |
90 | # DeepSeek API Key
91 | DEEPSEEK_API_KEY = "your_deepseek_api_key_here"
92 |
93 | # OpenAI API Key (可选)
94 | OPENAI_API_KEY = "your_openai_api_key_here"
95 |
96 | # Tavily搜索API Key
97 | TAVILY_API_KEY = "your_tavily_api_key_here"
98 |
99 | # 配置参数
100 | DEFAULT_LLM_PROVIDER = "deepseek"
101 | DEEPSEEK_MODEL = "deepseek-chat"
102 | OPENAI_MODEL = "gpt-4o-mini"
103 |
104 | MAX_REFLECTIONS = 2
105 | SEARCH_RESULTS_PER_QUERY = 3
106 | SEARCH_CONTENT_MAX_LENGTH = 20000
107 | OUTPUT_DIR = "reports"
108 | SAVE_INTERMEDIATE_STATES = True
109 | ```
110 |
111 | ### 5. 开始使用
112 |
113 | 现在您可以开始使用Deep Search Agent了!
114 |
115 | ## 使用方法
116 |
117 | ### 方式一:运行示例脚本
118 |
119 | **基本使用示例**:
120 | ```bash
121 | python examples/basic_usage.py
122 | ```
123 | 这个示例展示了最简单的使用方式,执行一个预设的研究查询并显示结果。
124 |
125 | **高级使用示例**:
126 | ```bash
127 | python examples/advanced_usage.py
128 | ```
129 | 这个示例展示了更复杂的使用场景,包括:
130 | - 自定义配置参数
131 | - 执行多个研究任务
132 | - 状态管理和恢复
133 | - 不同模型的使用
134 |
135 | ### 方式二:Web界面
136 |
137 | 启动Streamlit Web界面:
138 | ```bash
139 | streamlit run examples/streamlit_app.py
140 | ```
141 | Web界面无需配置文件,直接在界面中输入API密钥即可使用。
142 |
143 | ### 方式三:编程方式
144 |
145 | ```python
146 | from src import DeepSearchAgent, load_config
147 |
148 | # 加载配置
149 | config = load_config()
150 |
151 | # 创建Agent
152 | agent = DeepSearchAgent(config)
153 |
154 | # 执行研究
155 | query = "2025年人工智能发展趋势"
156 | final_report = agent.research(query, save_report=True)
157 |
158 | print(final_report)
159 | ```
160 |
161 | ### 方式四:自定义配置(编程方式)
162 |
163 | 如果需要在代码中动态设置配置,可以使用以下方式:
164 |
165 | ```python
166 | from src import DeepSearchAgent, Config
167 |
168 | # 自定义配置
169 | config = Config(
170 | default_llm_provider="deepseek",
171 | deepseek_model="deepseek-chat",
172 | max_reflections=3, # 增加反思次数
173 | max_search_results=5, # 增加搜索结果数
174 | output_dir="my_reports" # 自定义输出目录
175 | )
176 |
177 | # 设置API密钥
178 | config.deepseek_api_key = "your_api_key"
179 | config.tavily_api_key = "your_tavily_key"
180 |
181 | agent = DeepSearchAgent(config)
182 | ```
183 |
184 | ## 项目结构
185 |
186 | ```
187 | Demo DeepSearch Agent/
188 | ├── src/ # 核心代码
189 | │ ├── llms/ # LLM调用模块
190 | │ │ ├── base.py # LLM基类
191 | │ │ ├── deepseek.py # DeepSeek实现
192 | │ │ └── openai_llm.py # OpenAI实现
193 | │ ├── nodes/ # 处理节点
194 | │ │ ├── base_node.py # 节点基类
195 | │ │ ├── report_structure_node.py # 结构生成
196 | │ │ ├── search_node.py # 搜索节点
197 | │ │ ├── summary_node.py # 总结节点
198 | │ │ └── formatting_node.py # 格式化节点
199 | │ ├── prompts/ # 提示词模块
200 | │ │ └── prompts.py # 所有提示词定义
201 | │ ├── state/ # 状态管理
202 | │ │ └── state.py # 状态数据结构
203 | │ ├── tools/ # 工具调用
204 | │ │ └── search.py # 搜索工具
205 | │ ├── utils/ # 工具函数
206 | │ │ ├── config.py # 配置管理
207 | │ │ └── text_processing.py # 文本处理
208 | │ └── agent.py # 主Agent类
209 | ├── examples/ # 使用示例
210 | │ ├── basic_usage.py # 基本使用示例
211 | │ ├── advanced_usage.py # 高级使用示例
212 | │ └── streamlit_app.py # Web界面
213 | ├── reports/ # 输出报告目录
214 | ├── requirements.txt # 依赖列表
215 | ├── config.py # 配置文件
216 | └── README.md # 项目文档
217 | ```
218 |
219 | ## 代码结构
220 |
221 | ```mermaid
222 | graph TB
223 | subgraph "用户层"
224 | A[用户查询]
225 | B[Web界面]
226 | C[命令行接口]
227 | end
228 |
229 | subgraph "主控制层"
230 | D[DeepSearchAgent]
231 | end
232 |
233 | subgraph "处理节点层"
234 | E[ReportStructureNode<br/>报告结构生成]
235 | F[FirstSearchNode<br/>初始搜索]
236 | G[FirstSummaryNode<br/>初始总结]
237 | H[ReflectionNode<br/>反思搜索]
238 | I[ReflectionSummaryNode<br/>反思总结]
239 | J[ReportFormattingNode<br/>报告格式化]
240 | end
241 |
242 | subgraph "LLM层"
243 | K[DeepSeekLLM]
244 | L[OpenAILLM]
245 | M[BaseLLM抽象类]
246 | end
247 |
248 | subgraph "工具层"
249 | N[Tavily搜索]
250 | O[文本处理工具]
251 | P[配置管理]
252 | end
253 |
254 | subgraph "状态管理层"
255 | Q[State状态对象]
256 | R[Paragraph段落对象]
257 | S[Research研究对象]
258 | T[Search搜索记录]
259 | end
260 |
261 | subgraph "数据持久化"
262 | U[JSON状态文件]
263 | V[Markdown报告]
264 | W[日志文件]
265 | end
266 |
267 | A --> D
268 | B --> D
269 | C --> D
270 |
271 | D --> E
272 | D --> F
273 | D --> G
274 | D --> H
275 | D --> I
276 | D --> J
277 |
278 | E --> K
279 | E --> L
280 | F --> K
281 | F --> L
282 | G --> K
283 | G --> L
284 | H --> K
285 | H --> L
286 | I --> K
287 | I --> L
288 | J --> K
289 | J --> L
290 |
291 | K --> M
292 | L --> M
293 |
294 | F --> N
295 | H --> N
296 |
297 | D --> O
298 | D --> P
299 |
300 | D --> Q
301 | Q --> R
302 | R --> S
303 | S --> T
304 |
305 | Q --> U
306 | D --> V
307 | D --> W
308 |
309 | style A fill:#e1f5fe
310 | style D fill:#f3e5f5
311 | style E fill:#fff3e0
312 | style F fill:#fff3e0
313 | style G fill:#fff3e0
314 | style H fill:#fff3e0
315 | style I fill:#fff3e0
316 | style J fill:#fff3e0
317 | style K fill:#e8f5e8
318 | style L fill:#e8f5e8
319 | style N fill:#fce4ec
320 | style Q fill:#f1f8e9
321 | ```
322 |
323 | ## API 参考
324 |
325 | ### DeepSearchAgent
326 |
327 | 主要的Agent类,提供完整的深度搜索功能。
328 |
329 | ```python
330 | class DeepSearchAgent:
331 | def __init__(self, config: Optional[Config] = None)
332 | def research(self, query: str, save_report: bool = True) -> str
333 | def get_progress_summary(self) -> Dict[str, Any]
334 | def load_state(self, filepath: str)
335 | def save_state(self, filepath: str)
336 | ```
337 |
338 | ### Config
339 |
340 | 配置管理类,控制Agent的行为参数。
341 |
342 | ```python
343 | class Config:
344 | # API密钥
345 | deepseek_api_key: Optional[str]
346 | openai_api_key: Optional[str]
347 | tavily_api_key: Optional[str]
348 |
349 | # 模型配置
350 | default_llm_provider: str = "deepseek"
351 | deepseek_model: str = "deepseek-chat"
352 | openai_model: str = "gpt-4o-mini"
353 |
354 | # 搜索配置
355 | max_search_results: int = 3
356 | search_timeout: int = 240
357 | max_content_length: int = 20000
358 |
359 | # Agent配置
360 | max_reflections: int = 2
361 | max_paragraphs: int = 5
362 | ```
363 |
364 | ## 示例
365 |
366 | ### 示例1:基本研究
367 |
368 | ```python
369 | from src import create_agent
370 |
371 | # 快速创建Agent
372 | agent = create_agent()
373 |
374 | # 执行研究
375 | report = agent.research("量子计算的发展现状")
376 | print(report)
377 | ```
378 |
379 | ### 示例2:自定义研究参数
380 |
381 | ```python
382 | from src import DeepSearchAgent, Config
383 |
384 | config = Config(
385 | max_reflections=4, # 更深度的反思
386 | max_search_results=8, # 更多搜索结果
387 | max_paragraphs=6 # 更长的报告
388 | )
389 |
390 | agent = DeepSearchAgent(config)
391 | report = agent.research("人工智能的伦理问题")
392 | ```
393 |
394 | ### 示例3:状态管理
395 |
396 | ```python
397 | # 开始研究
398 | agent = DeepSearchAgent()
399 | report = agent.research("区块链技术应用")
400 |
401 | # 保存状态
402 | agent.save_state("blockchain_research.json")
403 |
404 | # 稍后恢复状态
405 | new_agent = DeepSearchAgent()
406 | new_agent.load_state("blockchain_research.json")
407 |
408 | # 检查进度
409 | progress = new_agent.get_progress_summary()
410 | print(f"研究进度: {progress['progress_percentage']}%")
411 | ```
412 |
413 | ## 高级功能
414 |
415 | ### 多模型支持
416 |
417 | ```python
418 | # 使用DeepSeek
419 | config = Config(default_llm_provider="deepseek")
420 |
421 | # 使用OpenAI
422 | config = Config(default_llm_provider="openai", openai_model="gpt-4o")
423 | ```
424 |
425 | ### 自定义输出
426 |
427 | ```python
428 | config = Config(
429 | output_dir="custom_reports", # 自定义输出目录
430 | save_intermediate_states=True # 保存中间状态
431 | )
432 | ```
433 |
434 | ## 常见问题
435 |
436 | ### Q: 支持哪些LLM?
437 |
438 | A: 目前支持:
439 | - **DeepSeek**: 推荐使用,性价比高
440 | - **OpenAI**: GPT-4o、GPT-4o-mini等
441 | - 可以通过继承`BaseLLM`类轻松添加其他模型
442 |
443 | ### Q: 如何获取API密钥?
444 |
445 | A:
446 | - **DeepSeek**: 访问 [DeepSeek平台](https://platform.deepseek.com/) 注册获取
447 | - **Tavily**: 访问 [Tavily](https://tavily.com/) 注册获取(每月1000次免费)
448 | - **OpenAI**: 访问 [OpenAI平台](https://platform.openai.com/) 获取
449 |
450 | 获取密钥后,直接编辑项目根目录的`config.py`文件填入即可。
451 |
452 | ### Q: 研究报告质量如何提升?
453 |
454 | A: 可以通过以下方式优化:
455 | - 增加`max_reflections`参数(更多反思轮次)
456 | - 增加`max_search_results`参数(更多搜索结果)
457 | - 调整`max_content_length`参数(更长的搜索内容)
458 | - 使用更强大的LLM模型
459 |
460 | ### Q: 如何自定义提示词?
461 |
462 | A: 修改`src/prompts/prompts.py`文件中的系统提示词,可以根据需要调整Agent的行为。
463 |
464 | ### Q: 支持其他搜索引擎吗?
465 |
466 | A: 当前主要支持Tavily,但可以通过修改`src/tools/search.py`添加其他搜索引擎支持。
467 |
468 | ## 贡献
469 |
470 | 欢迎贡献代码!请遵循以下步骤:
471 |
472 | 1. Fork本项目
473 | 2. 创建特性分支 (`git checkout -b feature/AmazingFeature`)
474 | 3. 提交更改 (`git commit -m 'Add some AmazingFeature'`)
475 | 4. 推送到分支 (`git push origin feature/AmazingFeature`)
476 | 5. 开启Pull Request
477 |
478 | ## 许可证
479 |
480 | 本项目采用MIT许可证 - 查看 [LICENSE](LICENSE) 文件了解详情。
481 |
482 | ## 致谢
483 |
484 | - 感谢 [DeepSeek](https://www.deepseek.com/) 提供优秀的LLM服务
485 | - 感谢 [Tavily](https://tavily.com/) 提供高质量的搜索API
486 |
487 | ---
488 |
489 | 如果这个项目对您有帮助,请给个Star!
490 |
--------------------------------------------------------------------------------
/src/agent.py:
--------------------------------------------------------------------------------
1 | """
2 | Deep Search Agent主类
3 | 整合所有模块,实现完整的深度搜索流程
4 | """
5 |
6 | import json
7 | import os
8 | from datetime import datetime
9 | from typing import Optional, Dict, Any, List
10 |
11 | from .llms import DeepSeekLLM, OpenAILLM, BaseLLM
12 | from .nodes import (
13 | ReportStructureNode,
14 | FirstSearchNode,
15 | ReflectionNode,
16 | FirstSummaryNode,
17 | ReflectionSummaryNode,
18 | ReportFormattingNode
19 | )
20 | from .state import State
21 | from .tools import tavily_search
22 | from .utils import Config, load_config, format_search_results_for_prompt
23 |
24 |
25 | class DeepSearchAgent:
26 | """Deep Search Agent主类"""
27 |
28 | def __init__(self, config: Optional[Config] = None):
29 | """
30 | 初始化Deep Search Agent
31 |
32 | Args:
33 | config: 配置对象,如果不提供则自动加载
34 | """
35 | # 加载配置
36 | self.config = config or load_config()
37 |
38 | # 初始化LLM客户端
39 | self.llm_client = self._initialize_llm()
40 |
41 | # 初始化节点
42 | self._initialize_nodes()
43 |
44 | # 状态
45 | self.state = State()
46 |
47 | # 确保输出目录存在
48 | os.makedirs(self.config.output_dir, exist_ok=True)
49 |
50 | print(f"Deep Search Agent 已初始化")
51 | print(f"使用LLM: {self.llm_client.get_model_info()}")
52 |
53 | def _initialize_llm(self) -> BaseLLM:
54 | """初始化LLM客户端"""
55 | if self.config.default_llm_provider == "deepseek":
56 | return DeepSeekLLM(
57 | api_key=self.config.deepseek_api_key,
58 | model_name=self.config.deepseek_model
59 | )
60 | elif self.config.default_llm_provider == "openai":
61 | return OpenAILLM(
62 | api_key=self.config.openai_api_key,
63 | model_name=self.config.openai_model
64 | )
65 | else:
66 | raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")
67 |
68 | def _initialize_nodes(self):
69 | """初始化处理节点"""
70 | self.first_search_node = FirstSearchNode(self.llm_client)
71 | self.reflection_node = ReflectionNode(self.llm_client)
72 | self.first_summary_node = FirstSummaryNode(self.llm_client)
73 | self.reflection_summary_node = ReflectionSummaryNode(self.llm_client)
74 | self.report_formatting_node = ReportFormattingNode(self.llm_client)
75 |
76 | def research(self, query: str, save_report: bool = True) -> str:
77 | """
78 | 执行深度研究
79 |
80 | Args:
81 | query: 研究查询
82 | save_report: 是否保存报告到文件
83 |
84 | Returns:
85 | 最终报告内容
86 | """
87 | print(f"\n{'='*60}")
88 | print(f"开始深度研究: {query}")
89 | print(f"{'='*60}")
90 |
91 | try:
92 | # Step 1: 生成报告结构
93 | self._generate_report_structure(query)
94 |
95 | # Step 2: 处理每个段落
96 | self._process_paragraphs()
97 |
98 | # Step 3: 生成最终报告
99 | final_report = self._generate_final_report()
100 |
101 | # Step 4: 保存报告
102 | if save_report:
103 | self._save_report(final_report)
104 |
105 | print(f"\n{'='*60}")
106 | print("深度研究完成!")
107 | print(f"{'='*60}")
108 |
109 | return final_report
110 |
111 | except Exception as e:
112 | print(f"研究过程中发生错误: {str(e)}")
113 | raise e
114 |
115 | def _generate_report_structure(self, query: str):
116 | """生成报告结构"""
117 | print(f"\n[步骤 1] 生成报告结构...")
118 |
119 | # 创建报告结构节点
120 | report_structure_node = ReportStructureNode(self.llm_client, query)
121 |
122 | # 生成结构并更新状态
123 | self.state = report_structure_node.mutate_state(state=self.state)
124 |
125 | print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
126 | for i, paragraph in enumerate(self.state.paragraphs, 1):
127 | print(f" {i}. {paragraph.title}")
128 |
129 | def _process_paragraphs(self):
130 | """处理所有段落"""
131 | total_paragraphs = len(self.state.paragraphs)
132 |
133 | for i in range(total_paragraphs):
134 | print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
135 | print("-" * 50)
136 |
137 | # 初始搜索和总结
138 | self._initial_search_and_summary(i)
139 |
140 | # 反思循环
141 | self._reflection_loop(i)
142 |
143 | # 标记段落完成
144 | self.state.paragraphs[i].research.mark_completed()
145 |
146 | progress = (i + 1) / total_paragraphs * 100
147 | print(f"段落处理完成 ({progress:.1f}%)")
148 |
149 | def _initial_search_and_summary(self, paragraph_index: int):
150 | """执行初始搜索和总结"""
151 | paragraph = self.state.paragraphs[paragraph_index]
152 |
153 | # 准备搜索输入
154 | search_input = {
155 | "title": paragraph.title,
156 | "content": paragraph.content
157 | }
158 |
159 | # 生成搜索查询
160 | print(" - 生成搜索查询...")
161 | search_output = self.first_search_node.run(search_input)
162 | search_query = search_output["search_query"]
163 | reasoning = search_output["reasoning"]
164 |
165 | print(f" - 搜索查询: {search_query}")
166 | print(f" - 推理: {reasoning}")
167 |
168 | # 执行搜索
169 | print(" - 执行网络搜索...")
170 | search_results = tavily_search(
171 | search_query,
172 | max_results=self.config.max_search_results,
173 | timeout=self.config.search_timeout,
174 | api_key=self.config.tavily_api_key
175 | )
176 |
177 | if search_results:
178 | print(f" - 找到 {len(search_results)} 个搜索结果")
179 | for j, result in enumerate(search_results, 1):
180 | print(f" {j}. {result['title'][:50]}...")
181 | else:
182 | print(" - 未找到搜索结果")
183 |
184 | # 更新状态中的搜索历史
185 | paragraph.research.add_search_results(search_query, search_results)
186 |
187 | # 生成初始总结
188 | print(" - 生成初始总结...")
189 | summary_input = {
190 | "title": paragraph.title,
191 | "content": paragraph.content,
192 | "search_query": search_query,
193 | "search_results": format_search_results_for_prompt(
194 | search_results, self.config.max_content_length
195 | )
196 | }
197 |
198 | # 更新状态
199 | self.state = self.first_summary_node.mutate_state(
200 | summary_input, self.state, paragraph_index
201 | )
202 |
203 | print(" - 初始总结完成")
204 |
def _reflection_loop(self, paragraph_index: int):
    """Refine one paragraph through ``config.max_reflections`` reflection rounds.

    Each round asks ``reflection_node`` for a follow-up search query based on
    the paragraph's latest summary and runs it with ``tavily_search``. When
    results come back, they are recorded in the paragraph's search history and
    handed to ``reflection_summary_node``, whose ``mutate_state`` return value
    replaces ``self.state``.

    Args:
        paragraph_index: Index of the paragraph in ``self.state.paragraphs``.
    """
    for reflection_i in range(self.config.max_reflections):
        print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")

        # Look the paragraph up afresh every round. ``self.state`` is rebound
        # to whatever ``mutate_state`` returns at the end of the round. A
        # reference captured once before the loop would go stale if that call
        # ever returned a new State object instead of mutating in place.
        paragraph = self.state.paragraphs[paragraph_index]

        # Ask the reflection node for a follow-up query based on the current summary.
        reflection_input = {
            "title": paragraph.title,
            "content": paragraph.content,
            "paragraph_latest_state": paragraph.research.latest_summary,
        }
        reflection_output = self.reflection_node.run(reflection_input)
        search_query = reflection_output["search_query"]
        reasoning = reflection_output["reasoning"]

        print(f" 反思查询: {search_query}")
        print(f" 反思推理: {reasoning}")

        # Run the follow-up web search.
        search_results = tavily_search(
            search_query,
            max_results=self.config.max_search_results,
            timeout=self.config.search_timeout,
            api_key=self.config.tavily_api_key,
        )

        if search_results:
            print(f" 找到 {len(search_results)} 个反思搜索结果")

            # Record the new results in the paragraph's search history.
            paragraph.research.add_search_results(search_query, search_results)

            # Fold the new results into the summary on top of the current one.
            reflection_summary_input = {
                "title": paragraph.title,
                "content": paragraph.content,
                "search_query": search_query,
                "search_results": format_search_results_for_prompt(
                    search_results, self.config.max_content_length
                ),
                "paragraph_latest_state": paragraph.research.latest_summary,
            }
            self.state = self.reflection_summary_node.mutate_state(
                reflection_summary_input, self.state, paragraph_index
            )

        print(f" 反思 {reflection_i + 1} 完成")
258 |
259 | def _generate_final_report(self) -> str:
260 | """生成最终报告"""
261 | print(f"\n[步骤 3] 生成最终报告...")
262 |
263 | # 准备报告数据
264 | report_data = []
265 | for paragraph in self.state.paragraphs:
266 | report_data.append({
267 | "title": paragraph.title,
268 | "paragraph_latest_state": paragraph.research.latest_summary
269 | })
270 |
271 | # 格式化报告
272 | try:
273 | final_report = self.report_formatting_node.run(report_data)
274 | except Exception as e:
275 | print(f"LLM格式化失败,使用备用方法: {str(e)}")
276 | final_report = self.report_formatting_node.format_report_manually(
277 | report_data, self.state.report_title
278 | )
279 |
280 | # 更新状态
281 | self.state.final_report = final_report
282 | self.state.mark_completed()
283 |
284 | print("最终报告生成完成")
285 | return final_report
286 |
287 | def _save_report(self, report_content: str):
288 | """保存报告到文件"""
289 | # 生成文件名
290 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
291 | query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip()
292 | query_safe = query_safe.replace(' ', '_')[:30]
293 |
294 | filename = f"deep_search_report_{query_safe}_{timestamp}.md"
295 | filepath = os.path.join(self.config.output_dir, filename)
296 |
297 | # 保存报告
298 | with open(filepath, 'w', encoding='utf-8') as f:
299 | f.write(report_content)
300 |
301 | print(f"报告已保存到: {filepath}")
302 |
303 | # 保存状态(如果配置允许)
304 | if self.config.save_intermediate_states:
305 | state_filename = f"state_{query_safe}_{timestamp}.json"
306 | state_filepath = os.path.join(self.config.output_dir, state_filename)
307 | self.state.save_to_file(state_filepath)
308 | print(f"状态已保存到: {state_filepath}")
309 |
def get_progress_summary(self) -> Dict[str, Any]:
    """Return the progress summary reported by the agent's current state."""
    current_state = self.state
    return current_state.get_progress_summary()
313 |
def load_state(self, filepath: str):
    """Replace the agent's state with one loaded via ``State.load_from_file``.

    Args:
        filepath: Path of the state file to read.
    """
    restored = State.load_from_file(filepath)
    self.state = restored
    print(f"状态已从 {filepath} 加载")
318 |
def save_state(self, filepath: str):
    """Persist the agent's current state via its ``save_to_file`` method.

    Args:
        filepath: Destination path for the state file.
    """
    state_to_save = self.state
    state_to_save.save_to_file(filepath)
    print(f"状态已保存到 {filepath}")
323 |
324 |
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
    """Build a DeepSearchAgent from a configuration loaded by ``load_config``.

    Args:
        config_file: Optional configuration file path, passed through to
            ``load_config`` unchanged.

    Returns:
        A new DeepSearchAgent constructed with the loaded configuration.
    """
    return DeepSearchAgent(load_config(config_file))
337 |
--------------------------------------------------------------------------------