├── .github
│   ├── stale.yml
│   └── workflows
│       └── ubuntu.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── config.example.yaml
├── deep_research_demo.py
├── docs
│   ├── gradio.png
│   ├── logo.png
│   ├── project.md
│   ├── report.png
│   ├── wechat.jpeg
│   └── wechat_group.jpg
├── main.py
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── config.py
│   ├── deep_research.py
│   ├── gradio_chat.py
│   ├── model_utils.py
│   ├── mp_search_client.py
│   ├── prompts.py
│   ├── providers.py
│   ├── search_utils.py
│   ├── serper_client.py
│   └── tavily_client.py
└── tests
    ├── __init__.py
    ├── test_providers.py
    └── test_serper_client.py
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 60
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 7
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 | This issue has been automatically marked as stale because it has not had
14 | recent activity. It will be closed if no further activity occurs. Thank you
15 |   for your contributions. (This issue was automatically closed by the bot due to long inactivity; feel free to ask again if needed.)
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: false
--------------------------------------------------------------------------------
/.github/workflows/ubuntu.yml:
--------------------------------------------------------------------------------
1 | on:
2 |   workflow_dispatch:  # allows the workflow to be triggered manually
3 | name: Linux build
4 | jobs:
5 | test-ubuntu:
6 | runs-on: ubuntu-latest
7 | strategy:
8 | fail-fast: false
9 | matrix:
10 | # python-version: [ 3.7, 3.8, 3.9 ]
11 | python-version: [ 3.12 ]
12 | steps:
13 | - uses: actions/checkout@v2
14 | - name: Cache pip
15 | uses: actions/cache@v2
16 | if: startsWith(runner.os, 'Linux')
17 | with:
18 | path: ~/.cache/pip
19 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
20 | restore-keys: |
21 | ${{ runner.os }}-pip-
22 | - name: Cache huggingface models
23 | uses: actions/cache@v2
24 | with:
25 | path: ~/.cache/huggingface
26 | key: ${{ runner.os }}-huggingface-
27 | - name: Cache text2vec models
28 | uses: actions/cache@v2
29 | with:
30 | path: ~/.text2vec
31 | key: ${{ runner.os }}-text2vec-
32 | - name: Set up Python
33 | uses: actions/setup-python@v2
34 | with:
35 | python-version: ${{ matrix.python-version }}
36 | - name: Install torch
37 | run: |
38 | python -m pip install --upgrade pip
39 | pip install Cython
40 | pip install torch
41 |       - name: Install from PyPI
42 | run: |
43 | pip install -U agentica
44 | python -c "import agentica; print(agentica.__version__)"
45 | pip uninstall -y agentica
46 | - name: Install dependencies
47 | run: |
48 | pip install -r requirements.txt
49 | pip install .
50 | pip install pytest
51 | - name: PKG-TEST
52 | run: |
53 | python -m pytest
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | We are happy to accept your contributions to make this repo better and more awesome! To avoid unnecessary work on either
4 | side, please stick to the following process:
5 |
6 | 1. Check if there is already [an issue](https://github.com/shibing624/deep-research/issues) for your concern.
7 | 2. If there is not, open a new one to start a discussion. We would hate to close a finished PR!
8 | 3. If we decide your concern needs code changes, we would be happy to accept a pull request. Please follow the
9 |    contribution guidelines in the README (add unit tests under `tests` and make sure they all pass).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
6 |
7 | -----------------
8 |
9 | # Open Deep Research (Python)
10 | [](CONTRIBUTING.md)
11 | [](LICENSE)
12 | [](requirements.txt)
13 | [](https://github.com/shibing624/deep-research/issues)
14 | [](#contact)
15 |
16 |
17 | **Deep Research**: a Python implementation of an AI-powered research assistant that performs iterative, deep research on any topic by combining search engines, web scraping, and large language models.
18 |
19 |
20 | ## Features
21 |
22 | - **Deep search**: intelligently generates a research plan and executes parallel searches according to it
23 | - **Smart query generation**: automatically generates follow-up queries based on the initial question and the information gathered so far
24 | - **Multiple output formats**: supports both a concise answer and a detailed report
25 | - **Multilingual support**: full support for Chinese input and output
26 | - **Automatic context management**: intelligently limits the context length passed to the LLM to prevent token-limit errors
27 | - **Configurable clarification flow**: the clarification step can be skipped to start research directly
28 | - **Multiple ways to use it**:
29 |   - Command-line interface
30 |   - Gradio web interface (with streaming CoT output)
31 |   - Direct calls as a Python module
32 |
33 | ## Demo
34 | - Official demo: https://deepresearch.mulanai.com
35 |
36 | ## Setup
37 |
38 | 1. Clone the repository:
39 |
40 | ```bash
41 | git clone https://github.com/shibing624/deep-research.git
42 | ```
43 | 2. Install dependencies:
44 |
45 | ```bash
46 | pip install -r requirements.txt
47 | ```
48 |
49 | 3. Create a configuration file:
50 |
51 | ```bash
52 | # Copy the example configuration
53 | cp config.example.yaml config.yaml
54 | ```
55 |
56 | The configuration file allows you to set:
57 | - API keys for OpenAI, Tavily, and Serper
58 | - Model preferences
59 | - Search engine options (Serper, MP Search, and Tavily)
60 |
61 |
62 | ### Key Configuration Options
63 |
64 | - **context_size**: Controls the maximum size of context sent to the LLM. The system will automatically truncate longer contexts to prevent token limit errors while preserving as much relevant information as possible. Default is `128000`.
65 |
66 | - **enable_clarification**: When set to `False`, the system will skip the clarification step and proceed directly with research. This is useful for straightforward queries where clarification might be unnecessary. Default is `False`.
67 |
68 | - **search_source**: Choose your preferred search provider. The example config ships with `tavily`; the built-in default in `src/config.py` is `serper`. A combined example follows this list.
69 |
70 | - **enable_refine_search_result**: When enabled, the system will refine search results for better relevance. Default is `False`.
71 |
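A minimal `research` section of `config.yaml`, mirroring the keys in `config.example.yaml` (the values here are illustrative), might look like:

```yaml
research:
  concurrency_limit: 3              # number of parallel searches
  context_size: 128000              # max context (in tokens) passed to the LLM
  search_source: "tavily"           # serper, tavily, or mp_search
  max_results_per_query: 3
  enable_refine_search_result: False
  enable_clarification: False
```
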
72 | ## Search Engines
73 |
74 | Deep Research supports multiple search engines:
75 |
76 | 1. **Serper** (built-in default): uses Google search results via the Serper.dev API
77 | 2. **MP Search**: an alternative search provider
78 | 3. **Tavily**: a specialized, AI-optimized search engine
79 |
80 | To use Tavily search:
81 | 1. Get an API key from [Tavily](https://tavily.com)
82 | 2. Add it to your config.yaml:
83 | ```yaml
84 | tavily:
85 | api_key: "your-tavily-api-key" # Use the token without 'Bearer' prefix
86 | base_url: "https://api.tavily.com/search"
87 | ```
88 | Note: For Tavily, provide just the API token (e.g., "tvly-dev-xxx") without the "Bearer" prefix.
89 |
90 | 3. Set Tavily as your search source in the Gradio interface or in your config.yaml:
91 | ```yaml
92 | research:
93 | search_source: "tavily"
94 | ```
95 |
96 | The Tavily search engine provides high-quality, AI-optimized search results and may include (see the sketch after this list):
97 | - Ranked search results with relevance scores
98 | - Follow-up questions (when available)
99 | - Direct answers for certain queries (when available)
100 |
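For illustration, a Tavily response surfaced to the research pipeline has roughly the following shape. This is an assumed sketch based on Tavily's public search API, not necessarily the exact payload this project's `tavily_client` returns, and the values are invented:

```python
# Hypothetical, abbreviated Tavily /search response payload.
tavily_response = {
    "query": "tesla stock analysis",
    "answer": "A short direct answer, when available.",
    "follow_up_questions": ["How did TSLA perform in Q4 2024?"],
    "results": [
        {
            "title": "Tesla, Inc. (TSLA) Stock Overview",
            "url": "https://example.com/tsla",
            "content": "Snippet of the page content...",
            "score": 0.92,  # relevance score used for ranking
        },
    ],
}
```
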
101 | ## Usage
102 |
103 | ### Command Line Interface
104 |
105 | The main.py script provides several ways to use the research assistant:
106 |
107 | ```bash
108 | # Show help
109 | python main.py --help
110 |
111 | # Run research directly from command line
112 | python main.py research "中国2024年经济情况分析"
113 |
114 | # Launch the Gradio demo interface
115 | python main.py demo
116 |
117 | # Use a specific configuration file
118 | python main.py --config my-config.yaml research "Your query"
119 | ```
120 |
121 | ### Demo Script
122 |
123 | Run the demo script to see the complete research workflow:
124 |
125 | ```bash
126 | python deep_research_demo.py
127 | ```
128 |
129 | This runs an example research task for the query "中国历史上最伟大的发明是什么?" (What is the greatest invention in Chinese history?) and saves the detailed report to [report.md](https://github.com/shibing624/deep-research/blob/main/report.md).
130 |
131 | Output:
132 |
133 | 
134 |
135 | ### Gradio Demo
136 |
137 | For a user-friendly interface, run the Gradio demo:
138 |
139 | ```bash
140 | python main.py demo
141 | ```
142 |
143 | This will start a web interface where you can enter your research query, adjust parameters, and view results.
144 |
145 | 
146 |
147 | ### Python Module
148 |
149 | Or use the module directly:
150 |
151 | ```python
152 | import asyncio
153 | from src.deep_research import deep_research_stream
154 |
155 |
156 | async def run_research():
157 |     # Run the research
158 | async for result in deep_research_stream(
159 | query="特斯拉股票走势分析",
160 | history_context="",
161 | ):
162 |         # When the research completes, print the final report
163 | if result.get("stage") == "completed":
164 | report = result.get("final_report", "")
165 | print(report)
166 | break
167 |
168 |
169 | if __name__ == "__main__":
170 | asyncio.run(run_research())
171 | ```
172 |
173 | Note: since these are asynchronous functions, call them with `asyncio.run()` or `await` them from an async context.
174 |
175 | ## Roadmap
176 | 
177 | - Support more search engines
178 | - Improve the query-generation strategy
179 | - Enhance result visualization
180 | - Support more large language models
181 | - Add document embedding and vector search
182 |
183 | ## Contact
184 |
185 | - Issues (suggestions):
186 |   [](https://github.com/shibing624/deep-research/issues)
187 | - Email me: xuming624@qq.com
188 | - WeChat: add my WeChat ID *xuming624* with the note *Name-Company-NLP* to join the NLP discussion group.
189 |
190 |
191 |
192 | ## Citation
193 |
194 | If you use `deep-research` in your research, please cite it as follows:
195 |
196 | APA:
197 |
198 | ```
199 | Xu, M. deep-research: Deep Research with LLM (Version 0.0.1) [Computer software]. https://github.com/shibing624/deep-research
200 | ```
201 |
202 | BibTeX:
203 |
204 | ```
205 | @misc{Xu_deep_research,
206 | title={deep-research: Deep Research with LLM},
207 | author={Xu Ming},
208 | year={2025},
209 | howpublished={\url{https://github.com/shibing624/deep-research}},
210 | }
211 | ```
212 |
213 | ## License
214 |
215 | Licensed under [The Apache License 2.0](/LICENSE), which is free for commercial use. Please include a link to `deep-research` and the license in your product documentation.
216 | ## Contribute
217 |
218 | The project code is still rough. If you improve it, you are welcome to contribute your changes back. Before submitting, please note the following two points:
219 | 
220 | - Add corresponding unit tests under `tests`
221 | - Run `python -m pytest` to make sure all unit tests pass
222 | 
223 | Then you can submit your PR.
224 |
225 | ## Acknowledgements
226 |
227 | - [dzhng/deep-research](https://github.com/dzhng/deep-research)
228 |
229 | Thanks for their great work!
--------------------------------------------------------------------------------
/config.example.yaml:
--------------------------------------------------------------------------------
1 | # Deep Research Configuration
2 |
3 | # OpenAI settings
4 | openai:
5 | api_key: "sk-xxx" # Set your OpenAI API key here
6 | base_url: "https://api.openai.com/v1"
7 | model: "gpt-4o-mini" # Default model, can be "gpt-4o", "gpt-4o-mini"
8 |
9 | # Report LLM settings
10 | report_llm:
11 | api_key: "sk-xxx" # Set your OpenAI API key here
12 | base_url: "https://api.openai.com/v1"
13 | model: "gpt-4o"
14 |
15 | # Serper settings
16 | serper:
17 | api_key: "xxx" # Set your Serper API key here
18 | base_url: "https://google.serper.dev/search"
19 |
20 | # Tavily settings
21 | tavily:
22 | api_key: "tvly-dev-xxx" # Use the token without 'Bearer' prefix
23 | base_url: "https://api.tavily.com/search"
24 |
25 | # Research settings
26 | research:
27 |   concurrency_limit: 3 # number of concurrent searches
28 |   context_size: 64000 # maximum context length (in tokens) passed to the LLM
29 |   search_source: "tavily" # search engine
30 |   max_results_per_query: 3 # maximum number of results per search query
31 |   enable_refine_search_result: False # whether to refine search results: if True, extract key snippets from the results; if False, use the raw search results
32 |   enable_next_plan: False # whether to generate a next-step plan: if True, summarize the search results and propose next steps as input for the report; if False, skip this analysis
33 |   enable_clarification: False # whether to clarify the question: if True, clarify the query first; if False, skip clarification and search directly
34 |
--------------------------------------------------------------------------------
/deep_research_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import asyncio
3 | from loguru import logger
4 |
5 | from src.config import get_config
6 | from src.deep_research import deep_research_stream
7 |
8 |
9 | async def run_demo():
10 | """
11 |     Demonstrate how to use the deep_research module to run a research task.
12 |     """
13 |     # Load configuration
14 | config = get_config()
15 |
16 |     # Define the research query
17 | query = "中国元朝的货币制度改革的影响意义?"
18 |
19 |     # Define a progress callback
20 | def progress_callback(progress):
21 | current = progress.get("currentQuery", "")
22 | logger.info(f"进度: 当前查询: {current}")
23 |
24 | logger.info(f"开始研究: {query}")
25 | async for result in deep_research_stream(
26 | query=query,
27 | on_progress=progress_callback,
28 |         user_clarifications={'all': 'skip'},  # special marker to skip clarification
29 |         history_context=""  # pass an empty history_context
30 | ):
31 |         # Handle status updates
32 | if result.get("status_update"):
33 | logger.info(f"状态更新: {result['status_update']}")
34 |
35 |         # When the research is complete, get the final report
36 | if result.get("stage") == "completed":
37 | final_report = result.get("final_report", "")
38 | with open("report.md", "w", encoding="utf-8") as f:
39 | f.write(final_report)
40 | logger.info("报告已保存到 report.md")
41 | print("\n" + "=" * 50)
42 | print(f"研究问题: {query}")
43 | print("=" * 50)
44 | print("\n研究报告:")
45 | print(final_report)
46 | print("\n" + "=" * 50)
47 |
48 | break
49 |
50 |
51 | if __name__ == "__main__":
52 |     # Run the async demo
53 | asyncio.run(run_demo())
54 |
--------------------------------------------------------------------------------
/docs/gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/gradio.png
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/logo.png
--------------------------------------------------------------------------------
/docs/project.md:
--------------------------------------------------------------------------------
1 | # Deep Research Project Analysis
2 | 
3 | ## Project Overview
4 | 
5 | Deep Research is a powerful research assistant tool that helps users gather comprehensive research information through multi-stage search, analysis, and summarization. By integrating large language models with search engines, it automates the research process, including query clarification, research plan generation, multi-step search, and final report generation.
6 | 
7 | ## Current Strengths
8 | 
9 | 1. **Intelligent query clarification**
10 |    - Automatically detects ambiguous queries and generates clarification questions
11 |    - Processes user answers to distill a more precise query
12 |    - Uses a caching mechanism to avoid repeated API calls and improve performance
13 | 
14 | 2. **Multi-level research workflow**
15 |    - Dynamically generates a research plan, adjusting the number of steps to query complexity
16 |    - Supports breadth and depth parameters to fit different research needs
17 |    - Processes multiple queries in parallel for better efficiency
18 | 
19 | 3. **Transparent research process**
20 |    - Shows research progress and findings in real time
21 |    - Clearly presents the research plan and execution steps
22 |    - Records all data sources to ensure traceability
23 | 
24 | 4. **Polished user interface**
25 |    - Clean and easy-to-use Gradio interface
26 |    - Supports real-time feedback and status updates
27 |    - Presents research results and intermediate steps in a structured way
28 | 
29 | 5. **Flexible search sources**
30 |    - Supports multiple search providers (serper, tavily, mp_search)
31 |    - Lets users choose their preferred search engine
32 | 
33 | ## Current Weaknesses and Limitations
34 | 
35 | 1. **Incomplete error handling**
36 |    - Limited recovery strategies for search failures or API rate limits
37 |    - Exception handling is simplistic, lacking graceful degradation
38 | 
39 | 2. **User interaction needs work**
40 |    - The clarification-question flow is not smooth enough
41 |    - No real-time validation of user input
42 |    - Users cannot intervene or adjust the research direction midway
43 | 
44 | 3. **Performance and resource optimization**
45 |    - Large queries can incur high API costs
46 |    - Response times are long for complex queries
47 |    - Caching and memory mechanisms are limited, so previous research results are poorly reused
48 | 
49 | 4. **Insufficient toolchain integration**
50 |    - No integration with other research tools
51 |    - Cannot handle non-text content (charts, datasets, etc.)
52 |    - No local knowledge base support
53 | 
54 | 5. **Evaluation and validation**
55 |    - No self-assessment mechanism for search result quality
56 |    - Cannot adequately verify the accuracy and timeliness of information
57 |    - No mechanism for handling contradictory information
58 | 
59 | ## Suggested Improvements
60 | 
61 | 1. **Strengthen error recovery**
62 |    - Implement more robust error handling and retry logic
63 |    - Add automatic failover to backup search sources
64 |    - Provide more detailed error diagnostics and suggested fixes
65 | 
66 | 2. **Improve the user experience**
67 |    - Support user intervention during the research process
68 |    - Add more interaction options, such as upvoting or rejecting specific findings
69 |    - Provide a toggle between concise and detailed modes
70 |    - Improve support for mobile devices
71 | 
72 | 3. **Extend features and integrations**
73 |    - Add more search providers
74 |    - Integrate local knowledge bases and documents
75 |    - Support exporting research results to different formats (PDF, Word, Markdown, etc.)
76 |    - Add multimedia content processing
77 | 
78 | 4. **Optimize performance**
79 |    - Implement smarter caching
80 |    - Optimize parallel request management
81 |    - Add a low-cost mode that reduces API calls
82 |    - Support incremental research updates to avoid repeated work
83 | 
84 | 5. **Enhance validation**
85 |    - Add cross-validation of information
86 |    - Detect staleness and flag potentially outdated information
87 |    - Add a source rating system that prioritizes high-quality sources
88 |    - Support comparison and analysis of contradictory information
89 | 
90 | 6. **Expand language support**
91 |    - Strengthen multilingual processing
92 |    - Improve clarification and handling of non-English queries
93 |    - Support language switching and translation
94 | 
95 | 7. **Community and developer support**
96 |    - Improve the API documentation
97 |    - Provide more customization and extension points
98 |    - Simplify deployment
99 |    - Create a plugin system to support community contributions
100 | 
101 | 8. **Broaden application scenarios**
102 |    - Develop research modes focused on specific domains (academic, news, medical, etc.)
103 |    - Support long-term research projects, including progress tracking and periodic updates
104 |    - Add team collaboration features for multi-user research
105 | 
106 | ## Conclusion
107 | 
108 | The Deep Research project shows strong potential: by combining large language models with search engines, it gives users automated, in-depth research capabilities. Despite its current limitations, continued improvement can turn it into a more comprehensive, reliable, and easy-to-use research assistant. The focus should be on improving research quality, enhancing the user experience, and expanding feature integrations so the tool can meet a wide range of complex research needs.
--------------------------------------------------------------------------------
/docs/report.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/report.png
--------------------------------------------------------------------------------
/docs/wechat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/wechat.jpeg
--------------------------------------------------------------------------------
/docs/wechat_group.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/wechat_group.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import argparse
7 | import asyncio
8 | from src.config import load_config, get_config
9 | from src.deep_research import deep_research_stream
10 | from loguru import logger
11 |
12 |
13 | async def run_research(args):
14 |     # Load configuration
15 | config = get_config()
16 |
17 |     # Define the research query
18 | query = args.query
19 |
20 |     # Define a progress callback
21 | def progress_callback(progress):
22 | current = progress.get("currentQuery", "")
23 |
24 | logger.info(f"进度: 当前查询: {current}")
25 |
26 |     # Run the research
27 |     logger.info(f"开始研究: {query}")
28 | 
29 |     # Use streaming research; skip the clarification step via user_clarifications
30 | async for result in deep_research_stream(
31 | query=query,
32 | on_progress=progress_callback,
33 |         user_clarifications={'all': 'skip'},  # special marker to skip clarification
34 |         history_context=""  # pass an empty history_context
35 | ):
36 |         # Handle status updates
37 | if result.get("status_update"):
38 | logger.info(f"状态更新: {result['status_update']}")
39 |
40 |         # When the research is complete, get the final report
41 | if result.get("stage") == "completed":
42 | learnings = result.get("learnings", [])
43 | visited_urls = result.get("visitedUrls", [])
44 | final_report = result.get("final_report", "")
45 |
46 | logger.info(f"研究完成! 发现 {len(learnings)} 条学习内容和 {len(visited_urls)} 个来源。")
47 |
48 |             # Save the result to a file
49 | with open("report.md", "w", encoding="utf-8") as f:
50 | f.write(final_report)
51 |
52 | logger.info("报告已保存到 report.md")
53 |
54 | print("\n" + "=" * 50)
55 | print(f"研究问题: {query}")
56 | print("=" * 50)
57 | print("\n研究报告:")
58 | print(final_report)
59 | print("\n" + "=" * 50)
60 |
61 | break
62 |
63 |
64 | def main():
65 | """Main entry point with argument parsing"""
66 | parser = argparse.ArgumentParser(
67 | description="Deep Research - AI-powered research assistant"
68 | )
69 |
70 | # Add config file argument
71 | parser.add_argument(
72 | "--config", type=str,
73 | help="Path to YAML configuration file"
74 | )
75 |
76 | subparsers = parser.add_subparsers(dest="command", help="Command to run")
77 |
78 | # demo command
79 | demo_parser = subparsers.add_parser("demo", help="Run the gradio demo server")
80 | demo_parser.add_argument(
81 | "--host", type=str, default='0.0.0.0', help="Host ip"
82 | )
83 |
84 | # Research command
85 | research_parser = subparsers.add_parser("research", help="Run research directly")
86 | research_parser.add_argument(
87 | "query", type=str,
88 | help="Research query"
89 | )
90 | args = parser.parse_args()
91 |
92 | # Load configuration
93 | if args.config:
94 | load_config(args.config)
95 |
96 | # Execute command
97 | if args.command == "research":
98 | asyncio.run(run_research(args))
99 | elif args.command == "demo":
100 | from src.gradio_chat import run_gradio_demo
101 | run_gradio_demo()
102 | else:
103 | # Default to showing help if no command specified
104 | parser.print_help()
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | loguru
2 | openai
3 | httpx
4 | gradio==5.22.0
5 | pyyaml
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/src/__init__.py
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import os
7 | import yaml
8 | from typing import Dict, Any, Optional
9 | from loguru import logger
10 |
11 | # Default configuration
12 | DEFAULT_CONFIG = {
13 | "openai": {
14 | "api_key": None,
15 | "base_url": "https://api.openai.com/v1",
16 | "model": "o3-mini"
17 | },
18 | "serper": {
19 | "api_key": None,
20 | "base_url": "https://google.serper.dev/search"
21 | },
22 | "tavily": {
23 | "api_key": None,
24 | "base_url": "https://api.tavily.com/search"
25 | },
26 | "research": {
27 | "concurrency_limit": 1,
28 | "context_size": 128000,
29 | "search_source": "serper",
30 | "max_results_per_query": 5,
31 | "enable_refine_search_result": False,
32 | "enable_next_plan": False
33 | }
34 | }
35 |
36 | # Global configuration object
37 | _config = None
38 |
39 |
40 | def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
41 | """
42 | Load configuration from YAML file with fallback to environment variables
43 |
44 | Args:
45 | config_path: Path to YAML configuration file
46 |
47 | Returns:
48 | Dict containing configuration
49 | """
50 | global _config
51 |
52 | # Start with default configuration
53 | config = DEFAULT_CONFIG.copy()
54 |
55 | # Try to load from specified config file
56 | if config_path and os.path.exists(config_path):
57 | try:
58 | with open(config_path, 'r') as f:
59 | yaml_config = yaml.safe_load(f)
60 | if yaml_config:
61 | # Merge with defaults (deep merge would be better but this is simple)
62 | for section, values in yaml_config.items():
63 | if section in config and isinstance(values, dict):
64 | config[section].update(values)
65 | else:
66 | config[section] = values
67 | logger.info(f"Loaded configuration from {config_path}")
68 | except Exception as e:
69 | logger.warning(f"Error loading config from {config_path}: {str(e)}")
70 | else:
71 |         # No config file specified (or the given path does not exist); look for config.yaml in the current directory
72 | default_path = './config.yaml'
73 | if os.path.exists(default_path):
74 | return load_config(default_path)
75 | else:
76 | logger.info("No configuration file found, using defaults")
77 |
78 | # Store the config globally
79 | _config = config
80 | return config
81 |
82 |
83 | def get_config() -> Dict[str, Any]:
84 | """
85 | Get the current configuration
86 |
87 | Returns:
88 | Dict containing configuration
89 | """
90 | global _config
91 | if _config is None:
92 | return load_config()
93 | return _config
94 |
--------------------------------------------------------------------------------
/src/deep_research.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 |
6 | Deep research functionality for comprehensive research.
7 | """
8 |
9 | import asyncio
10 | import json
11 | import traceback
12 | import inspect
13 | import platform
14 | from typing import Optional, Callable, Dict, List, Any, Union, Tuple, AsyncGenerator
15 | from loguru import logger
16 | from datetime import datetime
17 |
18 | from .config import get_config
19 | from .providers import get_search_provider
20 | from .prompts import (
21 | SHOULD_CLARIFY_QUERY_PROMPT,
22 | FOLLOW_UP_QUESTIONS_PROMPT,
23 | PROCESS_NO_CLARIFICATIONS_PROMPT,
24 | PROCESS_CLARIFICATIONS_PROMPT,
25 | RESEARCH_PLAN_PROMPT,
26 | EXTRACT_SEARCH_RESULTS_SYSTEM_PROMPT,
27 | EXTRACT_SEARCH_RESULTS_PROMPT,
28 | RESEARCH_SUMMARY_PROMPT,
29 | FINAL_REPORT_SYSTEM_PROMPT,
30 | FINAL_REPORT_PROMPT,
31 | FINAL_ANSWER_PROMPT,
32 | )
33 | from .model_utils import generate_completion, generate_json_completion
34 | from .search_utils import search_with_query
35 |
36 |
37 | def get_current_date():
38 | """Return the current date in ISO format."""
39 | return datetime.now().isoformat()
40 |
41 |
42 | def add_event_loop_policy():
43 | """Add event loop policy for Windows if needed."""
44 | if platform.system() == "Windows":
45 | try:
46 | # Set event loop policy for Windows
47 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
48 | except Exception as e:
49 | print(f"Error setting event loop policy: {e}")
50 |
51 |
52 | def limit_context_size(text: str, max_size: int) -> str:
53 | """
54 | Limit the context size to prevent LLM token limit errors.
55 |
56 | Args:
57 | text: Text to limit
58 | max_size: Approximate maximum character count (rough estimate for tokens)
59 |
60 | Returns:
61 | Limited text
62 | """
63 |     # Simple character-based truncation (rough approximation):
64 |     # on average 1 token is roughly 4 characters for English text,
65 |     # and closer to 2 characters for Chinese, so use the conservative ratio of 2
66 |     char_limit = max_size * 2
67 |
68 | if len(text) <= char_limit:
69 | return text
70 |
71 | logger.warning(f"Truncating context from {len(text)} chars to ~{char_limit} chars")
72 |
73 | # For JSON strings, try to preserve structure
74 |     if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')):
75 | try:
76 | # Try to parse as JSON
77 | data = json.loads(text)
78 | # If it's a list of items, truncate the list
79 | if isinstance(data, list):
80 | # Calculate approx chars per item
81 | if len(data) > 0:
82 | chars_per_item = len(text) / len(data)
83 | items_to_keep = int(char_limit / chars_per_item)
84 | items_to_keep = max(1, min(items_to_keep, len(data)))
85 | return json.dumps(data[:items_to_keep], ensure_ascii=False)
86 | # If it contains a list property, truncate that
87 | for key, value in data.items():
88 | if isinstance(value, list) and len(value) > 0:
89 | chars_per_item = len(json.dumps(value, ensure_ascii=False)) / len(value)
90 | items_to_keep = int(char_limit / 2 / chars_per_item) # Use half for lists
91 | items_to_keep = max(1, min(items_to_keep, len(value)))
92 | data[key] = value[:items_to_keep]
93 | return json.dumps(data, ensure_ascii=False)
94 |         except Exception:
95 | # Not valid JSON, use simple truncation
96 | pass
97 |
98 | # Simple truncation with indicator
99 | return text[:char_limit - 50] + "... [content truncated due to token limit]"
100 |
101 |
102 | async def should_clarify_query(query: str, history_context: str = '') -> bool:
103 | """
104 | Use the language model to determine if a query needs clarification.
105 | """
106 | try:
107 | prompt = SHOULD_CLARIFY_QUERY_PROMPT.format(
108 | query=query,
109 | history_context=history_context,
110 | current_date=get_current_date()
111 | )
112 | result = await generate_completion(prompt, temperature=0)
113 | logger.debug(f"query: {query}, should clarify query result: {result}")
114 |
115 |         # Treat the answer as affirmative only if it starts with "yes"/"y" (a substring check would misfire on words containing "y")
116 |         needs_clarification = result.strip().lower().startswith(("yes", "y"))
117 | return needs_clarification
118 | except Exception as e:
119 | logger.error(f"error: {str(e)}")
120 | return True
121 |
122 |
123 | async def generate_followup_questions(query: str, history_context: str = '') -> Dict[str, Any]:
124 | """
125 | Generate clarifying follow-up questions for the given query.
126 |
127 | Args:
128 | query: The user's research query
129 | history_context: Chat history context
130 |
131 | Returns:
132 | Dict containing whether clarification is needed and questions
133 | """
134 | try:
135 | # Format the prompt
136 | prompt = FOLLOW_UP_QUESTIONS_PROMPT.format(
137 | query=query,
138 | history_context=history_context,
139 |             current_date=get_current_date()  # include the current date
140 | )
141 |
142 | # Generate the followup questions
143 | result = await generate_json_completion(prompt, temperature=0.7)
144 |
145 | # Ensure the expected structure
146 | if "needs_clarification" not in result:
147 | result["needs_clarification"] = False
148 |
149 | if "questions" not in result:
150 | result["questions"] = []
151 |
152 | logger.debug(f"Follow-up questions: {result}")
153 | return result
154 |
155 | except Exception as e:
156 | logger.error(f"Error generating followup questions: {str(e)}")
157 | return {
158 | "needs_clarification": False,
159 | "questions": []
160 | }
161 |
162 |
163 | async def process_clarifications(
164 | query: str,
165 | user_responses: Dict[str, str],
166 | all_questions: List[Dict[str, Any]],
167 | history_context: str = ''
168 | ) -> Dict[str, Any]:
169 | """
170 | Process user responses to clarification questions and refine the query.
171 |
172 | Args:
173 | query: The original query
174 | user_responses: Dict mapping question keys to user responses
175 | all_questions: List of all questions that were asked
176 | history_context: Chat history context
177 |
178 | Returns:
179 | Dict with refined query and other information
180 | """
181 | try:
182 | # Format the questions and responses for the prompt
183 | clarifications = []
184 | unanswered = []
185 |
186 | for question in all_questions:
187 | key = question.get("key", "")
188 | question_text = question.get("question", "")
189 | default = question.get("default", "")
190 |
191 | if key in user_responses and user_responses[key]:
192 | clarifications.append(f"Q: {question_text}\nA: {user_responses[key]}")
193 | else:
194 | clarifications.append(f"Q: {question_text}\nA: [Not answered by user]")
195 | unanswered.append(f"Q: {question_text}\nDefault: {default}")
196 |
197 | # Check if all questions were unanswered - if so, modify the prompt
198 | all_unanswered = len(unanswered) == len(all_questions) and len(all_questions) > 0
199 |
200 | # Format the prompt
201 | if all_unanswered:
202 | prompt = PROCESS_NO_CLARIFICATIONS_PROMPT.format(
203 | query=query,
204 | unanswered_questions="\n\n".join(unanswered) if unanswered else "None",
205 | history_context=history_context,
206 | current_date=get_current_date()
207 | )
208 | else:
209 | prompt = PROCESS_CLARIFICATIONS_PROMPT.format(
210 | query=query,
211 | clarifications="\n\n".join(clarifications),
212 | unanswered_questions="\n\n".join(unanswered) if unanswered else "None",
213 | history_context=history_context,
214 | current_date=get_current_date()
215 | )
216 |
217 | # Generate the clarifications processing
218 | result = await generate_json_completion(prompt, temperature=0.7)
219 |
220 | # Ensure the expected structure
221 | if "refined_query" not in result:
222 | result["refined_query"] = query
223 |
224 | if "assumptions" not in result:
225 | result["assumptions"] = []
226 |
227 | if "requires_search" not in result:
228 | result["requires_search"] = True
229 |
230 | if "direct_answer" not in result:
231 | result["direct_answer"] = ""
232 |
233 | logger.debug(f"Processed clarifications: {result}")
234 | return result
235 |
236 | except Exception as e:
237 | logger.error(f"Error processing clarifications: {str(e)}")
238 | return {
239 | "refined_query": query,
240 | "assumptions": [],
241 | "requires_search": True,
242 | "direct_answer": ""
243 | }
244 |
245 |
246 | async def generate_research_plan(query: str, history_context: str = "") -> Dict[str, Any]:
247 | """
248 | Generate a research plan with a variable number of steps based on the query complexity.
249 |
250 | Args:
251 | query: The research query
252 | history_context: Chat history context
253 |
254 | Returns:
255 | Dict containing the research plan with steps
256 | """
257 | try:
258 | # Format the prompt
259 | prompt = RESEARCH_PLAN_PROMPT.format(
260 | query=query,
261 | history_context=history_context,
262 |             current_date=get_current_date()  # include the current date
263 | )
264 |
265 | # Generate the research plan
266 | result = await generate_json_completion(prompt, temperature=0.7)
267 |
268 | # Ensure the expected structure
269 | if "steps" not in result or not result["steps"]:
270 | # Create a default single-step plan if none was provided
271 | result["steps"] = [{
272 | "step_id": 1,
273 | "description": "Research the query",
274 | "search_queries": [query],
275 | "goal": "Find information about the query"
276 | }]
277 |
278 | if "assessments" not in result:
279 | result["assessments"] = "No complexity assessment provided"
280 |
281 | logger.debug(f"Research plan: {result}")
282 | return result
283 |
284 | except Exception as e:
285 | logger.error(f"Error generating research plan: {str(e)}")
286 | # Return a simple default plan
287 | return {
288 | "assessments": "Error occurred, using default plan",
289 | "steps": [{
290 | "step_id": 1,
291 | "description": "Research the query",
292 | "search_queries": [query],
293 | "goal": "Find information about the query"
294 | }]
295 | }
296 |
297 |
298 | async def extract_search_results(query: str, search_results: str) -> str:
299 | """
300 | Extract search results for a query.
301 |
302 | Args:
303 | query: The search query
304 | search_results: Formatted search results text
305 |
306 | Returns:
307 | extracted search results with detailed content and relevance information
308 | """
309 | try:
310 | # Get context size limit from config
311 | config = get_config()
312 | context_size = config.get("research", {}).get("context_size", 128000)
313 |
314 | # Limit search results size
315 | limited_search_results = limit_context_size(search_results, context_size // 2)
316 |
317 | # Format the prompt
318 | prompt = EXTRACT_SEARCH_RESULTS_PROMPT.format(
319 | query=query,
320 | search_results=limited_search_results,
321 | current_date=get_current_date()
322 | )
323 |
324 | # Generate the extracted_contents
325 | extracted_contents = await generate_json_completion(
326 | prompt=prompt,
327 | system_message=EXTRACT_SEARCH_RESULTS_SYSTEM_PROMPT,
328 | temperature=0
329 | )
330 |
331 | # Process and enrich the extracted content
332 | if "extracted_infos" in extracted_contents:
333 | # Make sure all entries have a relevance field (for backward compatibility)
334 | for info in extracted_contents["extracted_infos"]:
335 | if "relevance" not in info:
336 | info["relevance"] = "与查询相关的信息"
337 |
338 | # Convert to string for storage and transfer
339 | extracted_contents_str = json.dumps(extracted_contents, ensure_ascii=False)
340 | return extracted_contents_str
341 |
342 | except Exception as e:
343 |         logger.error(f"Error extracting search results: {str(e)}")
344 |         return f"Error extracting results for '{query}': {str(e)}"
345 |
346 |
347 | async def write_final_report_stream(query: str, context: str,
348 | history_context: str = '') -> AsyncGenerator[str, None]:
349 | """
350 | Streaming version of write_final_report that yields chunks of the report.
351 |
352 | Args:
353 | query: The original research query
354 |         context: Key learnings/facts discovered, with their sources, as a single string
355 | history_context: str
356 |
357 | Yields:
358 | Chunks of the final report
359 | """
360 | # Get context size limit from config
361 | config = get_config()
362 | context_size = config.get("research", {}).get("context_size", 128000)
363 |
364 | # Limit context sizes
365 | limited_context = limit_context_size(context, context_size // 2)
366 | limited_history = limit_context_size(history_context, context_size // 4)
367 | system_message = FINAL_REPORT_SYSTEM_PROMPT.format(
368 | current_date=get_current_date()
369 | )
370 | formatted_prompt = FINAL_REPORT_PROMPT.format(
371 | query=query,
372 | context=limited_context,
373 | history_context=limited_history
374 | )
375 |
376 | response_generator = await generate_completion(
377 | prompt=formatted_prompt,
378 | system_message=system_message,
379 | temperature=0.7,
380 | stream=True,
381 | is_report=True
382 | )
383 |
384 | # Stream the response chunks
385 | async for chunk in response_generator:
386 | yield chunk
387 |
388 |
389 | async def write_final_report(query: str, context: str, history_context: str = '') -> str:
390 | """
391 | Generate a final research report based on learnings and sources.
392 |
393 | Args:
394 | query: The original research query
395 |         context: Key learnings/facts discovered, with their sources, as a single string
396 | history_context: chat history
397 |
398 | Returns:
399 | A formatted research report
400 | """
401 | # Get context size limit from config
402 | config = get_config()
403 | context_size = config.get("research", {}).get("context_size", 128000)
404 |
405 | # Limit context sizes
406 | limited_context = limit_context_size(context, context_size // 2)
407 | limited_history = limit_context_size(history_context, context_size // 4)
408 |
409 | formatted_prompt = FINAL_REPORT_PROMPT.format(
410 | query=query,
411 | context=limited_context,
412 | history_context=limited_history,
413 | )
414 |
415 | system_message = FINAL_REPORT_SYSTEM_PROMPT.format(
416 | current_date=get_current_date()
417 | )
418 |
419 | report = await generate_completion(
420 | prompt=formatted_prompt,
421 | system_message=system_message,
422 | temperature=0.7,
423 | is_report=True
424 | )
425 |
426 | return report
427 |
428 |
429 | async def write_final_answer(query: str, context: str, history_context: str = '') -> str:
430 | """
431 | Generate a final concise answer based on the research.
432 |
433 | Args:
434 | query: The original research query
435 |         context: Key learnings/facts discovered, with their sources, as a single string
436 | history_context: chat history
437 |
438 | Returns:
439 | A concise answer to the query
440 | """
441 | # Get context size limit from config
442 | config = get_config()
443 | context_size = config.get("research", {}).get("context_size", 128000)
444 |
445 | # Limit context sizes
446 | limited_context = limit_context_size(context, context_size // 2)
447 | limited_history = limit_context_size(history_context, context_size // 4)
448 |
449 | formatted_prompt = FINAL_ANSWER_PROMPT.format(
450 | query=query,
451 | context=limited_context,
452 | history_context=limited_history,
453 | )
454 |
455 | system_message = FINAL_REPORT_SYSTEM_PROMPT.format(
456 | current_date=get_current_date()
457 | )
458 |
459 | answer = await generate_completion(
460 | prompt=formatted_prompt,
461 | system_message=system_message,
462 | temperature=0.7,
463 | is_report=True
464 | )
465 |
466 | return answer
467 |
468 |
469 | def write_final_report_sync(query: str, context: str, history_context: str = '') -> str:
470 | """
471 | Synchronous wrapper for the write_final_report function.
472 |
473 | Args:
474 | query: The original research query
475 |         context: Key learnings/facts discovered, with their sources, as a single string
476 | history_context: chat history
477 |
478 | Returns:
479 | A formatted research report
480 | """
481 | # Add event loop policy for Windows if needed
482 | add_event_loop_policy()
483 |
484 | # Run the async function in the event loop
485 | try:
486 | loop = asyncio.get_event_loop()
487 | except RuntimeError:
488 | loop = asyncio.new_event_loop()
489 | asyncio.set_event_loop(loop)
490 |
491 | return loop.run_until_complete(
492 | write_final_report(
493 | query=query,
494 | context=context,
495 | history_context=history_context
496 | )
497 | )
498 |
499 |
500 | def write_final_answer_sync(query: str, context: str, history_context: str = '') -> str:
501 | """
502 | Synchronous wrapper for the write_final_answer function.
503 |
504 | Args:
505 | query: The original research query
506 |         context: Key learnings/facts discovered, with their sources, as a single string
507 | history_context: chat history
508 |
509 | Returns:
510 | A concise answer to the query
511 | """
512 | # Add event loop policy for Windows if needed
513 | add_event_loop_policy()
514 |
515 | # Run the async function in the event loop
516 | try:
517 | loop = asyncio.get_event_loop()
518 | except RuntimeError:
519 | loop = asyncio.new_event_loop()
520 | asyncio.set_event_loop(loop)
521 |
522 | return loop.run_until_complete(
523 | write_final_answer(
524 | query=query,
525 | context=context,
526 | history_context=history_context
527 | )
528 | )
529 |
530 |
531 | async def research_step(
532 | query: str,
533 | config: Dict[str, Any],
534 | on_progress: Optional[Callable] = None,
535 | search_provider=None,
536 | ) -> Dict[str, Any]:
537 | """
538 | Perform a single step of the research process.
539 |
540 | Args:
541 | query: The query to research
542 | config: Configuration dictionary
543 | on_progress: Optional callback for progress updates
544 | search_provider: Search provider instance
545 |
546 | Returns:
547 | Dict with research results
548 | """
549 | if search_provider is None:
550 | search_provider = get_search_provider()
551 |
552 | # Progress update
553 | if on_progress:
554 | progress_data = {
555 | "currentQuery": query
556 | }
557 |
558 | # Check if on_progress is a coroutine function
559 | if inspect.iscoroutinefunction(on_progress):
560 | await on_progress(progress_data)
561 | else:
562 | on_progress(progress_data)
563 |
564 | # Get the search results
565 | search_result = await search_with_query(query, config, search_provider)
566 | search_results_text = search_result["summary"]
567 | urls = search_result["urls"]
568 |
569 | enable_refine_search_result = config.get("research", {}).get("enable_refine_search_result", False)
570 | if enable_refine_search_result:
571 | extracted_content = await extract_search_results(query, search_results_text)
572 | else:
573 | extracted_content = search_results_text
574 |
575 | return {
576 | "extracted_content": extracted_content,
577 | "urls": urls,
578 | }
579 |
580 |
581 | async def deep_research_stream(
582 | query: str,
583 | on_progress: Optional[Callable] = None,
584 |         user_clarifications: Optional[Dict[str, str]] = None,
585 | search_source: Optional[str] = None,
586 | history_context: Optional[str] = None,
587 | enable_clarification: bool = False,
588 | ) -> AsyncGenerator[Dict[str, Any], None]:
589 | """
590 | Streaming version of deep research that yields partial results.
591 |
592 | Args:
593 | query: The research query
594 | on_progress: Optional callback function for progress updates
595 | user_clarifications: User responses to clarification questions
596 | search_source: Optional search provider to use
597 |         history_context: Prior chat history content
598 | enable_clarification: Whether to use the clarification step
599 |
600 | Yields:
601 | Dict with partial research results and status updates
602 | """
603 | # Load configuration
604 | config = get_config()
605 | logger.debug(f"query: {query}, config: {config}")
606 |
607 | # Initialize tracking variables
608 | visited_urls = []
609 | all_learnings = []
610 |
611 | # Initialize search provider
612 | search_provider = get_search_provider(search_source=search_source)
613 |
614 | try:
615 | # Step 1: Yield initial status
616 | yield {
617 | "status_update": f"开始研究查询: '{query}'...",
618 | "learnings": all_learnings,
619 | "visitedUrls": visited_urls,
620 | "current_query": query,
621 | "stage": "initial"
622 | }
623 |
624 |         # Step 1.5: Decide whether clarification questions need to be generated
625 |         needs_clarification = False
626 | 
627 |         # If no user clarifications were provided and clarification is enabled, check whether the query needs it
628 | if not user_clarifications and enable_clarification:
629 | yield {
630 | "status_update": f"分析查询是否需要澄清...",
631 | "learnings": all_learnings,
632 | "visitedUrls": visited_urls,
633 | "current_query": query,
634 | "stage": "analyzing_query"
635 | }
636 |
637 |             # Cache the decision per query to avoid duplicate LLM calls
638 |             _clarification_cache_key = f"should_clarify_{query}"
639 |             if _clarification_cache_key not in globals():
640 |                 needs_clarification = await should_clarify_query(query, history_context)
641 |                 globals()[_clarification_cache_key] = needs_clarification
642 |             else:
643 |                 needs_clarification = globals()[_clarification_cache_key]
644 |                 logger.debug(f"Using cached clarification result: {query}, result: {needs_clarification}")
645 |
646 | if not needs_clarification:
647 | yield {
648 | "status_update": f"查询已足够清晰,跳过澄清步骤",
649 | "learnings": all_learnings,
650 | "visitedUrls": visited_urls,
651 | "current_query": query,
652 | "stage": "clarification_skipped"
653 | }
654 | elif not enable_clarification:
655 |             # Clarification is disabled in the config; just report the status
656 | yield {
657 | "status_update": f"配置为跳过澄清环节,直接开始研究",
658 | "learnings": all_learnings,
659 | "visitedUrls": visited_urls,
660 | "current_query": query,
661 | "stage": "clarification_skipped"
662 | }
663 |
664 |         # If the LLM decided clarification is needed, or user clarifications already exist, and clarification is enabled, generate or process the questions
665 | questions = []
666 | if (needs_clarification or user_clarifications) and enable_clarification:
667 | # Step 2: Generate clarification questions if needed
668 | if needs_clarification:
669 | yield {
670 | "status_update": f"生成澄清问题...",
671 | "learnings": all_learnings,
672 | "visitedUrls": visited_urls,
673 | "current_query": query,
674 | "stage": "generating_questions"
675 | }
676 |
677 | followup_result = await generate_followup_questions(query, history_context)
678 | questions = followup_result.get("questions", [])
679 |
680 | if questions:
681 | # If clarification is needed, update status
682 | yield {
683 | "status_update": f"查询需要澄清,生成了 {len(questions)} 个问题",
684 | "learnings": all_learnings,
685 | "visitedUrls": visited_urls,
686 | "current_query": query,
687 | "questions": questions,
688 | "stage": "clarification_needed"
689 | }
690 |
691 | # If we don't have user responses yet, wait for them
692 | if not user_clarifications:
693 | yield {
694 | "status_update": "等待用户回答澄清问题...",
695 | "learnings": all_learnings,
696 | "visitedUrls": visited_urls,
697 | "current_query": query,
698 | "questions": questions,
699 | "awaiting_clarification": True,
700 | "stage": "awaiting_clarification"
701 | }
702 | return
703 |
704 | # Step 3: Process user clarifications if provided
705 | refined_query = query
706 | user_responses = user_clarifications or {}
707 |
708 | if questions and user_clarifications and enable_clarification:
709 | # Track which questions were answered vs. which use defaults
710 | answered_questions = []
711 | unanswered_questions = []
712 |
713 | for q in questions:
714 | key = q.get("key", "")
715 | question_text = q.get("question", "")
716 | if key in user_responses and user_responses[key]:
717 | answered_questions.append(question_text)
718 | else:
719 | unanswered_questions.append(question_text)
720 |
721 | yield {
722 | "status_update": f"处理用户的澄清回答 ({len(answered_questions)}/{len(questions)} 已回答)",
723 | "learnings": all_learnings,
724 | "visitedUrls": visited_urls,
725 | "current_query": query,
726 | "answered_questions": answered_questions,
727 | "unanswered_questions": unanswered_questions,
728 | "stage": "processing_clarifications"
729 | }
730 |
731 | # Process the clarifications
732 | clarification_result = await process_clarifications(query, user_responses, questions, history_context)
733 | refined_query = clarification_result.get("refined_query", query)
734 |
735 | yield {
736 | "status_update": f"查询已优化: '{refined_query}'",
737 | "learnings": all_learnings,
738 | "visitedUrls": visited_urls,
739 | "original_query": query,
740 | "current_query": refined_query,
741 | "assumptions": clarification_result.get("assumptions", []),
742 | "stage": "query_refined"
743 | }
744 |
745 | # Check if this is a simple query that can be answered directly
746 | if not clarification_result.get("requires_search", True):
747 | direct_answer = clarification_result.get("direct_answer", "")
748 | if direct_answer:
749 | yield {
750 | "status_update": "查询可以直接回答,无需搜索",
751 | "requires_search": False,
752 | "direct_answer": direct_answer,
753 |                     "final_report": direct_answer,  # Use direct_answer as the final report
754 |                     "learnings": ["直接回答: " + direct_answer],
755 |                     "visitedUrls": [],
756 |                     "stage": "completed"
757 |                 }
758 |                 return  # No need to continue with the research flow
759 |
760 | # Step 4: Generate research plan with variable steps
761 | yield {
762 | "status_update": f"为查询生成研究计划: '{refined_query}'",
763 | "learnings": all_learnings,
764 | "visitedUrls": visited_urls,
765 | "current_query": refined_query,
766 | "stage": "planning"
767 | }
768 |
769 | plan_result = await generate_research_plan(refined_query, history_context)
770 | steps = plan_result.get("steps", [])
771 |
772 | yield {
773 | "status_update": f"生成了 {len(steps)} 步研究计划: {plan_result.get('assessments', '无评估')}",
774 | "learnings": all_learnings,
775 | "visitedUrls": visited_urls,
776 | "current_query": refined_query,
777 | "research_plan": steps,
778 | "stage": "plan_generated"
779 | }
780 |
781 | # Track step summaries for the final analysis
782 | step_summaries = []
783 |
784 | # Iterate through each step in the research plan
785 | for step_idx, step in enumerate(steps):
786 | step_id = step.get("step_id", step_idx + 1)
787 | description = step.get("description", f"Research step {step_id}")
788 | search_queries = step.get("search_queries", [refined_query])
789 |
790 | yield {
791 | "status_update": f"开始研究步骤 {step_id}/{len(steps)}: {description}",
792 | "learnings": all_learnings,
793 | "visitedUrls": visited_urls,
794 | "current_query": refined_query,
795 | "search_queries": search_queries,
796 | "current_step": step,
797 | "progress": {
798 | "current_step": step_id,
799 | "total_steps": len(steps)
800 | },
801 | "stage": "step_started"
802 | }
803 |
804 | # Create a queue of queries to process for this step
805 | step_urls = []
806 | step_learnings = []
807 |
808 | current_queries = search_queries.copy()
809 | # Research these queries concurrently
810 | yield {
811 | "status_update": f"步骤 {step_id}/{len(steps)}: 并行研究 {len(current_queries)} 个查询",
812 | "learnings": all_learnings,
813 | "visitedUrls": visited_urls,
814 | "current_queries": current_queries,
815 | "progress": {
816 | "current_step": step_id,
817 | "total_steps": len(steps),
818 | "processed_queries": len(current_queries),
819 | },
820 | "stage": "processing_queries"
821 | }
822 |
823 | # Process each query in the current batch
824 | research_tasks = []
825 | for current_query in current_queries:
826 | task = research_step(
827 | query=current_query,
828 | config=config,
829 | on_progress=on_progress,
830 | search_provider=search_provider,
831 | )
832 | research_tasks.append(task)
833 |
834 | # Execute tasks with concurrency
835 | results = await asyncio.gather(*research_tasks)
836 |
837 | # Process the results
838 | for result in results:
839 | # Update tracking variables
840 | urls = result["urls"]
841 | content = result["extracted_content"]
842 | step_urls.extend(urls)
843 | step_learnings.append(content)
844 |
845 | # Format learnings and URLs for display
846 | formatted_learnings = []
847 | for i, learning in enumerate(step_learnings):
848 | formatted_learnings.append(f"[{i + 1}] {learning}")
849 |
850 | formatted_urls = []
851 | for i, url in enumerate(step_urls):
852 | formatted_urls.append(f"[{i + 1}] {url}")
853 |
854 | # Truncate longer learnings for display
855 | new_learnings = [str(i)[:400] for i in step_learnings]
856 | yield {
857 | "status_update": f"步骤 {step_id}/{len(steps)}: 发现 {len(new_learnings)} 个新见解",
858 | "learnings": all_learnings + step_learnings,
859 | "visitedUrls": visited_urls + step_urls,
860 | "new_learnings": new_learnings,
861 | "formatted_new_learnings": formatted_learnings,
862 | "new_urls": step_urls,
863 | "formatted_new_urls": formatted_urls,
864 | "progress": {
865 | "current_step": step_id,
866 | "total_steps": len(steps),
867 | "processed_queries": len(current_queries)
868 | },
869 | "stage": "insights_found"
870 | }
871 |
872 | # Update visited URLs and learnings
873 | visited_urls.extend(step_urls)
874 | all_learnings.extend(step_learnings)
875 |
876 | # Save step summary
877 | step_summaries.append({
878 | "step_id": step_id,
879 | "description": description,
880 | "learnings": step_learnings,
881 | "urls": step_urls
882 | })
883 |
884 | # Format all step learnings for display
885 | formatted_step_learnings = []
886 | for i, learning in enumerate(step_learnings):
887 | formatted_step_learnings.append(f"[{i + 1}] {learning}")
888 |
889 | # Step completion update
890 | yield {
891 | "status_update": f"完成研究步骤 {step_id}/{len(steps)}: {description},获得 {len(step_learnings)} 个见解",
892 | "learnings": all_learnings,
893 | "visitedUrls": visited_urls,
894 | "step_learnings": step_learnings,
895 | "formatted_step_learnings": formatted_step_learnings,
896 | "step_urls": step_urls,
897 | "progress": {
898 | "current_step": step_id,
899 | "total_steps": len(steps),
900 | "completed": True
901 | },
902 | "stage": "step_completed"
903 | }
904 |
905 | enable_next_plan = config.get("research", {}).get("enable_next_plan", False)
906 | if enable_next_plan:
907 | # Perform final analysis
908 | yield {
909 | "status_update": "分析所有已收集的信息...",
910 | "learnings": all_learnings,
911 | "visitedUrls": visited_urls,
912 | "stage": "final_analysis"
913 | }
914 |
915 | steps_summary_text = "\n\n".join([
916 | f"步骤 {s['step_id']}: {s['description']}\n发现: {json.dumps(s['learnings'], ensure_ascii=False)}"
917 | for s in step_summaries
918 | ])
919 |
920 | future_research = RESEARCH_SUMMARY_PROMPT.format(
921 | query=refined_query,
922 | steps_summary=steps_summary_text,
923 | current_date=get_current_date()
924 | )
925 |
926 | future_research_result = await generate_json_completion(future_research, temperature=0.7)
927 |
928 | # Add final findings to learnings
929 | findings = []
930 | if "findings" in future_research_result:
931 | all_learnings.extend(future_research_result["findings"])
932 | findings = future_research_result["findings"]
933 |
934 | # Format final findings
935 | formatted_findings = []
936 | for i, finding in enumerate(findings):
937 | formatted_findings.append(f"[{i + 1}] {finding}")
938 |
939 | yield {
940 | "status_update": f"分析完成,得出 {len(findings)} 个主要发现",
941 | "learnings": all_learnings,
942 | "visitedUrls": visited_urls,
943 | "final_findings": findings,
944 | "formatted_final_findings": formatted_findings,
945 | "gaps": future_research_result.get("gaps", []),
946 | "recommendations": future_research_result.get("recommendations", []),
947 | "stage": "analysis_completed"
948 | }
949 | else:
950 | future_research_result = ""
951 | # No need to modify all_learnings when skipping summary
952 |
953 | yield {
954 | "status_update": "生成详细研究报告...",
955 | "learnings": all_learnings,
956 | "visitedUrls": visited_urls,
957 | "stage": "generating_report"
958 | }
959 |
960 | # Generate report (non-streaming for now)
961 | context = str(all_learnings) + '\n\n' + str(future_research_result)
962 | final_report = await write_final_report(refined_query, context, history_context)
963 |
964 | # Return compiled results
965 | yield {
966 | "status_update": "研究完成!",
967 | "query": refined_query,
968 | "originalQuery": query,
969 | "learnings": all_learnings,
970 | "visitedUrls": list(set(visited_urls)),
971 | "summary": future_research_result,
972 | "final_report": final_report,
973 | "stage": "completed"
974 | }
975 |
976 | except Exception as e:
977 | logger.error(f"Error in deep research stream: {str(e)}")
978 | logger.error(traceback.format_exc())
979 |
980 | yield {
981 | "status_update": f"错误: {str(e)}",
982 | "error": str(e),
983 | "learnings": all_learnings,
984 | "visitedUrls": visited_urls,
985 | "stage": "error"
986 | }
987 |
--------------------------------------------------------------------------------
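
A minimal sketch of driving `deep_research_stream` from a plain script (outside the Gradio app), assuming the `src` package layout above and a configured search provider; the query string is illustrative:

```python
import asyncio

from src.deep_research import deep_research_stream


async def main() -> None:
    final_report = ""
    # The stream yields dicts carrying a human-readable status and a stage tag
    async for update in deep_research_stream(query="特斯拉股票的最新行情?"):
        print(f"[{update.get('stage', '?')}] {update.get('status_update', '')}")
        if update.get("awaiting_clarification"):
            # The generator returns here; re-run it with user_clarifications to continue
            break
        if "final_report" in update:
            final_report = update["final_report"]
    print(final_report)


if __name__ == "__main__":
    asyncio.run(main())
```
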
/src/gradio_chat.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 |
6 | A simplified Gradio demo for Deep Research with basic conversation interface.
7 | """
8 |
9 | import time
10 | import gradio as gr
11 | from loguru import logger
12 | from .config import get_config
13 | from .deep_research import (
14 | deep_research_stream,
15 | generate_followup_questions,
16 | process_clarifications,
17 | write_final_report,
18 | should_clarify_query
19 | )
20 |
21 | # Load configuration
22 | config = get_config()
23 |
24 |
25 | def run_gradio_demo():
26 | """Run a modern Gradio demo for Deep Research using ChatInterface"""
27 | enable_clarification = config.get("research", {}).get("enable_clarification", False)
28 | search_source = config.get("research", {}).get("search_source", "tavily")
29 |
30 | # Conversation state (shared across functions)
31 | conversation_state = {
32 | "current_query": "",
33 | "needs_clarification": False,
34 | "questions": [],
35 | "waiting_for_clarification": False,
36 | "clarification_answers": {},
37 |         "last_status": "",  # Track the last status update
38 |         "history": [],  # Current chat history
39 |         "search_source": search_source,  # Which search provider to use
40 |         "enable_clarification": enable_clarification  # Whether clarification is enabled
41 | }
42 |
43 |     async def research_with_thinking(message, history):
44 |         """Handle a query, stream the research process, and yield the result."""
45 |         if not message:
46 |             yield history  # Empty message; echo the history back
47 |             return  # A bare return is allowed in an async generator
48 | 
49 |         # Reset state so consecutive queries do not interfere with each other
50 |         conversation_state["last_status"] = ""
51 |         conversation_state["current_query"] = ""
52 |         conversation_state["questions"] = []
53 | 
54 |         # If this is a clarification answer, leave waiting_for_clarification untouched
55 |         if not conversation_state["waiting_for_clarification"]:
56 |             conversation_state["waiting_for_clarification"] = False
57 |             conversation_state["clarification_answers"] = {}
58 |
59 | logger.debug(
60 | f"Starting research, message: {message}, search_source: {search_source}, enable_clarification: {enable_clarification}")
61 |
62 |         # Build the history context directly from the chat history
63 | history_context = ''
64 | for msg in history:
65 | if isinstance(msg, dict) and msg.get("role") == "user":
66 | q = 'Q:' + msg.get("content", "") + '\n'
67 | history_context += q
68 |
69 |         # 3. Check whether this message answers pending clarification questions
70 |         if conversation_state["waiting_for_clarification"]:
71 |             async for msg in handle_clarification_answer(message, history, history_context):
72 |                 yield msg
73 |             return  # A bare return is allowed in an async generator
74 | 
75 |         # 4. Create the research-progress message and add it to the history
76 | messages = []
77 | messages.append({"role": "assistant", "content": "正在进行研究...", "metadata": {"title": "研究过程"}})
78 | yield messages
79 |
80 |         # 5. Handle the clarification step
81 |         if not enable_clarification:
82 |             messages[-1]["content"] = "跳过澄清环节,直接开始研究..."
83 |             yield messages
84 |         else:
85 |             # Analyze whether the query needs clarification
86 | messages[-1]["content"] = "分析查询需求中..."
87 | yield messages
88 |
89 | needs_clarification = await should_clarify_query(message, history_context)
90 | if needs_clarification:
91 | messages[-1]["content"] = "生成澄清问题..."
92 | yield messages
93 |
94 | followup_result = await generate_followup_questions(message, history_context)
95 | questions = followup_result.get("questions", [])
96 |
97 | if questions:
98 |                     # Save the questions and state
99 | conversation_state["current_query"] = message
100 | conversation_state["questions"] = questions
101 | conversation_state["waiting_for_clarification"] = True
102 |
103 |                     # Show the questions to the user
104 | clarification_msg = "请回答以下问题,帮助我更好地理解您的查询:"
105 | for i, q in enumerate(questions, 1):
106 | clarification_msg += f"\n{i}. {q.get('question', '')}"
107 |
108 |                     # Replace the research-progress message with the clarification questions
109 | messages[-1] = {"role": "assistant", "content": clarification_msg}
110 | yield messages
111 |                     return  # Wait for the user's answers
112 | else:
113 | messages[-1]["content"] = "无法生成有效的澄清问题,继续研究..."
114 | yield messages
115 | else:
116 | messages[-1]["content"] = "查询已足够清晰,开始研究..."
117 | yield messages
118 |
119 |         # 6. Start searching
120 | messages[-1]["content"] = f"使用 {search_source} 搜索相关信息中..."
121 | yield messages
122 |
123 |         # 7. Run the research process
124 | final_result = None
125 | research_log = []
126 |
127 | async for partial_result in deep_research_stream(
128 | query=message,
129 | search_source=search_source,
130 | history_context=history_context,
131 | enable_clarification=enable_clarification
132 | ):
133 |             # Update research progress
134 | if partial_result.get("status_update"):
135 | status = partial_result.get("status_update")
136 | stage = partial_result.get("stage", "")
137 |
138 |             # Check whether the status has changed
139 | if status != conversation_state["last_status"]:
140 | conversation_state["last_status"] = status
141 |
142 |                 # Update the research-progress message
143 | timestamp = time.strftime('%H:%M:%S')
144 | status_line = f"[{timestamp}] {status}"
145 | research_log.append(status_line)
146 |
147 |                 # Show the current research-plan step
148 | if partial_result.get("current_step"):
149 | current_step = partial_result.get("current_step")
150 | step_id = current_step.get("step_id", "")
151 | description = current_step.get("description", "")
152 | step_line = f"当前步骤 {step_id}: {description}"
153 | research_log.append(step_line)
154 |
155 |                 # Show the current queries
156 | if partial_result.get("current_queries"):
157 | queries = partial_result.get("current_queries")
158 | queries_lines = ["**当前并行查询**:"]
159 | for i, q in enumerate(queries, 1):
160 | queries_lines.append(f"{i}. {q}")
161 | research_log.append("\n".join(queries_lines))
162 |
163 |                 # Add more information for specific stages
164 | if stage == "plan_generated" and partial_result.get("research_plan"):
165 | research_plan = partial_result.get("research_plan")
166 | plan_lines = ["**研究计划**:"]
167 | for i, step in enumerate(research_plan):
168 | step_id = step.get("step_id", i + 1)
169 | description = step.get("description", "")
170 | plan_lines.append(f"步骤 {step_id}: {description}")
171 | research_log.append("\n".join(plan_lines))
172 |
173 |                 # Add stage-specific details
174 | if stage == "insights_found" and partial_result.get("formatted_new_learnings"):
175 | if partial_result.get("formatted_new_urls") and len(
176 | partial_result.get("formatted_new_urls")) > 0:
177 | research_log.append("\n**来源**:\n" + "\n".join(
178 | partial_result.get("formatted_new_urls", [])[:3]))
179 |
180 | elif stage == "step_completed" and partial_result.get("formatted_step_learnings"):
181 | research_log.append("\n**步骤总结**:\n" + "\n".join(
182 | partial_result.get("formatted_step_learnings", [])))
183 |
184 | elif stage == "analysis_completed" and partial_result.get("formatted_final_findings"):
185 | research_log.append("\n**主要发现**:\n" + "\n".join(
186 | partial_result.get("formatted_final_findings", [])))
187 |
188 | if partial_result.get("gaps"):
189 | research_log.append("\n\n**研究空白**:\n- " + "\n- ".join(partial_result.get("gaps", [])))
190 |
191 |                 # Merge all log lines and update the research-progress message
192 | messages[-1]["content"] = "\n\n".join(research_log)
193 | yield messages
194 |
195 |             # Save the last result for report generation
196 | final_result = partial_result
197 |
198 |             # If the final report is ready, exit the loop
199 | if "final_report" in partial_result:
200 | break
201 |
202 |         # 8. Generate the report
203 |         if final_result:
204 |             # If the result already contains final_report, use it directly
205 | if "final_report" in final_result:
206 | report = final_result["final_report"]
207 |                 # Mark the research-progress message as finished
208 | research_process = messages[-1]["content"]
209 | messages[-1]["content"] = "研究完成,报告已生成。\n\n" + research_process
210 | yield messages
211 | else:
212 |                 # Otherwise, generate the report from the collected information
213 | research_process = messages[-1]["content"]
214 | messages[-1]["content"] = "正在整合研究结果并生成报告...\n\n" + research_process
215 | yield messages
216 |
217 | learnings = final_result.get("learnings", [])
218 |
219 | try:
220 | report = await write_final_report(
221 | query=message,
222 | context=str(learnings),
223 | history_context=history_context
224 | )
225 |                 # Make sure report is not None
226 | if report is None:
227 | report = "抱歉,无法生成研究报告。"
228 | logger.error(f"write_final_report returned None for query: {message}")
229 | except Exception as e:
230 | report = f"生成报告时出错: {str(e)}"
231 | logger.error(f"Error in write_final_report: {str(e)}")
232 |
233 |                 # Keep the research-process information
234 | messages[-1]["content"] = "研究完成,报告已生成。\n\n" + research_process
235 | yield messages
236 |
237 |             # Append the final report while keeping the research-progress message
238 | messages.append({"role": "assistant", "content": report})
239 | yield messages
240 | else:
241 | messages.append(
242 | {"role": "assistant", "content": "抱歉,我无法为您的查询生成研究报告。请尝试其他问题或稍后再试。"})
243 | yield messages
244 |
245 |     async def handle_clarification_answer(message, history, history_context):
246 |         """Handle the user's answers to clarification questions."""
247 |         # Reset the waiting flag
248 |         conversation_state["waiting_for_clarification"] = False
249 | 
250 |         # Fetch the original query and questions
251 |         query = conversation_state["current_query"]
252 |         questions = conversation_state["questions"]
253 | 
254 |         # Reset state so consecutive queries do not interfere with each other
255 |         conversation_state["last_status"] = ""
256 |
257 |         # 1. Create the message list and add the research-progress message
258 | messages = []
259 | messages.append({"role": "assistant", "content": "解析您的澄清回答...", "metadata": {"title": "研究过程"}})
260 | yield messages
261 |
262 |         # 2. Parse the user's answers
263 |         lines = [line.strip() for line in message.split('\n') if line.strip()]
264 |         if len(lines) < len(questions):
265 |             # Fall back to comma-separated answers
266 | if ',' in message:
267 | lines = [ans.strip() for ans in message.split(',')]
268 |
269 |         # 3. Build the response dict
270 | user_responses = {}
271 | for i, q in enumerate(questions):
272 | key = q.get("key", f"q{i}")
273 | if i < len(lines) and lines[i]:
274 | user_responses[key] = lines[i]
275 |
276 |         # 4. Process the clarification content
277 | messages[-1]["content"] = "处理您的澄清内容..."
278 | yield messages
279 |
280 |         # 5. Process the clarifications and refine the query
281 | clarification_result = await process_clarifications(
282 | query=query,
283 | user_responses=user_responses,
284 | all_questions=questions,
285 | history_context=history_context
286 | )
287 |
288 |         # 6. Get the refined query
289 | refined_query = clarification_result.get("refined_query", query)
290 | messages[-1]["content"] = f"已优化查询: {refined_query}"
291 | yield messages
292 |
293 |         # 7. Check whether a direct answer is possible
294 | if not clarification_result.get("requires_search", True) and clarification_result.get("direct_answer"):
295 | direct_answer = clarification_result.get("direct_answer", "")
296 |
297 |             # Keep the research-progress message and add the direct answer
298 | research_process = messages[-1]["content"]
299 | messages[-1]["content"] = "提供直接回答,无需搜索。\n\n" + research_process
300 | yield messages
301 |
302 |             # Add the final answer while keeping the research progress
303 |             messages.append({"role": "assistant", "content": direct_answer})
304 |             yield messages
305 |             return  # Direct answer delivered; skip the search flow below
306 |         # 8. Start searching
307 | messages[-1]["content"] = "基于您的澄清开始搜索信息..."
308 | yield messages
309 |
310 |         # 9. Run the research process
311 | final_result = None
312 | research_log = []
313 |
314 | async for partial_result in deep_research_stream(
315 | query=refined_query,
316 | user_clarifications=user_responses,
317 | search_source=search_source,
318 | history_context=history_context
319 | ):
320 |             # Update research progress
321 | if partial_result.get("status_update"):
322 | status = partial_result.get("status_update")
323 | stage = partial_result.get("stage", "")
324 |
325 |             # Check whether the status has changed
326 | if status != conversation_state["last_status"]:
327 | conversation_state["last_status"] = status
328 |
329 |                 # Update the research-progress message
330 | timestamp = time.strftime('%H:%M:%S')
331 | status_line = f"[{timestamp}] {status}"
332 | research_log.append(status_line)
333 |
334 |                 # Show the current research-plan step
335 | if partial_result.get("current_step"):
336 | current_step = partial_result.get("current_step")
337 | step_id = current_step.get("step_id", "")
338 | description = current_step.get("description", "")
339 | step_line = f"当前步骤 {step_id}: {description}"
340 | research_log.append(step_line)
341 |
342 |                 # Show the current queries
343 | if partial_result.get("current_queries"):
344 | queries = partial_result.get("current_queries")
345 | queries_lines = ["当前并行查询:"]
346 | for i, q in enumerate(queries, 1):
347 | queries_lines.append(f"{i}. {q}")
348 | research_log.append("\n".join(queries_lines))
349 |
350 |                 # Add more information for specific stages
351 | if stage == "plan_generated" and partial_result.get("research_plan"):
352 | research_plan = partial_result.get("research_plan")
353 | plan_lines = ["研究计划:"]
354 | for i, step in enumerate(research_plan):
355 | step_id = step.get("step_id", i + 1)
356 | description = step.get("description", "")
357 | plan_lines.append(f"步骤 {step_id}: {description}")
358 | research_log.append("\n".join(plan_lines))
359 |
360 |                 # Add stage-specific details
361 | if stage == "insights_found" and partial_result.get("formatted_new_learnings"):
362 | if partial_result.get("formatted_new_urls") and len(
363 | partial_result.get("formatted_new_urls")) > 0:
364 | research_log.append("\n**来源**:\n" + "\n".join(
365 | partial_result.get("formatted_new_urls", [])[:3]))
366 |
367 | elif stage == "step_completed" and partial_result.get("formatted_step_learnings"):
368 | research_log.append("\n**步骤总结**:\n" + "\n".join(
369 | partial_result.get("formatted_step_learnings", [])))
370 |
371 | elif stage == "analysis_completed" and partial_result.get("formatted_final_findings"):
372 | research_log.append("\n**主要发现**:\n" + "\n".join(
373 | partial_result.get("formatted_final_findings", [])))
374 |
375 | if partial_result.get("gaps"):
376 | research_log.append("\n\n**研究空白**:\n- " + "\n- ".join(partial_result.get("gaps", [])))
377 |
378 |                 # Merge all log lines and update the research-progress message
379 | messages[-1]["content"] = "\n\n".join(research_log)
380 | yield messages
381 |
382 |             # Save the last result for report generation
383 | final_result = partial_result
384 |
385 |             # If the final report is ready, exit the loop
386 | if "final_report" in partial_result:
387 | break
388 |
389 |         # 10. Generate the report
390 |         if final_result:
391 |             # If the result already contains final_report, use it directly
392 | if "final_report" in final_result:
393 | report = final_result["final_report"]
394 |                 # Mark the research-progress message as finished
395 | research_process = messages[-1]["content"]
396 | messages[-1]["content"] = "研究完成,报告已生成。\n\n" + research_process
397 | yield messages
398 | else:
399 |                 # Otherwise, generate the report from the collected information
400 | research_process = messages[-1]["content"]
401 | messages[-1]["content"] = "正在整合研究结果并生成报告...\n\n" + research_process
402 | yield messages
403 |
404 | learnings = final_result.get("learnings", [])
405 |
406 | try:
407 | report = await write_final_report(
408 | query=refined_query,
409 | context=str(learnings),
410 | history_context=history_context
411 | )
412 |                     # Make sure report is not None
413 |                     if report is None:
414 |                         report = "抱歉,无法生成研究报告。"
415 |                         logger.error(f"write_final_report returned None for query: {refined_query}")
416 | except Exception as e:
417 | report = f"生成报告时出错: {str(e)}"
418 | logger.error(f"Error in write_final_report: {str(e)}")
419 |
420 |                 # Keep the research-process information
421 | messages[-1]["content"] = "研究完成,报告已生成。\n\n" + research_process
422 | yield messages
423 |
424 |             # Append the final report while keeping the research-progress message
425 | messages.append({"role": "assistant", "content": report})
426 | yield messages
427 | else:
428 | messages.append(
429 | {"role": "assistant", "content": "抱歉,我无法为您的查询生成研究报告。请尝试其他问题或稍后再试。"})
430 | yield messages
431 |
432 |     # Create the ChatInterface
433 | demo = gr.ChatInterface(
434 | research_with_thinking,
435 | type='messages',
436 | title="🔍 Deep Research",
437 | description="使用此工具进行深度研究,我将搜索互联网为您找到回答。Powered by [Deep Research](https://github.com/shibing624/deep-research) Made with ❤️ by [shibing624](https://github.com/shibing624)",
438 | examples=[
439 | ["特斯拉股票的最新行情?"],
440 | ["介绍一下最近的人工智能技术发展趋势"],
441 | ["中国2024年GDP增长了多少?"],
442 | ["Explain the differences between supervised and unsupervised machine learning."]
443 | ]
444 | )
445 |
446 |     # Launch the interface
447 | demo.queue()
448 | demo.launch(server_name="0.0.0.0", share=False, server_port=7860, show_api=False)
449 |
450 |
451 | if __name__ == "__main__":
452 | run_gradio_demo()
453 |
--------------------------------------------------------------------------------
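
For reference, the answer-parsing convention used by `handle_clarification_answer` above (one answer per line, falling back to comma-separated values) can be reproduced in isolation; the questions and answers below are hypothetical:

```python
from typing import Dict, List


def parse_clarification_answers(message: str, questions: List[dict]) -> Dict[str, str]:
    # One answer per line; fall back to comma-separated answers when too few lines
    lines = [line.strip() for line in message.split('\n') if line.strip()]
    if len(lines) < len(questions) and ',' in message:
        lines = [ans.strip() for ans in message.split(',')]
    responses: Dict[str, str] = {}
    for i, q in enumerate(questions):
        key = q.get("key", f"q{i}")
        if i < len(lines) and lines[i]:
            responses[key] = lines[i]
    return responses


questions = [{"key": "timeframe", "question": "哪个时间段?"},
             {"key": "aspect", "question": "关注哪个方面?"}]
print(parse_clarification_answers("2024年, 股价走势", questions))
# -> {'timeframe': '2024年', 'aspect': '股价走势'}
```
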
/src/model_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 |
6 | Utility functions for interacting with language models.
7 | """
8 |
9 | import json
10 | from typing import Dict, Any, List, Optional, AsyncGenerator, Union
11 | from loguru import logger
12 |
13 | from .providers import get_model
14 |
15 |
16 | async def generate_completion(
17 | prompt: str,
18 | system_message: str = "You are a helpful AI research assistant.",
19 | temperature: float = 0.7,
20 | json_response: bool = False,
21 | stream: bool = False,
22 | is_report: bool = False
23 | ) -> Union[str, AsyncGenerator[str, None]]:
24 | """
25 | Generate a completion from the language model.
26 |
27 | Args:
28 | prompt: The prompt to send to the model
29 | system_message: The system message to use
30 | temperature: The temperature to use for generation
31 | json_response: Whether to request a JSON response
32 | stream: Whether to stream the response
33 | is_report: Whether to use the report model configuration
34 |
35 | Returns:
36 | The model's response as a string or an async generator of response chunks
37 | """
38 | model_config = get_model(is_report=is_report)
39 |
40 | messages = [
41 | {"role": "system", "content": system_message},
42 | {"role": "user", "content": prompt}
43 | ]
44 |
45 | # Prepare the request parameters
46 | request_params = {
47 | "model": model_config["model"],
48 | "messages": messages,
49 | "temperature": temperature,
50 | }
51 |
52 | # Add response format if JSON is requested
53 | if json_response:
54 | request_params["response_format"] = {"type": "json_object"}
55 |
56 | # Add streaming parameter if requested
57 | if stream:
58 | request_params["stream"] = True
59 |
60 | try:
61 | # Make the API call
62 | response = await model_config["async_client"].chat.completions.create(**request_params)
63 |
64 | if stream:
65 | # Return an async generator for streaming responses
66 | async def response_generator():
67 | collected_chunks = []
68 | async for chunk in response:
69 | if chunk.choices and chunk.choices[0].delta.content:
70 | content = chunk.choices[0].delta.content
71 | collected_chunks.append(content)
72 | yield content
73 |
74 | # If no chunks were yielded, yield the empty string
75 | if not collected_chunks:
76 | yield ""
77 |
78 | return response_generator()
79 | else:
80 | # Return the full response for non-streaming
81 | res = response.choices[0].message.content
82 | logger.debug(f"prompt: {prompt}\n\nGenerated completion: {res}")
83 | return res
84 |
85 | except Exception as e:
86 | logger.error(f"Error generating completion: {str(e)}")
87 | if stream:
88 | # Return an empty generator for streaming
89 | async def error_generator():
90 | yield f"Error: {str(e)}"
91 |
92 | return error_generator()
93 | else:
94 | # Return an error message for non-streaming
95 | return f"Error: {str(e)}"
96 |
97 |
98 | async def generate_json_completion(
99 | prompt: str,
100 | system_message: str = "You are a helpful AI research assistant.",
101 | temperature: float = 0.7
102 | ) -> Dict[str, Any]:
103 | """
104 | Generate a JSON completion from the language model.
105 |
106 | Args:
107 | prompt: The prompt to send to the model
108 | system_message: The system message to use
109 | temperature: The temperature to use for generation
110 |
111 | Returns:
112 | The model's response parsed as a JSON object
113 | """
114 | response_text = ""
115 | try:
116 | response_text = await generate_completion(
117 | prompt=prompt,
118 | system_message=system_message,
119 | temperature=temperature,
120 | json_response=True
121 | )
122 |
123 | # Parse the JSON response
124 | result = json.loads(response_text)
125 | return result
126 |
127 | except json.JSONDecodeError as e:
128 | logger.error(f"Error parsing JSON response: {str(e)}")
129 | logger.error(f"Response text: {response_text}")
130 | return {}
131 |
132 | except Exception as e:
133 | logger.error(f"Error in generate_json_completion: {str(e)}")
134 | return {}
135 |
--------------------------------------------------------------------------------
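
A minimal sketch of calling the helpers above, assuming a valid model configuration in config.yaml; the prompts are illustrative:

```python
import asyncio

from src.model_utils import generate_completion, generate_json_completion


async def main() -> None:
    # Non-streaming: returns the full response as a string
    text = await generate_completion(prompt="用一句话介绍深度研究。")
    print(text)

    # Streaming: returns an async generator of content chunks
    stream = await generate_completion(prompt="列出三个研究步骤。", stream=True)
    async for chunk in stream:
        print(chunk, end="", flush=True)

    # JSON mode: returns a parsed dict, or {} on parse errors
    data = await generate_json_completion(prompt='Return {"ok": true} as a JSON object.')
    print(data)


asyncio.run(main())
```
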
/src/mp_search_client.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import httpx
7 | import aiohttp
8 | from typing import Dict, Any, Optional
9 | import json
10 | from loguru import logger
11 |
12 | from .config import get_config
13 |
14 |
15 | class MPSearchClient:
16 | """Client for MP Search API."""
17 |
18 | def __init__(self):
19 | config = get_config()
20 | self.api_key = config.get("mp_search", {}).get("api_key", "")
21 | self.base_url = config.get("mp_search", {}).get("base_url", "https://api.mpsrch.com/v1/search")
22 | self.forward_service = config.get("mp_search", {}).get("forward_service", "hyaide-application-1111")
23 | self.client = httpx.Client(timeout=30.0)
24 |
25 | def search_sync(self, query: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
26 | """
27 | Perform a search using MP Search API.
28 |
29 | Args:
30 | query: Search query
31 | options: Additional options for the search
32 |
33 | Returns:
34 | Dict containing search results
35 | """
36 | if not self.api_key:
37 | raise ValueError("MP Search API key not configured")
38 |
39 | if options is None:
40 | options = {}
41 |
42 | # Default payload
43 | payload = {
44 | "query": query,
45 | "forward_service": self.forward_service,
46 | "query_id": options.get("query_id", f"qid_{hash(query)}"),
47 | "stream": options.get("stream", False)
48 | }
49 |
50 | headers = {
51 | "Authorization": f"Bearer {self.api_key}"
52 | }
53 |
54 | try:
55 | logger.debug(f"Searching with MP Search API: {query}")
56 | response = self.client.post(
57 | self.base_url,
58 | headers=headers,
59 | json=payload
60 | )
61 | response.raise_for_status()
62 | result = response.json()
63 |
64 | # Transform the result to match the expected format
65 | transformed_result = self._transform_result(result, query)
66 | logger.debug(f"Transformed result: {transformed_result}")
67 | return transformed_result
68 |
69 | except Exception as e:
70 | logger.error(f"Error searching with MP Search API: {str(e)}")
71 | raise
72 |
73 | async def search(self, query: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
74 | """
75 | Perform an async search using MP Search API.
76 |
77 | Args:
78 | query: Search query
79 | options: Additional options for the search
80 |
81 | Returns:
82 | Dict containing search results
83 | """
84 | if not self.api_key:
85 | raise ValueError("MP Search API key not configured")
86 |
87 | if options is None:
88 | options = {}
89 |
90 | # Default payload
91 | payload = {
92 | "query": query,
93 | "forward_service": self.forward_service,
94 | "query_id": options.get("query_id", f"qid_{hash(query)}"),
95 | "stream": options.get("stream", False)
96 | }
97 |
98 | headers = {
99 | "Authorization": f"Bearer {self.api_key}"
100 | }
101 |
102 | try:
103 | logger.debug(f"Searching with MP Search API: {query}")
104 | async with aiohttp.ClientSession() as session:
105 | async with session.post(
106 | self.base_url,
107 | headers=headers,
108 | json=payload
109 | ) as response:
110 | response.raise_for_status()
111 |
112 | text_content = await response.text()
113 | try:
114 | # Try to parse as JSON anyway
115 | result = json.loads(text_content)
116 | except json.JSONDecodeError:
117 | logger.error(f"Error parsing JSON response: {text_content}")
118 | result = ''
119 |
120 | # Transform the result to match the expected format
121 | transformed_result = self._transform_result(result, query)
122 | logger.debug(f"Transformed result: {transformed_result}")
123 | return transformed_result
124 |
125 | except Exception as e:
126 | logger.error(f"Error searching with MP Search API: {str(e)}")
127 | raise
128 |
129 | def _transform_result(self, mp_result: Dict[str, Any], query: str) -> Dict[str, Any]:
130 | """
131 | Transform MP Search API result to match the expected format.
132 |
133 | Args:
134 | mp_result: Raw result from MP Search API
135 | query: Original search query
136 |
137 | Returns:
138 | Transformed result in compatible format
139 | """
140 | transformed_data = []
141 | logger.debug(f"Transforming MP Search result: {mp_result}")
142 | try:
143 |             # Use the "result" field directly as the content
144 |             if "result" in mp_result and mp_result["result"]:
145 |                 res = mp_result["result"]
146 |                 if isinstance(res, str):
147 |                     res = json.loads(res)
148 | for item in res:
149 | if "value" in item and item['value']:
150 | val = item['value']
151 | title, content, dt, url = val.split(" ||| ")
152 | transformed_item = {
153 | "url": url,
154 | "title": title,
155 | "content": str(content)[:4000],
156 | "source": "mp_search"
157 | }
158 | transformed_data.append(transformed_item)
159 | except Exception as e:
160 | logger.error(f"Error transforming MP Search result: {str(e)}")
161 | # Add a fallback item with the error
162 | transformed_item = {
163 | "url": "",
164 | "title": "MP Search Error",
165 | "content": f"Error processing result: {str(e)}\nRaw result: {mp_result}",
166 | "source": "mp_search"
167 | }
168 | transformed_data.append(transformed_item)
169 |
170 | return {
171 | "query": query,
172 | "data": transformed_data
173 | }
174 |
--------------------------------------------------------------------------------
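
A sketch of the payload shape `_transform_result` expects: the MP Search "result" field is a JSON string of items whose "value" packs `title ||| content ||| date ||| url`. The sample data below is fabricated, and `__new__` is used only to skip config loading:

```python
import json

from src.mp_search_client import MPSearchClient

raw = {
    "result": json.dumps([
        {"value": "示例标题 ||| 示例正文内容 ||| 2024-01-01 ||| https://example.com/a"}
    ])
}

# Bypass __init__ so no API key or config file is needed for this demo
client = MPSearchClient.__new__(MPSearchClient)
print(client._transform_result(raw, query="示例查询"))
# -> {'query': '示例查询', 'data': [{'url': 'https://example.com/a', 'title': '示例标题', ...}]}
```
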
/src/prompts.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 |
6 | Prompts used for deep research functionality.
7 | """
8 |
9 | SHOULD_CLARIFY_QUERY_PROMPT = """
10 | 请判断以下查询是否需要澄清问题。
11 | 一个好的查询应该明确、具体且包含足够的上下文。
12 | 如果查询模糊、缺少重要上下文、过于宽泛或包含多个可能的解释,则需要澄清。
13 |
14 | 对话历史:
15 | ```
16 | {history_context}
17 | ```
18 |
19 | 查询是: ```{query}```
20 |
21 | 当前日期是{current_date}。
22 |
23 | 请只回答 "yes" 或 "no"。如果查询已经足够清晰,请回答"no"。
24 | """
25 |
26 | # Prompt for generating follow-up questions
27 | FOLLOW_UP_QUESTIONS_PROMPT = """
28 | You are an expert researcher and I need your help to generate clarifying questions for a given research query.
29 |
30 | chat history:
31 | ```
32 | {history_context}
33 | ```
34 |
35 | The query is: ```{query}```
36 |
37 | Based on this query, please generate clarifying questions that would help you better understand what the user is looking for.
38 | For effective questions:
39 | 1. Identify ambiguous terms or concepts that need clarification
40 | 2. Ask about the scope or timeframe of interest
41 | 3. Check if there are specific aspects the user is most interested in
42 | 4. Consider what background information might be helpful
43 | 5. Ask about intended use of the information (academic, personal interest, decision-making, etc.)
44 |
45 | - User's query is written in Chinese, 需要用中文输出.
46 | - 当前日期是{current_date}。
47 |
48 | Format your response as a valid JSON object with the following structure:
49 | {{
50 | "needs_clarification": true/false (boolean indicating if clarification questions are needed),
51 | "questions": [
52 | {{
53 | "key": "specific_key_1",
54 | "question": "The clarifying question text",
55 | "default": "A reasonable default answer if the user doesn't provide one"
56 | }},
57 | ... additional questions ...
58 | ]
59 | }}
60 |
61 | If the query seems clear enough and doesn't require clarification, return "needs_clarification": false with an empty questions array.
62 | For simple factual queries or clear requests, clarification is usually not needed.
63 | """
64 |
65 | # Prompt for processing clarifications
66 | PROCESS_CLARIFICATIONS_PROMPT = """
67 | I'm reviewing a user query with clarification questions and their responses.
68 |
69 | Chat history: ```
70 | {history_context}
71 | ```
72 |
73 | Original query: ```{query}```
74 |
75 | Clarification questions and responses:
76 | ```
77 | {clarifications}
78 | ```
79 |
80 | Questions that were not answered:
81 | ```
82 | {unanswered_questions}
83 | ```
84 |
85 | Based on this information, please:
86 | 1. Summarize the original query with the additional context provided by the clarifications
87 | 2. For questions that were not answered, use reasonable default assumptions and clearly state what you're assuming
88 | 3. Identify if this is a simple factual query that doesn't require search
89 | - User's query is written in Chinese, 需要用中文输出.
90 | - 当前日期是{current_date}。
91 |
92 | Format your response as a valid JSON object with the following structure:
93 | {{
94 | "refined_query": "The refined and clarified query",
95 | "assumptions": ["List of assumptions made for unanswered questions"],
96 | "requires_search": true/false (boolean indicating if this query needs web search or can be answered directly),
97 | "direct_answer": "If requires_search is false, provide a direct answer here, otherwise empty string"
98 | }}
99 | """
100 |
101 | # Prompt for no clarifications needed
102 | PROCESS_NO_CLARIFICATIONS_PROMPT = """
103 | I'm reviewing a user query where they chose not to provide any clarifications.
104 |
105 | Chat history:
106 | ```
107 | {history_context}
108 | ```
109 |
110 | Original query: ```{query}```
111 |
112 | The user was asked the following clarification questions but chose not to answer any:
113 | ```
114 | {unanswered_questions}
115 | ```
116 |
117 | Since the user didn't provide any clarifications, please:
118 | 1. Analyze the original query as comprehensively as possible
119 | 2. Make reasonable assumptions for all ambiguous aspects
120 | 3. Determine if this is a simple factual query that doesn't require search
121 | 4. If possible, provide a direct answer along with the refined query
122 | - User's query is written in Chinese, 需要用中文输出.
123 | - 当前日期是{current_date}。
124 |
125 | Format your response as a valid JSON object with the following structure:
126 | {{
127 | "refined_query": "The refined query with all possible considerations",
128 | "assumptions": ["List of all assumptions made"],
129 | "requires_search": true/false (boolean indicating if this query needs web search or can be answered directly),
130 | "direct_answer": "If requires_search is false, provide a comprehensive direct answer here, otherwise empty string"
131 | }}
132 |
133 | Since the user chose not to provide clarifications, be as thorough and comprehensive as possible in your analysis and answer.
134 | """
135 |
136 | # Prompt for generating research plan
137 | RESEARCH_PLAN_PROMPT = """
138 | You are an expert researcher creating a flexible research plan for a given query.
139 |
140 | Chat history:
141 | ```
142 | {history_context}
143 | ```
144 |
145 | QUERY: ```{query}```
146 |
147 | Please analyze this query and create an appropriate research plan. The number of steps should vary based on complexity:
148 | - For simple questions, you might need only 1 step
149 | - For moderately complex questions, 2 steps may be appropriate
150 | - For very complex questions, 3 or more steps may be needed
151 | - User's query is written in Chinese, 需要用中文输出.
152 | - 当前日期是{current_date}。
153 |
154 | Consider:
155 | 1. The complexity of the query
156 | 2. Whether multiple angles of research are needed
157 | 3. If the topic requires exploration of causes, effects, comparisons, or historical context
158 | 4. If the topic is controversial and needs different perspectives
159 |
160 | Format your response as a valid JSON object with the following structure:
161 | {{
162 | "assessments": "Brief assessment of query complexity and reasoning",
163 | "steps": [
164 | {{
165 | "step_id": 1,
166 | "description": "Description of this research step",
167 | "search_queries": ["search query 1", "search query 2", ...],
168 | "goal": "What this step aims to discover"
169 | }},
170 | ... additional steps as needed ...
171 | ]
172 | }}
173 |
174 | Make each step logical and focused on a specific aspect of the research. Steps should build on each other,
175 | and search queries should be specific and effective for web search.
176 | """
177 |
178 | # Prompt for extract search results
179 | EXTRACT_SEARCH_RESULTS_SYSTEM_PROMPT = "You are an expert in extracting the most relevant and detailed information from search results."
180 | EXTRACT_SEARCH_RESULTS_PROMPT = """
181 | User query: ```{query}```
182 |
183 | search result(Webpage Content):
184 | ```
185 | {search_results}
186 | ```
187 |
188 | - User's query is written in Chinese, 需要用中文输出.
189 | - 当前日期是{current_date}。
190 |
191 | 作为信息提取专家,请从网页内容中提取与用户查询最相关的核心片段。需要提取的内容要求:
192 | 1. 包含具体的细节、数据、定义和重要论点,不要使用笼统的总结替代原始的详细内容
193 | 2. 保留原文中的关键事实、数字、日期和引用
194 | 3. 提取完整的相关段落,而不仅仅是简短的摘要
195 | 4. 特别关注可以直接回答用户查询的内容
196 | 5. 如果内容包含表格或列表中的重要信息,请完整保留这些结构化数据
197 |
198 | Output your response in the following JSON format:
199 | {{
200 | "extracted_infos": [
201 | {{
202 | "info": "核心片段1,包含详细内容、数据和定义等",
203 | "url": "url 1",
204 | "relevance": "解释这段内容与查询的相关性"
205 | }},
206 | {{
207 | "info": "核心片段2,包含详细内容、数据和定义等",
208 | "url": "url 2",
209 | "relevance": "解释这段内容与查询的相关性"
210 | }},
211 | ...
212 | ]
213 | }}
214 |
215 | - info: 保留原文格式的关键信息片段,包含详细内容而非简单摘要
216 | - url: 信息来源的网页URL
217 | - relevance: 简要说明这段内容如何回答了用户的查询
218 | """
219 |
220 | # Prompt for final research summary
221 | RESEARCH_SUMMARY_PROMPT = """
222 | Based on our research, we've explored the query: ```{query}```
223 |
224 | Research Summary by Step:
225 | ```
226 | {steps_summary}
227 | ```
228 |
229 | Please analyze this information and provide:
230 | 1. A set of key findings that answer the main query
231 | 2. Identification of any areas where the research is lacking or more information is needed
232 | - User's query is written in Chinese, 需要用中文输出.
233 | - 当前日期是{current_date}。
234 |
235 | Format your response as a valid JSON object with:
236 | {{
237 | "findings": [{{"finding": "finding 1", "url": "cite url 1"}}, {{"finding": "finding 2", "url": "cite url 2"}}, ...], (key conclusions from the research, and the cite url)
238 | "gaps": ["gap 1", "gap 2", ...], (areas where more research is needed)
239 | "recommendations": ["recommendation 1", "recommendation 2", ...] (suggestions for further research directions)
240 | }}
241 | """
242 |
243 | FINAL_REPORT_SYSTEM_PROMPT = """You are an expert researcher. Follow these instructions when responding:
244 | - You may be asked to research subjects that are after your knowledge cutoff; assume the user is right when presented with news.
245 | - The user is a highly experienced analyst, no need to simplify it, be as detailed as possible and make sure your response is correct.
246 | - Be highly organized.
247 | - Suggest solutions that I didn't think about.
248 | - Be proactive and anticipate my needs.
249 | - Treat me as an expert in all subject matter.
250 | - Mistakes erode my trust, so be accurate and thorough.
251 | - Provide detailed explanations, I'm comfortable with lots of detail.
252 | - Value good arguments over authorities, the source is irrelevant.
253 | - Consider new technologies and contrarian ideas, not just the conventional wisdom.
254 | - User's query is written in Chinese, 需要用中文输出.
255 | - 当前日期是{current_date}。
256 | """
257 |
258 | # Prompt for final report
259 | FINAL_REPORT_PROMPT = """
260 | I've been researching the following query: ```{query}```
261 |
262 | Please write a comprehensive research report on this topic.
263 | The report should be well-structured with headings, subheadings, and a conclusion.
264 |
265 | [要求]:
266 | - 输出markdown格式的回答。
267 | - [context]是参考资料,回答中需要包含引用来源,格式为 [cite](url) ,其中url是实际的链接。
268 | - 除代码、专名外,你必须使用与问题相同语言回答。
269 |
270 | Chat history:
271 | ```
272 | {history_context}
273 | ```
274 |
275 | [context]:
276 | ```
277 | {context}
278 | ```
279 | """
280 |
281 | # Prompt for final answer
282 | FINAL_ANSWER_PROMPT = """
283 | I've been researching the following query: ```{query}```
284 |
285 | 详细、专业回答用户的query。
286 |
287 | [要求]:
288 | - 输出markdown格式的回答。
289 | - [context]是参考资料,回答中需要包含引用来源,格式为 [cite](url) ,其中url是实际的链接。
290 | - 除代码、专名外,你必须使用与问题相同语言回答。
291 |
292 | Chat history:
293 | ```
294 | {history_context}
295 | ```
296 |
297 | [context]:
298 | ```
299 | {context}
300 | ```
301 | """
302 |
--------------------------------------------------------------------------------
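
A small sketch of filling one of these templates; note that the doubled braces in the templates escape literal JSON braces for str.format:

```python
from datetime import datetime

from src.prompts import FOLLOW_UP_QUESTIONS_PROMPT

prompt = FOLLOW_UP_QUESTIONS_PROMPT.format(
    query="介绍一下最近的人工智能技术发展趋势",
    history_context="",
    current_date=datetime.now().strftime("%Y-%m-%d"),
)
print(prompt[:300])
```
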
/src/providers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | from typing import Dict, Any, Optional
7 | import openai
8 | from .config import get_config
9 |
10 |
11 | def get_model(is_report: bool = False) -> Dict[str, Any]:
12 | """
13 | Get model configuration including client and model name.
14 | 
15 |     Args:
16 |         is_report: Whether to get the model configuration for a report
17 | Returns:
18 | Dict containing model configuration
19 | """
20 | config = get_config()
21 | if is_report:
22 | report_config = config.get("report_llm", {})
23 | api_key = report_config.get("api_key", "")
24 | model = report_config.get("model", "gpt-4o")
25 | base_url = report_config.get("base_url", None)
26 | else:
27 | openai_config = config.get("openai", {})
28 | api_key = openai_config.get("api_key", "")
29 | model = openai_config.get("model", "gpt-4o-mini")
30 | base_url = openai_config.get("base_url", None)
31 |
32 | # Initialize OpenAI client
33 | client_args = {"api_key": api_key}
34 | if base_url:
35 | client_args["base_url"] = base_url
36 |
37 | client = openai.OpenAI(**client_args)
38 | async_client = openai.AsyncOpenAI(**client_args)
39 |
40 | return {
41 | "client": client,
42 | "async_client": async_client,
43 | "model": model
44 | }
45 |
46 |
47 | def get_search_provider(search_source=None):
48 | """
49 | Get the appropriate search provider based on configuration.
50 |
51 | Returns:
52 | An instance of the search provider class
53 | """
54 | if search_source is None:
55 | config = get_config()
56 | search_source = config.get("research", {}).get("search_source", "serper")
57 |
58 | if search_source == "mp_search":
59 | from .mp_search_client import MPSearchClient
60 | return MPSearchClient()
61 | elif search_source == "tavily":
62 | from .tavily_client import TavilyClient
63 | return TavilyClient()
64 | else: # Default to serper
65 | from .serper_client import SerperClient
66 | return SerperClient()
67 |
--------------------------------------------------------------------------------
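
A minimal sketch of the two factories above, assuming config.yaml holds valid model and search-provider keys:

```python
from src.providers import get_model, get_search_provider

model_cfg = get_model(is_report=False)   # dict with "client", "async_client", "model"
print(model_cfg["model"])                # "gpt-4o-mini" unless overridden in config

search = get_search_provider("tavily")   # or "mp_search" / "serper"
print(type(search).__name__)             # TavilyClient
```
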
/src/search_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | Utility functions for handling search operations, including concurrent searches.
6 | """
7 |
8 | import asyncio
9 | from typing import Dict, Any, List, Optional, Callable
10 | from loguru import logger
11 |
12 | from .config import get_config
13 | from .providers import get_search_provider
14 |
15 |
16 | async def search_with_query(
17 | query: str,
18 | config: Dict[str, Any] = None,
19 | search_provider=None
20 | ) -> Dict[str, Any]:
21 | """
22 |     Retrieve search results for a single query.
23 |
24 | Args:
25 | query: The query to search for
26 | config: Configuration dictionary
27 | search_provider: Search provider instance
28 |
29 | Returns:
30 | Dict with summary and URLs
31 | """
32 | if config is None:
33 | config = get_config()
34 |
35 | if search_provider is None:
36 | search_provider = get_search_provider()
37 |
38 | try:
39 | # Perform the search
40 | results = await search_provider.search(query)
41 |
42 | # For web search, ensure URLs are captured
43 | if hasattr(search_provider, 'get_organic_urls'):
44 | urls = search_provider.get_organic_urls()
45 | else:
46 | urls = []
47 |
48 | # Default to 2 results if not specified in config
49 | max_results = config.get("research", {}).get("max_results_per_query", 2)
50 |
51 | # Ensure results is a list before slicing
52 | if isinstance(results, list):
53 | results_to_process = results[:max_results]
54 | elif isinstance(results, dict) and "data" in results and isinstance(results["data"], list):
55 | # Some search providers return a dict with a "data" field containing the results
56 | results_to_process = results["data"][:max_results]
57 | else:
58 | # If results is not a list or a dict with a "data" field, use an empty list
59 | results_to_process = []
60 | logger.warning(f"Unexpected search results format: {type(results)}")
61 |
62 | search_results_text = results_to_process if results_to_process else "No search results found."
63 | # Return the formatted results and URLs
64 | return {
65 | "summary": search_results_text,
66 | "urls": urls,
67 | "raw_results": results_to_process
68 | }
69 |
70 | except Exception as e:
71 | logger.error(f"Error in search: {str(e)}")
72 | return {
73 | "summary": f"Error searching for '{query}': {str(e)}",
74 | "urls": [],
75 | "raw_results": []
76 | }
77 |
78 |
79 | async def concurrent_search(
80 | queries: List[str],
81 | config: Dict[str, Any] = None,
82 | search_provider=None
83 | ) -> List[Dict[str, Any]]:
84 | """
85 | Perform concurrent searches for multiple queries.
86 |
87 | Args:
88 | queries: List of queries to search for
89 | config: Configuration dictionary
90 | search_provider: Search provider instance
91 |
92 | Returns:
93 | List of search results
94 | """
95 | if config is None:
96 | config = get_config()
97 |
98 | if search_provider is None:
99 | search_provider = get_search_provider()
100 |
101 | # Get the concurrency limit from config
102 | concurrency_limit = config.get("research", {}).get("concurrency_limit", 1)
103 |
104 | # Create a semaphore to limit concurrency
105 | semaphore = asyncio.Semaphore(concurrency_limit)
106 |
107 | async def search_with_semaphore(query: str) -> Dict[str, Any]:
108 | """Perform a search with semaphore-based concurrency control."""
109 | async with semaphore:
110 | return await search_with_query(query, config, search_provider)
111 |
112 | # Create tasks for all queries
113 | tasks = [search_with_semaphore(query) for query in queries]
114 |
115 | # Run all tasks concurrently and gather results
116 | results = await asyncio.gather(*tasks)
117 |
118 | return results
119 |
--------------------------------------------------------------------------------
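
A sketch of running several searches under the configured concurrency limit; the queries are illustrative and a search API key must be configured:

```python
import asyncio

from src.search_utils import concurrent_search


async def main() -> None:
    results = await concurrent_search(["AI 发展趋势", "特斯拉 股价"])
    for r in results:
        print(len(r["urls"]), "urls;", str(r["summary"])[:80])


asyncio.run(main())
```
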
/src/serper_client.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import os
7 | import json
8 | import aiohttp
9 | import httpx
10 | from typing import Dict, List, Any, Union, Optional
11 | from loguru import logger
12 |
13 | from .config import get_config
14 |
15 |
16 | class SerperClient:
17 | """Client for the Serper.dev API to perform web searches."""
18 |
19 | def __init__(self):
20 | config = get_config()
21 | self.api_key = config.get("serper", {}).get("api_key", "")
22 |
23 | if not self.api_key:
24 | logger.warning("No Serper API key found. Searches will fail.")
25 |
26 | self.api_url = "https://google.serper.dev/search"
27 | self.headers = {
28 | "X-API-KEY": self.api_key,
29 | "Content-Type": "application/json"
30 | }
31 | self.organic_urls = []
32 |
33 | def search_sync(self, query: str) -> List[Dict[str, Any]]:
34 | """
35 | Perform a search using the Serper API.
36 |
37 | Args:
38 | query: Search query
39 |
40 | Returns:
41 | List of search result items
42 | """
43 | try:
44 | payload = json.dumps({
45 | "q": query
46 | })
47 |
48 | response = httpx.post(self.api_url, headers=self.headers, data=payload)
49 | response.raise_for_status()
50 |
51 | result = response.json()
52 | self.organic_urls = self._extract_urls(result)
53 |
54 | # Format the results for consumption
55 | formatted_results = self._format_results(result)
56 | return formatted_results
57 |
58 | except Exception as e:
59 | logger.error(f"Error searching with Serper: {str(e)}")
60 | return []
61 |
62 | async def search(self, query: str) -> List[Dict[str, Any]]:
63 | """
64 | Perform an async search using the Serper API.
65 |
66 | Args:
67 | query: Search query
68 |
69 | Returns:
70 | List of search result items
71 | """
72 | try:
73 | payload = json.dumps({
74 | "q": query
75 | })
76 |
77 | async with aiohttp.ClientSession() as session:
78 | async with session.post(self.api_url, headers=self.headers, data=payload) as response:
79 | result = await response.json()
80 |
81 | if response.status != 200:
82 | logger.error(f"Serper API error: {result}")
83 | return []
84 |
85 | self.organic_urls = self._extract_urls(result)
86 |
87 | # Format the results for consumption
88 | formatted_results = self._format_results(result)
89 | return formatted_results
90 |
91 | except Exception as e:
92 | logger.error(f"Error searching with Serper: {str(e)}")
93 | return []
94 |
95 | def _extract_urls(self, result: Dict[str, Any]) -> List[str]:
96 | """
97 | Extract URLs from search results.
98 |
99 | Args:
100 | result: Search result object
101 |
102 | Returns:
103 | List of URLs
104 | """
105 | urls = []
106 |
107 | # Extract organic results
108 | if "organic" in result:
109 | for item in result["organic"]:
110 | if "link" in item:
111 | urls.append(item["link"])
112 |
113 | return urls
114 |
115 | def _format_results(self, result: Dict[str, Any]) -> List[Dict[str, Any]]:
116 | """
117 | Format search results into a standardized format.
118 |
119 | Args:
120 | result: Search result object
121 |
122 | Returns:
123 | List of formatted result items
124 | """
125 | formatted_results = []
126 |
127 | # Format organic results
128 | if "organic" in result:
129 | for item in result["organic"]:
130 | formatted_item = {
131 | "title": item.get("title", ""),
132 | "url": item.get("link", ""),
133 | "snippet": item.get("snippet", ""),
134 | "content": f"{item.get('title', '')} - {item.get('snippet', '')}"
135 | }
136 | formatted_results.append(formatted_item)
137 |
138 | # Include featured snippet if available
139 | if "answerBox" in result:
140 | answer_box = result["answerBox"]
141 | formatted_item = {
142 | "title": answer_box.get("title", "Featured Snippet"),
143 | "url": answer_box.get("link", ""),
144 | "snippet": answer_box.get("snippet", ""),
145 | "content": answer_box.get("answer", answer_box.get("snippet", ""))
146 | }
147 | formatted_results.insert(0, formatted_item)
148 |
149 | return formatted_results
150 |
151 | def get_organic_urls(self) -> List[str]:
152 | """
153 | Get the URLs of organic search results from the last search.
154 |
155 | Returns:
156 | List of URLs
157 | """
158 | return self.organic_urls
159 |
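160 | 
161 | if __name__ == "__main__":
162 |     # Minimal usage sketch, assuming config.yaml provides a serper.api_key;
163 |     # the query is illustrative only.
164 |     client = SerperClient()
165 |     for row in client.search_sync("open source deep research"):
166 |         print(row["url"], "-", row["title"])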
--------------------------------------------------------------------------------
/src/tavily_client.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description: Client for the Tavily Search API.
5 | 
6 | Example request:
7 |
8 | curl -X POST https://api.tavily.com/search \\
9 | -H 'Content-Type: application/json' \\
10 | -H 'Authorization: Bearer tvly-dev-xxx' \\
11 | -d '{
12 | "query": "good"
13 | }'
14 |
15 | Example response:
16 | {
17 | "query": "good",
18 | "follow_up_questions": null,
19 | "answer": null,
20 | "images": [],
21 | "results": [
22 | {
23 | "title": "GOOD Definition & Meaning - Merriam-Webster",
24 | "url": "https://www.merriam-webster.com/dictionary/good",
25 | "content": "good a good many of us good : something that is good...",
26 | "score": 0.7283819,
27 | "raw_content": null
28 | },
29 | ...
30 | ],
31 | "response_time": 1.38
32 | }
33 | """
34 | import httpx
35 | import aiohttp
36 | from typing import Dict, Any, Optional
37 | from loguru import logger
38 |
39 | from .config import get_config
40 |
41 |
42 | class TavilyClient:
43 | """Client for Tavily Search API."""
44 |
45 | def __init__(self):
46 | config = get_config()
47 | self.api_key = config.get("tavily", {}).get("api_key")
48 | self.base_url = config.get("tavily", {}).get("base_url", "https://api.tavily.com/search")
49 | self.client = httpx.Client(timeout=30.0)
50 |
51 |     def search_sync(self, query: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
52 | """
53 | Perform a search using Tavily API.
54 |
55 | Args:
56 | query: Search query
57 | options: Additional options for the search
58 |
59 | Returns:
60 | Dict containing search results
61 | """
62 | if not self.api_key:
63 | raise ValueError("Tavily API key not configured")
64 |
65 | if options is None:
66 | options = {}
67 |
68 | # Default payload for Tavily API
69 | payload = {
70 | "query": query
71 | }
72 |
73 | # Add optional parameters if provided
74 | if "search_depth" in options:
75 | payload["search_depth"] = options.get("search_depth")
76 | if "include_domains" in options and options["include_domains"]:
77 | payload["include_domains"] = options.get("include_domains")
78 | if "exclude_domains" in options and options["exclude_domains"]:
79 | payload["exclude_domains"] = options.get("exclude_domains")
80 | if "max_results" in options:
81 | payload["max_results"] = options.get("max_results")
82 |
83 | # Use Authorization Bearer format as per the official API example
84 | headers = {
85 | "Content-Type": "application/json",
86 | "Authorization": f"Bearer {self.api_key}"
87 | }
88 |
89 | try:
90 | logger.debug(f"Searching with Tavily API: {query}")
91 | response = self.client.post(
92 | self.base_url,
93 | headers=headers,
94 | json=payload
95 | )
96 | response.raise_for_status()
97 | result = response.json()
98 |
99 | # Transform the result to match the expected format
100 | transformed_result = self._transform_result(result, query)
101 | return transformed_result
102 |
103 | except Exception as e:
104 | logger.error(f"Error searching with Tavily API: {str(e)}")
105 | raise
106 |
107 |     async def search(self, query: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
108 | """
109 | Perform an async search using Tavily API.
110 |
111 | Args:
112 | query: Search query
113 | options: Additional options for the search
114 |
115 | Returns:
116 | Dict containing search results
117 | """
118 | if not self.api_key:
119 | raise ValueError("Tavily API key not configured")
120 |
121 | if options is None:
122 | options = {}
123 |
124 | # Default payload for Tavily API
125 | payload = {
126 | "query": query
127 | }
128 |
129 | # Add optional parameters if provided
130 | if "search_depth" in options:
131 | payload["search_depth"] = options.get("search_depth")
132 | if "include_domains" in options and options["include_domains"]:
133 | payload["include_domains"] = options.get("include_domains")
134 | if "exclude_domains" in options and options["exclude_domains"]:
135 | payload["exclude_domains"] = options.get("exclude_domains")
136 | if "max_results" in options:
137 | payload["max_results"] = options.get("max_results")
138 |
139 | # Use Authorization Bearer format as per the official API example
140 | headers = {
141 | "Content-Type": "application/json",
142 | "Authorization": f"Bearer {self.api_key}"
143 | }
144 |
145 | try:
146 | logger.debug(f"Searching with Tavily API: {query}")
147 | async with aiohttp.ClientSession() as session:
148 | async with session.post(
149 | self.base_url,
150 | headers=headers,
151 | json=payload
152 | ) as response:
153 | response.raise_for_status()
154 | result = await response.json()
155 |
156 | # Transform the result to match the expected format
157 | transformed_result = self._transform_result(result, query)
158 | return transformed_result
159 |
160 | except Exception as e:
161 | logger.error(f"Error searching with Tavily API: {str(e)}")
162 | raise
163 |
164 | def _transform_result(self, tavily_result: Dict[str, Any], query: str) -> Dict[str, Any]:
165 | """
166 | Transform Tavily API result to match the expected format.
167 |
168 | Args:
169 | tavily_result: Raw result from Tavily API
170 | query: Original search query
171 |
172 | Returns:
173 | Transformed result in compatible format
174 | """
175 | transformed_data = []
176 |
177 | # Process results from Tavily API
178 | if "results" in tavily_result and isinstance(tavily_result["results"], list):
179 | for item in tavily_result["results"]:
180 | content = ""
181 |
182 | # Extract content from the result
183 | if "content" in item and item["content"]:
184 | content += item["content"] + "\n\n"
185 |
186 | # Add score information if available
187 | if "score" in item:
188 | content += f"Relevance Score: {item['score']}\n\n"
189 |
190 | transformed_item = {
191 | "url": item.get("url", ""),
192 | "title": item.get("title", ""),
193 | "content": content.strip(),
194 | "source": "tavily"
195 | }
196 | transformed_data.append(transformed_item)
197 |
198 | # If answer is provided by Tavily, add it as a special result
199 | if "answer" in tavily_result and tavily_result["answer"]:
200 | transformed_item = {
201 | "url": "",
202 | "title": "Tavily Direct Answer",
203 | "content": tavily_result["answer"],
204 | "source": "tavily_answer"
205 | }
206 | transformed_data.append(transformed_item)
207 |
208 | # If follow-up questions are provided, add them as a special result
209 | if "follow_up_questions" in tavily_result and tavily_result["follow_up_questions"]:
210 | follow_up_content = "Suggested follow-up questions:\n\n"
211 | for question in tavily_result["follow_up_questions"]:
212 | follow_up_content += f"- {question}\n"
213 |
214 | transformed_item = {
215 | "url": "",
216 | "title": "Suggested Follow-up Questions",
217 | "content": follow_up_content,
218 | "source": "tavily_follow_up"
219 | }
220 | transformed_data.append(transformed_item)
221 |
222 | # If no results were found or processing failed
223 | if not transformed_data:
224 | transformed_item = {
225 | "url": "",
226 | "title": f"Tavily Search Result for: {query}",
227 | "content": "No results found or could not process the search response.",
228 | "source": "tavily"
229 | }
230 | transformed_data.append(transformed_item)
231 |
232 | # Add response time if available
233 | if "response_time" in tavily_result:
234 | logger.debug(f"Tavily search completed in {tavily_result['response_time']} seconds")
235 |
236 | return {
237 | "query": query,
238 | "data": transformed_data
239 | }
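240 | 
241 | 
242 | if __name__ == "__main__":
243 |     # Minimal usage sketch, assuming config.yaml provides a tavily.api_key;
244 |     # the query and max_results value are illustrative only.
245 |     client = TavilyClient()
246 |     result = client.search_sync("deep research agents", {"max_results": 3})
247 |     for row in result["data"]:
248 |         print(row["title"], "-", row["url"])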
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Tests package initialization
--------------------------------------------------------------------------------
/tests/test_providers.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import patch, MagicMock
3 | from src.providers import get_model, trim_prompt
4 |
5 |
6 | class TestProviders(unittest.TestCase):
7 | """测试 providers 模块功能"""
8 |
9 | @patch('src.providers.get_config')
10 | @patch('src.providers.openai_client')
11 | def test_get_model(self, mock_openai_client, mock_get_config):
12 | """测试获取模型配置"""
13 | # 模拟配置
14 | mock_get_config.return_value = {
15 | "openai": {
16 | "model": "test-model"
17 | }
18 | }
19 |
20 |         # Mock the OpenAI client
21 | mock_openai_client.__bool__.return_value = True
22 |
23 |         # Get the model configuration
24 | model_config = get_model()
25 |
26 |         # Verify the result
27 | self.assertEqual(model_config["client"], mock_openai_client)
28 | self.assertEqual(model_config["model"], "test-model")
29 |
30 | def test_trim_prompt_short(self):
31 | """测试短提示不需要裁剪"""
32 | prompt = "This is a short prompt"
33 | result = trim_prompt(prompt, context_size=1000)
34 | self.assertEqual(result, prompt)
35 |
36 | @patch('src.providers.encoder.encode')
37 | def test_trim_prompt_long(self, mock_encode):
38 | """测试长提示需要裁剪"""
39 | # 模拟编码器返回超长 token 数
40 | mock_encode.return_value = [0] * 2000
41 |
42 | prompt = "This is a very long prompt" * 100
43 | result = trim_prompt(prompt, context_size=1000)
44 |
45 |         # Verify the result was trimmed
46 | self.assertLess(len(result), len(prompt))
47 |
48 |
49 | if __name__ == '__main__':
50 | unittest.main()
51 |
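52 | # Sketch of the behaviour exercised above (encoder is an internal of
53 | # src.providers, assumed from the patch target): trim_prompt shortens text
54 | # whose encoded token count exceeds context_size, e.g.
55 | #   trim_prompt("word " * 5000, context_size=1000)  # -> truncated string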
--------------------------------------------------------------------------------
/tests/test_serper_client.py:
--------------------------------------------------------------------------------
1 | import json
2 | import unittest
3 | from unittest.mock import patch, MagicMock
4 | from src.serper_client import SerperClient
5 | 
6 | 
7 | class TestSerperClient(unittest.TestCase):
8 |     """Tests for the Serper search client."""
9 | 
10 |     @patch('src.serper_client.get_config')
11 |     def setUp(self, mock_get_config):
12 |         """Create a client instance against a mocked configuration."""
13 |         # Mock the configuration
14 |         mock_get_config.return_value = {
15 |             "serper": {
16 |                 "api_key": "test_api_key"
17 |             }
18 |         }
19 | 
20 |         # Create the client instance
21 |         self.client = SerperClient()
22 | 
23 |     def test_init(self):
24 |         """Test initialization."""
25 |         self.assertEqual(self.client.api_key, "test_api_key")
26 |         self.assertEqual(self.client.api_url, "https://google.serper.dev/search")
27 | 
28 |     @patch('src.serper_client.httpx.post')
29 |     def test_search_sync(self, mock_post):
30 |         """Test the synchronous search path."""
31 |         # Mock the HTTP response
32 |         mock_response = MagicMock()
33 |         mock_response.status_code = 200
34 |         mock_response.json.return_value = {
35 |             "organic": [
36 |                 {
37 |                     "title": "Test Result",
38 |                     "link": "https://example.com",
39 |                     "snippet": "This is a test result"
40 |                 }
41 |             ]
42 |         }
43 |         mock_post.return_value = mock_response
44 | 
45 |         # Run the search
46 |         results = self.client.search_sync("test query")
47 | 
48 |         # Verify the request: the client sends a JSON-encoded body with the
49 |         # API key in the X-API-KEY header
50 |         mock_post.assert_called_once()
51 |         args, kwargs = mock_post.call_args
52 |         self.assertEqual(kwargs["headers"]["X-API-KEY"], "test_api_key")
53 |         self.assertEqual(json.loads(kwargs["data"])["q"], "test query")
54 | 
55 |         # Verify the formatted results
56 |         self.assertEqual(len(results), 1)
57 |         self.assertEqual(results[0]["title"], "Test Result")
58 |         self.assertEqual(results[0]["url"], "https://example.com")
59 |         self.assertEqual(results[0]["snippet"], "This is a test result")
60 |         self.assertEqual(self.client.get_organic_urls(), ["https://example.com"])
61 | 
62 |     @patch('src.serper_client.httpx.post')
63 |     def test_search_sync_answer_box(self, mock_post):
64 |         """A featured snippet (answerBox) should be placed first."""
65 |         # Mock a response containing only an answer box
66 |         mock_response = MagicMock()
67 |         mock_response.status_code = 200
68 |         mock_response.json.return_value = {
69 |             "organic": [],
70 |             "answerBox": {
71 |                 "title": "Answer",
72 |                 "link": "https://a.example.com",
73 |                 "snippet": "A snippet",
74 |                 "answer": "An answer"
75 |             }
76 |         }
77 |         mock_post.return_value = mock_response
78 | 
79 |         results = self.client.search_sync("test query")
80 | 
81 |         # The featured snippet is inserted at the front of the results
82 |         self.assertEqual(results[0]["title"], "Answer")
83 |         self.assertEqual(results[0]["content"], "An answer")
84 | 
85 | 
86 | if __name__ == '__main__':
87 |     unittest.main()
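88 | 
89 | # To run these tests directly (pytest also discovers them):
90 | #   python -m unittest tests.test_serper_client -v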
--------------------------------------------------------------------------------