├── .github
│   ├── stale.yml
│   └── workflows
│       └── ubuntu.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── config.example.yaml
├── deep_research_demo.py
├── docs
│   ├── gradio.png
│   ├── logo.png
│   ├── project.md
│   ├── report.png
│   ├── wechat.jpeg
│   └── wechat_group.jpg
├── main.py
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── config.py
│   ├── deep_research.py
│   ├── gradio_chat.py
│   ├── model_utils.py
│   ├── mp_search_client.py
│   ├── prompts.py
│   ├── providers.py
│   ├── search_utils.py
│   ├── serper_client.py
│   └── tavily_client.py
└── tests
    ├── __init__.py
    ├── test_providers.py
    └── test_serper_client.py

/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 60
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 7
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 |   - pinned
8 |   - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 |   This issue has been automatically marked as stale because it has not had
14 |   recent activity. It will be closed if no further activity occurs. Thank you
15 |   for your contributions. (由于长期不活动,机器人自动关闭此问题,如果需要欢迎提问)
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: false
--------------------------------------------------------------------------------
/.github/workflows/ubuntu.yml:
--------------------------------------------------------------------------------
1 | on:
2 |   workflow_dispatch: # Manually running a workflow
3 | name: Linux build
4 | jobs:
5 |   test-ubuntu:
6 |     runs-on: ubuntu-latest
7 |     strategy:
8 |       fail-fast: false
9 |       matrix:
10 |         # python-version: [ 3.7, 3.8, 3.9 ]
11 |         python-version: [ 3.12 ]
12 |     steps:
13 |       - uses: actions/checkout@v2
14 |       - name: Cache pip
15 |         uses: actions/cache@v2
16 |         if: startsWith(runner.os, 'Linux')
17 |         with:
18 |           path: ~/.cache/pip
19 |           key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
20 |           restore-keys: |
21 |             ${{ runner.os }}-pip-
22 |       - name: Cache huggingface models
23 |         uses: actions/cache@v2
24 |         with:
25 |           path: ~/.cache/huggingface
26 |           key: ${{ runner.os }}-huggingface-
27 |       - name: Cache text2vec models
28 |         uses: actions/cache@v2
29 |         with:
30 |           path: ~/.text2vec
31 |           key: ${{ runner.os }}-text2vec-
32 |       - name: Set up Python
33 |         uses: actions/setup-python@v2
34 |         with:
35 |           python-version: ${{ matrix.python-version }}
36 |       - name: Install torch
37 |         run: |
38 |           python -m pip install --upgrade pip
39 |           pip install Cython
40 |           pip install torch
41 |       - name: Install from pypi
42 |         run: |
43 |           pip install -U agentica
44 |           python -c "import agentica; print(agentica.__version__)"
45 |           pip uninstall -y agentica
46 |       - name: Install dependencies
47 |         run: |
48 |           pip install -r requirements.txt
49 |           pip install .
50 | pip install pytest 51 | - name: PKG-TEST 52 | run: | 53 | python -m pytest -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We are happy to accept your contributions to make this repo better and more awesome! To avoid unnecessary work on either 4 | side, please stick to the following process: 5 | 6 | 1. Check if there is already [an issue](https://github.com/shibing624/deep-research/issues) for your concern. 7 | 2. If there is not, open a new one to start a discussion. We hate to close finished PRs! 8 | 3. If we decide your concern needs code changes, we would be happy to accept a pull request. Please consider the 9 | commit guidelines below. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | Logo 4 | 5 |
6 | 
7 | -----------------
8 | 
9 | # Open Deep Research (Python)
10 | [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
11 | [![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
12 | [![python_version](https://img.shields.io/badge/Python-3.8%2B-green.svg)](requirements.txt)
13 | [![GitHub issues](https://img.shields.io/github/issues/shibing624/deep-research.svg)](https://github.com/shibing624/deep-research/issues)
14 | [![Wechat Group](https://img.shields.io/badge/wechat-group-green.svg?logo=wechat)](#Contact)
15 | 
16 | 
17 | **Deep Research**: a Python implementation of an AI-powered research assistant that performs iterative, deep research on any topic by combining search engines, web scraping, and large language models.
18 | 
19 | 
20 | ## Features
21 | 
22 | - **Deep search**: intelligently generates a retrieval plan and executes the planned searches in parallel
23 | - **Smart query generation**: automatically generates follow-up queries based on the initial question and the information gathered so far
24 | - **Multiple output formats**: supports both a concise-answer mode and a detailed-report mode
25 | - **Multilingual support**: full support for Chinese input and output
26 | - **Automatic context management**: bounds the context length passed to the LLM to prevent token-limit errors
27 | - **Configurable clarification flow**: the clarification step can be skipped to start research directly
28 | - **Multiple ways to use it**:
29 |   - Command-line interface
30 |   - Gradio web UI (with streamed chain-of-thought output)
31 |   - Direct calls as a Python module
32 | 
33 | ## Demo
34 | - Official demo: https://deepresearch.mulanai.com
35 | 
36 | ## Setup
37 | 
38 | 1. Clone the repository:
39 | 
40 | ```bash
41 | git clone https://github.com/shibing624/deep-research.git
42 | ```
43 | 2. Install dependencies:
44 | 
45 | ```bash
46 | pip install -r requirements.txt
47 | ```
48 | 
49 | 3. Create a configuration file:
50 | 
51 | ```bash
52 | # Copy the example configuration
53 | cp config.example.yaml config.yaml
54 | ```
55 | 
56 | The configuration file allows you to set:
57 | - API keys for OpenAI, Tavily, and Serper
58 | - Model preferences
59 | - Search engine options (Serper, MP Search, and Tavily)
60 | 
61 | 
62 | ### Key Configuration Options
63 | 
64 | - **context_size**: Controls the maximum size of the context sent to the LLM. The system automatically truncates longer contexts to prevent token-limit errors while preserving as much relevant information as possible. Default is `128000`.
65 | 
66 | - **enable_clarification**: When set to `False`, the system skips the clarification step and proceeds directly with research. This is useful for straightforward queries where clarification is unnecessary. Default is `False`.
67 | 
68 | - **search_source**: Your preferred search provider. The example config sets `tavily`; when no config file is present, the built-in fallback is `serper`.
69 | 
70 | - **enable_refine_search_result**: When enabled, the system refines search results for better relevance. Default is `False`.
71 | 
72 | ## Search Engines
73 | 
74 | Deep Research supports multiple search engines:
75 | 
76 | 1. **Serper**: Uses Google search results via the Serper.dev API (the built-in default when no config file is present)
77 | 2. **MP Search**: An alternative search provider
78 | 3. **Tavily**: A specialized AI-optimized search engine (the default in `config.example.yaml`)
79 | 
80 | To use Tavily search:
81 | 1. Get an API key from [Tavily](https://tavily.com)
82 | 2. Add it to your config.yaml:
83 | ```yaml
84 | tavily:
85 |   api_key: "your-tavily-api-key" # Use the token without 'Bearer' prefix
86 |   base_url: "https://api.tavily.com/search"
87 | ```
88 | Note: For Tavily, provide just the API token (e.g., "tvly-dev-xxx") without the "Bearer" prefix.
89 | 
90 | 3. Set Tavily as your search source in the Gradio interface or in your config.yaml:
91 | ```yaml
92 | research:
93 |   search_source: "tavily"
94 | ```
95 | 
96 | The Tavily search engine provides high-quality, AI-optimized search results and may include:
97 | - Ranked search results with relevance scores
98 | - Follow-up questions (when available)
99 | - Direct answers for certain queries (when available)
100 | 
101 | ## Usage
102 | 
103 | ### Command Line Interface
104 | 
105 | The main.py script provides several ways to use the research assistant:
106 | 
107 | ```bash
108 | # Show help
109 | python main.py --help
110 | 
111 | # Run research directly from command line
112 | python main.py research "中国2024年经济情况分析"
113 | 
114 | # Launch the Gradio demo interface
115 | python main.py demo
116 | 
117 | # Use a specific configuration file
118 | python main.py --config my-config.yaml research "Your query"
119 | ```
120 | 
121 | ### Demo Script
122 | 
123 | Run the demo script to walk through the full research flow:
124 | 
125 | ```bash
126 | python deep_research_demo.py
127 | ```
128 | 
129 | This runs an example research task for the script's default query (中国元朝的货币制度改革的影响意义?), generates a detailed report, and saves it to [report.md](https://github.com/shibing624/deep-research/blob/main/report.md).
130 | 
131 | Output:
132 | 
133 | ![report](https://github.com/shibing624/deep-research/blob/main/docs/report.png)
134 | 
135 | ### Gradio Demo
136 | 
137 | For a user-friendly interface, run the Gradio demo:
138 | 
139 | ```bash
140 | python main.py demo
141 | ```
142 | 
143 | This will start a web interface where you can enter your research query, adjust parameters, and view results.
144 | 
145 | ![gradio](https://github.com/shibing624/deep-research/blob/main/docs/gradio.png)
146 | 
147 | ### Python Module
148 | 
149 | Or use the module directly:
150 | 
151 | ```python
152 | import asyncio
153 | from src.deep_research import deep_research_stream
154 | 
155 | 
156 | async def run_research():
157 |     # 运行研究
158 |     async for result in deep_research_stream(
159 |         query="特斯拉股票走势分析",
160 |         history_context="",
161 |     ):
162 |         # 如果研究完成,保存报告
163 |         if result.get("stage") == "completed":
164 |             report = result.get("final_report", "")
165 |             print(report)
166 |             break
167 | 
168 | 
169 | if __name__ == "__main__":
170 |     asyncio.run(run_research())
171 | ```
172 | 
173 | Note: Since asynchronous functions are used, you need to call `asyncio.run()` or use `await` in an asynchronous context.
174 | 
175 | ## Roadmap
176 | 
177 | - Add support for more search engines
178 | - Improve the query-generation strategy
179 | - Enhance result visualization
180 | - Support more large language models
181 | - Add document embedding and vector search
182 | 
183 | ## Contact
184 | 
185 | - Issues (suggestions): [![GitHub issues](https://img.shields.io/github/issues/shibing624/deep-research.svg)](https://github.com/shibing624/deep-research/issues)
186 | - Email me: xuming624@qq.com
187 | - WeChat: add my WeChat ID *xuming624* with the note *Name-Company-NLP* to join the NLP discussion group.
188 | 
189 | 
190 | 
191 | ## Citation
192 | 
193 | If you use `deep-research` in your research, please cite it in the following format:
194 | 
195 | APA:
196 | 
197 | ```
198 | Xu, M. deep-research: Deep Research with LLM (Version 0.0.1) [Computer software]. https://github.com/shibing624/deep-research
199 | ```
200 | 
201 | BibTeX:
202 | 
203 | ```
204 | @misc{Xu_deep_research,
205 |   title={deep-research: Deep Research with LLM},
206 |   author={Xu Ming},
207 |   year={2025},
208 |   howpublished={\url{https://github.com/shibing624/deep-research}},
209 | }
210 | ```
211 | 
212 | ## License
213 | 
214 | This project is licensed under [The Apache License 2.0](/LICENSE) and is free for commercial use. Please include a link to `deep-research` and the license in your product documentation.
215 | 
216 | ## Contribute
217 | 
218 | The code is still rough in places; improvements are very welcome. Before submitting a PR, please note two points:
219 | 
220 | - Add corresponding unit tests under `tests`
221 | - Run all unit tests with `python -m pytest` and make sure they all pass
222 | 
223 | After that, you can submit your PR.
224 | 
225 | ## Acknowledgements
226 | 
227 | - [dzhng/deep-research](https://github.com/dzhng/deep-research)
228 | 
229 | Thanks for their great work!
--------------------------------------------------------------------------------
/config.example.yaml:
--------------------------------------------------------------------------------
1 | # Deep Research Configuration
2 | 
3 | # OpenAI settings
4 | openai:
5 |   api_key: "sk-xxx" # Set your OpenAI API key here
6 |   base_url: "https://api.openai.com/v1"
7 |   model: "gpt-4o-mini" # Default model, can be "gpt-4o", "gpt-4o-mini"
8 | 
9 | # Report LLM settings
10 | report_llm:
11 |   api_key: "sk-xxx" # Set your OpenAI API key here
12 |   base_url: "https://api.openai.com/v1"
13 |   model: "gpt-4o"
14 | 
15 | # Serper settings
16 | serper:
17 |   api_key: "xxx" # Set your Serper API key here
18 |   base_url: "https://google.serper.dev/search"
19 | 
20 | # Tavily settings
21 | tavily:
22 |   api_key: "tvly-dev-xxx" # Use the token without 'Bearer' prefix
23 |   base_url: "https://api.tavily.com/search"
24 | 
25 | # Research settings
26 | research:
27 |   concurrency_limit: 3 # 并发搜索数量
28 |   context_size: 64000 # 传入LLM的最大文本长度(token数)
29 |   search_source: "tavily" # 搜索引擎
30 |   max_results_per_query: 3 # 每个搜索结果的最大数量
31 |   enable_refine_search_result: False # 是否需要精简搜索结果,如果为True,将会对搜索结果提取关键片段;如果是False,直接用搜索结果原文
32 |   enable_next_plan: False # 是否需要下一步计划,如果为True,将会总结搜索结果并提供下一步计划,为生成报告提供参考;如果是False,不要下一步计划的分析
33 |   enable_clarification: False # 是否需要澄清问题,如果为True, 将会先澄清问题;如果是False,将会跳过澄清环节,直接进行搜索
--------------------------------------------------------------------------------
/deep_research_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import asyncio
3 | from loguru import logger
4 | 
5 | from src.config import get_config
6 | from src.deep_research import deep_research_stream
7 | 
8 | 
9 | async def run_demo():
10 |     """
11 |     演示如何使用 deep_research 模块进行研究
12 |     """
13 |     # 加载配置
14 |     config = get_config()
15 | 
16 |     # 定义研究问题
17 |     query = "中国元朝的货币制度改革的影响意义?"
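    # 其他可选的示例查询(摘自 README,可替换上面的 query):
    # query = "中国2024年经济情况分析"
    # query = "特斯拉股票走势分析"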
18 | 
19 |     # 定义进度回调函数
20 |     def progress_callback(progress):
21 |         current = progress.get("currentQuery", "")
22 |         logger.info(f"进度: 当前查询: {current}")
23 | 
24 |     logger.info(f"开始研究: {query}")
25 |     async for result in deep_research_stream(
26 |         query=query,
27 |         on_progress=progress_callback,
28 |         user_clarifications={'all': 'skip'},  # 使用特殊标记跳过澄清
29 |         history_context=""  # 添加空的history_context
30 |     ):
31 |         # 处理状态更新
32 |         if result.get("status_update"):
33 |             logger.info(f"状态更新: {result['status_update']}")
34 | 
35 |         # 如果研究完成,获取最终报告
36 |         if result.get("stage") == "completed":
37 |             final_report = result.get("final_report", "")
38 |             with open("report.md", "w", encoding="utf-8") as f:
39 |                 f.write(final_report)
40 |             logger.info("报告已保存到 report.md")
41 |             print("\n" + "=" * 50)
42 |             print(f"研究问题: {query}")
43 |             print("=" * 50)
44 |             print("\n研究报告:")
45 |             print(final_report)
46 |             print("\n" + "=" * 50)
47 | 
48 |             break
49 | 
50 | 
51 | if __name__ == "__main__":
52 |     # 运行异步演示
53 |     asyncio.run(run_demo())
54 | 
--------------------------------------------------------------------------------
/docs/gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/gradio.png
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/logo.png
--------------------------------------------------------------------------------
/docs/project.md:
--------------------------------------------------------------------------------
1 | # Deep Research Project Analysis
2 | 
3 | ## Project Overview
4 | 
5 | Deep Research is a powerful research-assistant tool that helps users gather comprehensive research information through multi-stage search, analysis, and summarization. By integrating large language models with search engines, it automates the research process end to end, including query clarification, research-plan generation, multi-step search, and final report generation.
6 | 
7 | ## Current Strengths
8 | 
9 | 1. **Smart query clarification**
10 |    - Automatically detects ambiguous queries and generates clarification questions
11 |    - Processes user answers to distill a more precise query
12 |    - Caches results to avoid duplicate API calls and improve performance
13 | 
14 | 2. **Multi-level research flow**
15 |    - Dynamically generates a research plan, adjusting the number of steps to query complexity
16 |    - Supports breadth and depth parameters to fit different research needs
17 |    - Processes multiple queries in parallel for efficiency
18 | 
19 | 3. **Transparent research process**
20 |    - Shows research progress and findings in real time
21 |    - Clearly presents the research plan and execution steps
22 |    - Records all data sources to ensure traceability
23 | 
24 | 4. **Polished user interface**
25 |    - The Gradio UI is clean and easy to use
26 |    - Supports real-time feedback and status updates
27 |    - Presents results and intermediate steps in a structured way
28 | 
29 | 5. **Flexible search sources**
30 |    - Supports multiple search providers (serper, tavily, mp_search)
31 |    - Lets users choose their preferred search engine
32 | 
33 | ## Current Weaknesses and Limitations
34 | 
35 | 1. **Immature error handling**
36 |    - Limited recovery strategies when searches fail or APIs are rate-limited
37 |    - Exception handling is simplistic, with no graceful degradation
38 | 
39 | 2. **User interaction needs work**
40 |    - The clarification step is not a smooth interaction
41 |    - No real-time validation of user input
42 |    - Users cannot intervene mid-run or adjust the research direction
43 | 
44 | 3. **Performance and resource usage**
45 |    - Large queries can incur high API costs
46 |    - Response times are long for complex queries
47 |    - Caching and memory mechanisms are limited, so earlier research results are poorly reused
48 | 
49 | 4. **Limited tool-chain integration**
50 |    - No integration with other research tools
51 |    - Cannot handle non-text content (charts, datasets, etc.)
52 |    - No local knowledge-base support
53 | 
54 | 5. **Evaluation and validation**
55 |    - No self-assessment of search-result quality
56 |    - Cannot fully verify the accuracy or timeliness of information
57 |    - No mechanism for handling contradictory information
58 | 
59 | ## Suggested Directions for Improvement
60 | 
61 | 1. **Strengthen error recovery**
62 |    - Implement more robust error handling and retry logic
63 |    - Add automatic failover to backup search sources
64 |    - Provide more detailed error diagnostics and remediation advice
65 | 
66 | 2. **Improve the user experience**
67 |    - Support user intervention during the research process
68 |    - Add more interaction options, such as upvoting/vetoing specific findings
69 |    - Offer a switch between concise and detailed modes
70 |    - Improve support for mobile devices
71 | 
72 | 3. **Extend features and integrations**
73 |    - Add more search providers
74 |    - Integrate local knowledge bases and documents
75 |    - Support exporting results to different formats (PDF, Word, Markdown, etc.)
76 |    - Add multimedia content handling
77 | 
78 | 4. **Optimize performance**
79 |    - Implement smarter caching
80 |    - Optimize parallel request management
81 |    - Add a low-cost mode that reduces API calls
82 |    - Support incremental research updates to avoid repeated work
83 | 
84 | 5. **Strengthen validation**
85 |    - Add cross-validation of information
86 |    - Detect staleness and flag potentially outdated information
87 |    - Add a source-rating system that favors high-quality sources
88 |    - Support comparison and analysis of contradictory information
89 | 
90 | 6. **Extend language support**
91 |    - Strengthen multilingual processing
92 |    - Improve clarification and handling of non-English queries
93 |    - Support language switching and translation
94 | 
95 | 7. **Community and developer support**
96 |    - Improve the API documentation
97 |    - Provide more customization and extension points
98 |    - Simplify deployment
99 |    - Create a plugin system to support community contributions
100 | 
101 | 8. **Expand application scenarios**
102 |    - Develop research modes focused on specific domains (academic, news, medical, etc.)
103 |    - Support long-term research projects, including progress tracking and periodic updates
104 |    - Add team collaboration so multiple people can research together
105 | 
106 | ## Conclusion
107 | 
108 | The Deep Research project shows strong potential: by combining large language models with search engines, it gives users automated, in-depth research capabilities. Although some limitations remain, continued improvement can turn it into a more comprehensive, reliable, and easy-to-use research assistant. The focus should be on improving research quality, enhancing the user experience, and expanding feature integrations so that it can meet a wide range of complex research needs.
--------------------------------------------------------------------------------
/docs/report.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/report.png
--------------------------------------------------------------------------------
/docs/wechat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/wechat.jpeg
--------------------------------------------------------------------------------
/docs/wechat_group.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/docs/wechat_group.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import argparse
7 | import asyncio
8 | from src.config import load_config, get_config
9 | from src.deep_research import deep_research_stream
10 | from loguru import logger
11 | 
12 | 
13 | async def run_research(args):
14 |     # 加载配置
15 |     config = get_config()
16 | 
17 |     # 定义研究问题
18 |     query = args.query
19 | 
20 |     # 定义进度回调函数
21 |     def progress_callback(progress):
22 |         current = progress.get("currentQuery", "")
23 | 
24 |         logger.info(f"进度: 当前查询: {current}")
25 | 
26 |     # 运行研究
27 |     logger.info(f"开始研究: {query}")
28 | 
29 |     # 使用流式研究,通过user_clarifications跳过澄清步骤
30 |     async for result in deep_research_stream(
31 |         query=query,
32 |         on_progress=progress_callback,
33 |         user_clarifications={'all': 'skip'},  # 使用特殊标记跳过澄清
34 |         history_context=""  # 添加空的history_context
35 |     ):
36 |         # 处理状态更新
37 |         if result.get("status_update"):
38 |             logger.info(f"状态更新: {result['status_update']}")
39 | 
40 |         # 如果研究完成,获取最终报告
41 |         if result.get("stage") == "completed":
42 |             learnings = result.get("learnings", [])
43 |             visited_urls = result.get("visitedUrls", [])
44 |             final_report = result.get("final_report", "")
45 | 
46 |             logger.info(f"研究完成!
发现 {len(learnings)} 条学习内容和 {len(visited_urls)} 个来源。") 47 | 48 | # 保存结果到文件 49 | with open("report.md", "w", encoding="utf-8") as f: 50 | f.write(final_report) 51 | 52 | logger.info("报告已保存到 report.md") 53 | 54 | print("\n" + "=" * 50) 55 | print(f"研究问题: {query}") 56 | print("=" * 50) 57 | print("\n研究报告:") 58 | print(final_report) 59 | print("\n" + "=" * 50) 60 | 61 | break 62 | 63 | 64 | def main(): 65 | """Main entry point with argument parsing""" 66 | parser = argparse.ArgumentParser( 67 | description="Deep Research - AI-powered research assistant" 68 | ) 69 | 70 | # Add config file argument 71 | parser.add_argument( 72 | "--config", type=str, 73 | help="Path to YAML configuration file" 74 | ) 75 | 76 | subparsers = parser.add_subparsers(dest="command", help="Command to run") 77 | 78 | # demo command 79 | demo_parser = subparsers.add_parser("demo", help="Run the gradio demo server") 80 | demo_parser.add_argument( 81 | "--host", type=str, default='0.0.0.0', help="Host ip" 82 | ) 83 | 84 | # Research command 85 | research_parser = subparsers.add_parser("research", help="Run research directly") 86 | research_parser.add_argument( 87 | "query", type=str, 88 | help="Research query" 89 | ) 90 | args = parser.parse_args() 91 | 92 | # Load configuration 93 | if args.config: 94 | load_config(args.config) 95 | 96 | # Execute command 97 | if args.command == "research": 98 | asyncio.run(run_research(args)) 99 | elif args.command == "demo": 100 | from src.gradio_chat import run_gradio_demo 101 | run_gradio_demo() 102 | else: 103 | # Default to showing help if no command specified 104 | parser.print_help() 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru 2 | openai 3 | httpx 4 | gradio==5.22.0 5 | pyyaml -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shibing624/deep-research/01e649e5158adca07cb68b109ea484803f7d368e/src/__init__.py -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import os 7 | import yaml 8 | from typing import Dict, Any, Optional 9 | from loguru import logger 10 | 11 | # Default configuration 12 | DEFAULT_CONFIG = { 13 | "openai": { 14 | "api_key": None, 15 | "base_url": "https://api.openai.com/v1", 16 | "model": "o3-mini" 17 | }, 18 | "serper": { 19 | "api_key": None, 20 | "base_url": "https://google.serper.dev/search" 21 | }, 22 | "tavily": { 23 | "api_key": None, 24 | "base_url": "https://api.tavily.com/search" 25 | }, 26 | "research": { 27 | "concurrency_limit": 1, 28 | "context_size": 128000, 29 | "search_source": "serper", 30 | "max_results_per_query": 5, 31 | "enable_refine_search_result": False, 32 | "enable_next_plan": False 33 | } 34 | } 35 | 36 | # Global configuration object 37 | _config = None 38 | 39 | 40 | def load_config(config_path: Optional[str] = None) -> Dict[str, Any]: 41 | """ 42 | Load configuration from YAML file with fallback to environment variables 43 | 44 | Args: 45 | config_path: Path to YAML configuration file 46 | 47 | Returns: 48 | Dict containing configuration 
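
    Example (illustrative; the keys shown exist in DEFAULT_CONFIG above):

        config = load_config("./config.yaml")
        model = config["openai"]["model"]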
49 | """ 50 | global _config 51 | 52 | # Start with default configuration 53 | config = DEFAULT_CONFIG.copy() 54 | 55 | # Try to load from specified config file 56 | if config_path and os.path.exists(config_path): 57 | try: 58 | with open(config_path, 'r') as f: 59 | yaml_config = yaml.safe_load(f) 60 | if yaml_config: 61 | # Merge with defaults (deep merge would be better but this is simple) 62 | for section, values in yaml_config.items(): 63 | if section in config and isinstance(values, dict): 64 | config[section].update(values) 65 | else: 66 | config[section] = values 67 | logger.info(f"Loaded configuration from {config_path}") 68 | except Exception as e: 69 | logger.warning(f"Error loading config from {config_path}: {str(e)}") 70 | else: 71 | # If no config file specified, look for config.yaml in the current directory 72 | default_path = './config.yaml' 73 | if os.path.exists(default_path): 74 | return load_config(default_path) 75 | else: 76 | logger.info("No configuration file found, using defaults") 77 | 78 | # Store the config globally 79 | _config = config 80 | return config 81 | 82 | 83 | def get_config() -> Dict[str, Any]: 84 | """ 85 | Get the current configuration 86 | 87 | Returns: 88 | Dict containing configuration 89 | """ 90 | global _config 91 | if _config is None: 92 | return load_config() 93 | return _config 94 | -------------------------------------------------------------------------------- /src/deep_research.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | 6 | Deep research functionality for comprehensive research. 7 | """ 8 | 9 | import asyncio 10 | import json 11 | import traceback 12 | import inspect 13 | import platform 14 | from typing import Optional, Callable, Dict, List, Any, Union, Tuple, AsyncGenerator 15 | from loguru import logger 16 | from datetime import datetime 17 | 18 | from .config import get_config 19 | from .providers import get_search_provider 20 | from .prompts import ( 21 | SHOULD_CLARIFY_QUERY_PROMPT, 22 | FOLLOW_UP_QUESTIONS_PROMPT, 23 | PROCESS_NO_CLARIFICATIONS_PROMPT, 24 | PROCESS_CLARIFICATIONS_PROMPT, 25 | RESEARCH_PLAN_PROMPT, 26 | EXTRACT_SEARCH_RESULTS_SYSTEM_PROMPT, 27 | EXTRACT_SEARCH_RESULTS_PROMPT, 28 | RESEARCH_SUMMARY_PROMPT, 29 | FINAL_REPORT_SYSTEM_PROMPT, 30 | FINAL_REPORT_PROMPT, 31 | FINAL_ANSWER_PROMPT, 32 | ) 33 | from .model_utils import generate_completion, generate_json_completion 34 | from .search_utils import search_with_query 35 | 36 | 37 | def get_current_date(): 38 | """Return the current date in ISO format.""" 39 | return datetime.now().isoformat() 40 | 41 | 42 | def add_event_loop_policy(): 43 | """Add event loop policy for Windows if needed.""" 44 | if platform.system() == "Windows": 45 | try: 46 | # Set event loop policy for Windows 47 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) 48 | except Exception as e: 49 | print(f"Error setting event loop policy: {e}") 50 | 51 | 52 | def limit_context_size(text: str, max_size: int) -> str: 53 | """ 54 | Limit the context size to prevent LLM token limit errors. 
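
    A minimal usage sketch (64000 here is illustrative; callers pass the
    configured research.context_size, usually divided per sub-context):

        limited = limit_context_size(long_text, max_size=64000)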
55 | 
56 |     Args:
57 |         text: Text to limit
58 |         max_size: Approximate maximum token budget (converted to a character limit below)
59 | 
60 |     Returns:
61 |         Limited text
62 |     """
63 |     # Simple character-based truncation (rough approximation)
64 |     # On average, 1 token is roughly 4 characters for English text
65 |     # For Chinese, the ratio is about 2, so use 2 as the conservative bound
66 |     char_limit = max_size * 2
67 | 
68 |     if len(text) <= char_limit:
69 |         return text
70 | 
71 |     logger.warning(f"Truncating context from {len(text)} chars to ~{char_limit} chars")
72 | 
73 |     # For JSON strings (objects or arrays), try to preserve structure
74 |     if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')):
75 |         try:
76 |             # Try to parse as JSON
77 |             data = json.loads(text)
78 |             # If it's a list of items, truncate the list
79 |             if isinstance(data, list):
80 |                 # Calculate approx chars per item
81 |                 if len(data) > 0:
82 |                     chars_per_item = len(text) / len(data)
83 |                     items_to_keep = int(char_limit / chars_per_item)
84 |                     items_to_keep = max(1, min(items_to_keep, len(data)))
85 |                     return json.dumps(data[:items_to_keep], ensure_ascii=False)
86 |             elif isinstance(data, dict):  # if it contains a list property, truncate that
87 |                 for key, value in data.items():
88 |                     if isinstance(value, list) and len(value) > 0:
89 |                         chars_per_item = len(json.dumps(value, ensure_ascii=False)) / len(value)
90 |                         items_to_keep = int(char_limit / 2 / chars_per_item)  # Use half for lists
91 |                         items_to_keep = max(1, min(items_to_keep, len(value)))
92 |                         data[key] = value[:items_to_keep]
93 |             return json.dumps(data, ensure_ascii=False)
94 |         except (json.JSONDecodeError, TypeError):
95 |             # Not valid JSON, use simple truncation
96 |             pass
97 | 
98 |     # Simple truncation with indicator
99 |     return text[:char_limit - 50] + "... [content truncated due to token limit]"
100 | 
101 | 
102 | async def should_clarify_query(query: str, history_context: str = '') -> bool:
103 |     """
104 |     Use the language model to determine if a query needs clarification.
105 |     """
106 |     try:
107 |         prompt = SHOULD_CLARIFY_QUERY_PROMPT.format(
108 |             query=query,
109 |             history_context=history_context,
110 |             current_date=get_current_date()
111 |         )
112 |         result = await generate_completion(prompt, temperature=0)
113 |         logger.debug(f"query: {query}, should clarify query result: {result}")
114 | 
115 |         # 检查结果是否为肯定回答(yes/y);子串匹配会被 "clarify" 等词里的 "y" 误触发
116 |         needs_clarification = result.strip().lower().startswith("yes") or result.strip().lower() == "y"
117 |         return needs_clarification
118 |     except Exception as e:
119 |         logger.error(f"Error while checking if query needs clarification: {str(e)}")
120 |         return True
121 | 
122 | 
123 | async def generate_followup_questions(query: str, history_context: str = '') -> Dict[str, Any]:
124 |     """
125 |     Generate clarifying follow-up questions for the given query.
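    Result shape expected by the callers (process_clarifications reads "key",
    "question" and "default" from each entry; the values here are illustrative):

        {"needs_clarification": true,
         "questions": [{"key": "scope", "question": "...?", "default": "..."}]}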
126 | 127 | Args: 128 | query: The user's research query 129 | history_context: Chat history context 130 | 131 | Returns: 132 | Dict containing whether clarification is needed and questions 133 | """ 134 | try: 135 | # Format the prompt 136 | prompt = FOLLOW_UP_QUESTIONS_PROMPT.format( 137 | query=query, 138 | history_context=history_context, 139 | current_date=get_current_date() # 添加当前日期 140 | ) 141 | 142 | # Generate the followup questions 143 | result = await generate_json_completion(prompt, temperature=0.7) 144 | 145 | # Ensure the expected structure 146 | if "needs_clarification" not in result: 147 | result["needs_clarification"] = False 148 | 149 | if "questions" not in result: 150 | result["questions"] = [] 151 | 152 | logger.debug(f"Follow-up questions: {result}") 153 | return result 154 | 155 | except Exception as e: 156 | logger.error(f"Error generating followup questions: {str(e)}") 157 | return { 158 | "needs_clarification": False, 159 | "questions": [] 160 | } 161 | 162 | 163 | async def process_clarifications( 164 | query: str, 165 | user_responses: Dict[str, str], 166 | all_questions: List[Dict[str, Any]], 167 | history_context: str = '' 168 | ) -> Dict[str, Any]: 169 | """ 170 | Process user responses to clarification questions and refine the query. 171 | 172 | Args: 173 | query: The original query 174 | user_responses: Dict mapping question keys to user responses 175 | all_questions: List of all questions that were asked 176 | history_context: Chat history context 177 | 178 | Returns: 179 | Dict with refined query and other information 180 | """ 181 | try: 182 | # Format the questions and responses for the prompt 183 | clarifications = [] 184 | unanswered = [] 185 | 186 | for question in all_questions: 187 | key = question.get("key", "") 188 | question_text = question.get("question", "") 189 | default = question.get("default", "") 190 | 191 | if key in user_responses and user_responses[key]: 192 | clarifications.append(f"Q: {question_text}\nA: {user_responses[key]}") 193 | else: 194 | clarifications.append(f"Q: {question_text}\nA: [Not answered by user]") 195 | unanswered.append(f"Q: {question_text}\nDefault: {default}") 196 | 197 | # Check if all questions were unanswered - if so, modify the prompt 198 | all_unanswered = len(unanswered) == len(all_questions) and len(all_questions) > 0 199 | 200 | # Format the prompt 201 | if all_unanswered: 202 | prompt = PROCESS_NO_CLARIFICATIONS_PROMPT.format( 203 | query=query, 204 | unanswered_questions="\n\n".join(unanswered) if unanswered else "None", 205 | history_context=history_context, 206 | current_date=get_current_date() 207 | ) 208 | else: 209 | prompt = PROCESS_CLARIFICATIONS_PROMPT.format( 210 | query=query, 211 | clarifications="\n\n".join(clarifications), 212 | unanswered_questions="\n\n".join(unanswered) if unanswered else "None", 213 | history_context=history_context, 214 | current_date=get_current_date() 215 | ) 216 | 217 | # Generate the clarifications processing 218 | result = await generate_json_completion(prompt, temperature=0.7) 219 | 220 | # Ensure the expected structure 221 | if "refined_query" not in result: 222 | result["refined_query"] = query 223 | 224 | if "assumptions" not in result: 225 | result["assumptions"] = [] 226 | 227 | if "requires_search" not in result: 228 | result["requires_search"] = True 229 | 230 | if "direct_answer" not in result: 231 | result["direct_answer"] = "" 232 | 233 | logger.debug(f"Processed clarifications: {result}") 234 | return result 235 | 236 | except Exception as e: 237 | 
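        # Fall back to the original query (keeping requires_search=True) so the research flow can continue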
logger.error(f"Error processing clarifications: {str(e)}") 238 | return { 239 | "refined_query": query, 240 | "assumptions": [], 241 | "requires_search": True, 242 | "direct_answer": "" 243 | } 244 | 245 | 246 | async def generate_research_plan(query: str, history_context: str = "") -> Dict[str, Any]: 247 | """ 248 | Generate a research plan with a variable number of steps based on the query complexity. 249 | 250 | Args: 251 | query: The research query 252 | history_context: Chat history context 253 | 254 | Returns: 255 | Dict containing the research plan with steps 256 | """ 257 | try: 258 | # Format the prompt 259 | prompt = RESEARCH_PLAN_PROMPT.format( 260 | query=query, 261 | history_context=history_context, 262 | current_date=get_current_date() # 添加当前日期 263 | ) 264 | 265 | # Generate the research plan 266 | result = await generate_json_completion(prompt, temperature=0.7) 267 | 268 | # Ensure the expected structure 269 | if "steps" not in result or not result["steps"]: 270 | # Create a default single-step plan if none was provided 271 | result["steps"] = [{ 272 | "step_id": 1, 273 | "description": "Research the query", 274 | "search_queries": [query], 275 | "goal": "Find information about the query" 276 | }] 277 | 278 | if "assessments" not in result: 279 | result["assessments"] = "No complexity assessment provided" 280 | 281 | logger.debug(f"Research plan: {result}") 282 | return result 283 | 284 | except Exception as e: 285 | logger.error(f"Error generating research plan: {str(e)}") 286 | # Return a simple default plan 287 | return { 288 | "assessments": "Error occurred, using default plan", 289 | "steps": [{ 290 | "step_id": 1, 291 | "description": "Research the query", 292 | "search_queries": [query], 293 | "goal": "Find information about the query" 294 | }] 295 | } 296 | 297 | 298 | async def extract_search_results(query: str, search_results: str) -> str: 299 | """ 300 | Extract search results for a query. 
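    Returns a JSON string whose top-level key is "extracted_infos", where each
    entry carries a "relevance" field; the other per-entry fields depend on the
    extraction prompt and are illustrative here:

        {"extracted_infos": [{"content": "...", "url": "...", "relevance": "..."}]}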
301 | 302 | Args: 303 | query: The search query 304 | search_results: Formatted search results text 305 | 306 | Returns: 307 | extracted search results with detailed content and relevance information 308 | """ 309 | try: 310 | # Get context size limit from config 311 | config = get_config() 312 | context_size = config.get("research", {}).get("context_size", 128000) 313 | 314 | # Limit search results size 315 | limited_search_results = limit_context_size(search_results, context_size // 2) 316 | 317 | # Format the prompt 318 | prompt = EXTRACT_SEARCH_RESULTS_PROMPT.format( 319 | query=query, 320 | search_results=limited_search_results, 321 | current_date=get_current_date() 322 | ) 323 | 324 | # Generate the extracted_contents 325 | extracted_contents = await generate_json_completion( 326 | prompt=prompt, 327 | system_message=EXTRACT_SEARCH_RESULTS_SYSTEM_PROMPT, 328 | temperature=0 329 | ) 330 | 331 | # Process and enrich the extracted content 332 | if "extracted_infos" in extracted_contents: 333 | # Make sure all entries have a relevance field (for backward compatibility) 334 | for info in extracted_contents["extracted_infos"]: 335 | if "relevance" not in info: 336 | info["relevance"] = "与查询相关的信息" 337 | 338 | # Convert to string for storage and transfer 339 | extracted_contents_str = json.dumps(extracted_contents, ensure_ascii=False) 340 | return extracted_contents_str 341 | 342 | except Exception as e: 343 | logger.error(f"Error extract search results: {str(e)}") 344 | return f"Error extract results for '{query}': {str(e)}" 345 | 346 | 347 | async def write_final_report_stream(query: str, context: str, 348 | history_context: str = '') -> AsyncGenerator[str, None]: 349 | """ 350 | Streaming version of write_final_report that yields chunks of the report. 351 | 352 | Args: 353 | query: The original research query 354 | context: List of key learnings/facts discovered with their sources 355 | history_context: str 356 | 357 | Yields: 358 | Chunks of the final report 359 | """ 360 | # Get context size limit from config 361 | config = get_config() 362 | context_size = config.get("research", {}).get("context_size", 128000) 363 | 364 | # Limit context sizes 365 | limited_context = limit_context_size(context, context_size // 2) 366 | limited_history = limit_context_size(history_context, context_size // 4) 367 | system_message = FINAL_REPORT_SYSTEM_PROMPT.format( 368 | current_date=get_current_date() 369 | ) 370 | formatted_prompt = FINAL_REPORT_PROMPT.format( 371 | query=query, 372 | context=limited_context, 373 | history_context=limited_history 374 | ) 375 | 376 | response_generator = await generate_completion( 377 | prompt=formatted_prompt, 378 | system_message=system_message, 379 | temperature=0.7, 380 | stream=True, 381 | is_report=True 382 | ) 383 | 384 | # Stream the response chunks 385 | async for chunk in response_generator: 386 | yield chunk 387 | 388 | 389 | async def write_final_report(query: str, context: str, history_context: str = '') -> str: 390 | """ 391 | Generate a final research report based on learnings and sources. 
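    Usage sketch (async context required; `context` is typically the
    JSON-serialized findings collected during the research steps):

        report = await write_final_report(query, context=context_json)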
392 | 393 | Args: 394 | query: The original research query 395 | context: List of key learnings/facts discovered with their sources 396 | history_context: chat history 397 | 398 | Returns: 399 | A formatted research report 400 | """ 401 | # Get context size limit from config 402 | config = get_config() 403 | context_size = config.get("research", {}).get("context_size", 128000) 404 | 405 | # Limit context sizes 406 | limited_context = limit_context_size(context, context_size // 2) 407 | limited_history = limit_context_size(history_context, context_size // 4) 408 | 409 | formatted_prompt = FINAL_REPORT_PROMPT.format( 410 | query=query, 411 | context=limited_context, 412 | history_context=limited_history, 413 | ) 414 | 415 | system_message = FINAL_REPORT_SYSTEM_PROMPT.format( 416 | current_date=get_current_date() 417 | ) 418 | 419 | report = await generate_completion( 420 | prompt=formatted_prompt, 421 | system_message=system_message, 422 | temperature=0.7, 423 | is_report=True 424 | ) 425 | 426 | return report 427 | 428 | 429 | async def write_final_answer(query: str, context: str, history_context: str = '') -> str: 430 | """ 431 | Generate a final concise answer based on the research. 432 | 433 | Args: 434 | query: The original research query 435 | context: List of key learnings/facts discovered with their sources 436 | history_context: chat history 437 | 438 | Returns: 439 | A concise answer to the query 440 | """ 441 | # Get context size limit from config 442 | config = get_config() 443 | context_size = config.get("research", {}).get("context_size", 128000) 444 | 445 | # Limit context sizes 446 | limited_context = limit_context_size(context, context_size // 2) 447 | limited_history = limit_context_size(history_context, context_size // 4) 448 | 449 | formatted_prompt = FINAL_ANSWER_PROMPT.format( 450 | query=query, 451 | context=limited_context, 452 | history_context=limited_history, 453 | ) 454 | 455 | system_message = FINAL_REPORT_SYSTEM_PROMPT.format( 456 | current_date=get_current_date() 457 | ) 458 | 459 | answer = await generate_completion( 460 | prompt=formatted_prompt, 461 | system_message=system_message, 462 | temperature=0.7, 463 | is_report=True 464 | ) 465 | 466 | return answer 467 | 468 | 469 | def write_final_report_sync(query: str, context: str, history_context: str = '') -> str: 470 | """ 471 | Synchronous wrapper for the write_final_report function. 472 | 473 | Args: 474 | query: The original research query 475 | context: List of key learnings/facts discovered with their sources 476 | history_context: chat history 477 | 478 | Returns: 479 | A formatted research report 480 | """ 481 | # Add event loop policy for Windows if needed 482 | add_event_loop_policy() 483 | 484 | # Run the async function in the event loop 485 | try: 486 | loop = asyncio.get_event_loop() 487 | except RuntimeError: 488 | loop = asyncio.new_event_loop() 489 | asyncio.set_event_loop(loop) 490 | 491 | return loop.run_until_complete( 492 | write_final_report( 493 | query=query, 494 | context=context, 495 | history_context=history_context 496 | ) 497 | ) 498 | 499 | 500 | def write_final_answer_sync(query: str, context: str, history_context: str = '') -> str: 501 | """ 502 | Synchronous wrapper for the write_final_answer function. 
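    Example (blocking call for plain scripts; write_final_report_sync above is
    the matching report variant):

        answer = write_final_answer_sync("your query", context=findings_json)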
503 | 504 | Args: 505 | query: The original research query 506 | context: List of key learnings/facts discovered with their sources 507 | history_context: chat history 508 | 509 | Returns: 510 | A concise answer to the query 511 | """ 512 | # Add event loop policy for Windows if needed 513 | add_event_loop_policy() 514 | 515 | # Run the async function in the event loop 516 | try: 517 | loop = asyncio.get_event_loop() 518 | except RuntimeError: 519 | loop = asyncio.new_event_loop() 520 | asyncio.set_event_loop(loop) 521 | 522 | return loop.run_until_complete( 523 | write_final_answer( 524 | query=query, 525 | context=context, 526 | history_context=history_context 527 | ) 528 | ) 529 | 530 | 531 | async def research_step( 532 | query: str, 533 | config: Dict[str, Any], 534 | on_progress: Optional[Callable] = None, 535 | search_provider=None, 536 | ) -> Dict[str, Any]: 537 | """ 538 | Perform a single step of the research process. 539 | 540 | Args: 541 | query: The query to research 542 | config: Configuration dictionary 543 | on_progress: Optional callback for progress updates 544 | search_provider: Search provider instance 545 | 546 | Returns: 547 | Dict with research results 548 | """ 549 | if search_provider is None: 550 | search_provider = get_search_provider() 551 | 552 | # Progress update 553 | if on_progress: 554 | progress_data = { 555 | "currentQuery": query 556 | } 557 | 558 | # Check if on_progress is a coroutine function 559 | if inspect.iscoroutinefunction(on_progress): 560 | await on_progress(progress_data) 561 | else: 562 | on_progress(progress_data) 563 | 564 | # Get the search results 565 | search_result = await search_with_query(query, config, search_provider) 566 | search_results_text = search_result["summary"] 567 | urls = search_result["urls"] 568 | 569 | enable_refine_search_result = config.get("research", {}).get("enable_refine_search_result", False) 570 | if enable_refine_search_result: 571 | extracted_content = await extract_search_results(query, search_results_text) 572 | else: 573 | extracted_content = search_results_text 574 | 575 | return { 576 | "extracted_content": extracted_content, 577 | "urls": urls, 578 | } 579 | 580 | 581 | async def deep_research_stream( 582 | query: str, 583 | on_progress: Optional[Callable] = None, 584 | user_clarifications: Dict[str, str] = None, 585 | search_source: Optional[str] = None, 586 | history_context: Optional[str] = None, 587 | enable_clarification: bool = False, 588 | ) -> AsyncGenerator[Dict[str, Any], None]: 589 | """ 590 | Streaming version of deep research that yields partial results. 
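    Consumption sketch (mirrors main.py; passing user_clarifications={'all': 'skip'}
    bypasses the clarification stage):

        async for result in deep_research_stream(query="...", history_context=""):
            if result.get("stage") == "completed":
                print(result.get("final_report", ""))
                break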
591 | 592 | Args: 593 | query: The research query 594 | on_progress: Optional callback function for progress updates 595 | user_clarifications: User responses to clarification questions 596 | search_source: Optional search provider to use 597 | history_context: history chat content 598 | enable_clarification: Whether to use the clarification step 599 | 600 | Yields: 601 | Dict with partial research results and status updates 602 | """ 603 | # Load configuration 604 | config = get_config() 605 | logger.debug(f"query: {query}, config: {config}") 606 | 607 | # Initialize tracking variables 608 | visited_urls = [] 609 | all_learnings = [] 610 | 611 | # Initialize search provider 612 | search_provider = get_search_provider(search_source=search_source) 613 | 614 | try: 615 | # Step 1: Yield initial status 616 | yield { 617 | "status_update": f"开始研究查询: '{query}'...", 618 | "learnings": all_learnings, 619 | "visitedUrls": visited_urls, 620 | "current_query": query, 621 | "stage": "initial" 622 | } 623 | 624 | # Step 1.5: 先判断是否需要生成澄清问题 625 | needs_clarification = False 626 | 627 | # 如果没有提供用户澄清,且不跳过澄清环节,则判断是否需要澄清 628 | if not user_clarifications and enable_clarification: 629 | yield { 630 | "status_update": f"分析查询是否需要澄清...", 631 | "learnings": all_learnings, 632 | "visitedUrls": visited_urls, 633 | "current_query": query, 634 | "stage": "analyzing_query" 635 | } 636 | 637 | # 修复重复调用问题 - 使用缓存变量存储结果 638 | _clarification_cache_key = f"should_clarify_{query}" 639 | if _clarification_cache_key not in globals(): 640 | needs_clarification = await should_clarify_query(query, history_context) 641 | globals()[_clarification_cache_key] = needs_clarification 642 | else: 643 | needs_clarification = globals()[_clarification_cache_key] 644 | logger.debug(f"使用缓存的澄清结果: {query}, result: {needs_clarification}") 645 | 646 | if not needs_clarification: 647 | yield { 648 | "status_update": f"查询已足够清晰,跳过澄清步骤", 649 | "learnings": all_learnings, 650 | "visitedUrls": visited_urls, 651 | "current_query": query, 652 | "stage": "clarification_skipped" 653 | } 654 | elif not enable_clarification: 655 | # 如果配置为跳过澄清环节,直接显示状态 656 | yield { 657 | "status_update": f"配置为跳过澄清环节,直接开始研究", 658 | "learnings": all_learnings, 659 | "visitedUrls": visited_urls, 660 | "current_query": query, 661 | "stage": "clarification_skipped" 662 | } 663 | 664 | # 如果LLM判断需要澄清,或者已经有用户澄清,且未配置跳过澄清环节,则继续生成或处理澄清问题 665 | questions = [] 666 | if (needs_clarification or user_clarifications) and enable_clarification: 667 | # Step 2: Generate clarification questions if needed 668 | if needs_clarification: 669 | yield { 670 | "status_update": f"生成澄清问题...", 671 | "learnings": all_learnings, 672 | "visitedUrls": visited_urls, 673 | "current_query": query, 674 | "stage": "generating_questions" 675 | } 676 | 677 | followup_result = await generate_followup_questions(query, history_context) 678 | questions = followup_result.get("questions", []) 679 | 680 | if questions: 681 | # If clarification is needed, update status 682 | yield { 683 | "status_update": f"查询需要澄清,生成了 {len(questions)} 个问题", 684 | "learnings": all_learnings, 685 | "visitedUrls": visited_urls, 686 | "current_query": query, 687 | "questions": questions, 688 | "stage": "clarification_needed" 689 | } 690 | 691 | # If we don't have user responses yet, wait for them 692 | if not user_clarifications: 693 | yield { 694 | "status_update": "等待用户回答澄清问题...", 695 | "learnings": all_learnings, 696 | "visitedUrls": visited_urls, 697 | "current_query": query, 698 | "questions": questions, 699 | "awaiting_clarification": 
True, 700 | "stage": "awaiting_clarification" 701 | } 702 | return 703 | 704 | # Step 3: Process user clarifications if provided 705 | refined_query = query 706 | user_responses = user_clarifications or {} 707 | 708 | if questions and user_clarifications and enable_clarification: 709 | # Track which questions were answered vs. which use defaults 710 | answered_questions = [] 711 | unanswered_questions = [] 712 | 713 | for q in questions: 714 | key = q.get("key", "") 715 | question_text = q.get("question", "") 716 | if key in user_responses and user_responses[key]: 717 | answered_questions.append(question_text) 718 | else: 719 | unanswered_questions.append(question_text) 720 | 721 | yield { 722 | "status_update": f"处理用户的澄清回答 ({len(answered_questions)}/{len(questions)} 已回答)", 723 | "learnings": all_learnings, 724 | "visitedUrls": visited_urls, 725 | "current_query": query, 726 | "answered_questions": answered_questions, 727 | "unanswered_questions": unanswered_questions, 728 | "stage": "processing_clarifications" 729 | } 730 | 731 | # Process the clarifications 732 | clarification_result = await process_clarifications(query, user_responses, questions, history_context) 733 | refined_query = clarification_result.get("refined_query", query) 734 | 735 | yield { 736 | "status_update": f"查询已优化: '{refined_query}'", 737 | "learnings": all_learnings, 738 | "visitedUrls": visited_urls, 739 | "original_query": query, 740 | "current_query": refined_query, 741 | "assumptions": clarification_result.get("assumptions", []), 742 | "stage": "query_refined" 743 | } 744 | 745 | # Check if this is a simple query that can be answered directly 746 | if not clarification_result.get("requires_search", True): 747 | direct_answer = clarification_result.get("direct_answer", "") 748 | if direct_answer: 749 | yield { 750 | "status_update": "查询可以直接回答,无需搜索", 751 | "requires_search": False, 752 | "direct_answer": direct_answer, 753 | "final_report": direct_answer, # 直接使用direct_answer作为最终报告 754 | "learnings": ["直接回答: " + direct_answer], 755 | "visitedUrls": [], 756 | "stage": "completed" 757 | } 758 | return # 不再需要继续执行查询流程 759 | 760 | # Step 4: Generate research plan with variable steps 761 | yield { 762 | "status_update": f"为查询生成研究计划: '{refined_query}'", 763 | "learnings": all_learnings, 764 | "visitedUrls": visited_urls, 765 | "current_query": refined_query, 766 | "stage": "planning" 767 | } 768 | 769 | plan_result = await generate_research_plan(refined_query, history_context) 770 | steps = plan_result.get("steps", []) 771 | 772 | yield { 773 | "status_update": f"生成了 {len(steps)} 步研究计划: {plan_result.get('assessments', '无评估')}", 774 | "learnings": all_learnings, 775 | "visitedUrls": visited_urls, 776 | "current_query": refined_query, 777 | "research_plan": steps, 778 | "stage": "plan_generated" 779 | } 780 | 781 | # Track step summaries for the final analysis 782 | step_summaries = [] 783 | 784 | # Iterate through each step in the research plan 785 | for step_idx, step in enumerate(steps): 786 | step_id = step.get("step_id", step_idx + 1) 787 | description = step.get("description", f"Research step {step_id}") 788 | search_queries = step.get("search_queries", [refined_query]) 789 | 790 | yield { 791 | "status_update": f"开始研究步骤 {step_id}/{len(steps)}: {description}", 792 | "learnings": all_learnings, 793 | "visitedUrls": visited_urls, 794 | "current_query": refined_query, 795 | "search_queries": search_queries, 796 | "current_step": step, 797 | "progress": { 798 | "current_step": step_id, 799 | "total_steps": len(steps) 800 | 
}, 801 | "stage": "step_started" 802 | } 803 | 804 | # Create a queue of queries to process for this step 805 | step_urls = [] 806 | step_learnings = [] 807 | 808 | current_queries = search_queries.copy() 809 | # Research these queries concurrently 810 | yield { 811 | "status_update": f"步骤 {step_id}/{len(steps)}: 并行研究 {len(current_queries)} 个查询", 812 | "learnings": all_learnings, 813 | "visitedUrls": visited_urls, 814 | "current_queries": current_queries, 815 | "progress": { 816 | "current_step": step_id, 817 | "total_steps": len(steps), 818 | "processed_queries": len(current_queries), 819 | }, 820 | "stage": "processing_queries" 821 | } 822 | 823 | # Process each query in the current batch 824 | research_tasks = [] 825 | for current_query in current_queries: 826 | task = research_step( 827 | query=current_query, 828 | config=config, 829 | on_progress=on_progress, 830 | search_provider=search_provider, 831 | ) 832 | research_tasks.append(task) 833 | 834 | # Execute tasks with concurrency 835 | results = await asyncio.gather(*research_tasks) 836 | 837 | # Process the results 838 | for result in results: 839 | # Update tracking variables 840 | urls = result["urls"] 841 | content = result["extracted_content"] 842 | step_urls.extend(urls) 843 | step_learnings.append(content) 844 | 845 | # Format learnings and URLs for display 846 | formatted_learnings = [] 847 | for i, learning in enumerate(step_learnings): 848 | formatted_learnings.append(f"[{i + 1}] {learning}") 849 | 850 | formatted_urls = [] 851 | for i, url in enumerate(step_urls): 852 | formatted_urls.append(f"[{i + 1}] {url}") 853 | 854 | # Truncate longer learnings for display 855 | new_learnings = [str(i)[:400] for i in step_learnings] 856 | yield { 857 | "status_update": f"步骤 {step_id}/{len(steps)}: 发现 {len(new_learnings)} 个新见解", 858 | "learnings": all_learnings + step_learnings, 859 | "visitedUrls": visited_urls + step_urls, 860 | "new_learnings": new_learnings, 861 | "formatted_new_learnings": formatted_learnings, 862 | "new_urls": step_urls, 863 | "formatted_new_urls": formatted_urls, 864 | "progress": { 865 | "current_step": step_id, 866 | "total_steps": len(steps), 867 | "processed_queries": len(current_queries) 868 | }, 869 | "stage": "insights_found" 870 | } 871 | 872 | # Update visited URLs and learnings 873 | visited_urls.extend(step_urls) 874 | all_learnings.extend(step_learnings) 875 | 876 | # Save step summary 877 | step_summaries.append({ 878 | "step_id": step_id, 879 | "description": description, 880 | "learnings": step_learnings, 881 | "urls": step_urls 882 | }) 883 | 884 | # Format all step learnings for display 885 | formatted_step_learnings = [] 886 | for i, learning in enumerate(step_learnings): 887 | formatted_step_learnings.append(f"[{i + 1}] {learning}") 888 | 889 | # Step completion update 890 | yield { 891 | "status_update": f"完成研究步骤 {step_id}/{len(steps)}: {description},获得 {len(step_learnings)} 个见解", 892 | "learnings": all_learnings, 893 | "visitedUrls": visited_urls, 894 | "step_learnings": step_learnings, 895 | "formatted_step_learnings": formatted_step_learnings, 896 | "step_urls": step_urls, 897 | "progress": { 898 | "current_step": step_id, 899 | "total_steps": len(steps), 900 | "completed": True 901 | }, 902 | "stage": "step_completed" 903 | } 904 | 905 | enable_next_plan = config.get("research", {}).get("enable_next_plan", False) 906 | if enable_next_plan: 907 | # Perform final analysis 908 | yield { 909 | "status_update": "分析所有已收集的信息...", 910 | "learnings": all_learnings, 911 | "visitedUrls": 
visited_urls, 912 | "stage": "final_analysis" 913 | } 914 | 915 | steps_summary_text = "\n\n".join([ 916 | f"步骤 {s['step_id']}: {s['description']}\n发现: {json.dumps(s['learnings'], ensure_ascii=False)}" 917 | for s in step_summaries 918 | ]) 919 | 920 | future_research = RESEARCH_SUMMARY_PROMPT.format( 921 | query=refined_query, 922 | steps_summary=steps_summary_text, 923 | current_date=get_current_date() 924 | ) 925 | 926 | future_research_result = await generate_json_completion(future_research, temperature=0.7) 927 | 928 | # Add final findings to learnings 929 | findings = [] 930 | if "findings" in future_research_result: 931 | all_learnings.extend(future_research_result["findings"]) 932 | findings = future_research_result["findings"] 933 | 934 | # Format final findings 935 | formatted_findings = [] 936 | for i, finding in enumerate(findings): 937 | formatted_findings.append(f"[{i + 1}] {finding}") 938 | 939 | yield { 940 | "status_update": f"分析完成,得出 {len(findings)} 个主要发现", 941 | "learnings": all_learnings, 942 | "visitedUrls": visited_urls, 943 | "final_findings": findings, 944 | "formatted_final_findings": formatted_findings, 945 | "gaps": future_research_result.get("gaps", []), 946 | "recommendations": future_research_result.get("recommendations", []), 947 | "stage": "analysis_completed" 948 | } 949 | else: 950 | future_research_result = "" 951 | # No need to modify all_learnings when skipping summary 952 | 953 | yield { 954 | "status_update": "生成详细研究报告...", 955 | "learnings": all_learnings, 956 | "visitedUrls": visited_urls, 957 | "stage": "generating_report" 958 | } 959 | 960 | # Generate report (non-streaming for now) 961 | context = str(all_learnings) + '\n\n' + str(future_research_result) 962 | final_report = await write_final_report(refined_query, context, history_context) 963 | 964 | # Return compiled results 965 | yield { 966 | "status_update": "研究完成!", 967 | "query": refined_query, 968 | "originalQuery": query, 969 | "learnings": all_learnings, 970 | "visitedUrls": list(set(visited_urls)), 971 | "summary": future_research_result, 972 | "final_report": final_report, 973 | "stage": "completed" 974 | } 975 | 976 | except Exception as e: 977 | logger.error(f"Error in deep research stream: {str(e)}") 978 | logger.error(traceback.format_exc()) 979 | 980 | yield { 981 | "status_update": f"错误: {str(e)}", 982 | "error": str(e), 983 | "learnings": all_learnings, 984 | "visitedUrls": visited_urls, 985 | "stage": "error" 986 | } 987 | -------------------------------------------------------------------------------- /src/gradio_chat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | 6 | A simplified Gradio demo for Deep Research with basic conversation interface. 
7 | """ 8 | 9 | import time 10 | import gradio as gr 11 | from loguru import logger 12 | from .config import get_config 13 | from .deep_research import ( 14 | deep_research_stream, 15 | generate_followup_questions, 16 | process_clarifications, 17 | write_final_report, 18 | should_clarify_query 19 | ) 20 | 21 | # Load configuration 22 | config = get_config() 23 | 24 | 25 | def run_gradio_demo(): 26 | """Run a modern Gradio demo for Deep Research using ChatInterface""" 27 | enable_clarification = config.get("research", {}).get("enable_clarification", False) 28 | search_source = config.get("research", {}).get("search_source", "tavily") 29 | 30 | # Conversation state (shared across functions) 31 | conversation_state = { 32 | "current_query": "", 33 | "needs_clarification": False, 34 | "questions": [], 35 | "waiting_for_clarification": False, 36 | "clarification_answers": {}, 37 | "last_status": "", # 跟踪最后一个状态更新 38 | "history": [], # 保存当前聊天历史 39 | "search_source": search_source, # 存储搜索提供商 40 | "enable_clarification": enable_clarification # 存储是否启用澄清 41 | } 42 | 43 | async def research_with_thinking(message, history): 44 | """处理查询,展示研究过程并返回结果""" 45 | if not message: 46 | yield history # 空消息,直接返回 47 | return # 无值返回是允许的 48 | 49 | # 重置状态,确保多次查询之间状态不混淆 50 | conversation_state["last_status"] = "" 51 | conversation_state["current_query"] = "" 52 | conversation_state["questions"] = [] 53 | 54 | # 判断是否是澄清回答,如果是则不重置waiting_for_clarification 55 | if not conversation_state["waiting_for_clarification"]: 56 | conversation_state["waiting_for_clarification"] = False 57 | conversation_state["clarification_answers"] = {} 58 | 59 | logger.debug( 60 | f"Starting research, message: {message}, search_source: {search_source}, enable_clarification: {enable_clarification}") 61 | 62 | # 构建历史上下文 - 直接使用history即可 63 | history_context = '' 64 | for msg in history: 65 | if isinstance(msg, dict) and msg.get("role") == "user": 66 | q = 'Q:' + msg.get("content", "") + '\n' 67 | history_context += q 68 | 69 | # 3. 检查是否是对澄清问题的回答 70 | if conversation_state["waiting_for_clarification"]: 71 | async for msg in handle_clarification_answer(message, history, history_context): 72 | yield msg 73 | return # 无值返回是允许的 74 | 75 | # 4. 创建研究过程消息并将其添加到历史记录 76 | messages = [] 77 | messages.append({"role": "assistant", "content": "正在进行研究...", "metadata": {"title": "研究过程"}}) 78 | yield messages 79 | 80 | # 5. 处理澄清环节 81 | if not enable_clarification: 82 | messages[-1]["content"] = "跳过澄清环节,直接开始研究..." 83 | yield messages 84 | else: 85 | # 分析查询是否需要澄清 86 | messages[-1]["content"] = "分析查询需求中..." 87 | yield messages 88 | 89 | needs_clarification = await should_clarify_query(message, history_context) 90 | if needs_clarification: 91 | messages[-1]["content"] = "生成澄清问题..." 92 | yield messages 93 | 94 | followup_result = await generate_followup_questions(message, history_context) 95 | questions = followup_result.get("questions", []) 96 | 97 | if questions: 98 | # 保存问题和状态 99 | conversation_state["current_query"] = message 100 | conversation_state["questions"] = questions 101 | conversation_state["waiting_for_clarification"] = True 102 | 103 | # 显示问题给用户 104 | clarification_msg = "请回答以下问题,帮助我更好地理解您的查询:" 105 | for i, q in enumerate(questions, 1): 106 | clarification_msg += f"\n{i}. {q.get('question', '')}" 107 | 108 | # 替换研究过程消息为澄清问题 109 | messages[-1] = {"role": "assistant", "content": clarification_msg} 110 | yield messages 111 | return # 等待用户回答 112 | else: 113 | messages[-1]["content"] = "无法生成有效的澄清问题,继续研究..." 
114 | yield messages 115 | else: 116 | messages[-1]["content"] = "查询已足够清晰,开始研究..." 117 | yield messages 118 | 119 | # 6. 开始搜索 120 | messages[-1]["content"] = f"使用 {search_source} 搜索相关信息中..." 121 | yield messages 122 | 123 | # 7. 执行研究过程 124 | final_result = None 125 | research_log = [] 126 | 127 | async for partial_result in deep_research_stream( 128 | query=message, 129 | search_source=search_source, 130 | history_context=history_context, 131 | enable_clarification=enable_clarification 132 | ): 133 | # 更新研究进度 134 | if partial_result.get("status_update"): 135 | status = partial_result.get("status_update") 136 | stage = partial_result.get("stage", "") 137 | 138 | # 检查状态是否有变化 139 | if status != conversation_state["last_status"]: 140 | conversation_state["last_status"] = status 141 | 142 | # 更新研究进度消息 143 | timestamp = time.strftime('%H:%M:%S') 144 | status_line = f"[{timestamp}] {status}" 145 | research_log.append(status_line) 146 | 147 | # 显示当前研究计划步骤 148 | if partial_result.get("current_step"): 149 | current_step = partial_result.get("current_step") 150 | step_id = current_step.get("step_id", "") 151 | description = current_step.get("description", "") 152 | step_line = f"当前步骤 {step_id}: {description}" 153 | research_log.append(step_line) 154 | 155 | # 显示当前查询 156 | if partial_result.get("current_queries"): 157 | queries = partial_result.get("current_queries") 158 | queries_lines = ["**当前并行查询**:"] 159 | for i, q in enumerate(queries, 1): 160 | queries_lines.append(f"{i}. {q}") 161 | research_log.append("\n".join(queries_lines)) 162 | 163 | # 对于特定阶段,添加更多信息 164 | if stage == "plan_generated" and partial_result.get("research_plan"): 165 | research_plan = partial_result.get("research_plan") 166 | plan_lines = ["**研究计划**:"] 167 | for i, step in enumerate(research_plan): 168 | step_id = step.get("step_id", i + 1) 169 | description = step.get("description", "") 170 | plan_lines.append(f"步骤 {step_id}: {description}") 171 | research_log.append("\n".join(plan_lines)) 172 | 173 | # 添加阶段详细信息 174 | if stage == "insights_found" and partial_result.get("formatted_new_learnings"): 175 | if partial_result.get("formatted_new_urls") and len( 176 | partial_result.get("formatted_new_urls")) > 0: 177 | research_log.append("\n**来源**:\n" + "\n".join( 178 | partial_result.get("formatted_new_urls", [])[:3])) 179 | 180 | elif stage == "step_completed" and partial_result.get("formatted_step_learnings"): 181 | research_log.append("\n**步骤总结**:\n" + "\n".join( 182 | partial_result.get("formatted_step_learnings", []))) 183 | 184 | elif stage == "analysis_completed" and partial_result.get("formatted_final_findings"): 185 | research_log.append("\n**主要发现**:\n" + "\n".join( 186 | partial_result.get("formatted_final_findings", []))) 187 | 188 | if partial_result.get("gaps"): 189 | research_log.append("\n\n**研究空白**:\n- " + "\n- ".join(partial_result.get("gaps", []))) 190 | 191 | # 合并所有日志并更新研究过程消息 192 | messages[-1]["content"] = "\n\n".join(research_log) 193 | yield messages 194 | 195 | # 保存最后一个结果用于生成报告 196 | final_result = partial_result 197 | 198 | # 如果有最终报告,跳出循环 199 | if "final_report" in partial_result: 200 | break 201 | 202 | # 8. 
生成报告
203 |         if final_result:
204 |             # 如果直接在结果中有final_report,直接使用
205 |             if "final_report" in final_result:
206 |                 report = final_result["final_report"]
207 |                 # 标记研究过程消息已完成
208 |                 research_process = messages[-1]["content"]
209 |                 messages[-1]["content"] = "研究完成,报告已生成。\n\n" + research_process
210 |                 yield messages
211 |             else:
212 |                 # 否则,使用收集到的信息生成报告
213 |                 research_process = messages[-1]["content"]
214 |                 messages[-1]["content"] = "正在整合研究结果并生成报告...\n\n" + research_process
215 |                 yield messages
216 | 
217 |                 learnings = final_result.get("learnings", [])
218 | 
219 |                 try:
220 |                     report = await write_final_report(
221 |                         query=message,
222 |                         context=str(learnings),
223 |                         history_context=history_context
224 |                     )
225 |                     # 确保report不为None
226 |                     if report is None:
227 |                         report = "抱歉,无法生成研究报告。"
228 |                         logger.error(f"write_final_report returned None for query: {message}")
229 |                 except Exception as e:
230 |                     report = f"生成报告时出错: {str(e)}"
231 |                     logger.error(f"Error in write_final_report: {str(e)}")
232 | 
233 |                 # 保留研究过程信息
234 |                 messages[-1]["content"] = "研究完成,报告已生成。\n\n" + research_process
235 |                 yield messages
236 | 
237 |             # 添加最终报告消息,但保留研究过程消息
238 |             messages.append({"role": "assistant", "content": report})
239 |             yield messages
240 |         else:
241 |             messages.append(
242 |                 {"role": "assistant", "content": "抱歉,我无法为您的查询生成研究报告。请尝试其他问题或稍后再试。"})
243 |             yield messages
244 | 
245 |     async def handle_clarification_answer(message, history, history_context):
246 |         """处理用户对澄清问题的回答"""
247 |         # 重置等待标志
248 |         conversation_state["waiting_for_clarification"] = False
249 | 
250 |         # 获取原始查询和问题
251 |         query = conversation_state["current_query"]
252 |         questions = conversation_state["questions"]
253 | 
254 |         # 重置状态,确保多次查询之间状态不混淆
255 |         conversation_state["last_status"] = ""
256 | 
257 |         # 1. 创建消息列表并添加研究过程消息
258 |         messages = []
259 |         messages.append({"role": "assistant", "content": "解析您的澄清回答...", "metadata": {"title": "研究过程"}})
260 |         yield messages
261 | 
262 |         # 2. 解析用户回答
263 |         lines = [line.strip() for line in message.split('\n') if line.strip()]
264 |         if len(lines) < len(questions):
265 |             # 尝试逗号分隔
266 |             if ',' in message:
267 |                 lines = [ans.strip() for ans in message.split(',')]
268 | 
269 |         # 3. 创建响应字典
270 |         user_responses = {}
271 |         for i, q in enumerate(questions):
272 |             key = q.get("key", f"q{i}")
273 |             if i < len(lines) and lines[i]:
274 |                 user_responses[key] = lines[i]
275 | 
276 |         # 4. 处理澄清内容
277 |         messages[-1]["content"] = "处理您的澄清内容..."
278 |         yield messages
279 | 
280 |         # 5. 处理澄清并优化查询
281 |         clarification_result = await process_clarifications(
282 |             query=query,
283 |             user_responses=user_responses,
284 |             all_questions=questions,
285 |             history_context=history_context
286 |         )
287 | 
288 |         # 6. 获取优化后的查询
289 |         refined_query = clarification_result.get("refined_query", query)
290 |         messages[-1]["content"] = f"已优化查询: {refined_query}"
291 |         yield messages
292 | 
293 |         # 7. 检查是否可以直接回答
294 |         if not clarification_result.get("requires_search", True) and clarification_result.get("direct_answer"):
295 |             direct_answer = clarification_result.get("direct_answer", "")
296 | 
297 |             # 保留研究过程消息,并添加直接回答
298 |             research_process = messages[-1]["content"]
299 |             messages[-1]["content"] = "提供直接回答,无需搜索。\n\n" + research_process
300 |             yield messages
301 | 
302 |             # 添加最终回答,但保留研究过程
303 |             messages.append({"role": "assistant", "content": direct_answer})
304 |             yield messages
305 |             return  # 直接回答后结束,否则下方的搜索流程会覆盖该回答
306 |         # 8. 开始搜索
307 |         messages[-1]["content"] = "基于您的澄清开始搜索信息..."
308 |         yield messages
309 | 
310 |         # 9. 
执行研究过程 311 | final_result = None 312 | research_log = [] 313 | 314 | async for partial_result in deep_research_stream( 315 | query=refined_query, 316 | user_clarifications=user_responses, 317 | search_source=search_source, 318 | history_context=history_context 319 | ): 320 | # 更新研究进度 321 | if partial_result.get("status_update"): 322 | status = partial_result.get("status_update") 323 | stage = partial_result.get("stage", "") 324 | 325 | # 检查状态是否有变化 326 | if status != conversation_state["last_status"]: 327 | conversation_state["last_status"] = status 328 | 329 | # 更新研究进度消息 330 | timestamp = time.strftime('%H:%M:%S') 331 | status_line = f"[{timestamp}] {status}" 332 | research_log.append(status_line) 333 | 334 | # 显示当前研究计划步骤 335 | if partial_result.get("current_step"): 336 | current_step = partial_result.get("current_step") 337 | step_id = current_step.get("step_id", "") 338 | description = current_step.get("description", "") 339 | step_line = f"当前步骤 {step_id}: {description}" 340 | research_log.append(step_line) 341 | 342 | # 显示当前查询 343 | if partial_result.get("current_queries"): 344 | queries = partial_result.get("current_queries") 345 | queries_lines = ["当前并行查询:"] 346 | for i, q in enumerate(queries, 1): 347 | queries_lines.append(f"{i}. {q}") 348 | research_log.append("\n".join(queries_lines)) 349 | 350 | # 对于特定阶段,添加更多信息 351 | if stage == "plan_generated" and partial_result.get("research_plan"): 352 | research_plan = partial_result.get("research_plan") 353 | plan_lines = ["研究计划:"] 354 | for i, step in enumerate(research_plan): 355 | step_id = step.get("step_id", i + 1) 356 | description = step.get("description", "") 357 | plan_lines.append(f"步骤 {step_id}: {description}") 358 | research_log.append("\n".join(plan_lines)) 359 | 360 | # 添加阶段详细信息 361 | if stage == "insights_found" and partial_result.get("formatted_new_learnings"): 362 | if partial_result.get("formatted_new_urls") and len( 363 | partial_result.get("formatted_new_urls")) > 0: 364 | research_log.append("\n**来源**:\n" + "\n".join( 365 | partial_result.get("formatted_new_urls", [])[:3])) 366 | 367 | elif stage == "step_completed" and partial_result.get("formatted_step_learnings"): 368 | research_log.append("\n**步骤总结**:\n" + "\n".join( 369 | partial_result.get("formatted_step_learnings", []))) 370 | 371 | elif stage == "analysis_completed" and partial_result.get("formatted_final_findings"): 372 | research_log.append("\n**主要发现**:\n" + "\n".join( 373 | partial_result.get("formatted_final_findings", []))) 374 | 375 | if partial_result.get("gaps"): 376 | research_log.append("\n\n**研究空白**:\n- " + "\n- ".join(partial_result.get("gaps", []))) 377 | 378 | # 合并所有日志并更新研究过程消息 379 | messages[-1]["content"] = "\n\n".join(research_log) 380 | yield messages 381 | 382 | # 保存最后一个结果用于生成报告 383 | final_result = partial_result 384 | 385 | # 如果有最终报告,跳出循环 386 | if "final_report" in partial_result: 387 | break 388 | 389 | # 10. 
生成报告 390 | if final_result: 391 | # 如果直接在结果中有final_report,直接使用 392 | if "final_report" in final_result: 393 | report = final_result["final_report"] 394 | # 标记研究过程消息已完成 395 | research_process = messages[-1]["content"] 396 | messages[-1]["content"] = "研究完成,报告已生成。\n\n" + research_process 397 | yield messages 398 | else: 399 | # 否则,使用收集到的信息生成报告 400 | research_process = messages[-1]["content"] 401 | messages[-1]["content"] = "正在整合研究结果并生成报告...\n\n" + research_process 402 | yield messages 403 | 404 | learnings = final_result.get("learnings", []) 405 | 406 | try: 407 | report = await write_final_report( 408 | query=refined_query, 409 | context=str(learnings), 410 | history_context=history_context 411 | ) 412 | # 确保report不为None 413 | if report is None: 414 | report = "抱歉,无法生成研究报告。" 415 | logger.error(f"returned None for query: {refined_query}") 416 | except Exception as e: 417 | report = f"生成报告时出错: {str(e)}" 418 | logger.error(f"Error in write_final_report: {str(e)}") 419 | 420 | # 保留研究过程信息 421 | messages[-1]["content"] = "研究完成,报告已生成。\n\n" + research_process 422 | yield messages 423 | 424 | # 添加最终报告消息,但保留研究过程消息 425 | messages.append({"role": "assistant", "content": report}) 426 | yield messages 427 | else: 428 | messages.append( 429 | {"role": "assistant", "content": "抱歉,我无法为您的查询生成研究报告。请尝试其他问题或稍后再试。"}) 430 | yield messages 431 | 432 | # 创建 ChatInterface 433 | demo = gr.ChatInterface( 434 | research_with_thinking, 435 | type='messages', 436 | title="🔍 Deep Research", 437 | description="使用此工具进行深度研究,我将搜索互联网为您找到回答。Powered by [Deep Research](https://github.com/shibing624/deep-research) Made with ❤️ by [shibing624](https://github.com/shibing624)", 438 | examples=[ 439 | ["特斯拉股票的最新行情?"], 440 | ["介绍一下最近的人工智能技术发展趋势"], 441 | ["中国2024年GDP增长了多少?"], 442 | ["Explain the differences between supervised and unsupervised machine learning."] 443 | ] 444 | ) 445 | 446 | # 启动界面 447 | demo.queue() 448 | demo.launch(server_name="0.0.0.0", share=False, server_port=7860, show_api=False) 449 | 450 | 451 | if __name__ == "__main__": 452 | run_gradio_demo() 453 | -------------------------------------------------------------------------------- /src/model_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | 6 | Utility functions for interacting with language models. 7 | """ 8 | 9 | import json 10 | from typing import Dict, Any, List, Optional, AsyncGenerator, Union 11 | from loguru import logger 12 | 13 | from .providers import get_model 14 | 15 | 16 | async def generate_completion( 17 | prompt: str, 18 | system_message: str = "You are a helpful AI research assistant.", 19 | temperature: float = 0.7, 20 | json_response: bool = False, 21 | stream: bool = False, 22 | is_report: bool = False 23 | ) -> Union[str, AsyncGenerator[str, None]]: 24 | """ 25 | Generate a completion from the language model. 
26 | 27 | Args: 28 | prompt: The prompt to send to the model 29 | system_message: The system message to use 30 | temperature: The temperature to use for generation 31 | json_response: Whether to request a JSON response 32 | stream: Whether to stream the response 33 | is_report: Whether to use the report model configuration 34 | 35 | Returns: 36 | The model's response as a string or an async generator of response chunks 37 | """ 38 | model_config = get_model(is_report=is_report) 39 | 40 | messages = [ 41 | {"role": "system", "content": system_message}, 42 | {"role": "user", "content": prompt} 43 | ] 44 | 45 | # Prepare the request parameters 46 | request_params = { 47 | "model": model_config["model"], 48 | "messages": messages, 49 | "temperature": temperature, 50 | } 51 | 52 | # Add response format if JSON is requested 53 | if json_response: 54 | request_params["response_format"] = {"type": "json_object"} 55 | 56 | # Add streaming parameter if requested 57 | if stream: 58 | request_params["stream"] = True 59 | 60 | try: 61 | # Make the API call 62 | response = await model_config["async_client"].chat.completions.create(**request_params) 63 | 64 | if stream: 65 | # Return an async generator for streaming responses 66 | async def response_generator(): 67 | collected_chunks = [] 68 | async for chunk in response: 69 | if chunk.choices and chunk.choices[0].delta.content: 70 | content = chunk.choices[0].delta.content 71 | collected_chunks.append(content) 72 | yield content 73 | 74 | # If no chunks were yielded, yield the empty string 75 | if not collected_chunks: 76 | yield "" 77 | 78 | return response_generator() 79 | else: 80 | # Return the full response for non-streaming 81 | res = response.choices[0].message.content 82 | logger.debug(f"prompt: {prompt}\n\nGenerated completion: {res}") 83 | return res 84 | 85 | except Exception as e: 86 | logger.error(f"Error generating completion: {str(e)}") 87 | if stream: 88 | # Return an empty generator for streaming 89 | async def error_generator(): 90 | yield f"Error: {str(e)}" 91 | 92 | return error_generator() 93 | else: 94 | # Return an error message for non-streaming 95 | return f"Error: {str(e)}" 96 | 97 | 98 | async def generate_json_completion( 99 | prompt: str, 100 | system_message: str = "You are a helpful AI research assistant.", 101 | temperature: float = 0.7 102 | ) -> Dict[str, Any]: 103 | """ 104 | Generate a JSON completion from the language model. 
105 | 106 | Args: 107 | prompt: The prompt to send to the model 108 | system_message: The system message to use 109 | temperature: The temperature to use for generation 110 | 111 | Returns: 112 | The model's response parsed as a JSON object 113 | """ 114 | response_text = "" 115 | try: 116 | response_text = await generate_completion( 117 | prompt=prompt, 118 | system_message=system_message, 119 | temperature=temperature, 120 | json_response=True 121 | ) 122 | 123 | # Parse the JSON response 124 | result = json.loads(response_text) 125 | return result 126 | 127 | except json.JSONDecodeError as e: 128 | logger.error(f"Error parsing JSON response: {str(e)}") 129 | logger.error(f"Response text: {response_text}") 130 | return {} 131 | 132 | except Exception as e: 133 | logger.error(f"Error in generate_json_completion: {str(e)}") 134 | return {} 135 | -------------------------------------------------------------------------------- /src/mp_search_client.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import httpx 7 | import aiohttp 8 | from typing import Dict, Any, Optional 9 | import json 10 | from loguru import logger 11 | 12 | from .config import get_config 13 | 14 | 15 | class MPSearchClient: 16 | """Client for MP Search API.""" 17 | 18 | def __init__(self): 19 | config = get_config() 20 | self.api_key = config.get("mp_search", {}).get("api_key", "") 21 | self.base_url = config.get("mp_search", {}).get("base_url", "https://api.mpsrch.com/v1/search") 22 | self.forward_service = config.get("mp_search", {}).get("forward_service", "hyaide-application-1111") 23 | self.client = httpx.Client(timeout=30.0) 24 | 25 | def search_sync(self, query: str, options: Dict[str, Any] = None) -> Dict[str, Any]: 26 | """ 27 | Perform a search using MP Search API. 28 | 29 | Args: 30 | query: Search query 31 | options: Additional options for the search 32 | 33 | Returns: 34 | Dict containing search results 35 | """ 36 | if not self.api_key: 37 | raise ValueError("MP Search API key not configured") 38 | 39 | if options is None: 40 | options = {} 41 | 42 | # Default payload 43 | payload = { 44 | "query": query, 45 | "forward_service": self.forward_service, 46 | "query_id": options.get("query_id", f"qid_{hash(query)}"), 47 | "stream": options.get("stream", False) 48 | } 49 | 50 | headers = { 51 | "Authorization": f"Bearer {self.api_key}" 52 | } 53 | 54 | try: 55 | logger.debug(f"Searching with MP Search API: {query}") 56 | response = self.client.post( 57 | self.base_url, 58 | headers=headers, 59 | json=payload 60 | ) 61 | response.raise_for_status() 62 | result = response.json() 63 | 64 | # Transform the result to match the expected format 65 | transformed_result = self._transform_result(result, query) 66 | logger.debug(f"Transformed result: {transformed_result}") 67 | return transformed_result 68 | 69 | except Exception as e: 70 | logger.error(f"Error searching with MP Search API: {str(e)}") 71 | raise 72 | 73 | async def search(self, query: str, options: Dict[str, Any] = None) -> Dict[str, Any]: 74 | """ 75 | Perform an async search using MP Search API. 
76 | 
77 |         Args:
78 |             query: Search query
79 |             options: Additional options for the search
80 | 
81 |         Returns:
82 |             Dict containing search results
83 |         """
84 |         if not self.api_key:
85 |             raise ValueError("MP Search API key not configured")
86 | 
87 |         if options is None:
88 |             options = {}
89 | 
90 |         # Default payload
91 |         payload = {
92 |             "query": query,
93 |             "forward_service": self.forward_service,
94 |             "query_id": options.get("query_id", f"qid_{hash(query)}"),
95 |             "stream": options.get("stream", False)
96 |         }
97 | 
98 |         headers = {
99 |             "Authorization": f"Bearer {self.api_key}"
100 |         }
101 | 
102 |         try:
103 |             logger.debug(f"Searching with MP Search API: {query}")
104 |             async with aiohttp.ClientSession() as session:
105 |                 async with session.post(
106 |                     self.base_url,
107 |                     headers=headers,
108 |                     json=payload
109 |                 ) as response:
110 |                     response.raise_for_status()
111 | 
112 |                     text_content = await response.text()
113 |                     try:
114 |                         # Try to parse as JSON anyway
115 |                         result = json.loads(text_content)
116 |                     except json.JSONDecodeError:
117 |                         logger.error(f"Error parsing JSON response: {text_content}")
118 |                         result = {}  # keep the type consistent for _transform_result
119 | 
120 |                     # Transform the result to match the expected format
121 |                     transformed_result = self._transform_result(result, query)
122 |                     logger.debug(f"Transformed result: {transformed_result}")
123 |                     return transformed_result
124 | 
125 |         except Exception as e:
126 |             logger.error(f"Error searching with MP Search API: {str(e)}")
127 |             raise
128 | 
129 |     def _transform_result(self, mp_result: Dict[str, Any], query: str) -> Dict[str, Any]:
130 |         """
131 |         Transform MP Search API result to match the expected format.
132 | 
133 |         Args:
134 |             mp_result: Raw result from MP Search API
135 |             query: Original search query
136 | 
137 |         Returns:
138 |             Transformed result in compatible format
139 |         """
140 |         transformed_data = []
141 |         logger.debug(f"Transforming MP Search result: {mp_result}")
142 |         try:
143 |             # Use the "result" field directly as the content
144 |             if "result" in mp_result and mp_result["result"]:
145 |                 res = mp_result["result"]
146 |                 if isinstance(res, str):
147 |                     res = json.loads(res)
148 |                 for item in res:
149 |                     if "value" in item and item['value']:
150 |                         val = item['value']
151 |                         title, content, dt, url = val.split(" ||| ")
152 |                         transformed_item = {
153 |                             "url": url,
154 |                             "title": title,
155 |                             "content": str(content)[:4000],
156 |                             "source": "mp_search"
157 |                         }
158 |                         transformed_data.append(transformed_item)
159 |         except Exception as e:
160 |             logger.error(f"Error transforming MP Search result: {str(e)}")
161 |             # Add a fallback item with the error
162 |             transformed_item = {
163 |                 "url": "",
164 |                 "title": "MP Search Error",
165 |                 "content": f"Error processing result: {str(e)}\nRaw result: {mp_result}",
166 |                 "source": "mp_search"
167 |             }
168 |             transformed_data.append(transformed_item)
169 | 
170 |         return {
171 |             "query": query,
172 |             "data": transformed_data
173 |         }
174 | 
-------------------------------------------------------------------------------- /src/prompts.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | 
6 | Prompts used for deep research functionality. 
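Note: every prompt below is a plain str.format template. Named fields such as
{query}, {history_context} and {current_date} are filled in at call time, and
the literal braces of the embedded JSON examples are escaped as {{ }} so that
.format() leaves them intact.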
7 | """ 8 | 9 | SHOULD_CLARIFY_QUERY_PROMPT = """ 10 | 请判断以下查询是否需要澄清问题。 11 | 一个好的查询应该明确、具体且包含足够的上下文。 12 | 如果查询模糊、缺少重要上下文、过于宽泛或包含多个可能的解释,则需要澄清。 13 | 14 | 对话历史: 15 | ``` 16 | {history_context} 17 | ``` 18 | 19 | 查询是: ```{query}``` 20 | 21 | 当前日期是{current_date}。 22 | 23 | 请只回答 "yes" 或 "no"。如果查询已经足够清晰,请回答"no"。 24 | """ 25 | 26 | # Prompt for generating follow-up questions 27 | FOLLOW_UP_QUESTIONS_PROMPT = """ 28 | You are an expert researcher and I need your help to generate clarifying questions for a given research query. 29 | 30 | chat history: 31 | ``` 32 | {history_context} 33 | ``` 34 | 35 | The query is: ```{query}``` 36 | 37 | Based on this query, please generate clarifying questions that would help you better understand what the user is looking for. 38 | For effective questions: 39 | 1. Identify ambiguous terms or concepts that need clarification 40 | 2. Ask about the scope or timeframe of interest 41 | 3. Check if there are specific aspects the user is most interested in 42 | 4. Consider what background information might be helpful 43 | 5. Ask about intended use of the information (academic, personal interest, decision-making, etc.) 44 | 45 | - User's query is written in Chinese, 需要用中文输出. 46 | - 当前日期是{current_date}。 47 | 48 | Format your response as a valid JSON object with the following structure: 49 | {{ 50 | "needs_clarification": true/false (boolean indicating if clarification questions are needed), 51 | "questions": [ 52 | {{ 53 | "key": "specific_key_1", 54 | "question": "The clarifying question text", 55 | "default": "A reasonable default answer if the user doesn't provide one" 56 | }}, 57 | ... additional questions ... 58 | ] 59 | }} 60 | 61 | If the query seems clear enough and doesn't require clarification, return "needs_clarification": false with an empty questions array. 62 | For simple factual queries or clear requests, clarification is usually not needed. 63 | """ 64 | 65 | # Prompt for processing clarifications 66 | PROCESS_CLARIFICATIONS_PROMPT = """ 67 | I'm reviewing a user query with clarification questions and their responses. 68 | 69 | Chat history: ``` 70 | {history_context} 71 | ``` 72 | 73 | Original query: ```{query}``` 74 | 75 | Clarification questions and responses: 76 | ``` 77 | {clarifications} 78 | ``` 79 | 80 | Questions that were not answered: 81 | ``` 82 | {unanswered_questions} 83 | ``` 84 | 85 | Based on this information, please: 86 | 1. Summarize the original query with the additional context provided by the clarifications 87 | 2. For questions that were not answered, use reasonable default assumptions and clearly state what you're assuming 88 | 3. Identify if this is a simple factual query that doesn't require search 89 | - User's query is written in Chinese, 需要用中文输出. 90 | - 当前日期是{current_date}。 91 | 92 | Format your response as a valid JSON object with the following structure: 93 | {{ 94 | "refined_query": "The refined and clarified query", 95 | "assumptions": ["List of assumptions made for unanswered questions"], 96 | "requires_search": true/false (boolean indicating if this query needs web search or can be answered directly), 97 | "direct_answer": "If requires_search is false, provide a direct answer here, otherwise empty string" 98 | }} 99 | """ 100 | 101 | # Prompt for no clarifications needed 102 | PROCESS_NO_CLARIFICATIONS_PROMPT = """ 103 | I'm reviewing a user query where they chose not to provide any clarifications. 
104 | 
105 | Chat history:
106 | ```
107 | {history_context}
108 | ```
109 | 
110 | Original query: ```{query}```
111 | 
112 | The user was asked the following clarification questions but chose not to answer any:
113 | ```
114 | {unanswered_questions}
115 | ```
116 | 
117 | Since the user didn't provide any clarifications, please:
118 | 1. Analyze the original query as comprehensively as possible
119 | 2. Make reasonable assumptions for all ambiguous aspects
120 | 3. Determine if this is a simple factual query that doesn't require search
121 | 4. If possible, provide a direct answer along with the refined query
122 | - User's query is written in Chinese, 需要用中文输出.
123 | - 当前日期是{current_date}。
124 | 
125 | Format your response as a valid JSON object with the following structure:
126 | {{
127 |     "refined_query": "The refined query with all possible considerations",
128 |     "assumptions": ["List of all assumptions made"],
129 |     "requires_search": true/false (boolean indicating if this query needs web search or can be answered directly),
130 |     "direct_answer": "If requires_search is false, provide a comprehensive direct answer here, otherwise empty string"
131 | }}
132 | 
133 | Since the user chose not to provide clarifications, be as thorough and comprehensive as possible in your analysis and answer.
134 | """
135 | 
136 | # Prompt for generating research plan
137 | RESEARCH_PLAN_PROMPT = """
138 | You are an expert researcher creating a flexible research plan for a given query.
139 | 
140 | Chat history:
141 | ```
142 | {history_context}
143 | ```
144 | 
145 | QUERY: ```{query}```
146 | 
147 | Please analyze this query and create an appropriate research plan. The number of steps should vary based on complexity:
148 | - For simple questions, you might need only 1 step
149 | - For moderately complex questions, 2 steps may be appropriate
150 | - For very complex questions, 3 or more steps may be needed
151 | - User's query is written in Chinese, 需要用中文输出.
152 | - 当前日期是{current_date}。
153 | 
154 | Consider:
155 | 1. The complexity of the query
156 | 2. Whether multiple angles of research are needed
157 | 3. If the topic requires exploration of causes, effects, comparisons, or historical context
158 | 4. If the topic is controversial and needs different perspectives
159 | 
160 | Format your response as a valid JSON object with the following structure:
161 | {{
162 |     "assessments": "Brief assessment of query complexity and reasoning",
163 |     "steps": [
164 |         {{
165 |             "step_id": 1,
166 |             "description": "Description of this research step",
167 |             "search_queries": ["search query 1", "search query 2", ...],
168 |             "goal": "What this step aims to discover"
169 |         }},
170 |         ... additional steps as needed ...
171 |     ]
172 | }}
173 | 
174 | Make each step logical and focused on a specific aspect of the research. Steps should build on each other,
175 | and search queries should be specific and effective for web search.
176 | """
177 | 
178 | # Prompt for extracting search results
179 | EXTRACT_SEARCH_RESULTS_SYSTEM_PROMPT = "You are an expert in extracting the most relevant and detailed information from search results."
180 | EXTRACT_SEARCH_RESULTS_PROMPT = """
181 | User query: ```{query}```
182 | 
183 | Search result (webpage content):
184 | ```
185 | {search_results}
186 | ```
187 | 
188 | - User's query is written in Chinese, 需要用中文输出.
189 | - 当前日期是{current_date}。
190 | 
191 | 作为信息提取专家,请从网页内容中提取与用户查询最相关的核心片段。需要提取的内容要求:
192 | 1. 包含具体的细节、数据、定义和重要论点,不要使用笼统的总结替代原始的详细内容
193 | 2. 保留原文中的关键事实、数字、日期和引用
194 | 3. 提取完整的相关段落,而不仅仅是简短的摘要
195 | 4. 特别关注可以直接回答用户查询的内容
196 | 5. 如果内容包含表格或列表中的重要信息,请完整保留这些结构化数据
197 | 
198 | Output your response in the following JSON format:
199 | {{
200 |     "extracted_infos": [
201 |         {{
202 |             "info": "核心片段1,包含详细内容、数据和定义等",
203 |             "url": "url 1",
204 |             "relevance": "解释这段内容与查询的相关性"
205 |         }},
206 |         {{
207 |             "info": "核心片段2,包含详细内容、数据和定义等",
208 |             "url": "url 2",
209 |             "relevance": "解释这段内容与查询的相关性"
210 |         }},
211 |         ...
212 |     ]
213 | }}
214 | 
215 | - info: 保留原文格式的关键信息片段,包含详细内容而非简单摘要
216 | - url: 信息来源的网页URL
217 | - relevance: 简要说明这段内容如何回答了用户的查询
218 | """
219 | 
220 | # Prompt for final research summary
221 | RESEARCH_SUMMARY_PROMPT = """
222 | Based on our research, we've explored the query: ```{query}```
223 | 
224 | Research Summary by Step:
225 | ```
226 | {steps_summary}
227 | ```
228 | 
229 | Please analyze this information and provide:
230 | 1. A set of key findings that answer the main query
231 | 2. Identification of any areas where the research is lacking or more information is needed
232 | - User's query is written in Chinese, 需要用中文输出.
233 | - 当前日期是{current_date}。
234 | 
235 | Format your response as a valid JSON object with:
236 | {{
237 |     "findings": [{{"finding": "finding 1", "url": "cite url 1"}}, {{"finding": "finding 2", "url": "cite url 2"}}, ...], (key conclusions from the research, and the cite url)
238 |     "gaps": ["gap 1", "gap 2", ...], (areas where more research is needed)
239 |     "recommendations": ["recommendation 1", "recommendation 2", ...] (suggestions for further research directions)
240 | }}
241 | """
242 | 
243 | FINAL_REPORT_SYSTEM_PROMPT = """You are an expert researcher. Follow these instructions when responding:
244 | - You may be asked to research subjects that are after your knowledge cutoff; assume the user is right when presented with news.
245 | - The user is a highly experienced analyst; no need to simplify, be as detailed as possible and make sure your response is correct.
246 | - Be highly organized.
247 | - Suggest solutions that I didn't think about.
248 | - Be proactive and anticipate my needs.
249 | - Treat me as an expert in all subject matter.
250 | - Mistakes erode my trust, so be accurate and thorough.
251 | - Provide detailed explanations; I'm comfortable with lots of detail.
252 | - Value good arguments over authorities; the source is irrelevant.
253 | - Consider new technologies and contrarian ideas, not just the conventional wisdom.
254 | - User's query is written in Chinese, 需要用中文输出.
255 | - 当前日期是{current_date}。
256 | """
257 | 
258 | # Prompt for final report
259 | FINAL_REPORT_PROMPT = """
260 | I've been researching the following query: ```{query}```
261 | 
262 | Please write a comprehensive research report on this topic.
263 | The report should be well-structured with headings, subheadings, and a conclusion. 
264 | 265 | [要求]: 266 | - 输出markdown格式的回答。 267 | - [context]是参考资料,回答中需要包含引用来源,格式为 [cite](url) ,其中url是实际的链接。 268 | - 除代码、专名外,你必须使用与问题相同语言回答。 269 | 270 | Chat history: 271 | ``` 272 | {history_context} 273 | ``` 274 | 275 | [context]: 276 | ``` 277 | {context} 278 | ``` 279 | """ 280 | 281 | # Prompt for final answer 282 | FINAL_ANSWER_PROMPT = """ 283 | I've been researching the following query: ```{query}``` 284 | 285 | 详细、专业回答用户的query。 286 | 287 | [要求]: 288 | - 输出markdown格式的回答。 289 | - [context]是参考资料,回答中需要包含引用来源,格式为 [cite](url) ,其中url是实际的链接。 290 | - 除代码、专名外,你必须使用与问题相同语言回答。 291 | 292 | Chat history: 293 | ``` 294 | {history_context} 295 | ``` 296 | 297 | [context]: 298 | ``` 299 | {context} 300 | ``` 301 | """ 302 | -------------------------------------------------------------------------------- /src/providers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | from typing import Dict, Any, Optional 7 | import openai 8 | from .config import get_config 9 | 10 | 11 | def get_model(is_report: bool = False) -> Dict[str, Any]: 12 | """ 13 | Get model configuration including client and model name. 14 | 15 | :param is_report: Whether to get the model configuration for a report 16 | 17 | Returns: 18 | Dict containing model configuration 19 | """ 20 | config = get_config() 21 | if is_report: 22 | report_config = config.get("report_llm", {}) 23 | api_key = report_config.get("api_key", "") 24 | model = report_config.get("model", "gpt-4o") 25 | base_url = report_config.get("base_url", None) 26 | else: 27 | openai_config = config.get("openai", {}) 28 | api_key = openai_config.get("api_key", "") 29 | model = openai_config.get("model", "gpt-4o-mini") 30 | base_url = openai_config.get("base_url", None) 31 | 32 | # Initialize OpenAI client 33 | client_args = {"api_key": api_key} 34 | if base_url: 35 | client_args["base_url"] = base_url 36 | 37 | client = openai.OpenAI(**client_args) 38 | async_client = openai.AsyncOpenAI(**client_args) 39 | 40 | return { 41 | "client": client, 42 | "async_client": async_client, 43 | "model": model 44 | } 45 | 46 | 47 | def get_search_provider(search_source=None): 48 | """ 49 | Get the appropriate search provider based on configuration. 50 | 51 | Returns: 52 | An instance of the search provider class 53 | """ 54 | if search_source is None: 55 | config = get_config() 56 | search_source = config.get("research", {}).get("search_source", "serper") 57 | 58 | if search_source == "mp_search": 59 | from .mp_search_client import MPSearchClient 60 | return MPSearchClient() 61 | elif search_source == "tavily": 62 | from .tavily_client import TavilyClient 63 | return TavilyClient() 64 | else: # Default to serper 65 | from .serper_client import SerperClient 66 | return SerperClient() 67 | -------------------------------------------------------------------------------- /src/search_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | Utility functions for handling search operations, including concurrent searches. 
6 | """ 7 | 8 | import asyncio 9 | from typing import Dict, Any, List, Optional, Callable 10 | from loguru import logger 11 | 12 | from .config import get_config 13 | from .providers import get_search_provider 14 | 15 | 16 | async def search_with_query( 17 | query: str, 18 | config: Dict[str, Any] = None, 19 | search_provider=None 20 | ) -> Dict[str, Any]: 21 | """ 22 | Search results for a query. 23 | 24 | Args: 25 | query: The query to search for 26 | config: Configuration dictionary 27 | search_provider: Search provider instance 28 | 29 | Returns: 30 | Dict with summary and URLs 31 | """ 32 | if config is None: 33 | config = get_config() 34 | 35 | if search_provider is None: 36 | search_provider = get_search_provider() 37 | 38 | try: 39 | # Perform the search 40 | results = await search_provider.search(query) 41 | 42 | # For web search, ensure URLs are captured 43 | if hasattr(search_provider, 'get_organic_urls'): 44 | urls = search_provider.get_organic_urls() 45 | else: 46 | urls = [] 47 | 48 | # Default to 2 results if not specified in config 49 | max_results = config.get("research", {}).get("max_results_per_query", 2) 50 | 51 | # Ensure results is a list before slicing 52 | if isinstance(results, list): 53 | results_to_process = results[:max_results] 54 | elif isinstance(results, dict) and "data" in results and isinstance(results["data"], list): 55 | # Some search providers return a dict with a "data" field containing the results 56 | results_to_process = results["data"][:max_results] 57 | else: 58 | # If results is not a list or a dict with a "data" field, use an empty list 59 | results_to_process = [] 60 | logger.warning(f"Unexpected search results format: {type(results)}") 61 | 62 | search_results_text = results_to_process if results_to_process else "No search results found." 63 | # Return the formatted results and URLs 64 | return { 65 | "summary": search_results_text, 66 | "urls": urls, 67 | "raw_results": results_to_process 68 | } 69 | 70 | except Exception as e: 71 | logger.error(f"Error in search: {str(e)}") 72 | return { 73 | "summary": f"Error searching for '{query}': {str(e)}", 74 | "urls": [], 75 | "raw_results": [] 76 | } 77 | 78 | 79 | async def concurrent_search( 80 | queries: List[str], 81 | config: Dict[str, Any] = None, 82 | search_provider=None 83 | ) -> List[Dict[str, Any]]: 84 | """ 85 | Perform concurrent searches for multiple queries. 
86 | 87 | Args: 88 | queries: List of queries to search for 89 | config: Configuration dictionary 90 | search_provider: Search provider instance 91 | 92 | Returns: 93 | List of search results 94 | """ 95 | if config is None: 96 | config = get_config() 97 | 98 | if search_provider is None: 99 | search_provider = get_search_provider() 100 | 101 | # Get the concurrency limit from config 102 | concurrency_limit = config.get("research", {}).get("concurrency_limit", 1) 103 | 104 | # Create a semaphore to limit concurrency 105 | semaphore = asyncio.Semaphore(concurrency_limit) 106 | 107 | async def search_with_semaphore(query: str) -> Dict[str, Any]: 108 | """Perform a search with semaphore-based concurrency control.""" 109 | async with semaphore: 110 | return await search_with_query(query, config, search_provider) 111 | 112 | # Create tasks for all queries 113 | tasks = [search_with_semaphore(query) for query in queries] 114 | 115 | # Run all tasks concurrently and gather results 116 | results = await asyncio.gather(*tasks) 117 | 118 | return results 119 | -------------------------------------------------------------------------------- /src/serper_client.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import os 7 | import json 8 | import aiohttp 9 | import httpx 10 | from typing import Dict, List, Any, Union, Optional 11 | from loguru import logger 12 | 13 | from .config import get_config 14 | 15 | 16 | class SerperClient: 17 | """Client for the Serper.dev API to perform web searches.""" 18 | 19 | def __init__(self): 20 | config = get_config() 21 | self.api_key = config.get("serper", {}).get("api_key", "") 22 | 23 | if not self.api_key: 24 | logger.warning("No Serper API key found. Searches will fail.") 25 | 26 | self.api_url = "https://google.serper.dev/search" 27 | self.headers = { 28 | "X-API-KEY": self.api_key, 29 | "Content-Type": "application/json" 30 | } 31 | self.organic_urls = [] 32 | 33 | def search_sync(self, query: str) -> List[Dict[str, Any]]: 34 | """ 35 | Perform a search using the Serper API. 36 | 37 | Args: 38 | query: Search query 39 | 40 | Returns: 41 | List of search result items 42 | """ 43 | try: 44 | payload = json.dumps({ 45 | "q": query 46 | }) 47 | 48 | response = httpx.post(self.api_url, headers=self.headers, data=payload) 49 | response.raise_for_status() 50 | 51 | result = response.json() 52 | self.organic_urls = self._extract_urls(result) 53 | 54 | # Format the results for consumption 55 | formatted_results = self._format_results(result) 56 | return formatted_results 57 | 58 | except Exception as e: 59 | logger.error(f"Error searching with Serper: {str(e)}") 60 | return [] 61 | 62 | async def search(self, query: str) -> List[Dict[str, Any]]: 63 | """ 64 | Perform an async search using the Serper API. 
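        Illustrative call (assumes serper.api_key is configured; the query is
        made up):

            client = SerperClient()
            items = await client.search("deep research agents")
            urls = client.get_organic_urls()  # organic URLs from this search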
65 | 66 | Args: 67 | query: Search query 68 | 69 | Returns: 70 | List of search result items 71 | """ 72 | try: 73 | payload = json.dumps({ 74 | "q": query 75 | }) 76 | 77 | async with aiohttp.ClientSession() as session: 78 | async with session.post(self.api_url, headers=self.headers, data=payload) as response: 79 | result = await response.json() 80 | 81 | if response.status != 200: 82 | logger.error(f"Serper API error: {result}") 83 | return [] 84 | 85 | self.organic_urls = self._extract_urls(result) 86 | 87 | # Format the results for consumption 88 | formatted_results = self._format_results(result) 89 | return formatted_results 90 | 91 | except Exception as e: 92 | logger.error(f"Error searching with Serper: {str(e)}") 93 | return [] 94 | 95 | def _extract_urls(self, result: Dict[str, Any]) -> List[str]: 96 | """ 97 | Extract URLs from search results. 98 | 99 | Args: 100 | result: Search result object 101 | 102 | Returns: 103 | List of URLs 104 | """ 105 | urls = [] 106 | 107 | # Extract organic results 108 | if "organic" in result: 109 | for item in result["organic"]: 110 | if "link" in item: 111 | urls.append(item["link"]) 112 | 113 | return urls 114 | 115 | def _format_results(self, result: Dict[str, Any]) -> List[Dict[str, Any]]: 116 | """ 117 | Format search results into a standardized format. 118 | 119 | Args: 120 | result: Search result object 121 | 122 | Returns: 123 | List of formatted result items 124 | """ 125 | formatted_results = [] 126 | 127 | # Format organic results 128 | if "organic" in result: 129 | for item in result["organic"]: 130 | formatted_item = { 131 | "title": item.get("title", ""), 132 | "url": item.get("link", ""), 133 | "snippet": item.get("snippet", ""), 134 | "content": f"{item.get('title', '')} - {item.get('snippet', '')}" 135 | } 136 | formatted_results.append(formatted_item) 137 | 138 | # Include featured snippet if available 139 | if "answerBox" in result: 140 | answer_box = result["answerBox"] 141 | formatted_item = { 142 | "title": answer_box.get("title", "Featured Snippet"), 143 | "url": answer_box.get("link", ""), 144 | "snippet": answer_box.get("snippet", ""), 145 | "content": answer_box.get("answer", answer_box.get("snippet", "")) 146 | } 147 | formatted_results.insert(0, formatted_item) 148 | 149 | return formatted_results 150 | 151 | def get_organic_urls(self) -> List[str]: 152 | """ 153 | Get the URLs of organic search results from the last search. 154 | 155 | Returns: 156 | List of URLs 157 | """ 158 | return self.organic_urls 159 | -------------------------------------------------------------------------------- /src/tavily_client.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | 6 | Client for Tavily Search API. 7 | 8 | curl -X POST https://api.tavily.com/search \\ 9 | -H 'Content-Type: application/json' \\ 10 | -H 'Authorization: Bearer tvly-dev-xxx' \\ 11 | -d '{ 12 | "query": "good" 13 | }' 14 | 15 | Example response: 16 | { 17 | "query": "good", 18 | "follow_up_questions": null, 19 | "answer": null, 20 | "images": [], 21 | "results": [ 22 | { 23 | "title": "GOOD Definition & Meaning - Merriam-Webster", 24 | "url": "https://www.merriam-webster.com/dictionary/good", 25 | "content": "good a good many of us good : something that is good...", 26 | "score": 0.7283819, 27 | "raw_content": null 28 | }, 29 | ... 
30 | ], 31 | "response_time": 1.38 32 | } 33 | """ 34 | import httpx 35 | import aiohttp 36 | from typing import Dict, Any, Optional 37 | from loguru import logger 38 | 39 | from .config import get_config 40 | 41 | 42 | class TavilyClient: 43 | """Client for Tavily Search API.""" 44 | 45 | def __init__(self): 46 | config = get_config() 47 | self.api_key = config.get("tavily", {}).get("api_key") 48 | self.base_url = config.get("tavily", {}).get("base_url", "https://api.tavily.com/search") 49 | self.client = httpx.Client(timeout=30.0) 50 | 51 | def search_sync(self, query: str, options: Dict[str, Any] = None) -> Dict[str, Any]: 52 | """ 53 | Perform a search using Tavily API. 54 | 55 | Args: 56 | query: Search query 57 | options: Additional options for the search 58 | 59 | Returns: 60 | Dict containing search results 61 | """ 62 | if not self.api_key: 63 | raise ValueError("Tavily API key not configured") 64 | 65 | if options is None: 66 | options = {} 67 | 68 | # Default payload for Tavily API 69 | payload = { 70 | "query": query 71 | } 72 | 73 | # Add optional parameters if provided 74 | if "search_depth" in options: 75 | payload["search_depth"] = options.get("search_depth") 76 | if "include_domains" in options and options["include_domains"]: 77 | payload["include_domains"] = options.get("include_domains") 78 | if "exclude_domains" in options and options["exclude_domains"]: 79 | payload["exclude_domains"] = options.get("exclude_domains") 80 | if "max_results" in options: 81 | payload["max_results"] = options.get("max_results") 82 | 83 | # Use Authorization Bearer format as per the official API example 84 | headers = { 85 | "Content-Type": "application/json", 86 | "Authorization": f"Bearer {self.api_key}" 87 | } 88 | 89 | try: 90 | logger.debug(f"Searching with Tavily API: {query}") 91 | response = self.client.post( 92 | self.base_url, 93 | headers=headers, 94 | json=payload 95 | ) 96 | response.raise_for_status() 97 | result = response.json() 98 | 99 | # Transform the result to match the expected format 100 | transformed_result = self._transform_result(result, query) 101 | return transformed_result 102 | 103 | except Exception as e: 104 | logger.error(f"Error searching with Tavily API: {str(e)}") 105 | raise 106 | 107 | async def search(self, query: str, options: Dict[str, Any] = None) -> Dict[str, Any]: 108 | """ 109 | Perform an async search using Tavily API. 
110 | 111 | Args: 112 | query: Search query 113 | options: Additional options for the search 114 | 115 | Returns: 116 | Dict containing search results 117 | """ 118 | if not self.api_key: 119 | raise ValueError("Tavily API key not configured") 120 | 121 | if options is None: 122 | options = {} 123 | 124 | # Default payload for Tavily API 125 | payload = { 126 | "query": query 127 | } 128 | 129 | # Add optional parameters if provided 130 | if "search_depth" in options: 131 | payload["search_depth"] = options.get("search_depth") 132 | if "include_domains" in options and options["include_domains"]: 133 | payload["include_domains"] = options.get("include_domains") 134 | if "exclude_domains" in options and options["exclude_domains"]: 135 | payload["exclude_domains"] = options.get("exclude_domains") 136 | if "max_results" in options: 137 | payload["max_results"] = options.get("max_results") 138 | 139 | # Use Authorization Bearer format as per the official API example 140 | headers = { 141 | "Content-Type": "application/json", 142 | "Authorization": f"Bearer {self.api_key}" 143 | } 144 | 145 | try: 146 | logger.debug(f"Searching with Tavily API: {query}") 147 | async with aiohttp.ClientSession() as session: 148 | async with session.post( 149 | self.base_url, 150 | headers=headers, 151 | json=payload 152 | ) as response: 153 | response.raise_for_status() 154 | result = await response.json() 155 | 156 | # Transform the result to match the expected format 157 | transformed_result = self._transform_result(result, query) 158 | return transformed_result 159 | 160 | except Exception as e: 161 | logger.error(f"Error searching with Tavily API: {str(e)}") 162 | raise 163 | 164 | def _transform_result(self, tavily_result: Dict[str, Any], query: str) -> Dict[str, Any]: 165 | """ 166 | Transform Tavily API result to match the expected format. 
167 | 168 | Args: 169 | tavily_result: Raw result from Tavily API 170 | query: Original search query 171 | 172 | Returns: 173 | Transformed result in compatible format 174 | """ 175 | transformed_data = [] 176 | 177 | # Process results from Tavily API 178 | if "results" in tavily_result and isinstance(tavily_result["results"], list): 179 | for item in tavily_result["results"]: 180 | content = "" 181 | 182 | # Extract content from the result 183 | if "content" in item and item["content"]: 184 | content += item["content"] + "\n\n" 185 | 186 | # Add score information if available 187 | if "score" in item: 188 | content += f"Relevance Score: {item['score']}\n\n" 189 | 190 | transformed_item = { 191 | "url": item.get("url", ""), 192 | "title": item.get("title", ""), 193 | "content": content.strip(), 194 | "source": "tavily" 195 | } 196 | transformed_data.append(transformed_item) 197 | 198 | # If answer is provided by Tavily, add it as a special result 199 | if "answer" in tavily_result and tavily_result["answer"]: 200 | transformed_item = { 201 | "url": "", 202 | "title": "Tavily Direct Answer", 203 | "content": tavily_result["answer"], 204 | "source": "tavily_answer" 205 | } 206 | transformed_data.append(transformed_item) 207 | 208 | # If follow-up questions are provided, add them as a special result 209 | if "follow_up_questions" in tavily_result and tavily_result["follow_up_questions"]: 210 | follow_up_content = "Suggested follow-up questions:\n\n" 211 | for question in tavily_result["follow_up_questions"]: 212 | follow_up_content += f"- {question}\n" 213 | 214 | transformed_item = { 215 | "url": "", 216 | "title": "Suggested Follow-up Questions", 217 | "content": follow_up_content, 218 | "source": "tavily_follow_up" 219 | } 220 | transformed_data.append(transformed_item) 221 | 222 | # If no results were found or processing failed 223 | if not transformed_data: 224 | transformed_item = { 225 | "url": "", 226 | "title": f"Tavily Search Result for: {query}", 227 | "content": "No results found or could not process the search response.", 228 | "source": "tavily" 229 | } 230 | transformed_data.append(transformed_item) 231 | 232 | # Add response time if available 233 | if "response_time" in tavily_result: 234 | logger.debug(f"Tavily search completed in {tavily_result['response_time']} seconds") 235 | 236 | return { 237 | "query": query, 238 | "data": transformed_data 239 | } -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Tests package initialization -------------------------------------------------------------------------------- /tests/test_providers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | from src.providers import get_model, trim_prompt 4 | 5 | 6 | class TestProviders(unittest.TestCase): 7 | """Tests for the providers module.""" 8 | 9 | @patch('src.providers.get_config') 10 | @patch('src.providers.openai_client') 11 | def test_get_model(self, mock_openai_client, mock_get_config): 12 | """Test fetching the model configuration.""" 13 | # Mock the configuration 14 | mock_get_config.return_value = { 15 | "openai": { 16 | "model": "test-model" 17 | } 18 | } 19 | 20 | # Mock the OpenAI client 21 | mock_openai_client.__bool__.return_value = True 22 | 23 | # Fetch the model configuration 24 | model_config = get_model() 25 | 26 | # Verify the result 27 | self.assertEqual(model_config["client"], mock_openai_client) 28 | self.assertEqual(model_config["model"], 
"test-model") 29 | 30 | def test_trim_prompt_short(self): 31 | """测试短提示不需要裁剪""" 32 | prompt = "This is a short prompt" 33 | result = trim_prompt(prompt, context_size=1000) 34 | self.assertEqual(result, prompt) 35 | 36 | @patch('src.providers.encoder.encode') 37 | def test_trim_prompt_long(self, mock_encode): 38 | """测试长提示需要裁剪""" 39 | # 模拟编码器返回超长 token 数 40 | mock_encode.return_value = [0] * 2000 41 | 42 | prompt = "This is a very long prompt" * 100 43 | result = trim_prompt(prompt, context_size=1000) 44 | 45 | # 验证结果被裁剪 46 | self.assertLess(len(result), len(prompt)) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /tests/test_serper_client.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | from src.serper_client import SerperClient 4 | 5 | 6 | class TestSerperClient(unittest.TestCase): 7 | """测试 Serper 搜索客户端""" 8 | 9 | @patch('src.serper_client.get_config') 10 | def setUp(self, mock_get_config): 11 | """测试前准备""" 12 | # 模拟配置 13 | mock_get_config.return_value = { 14 | "serper": { 15 | "api_key": "test_api_key", 16 | "base_url": "https://test.serper.dev/search" 17 | } 18 | } 19 | 20 | # 创建客户端实例 21 | self.client = SerperClient() 22 | 23 | def test_init(self): 24 | """测试初始化""" 25 | self.assertEqual(self.client.api_key, "test_api_key") 26 | self.assertEqual(self.client.base_url, "https://test.serper.dev/search") 27 | 28 | @patch('httpx.Client.post') 29 | def test_search(self, mock_post): 30 | """测试搜索功能""" 31 | # 模拟响应 32 | mock_response = MagicMock() 33 | mock_response.status_code = 200 34 | mock_response.json.return_value = { 35 | "organic": [ 36 | { 37 | "title": "Test Result", 38 | "link": "https://example.com", 39 | "snippet": "This is a test result" 40 | } 41 | ] 42 | } 43 | mock_post.return_value = mock_response 44 | 45 | # 执行搜索 46 | result = self.client.search("test query") 47 | 48 | # 验证请求 49 | mock_post.assert_called_once() 50 | args, kwargs = mock_post.call_args 51 | self.assertEqual(kwargs["headers"]["X-API-KEY"], "test_api_key") 52 | self.assertEqual(kwargs["json"]["q"], "test query") 53 | 54 | # 验证结果转换 55 | self.assertEqual(result["query"], "test query") 56 | self.assertEqual(len(result["data"]), 1) 57 | self.assertEqual(result["data"][0]["title"], "Test Result") 58 | self.assertEqual(result["data"][0]["url"], "https://example.com") 59 | self.assertEqual(result["data"][0]["content"], "This is a test result") 60 | 61 | @patch('httpx.Client.post') 62 | def test_search_with_options(self, mock_post): 63 | """测试带选项的搜索""" 64 | # 模拟响应 65 | mock_response = MagicMock() 66 | mock_response.status_code = 200 67 | mock_response.json.return_value = {"organic": []} 68 | mock_post.return_value = mock_response 69 | 70 | # 执行搜索 71 | self.client.search("test query", {"gl": "us", "num": 5}) 72 | 73 | # 验证请求选项 74 | args, kwargs = mock_post.call_args 75 | self.assertEqual(kwargs["json"]["gl"], "us") 76 | self.assertEqual(kwargs["json"]["num"], 5) 77 | 78 | 79 | if __name__ == '__main__': 80 | unittest.main() 81 | --------------------------------------------------------------------------------