├── .gitattributes
├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.yaml
│       └── config.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── ROADMAP.md
├── assets
│   ├── db
│   │   └── .gitkeep
│   └── models
│       └── .gitkeep
├── cli.py
├── docs
│   └── README_zh.md
├── rapid_rag
│   ├── __init__.py
│   ├── config.yaml
│   ├── encoder
│   │   ├── __init__.py
│   │   ├── erniebot.py
│   │   └── sentence_transformer.py
│   ├── file_loader
│   │   ├── __init__.py
│   │   ├── image_loader.py
│   │   ├── main.py
│   │   ├── office_loader.py
│   │   ├── pdf_loader.py
│   │   └── txt_loader.py
│   ├── llm
│   │   ├── __init__.py
│   │   ├── baichuan_7b.py
│   │   ├── chatglm2_6b.py
│   │   ├── ernie_bot_turbo.py
│   │   ├── internlm_7b.py
│   │   ├── llama2.py
│   │   ├── ollama.py
│   │   ├── openai.py
│   │   └── qwen7b_chat.py
│   ├── text_splitter
│   │   ├── __init__.py
│   │   └── chinese_text_splitter.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   └── utils.py
│   └── vector_utils
│       ├── __init__.py
│       └── sqlite_version.py
├── requirements.txt
├── tests
│   ├── demo_store_embedding.py
│   ├── test_bge.py
│   ├── test_chatglm2_6b.py
│   ├── test_file_loader.py
│   ├── test_files
│   │   ├── office
│   │   │   ├── excel_with_image.xlsx
│   │   │   ├── ppt_example.pptx
│   │   │   └── word_example.docx
│   │   ├── test.jpg
│   │   ├── test.md
│   │   ├── test.txt
│   │   ├── word_example.pdf
│   │   └── 长安三万里.pdf
│   ├── test_llama2_7b_chat.py
│   ├── test_m3e.py
│   ├── test_office_loader.py
│   ├── test_qwen.py
│   ├── test_search.py
│   └── test_sql_insert.py
└── webui.py

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Set the default behavior, in case people don't have core.autocrlf set.
2 | * text=auto
3 | 
4 | # Explicitly declare text files you want to always be normalized and converted
5 | # to native line endings on checkout.
6 | *.c text
7 | *.h text
8 | *.py text
9 | *.md text
10 | *.js text
11 | *.cpp text
12 | 
13 | # Declare files that will always have CRLF line endings on checkout.
14 | *.sln text eol=crlf
15 | 
16 | # Denote all files that are truly binary and should not be modified.
17 | *.png binary
18 | *.jpg binary
19 | *.pdf binary
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: "🐛 Bug Report"
2 | description: Create a report to help us improve RapidRAG
3 | body:
4 |   - type: markdown
5 |     attributes:
6 |       value: |
7 |         Thanks for taking the time to fill out this bug report!
8 | 
9 |         Please note that this tracker is only for bugs. Do not use the issue tracker for help or feature requests.
10 | 
11 |         [Our docs](https://rapidai.github.io/RapidRAG/) are a great place for most answers, but if you can't find your answer there, you can ask in the [community discussion forum](https://github.com/RapidAI/RapidRAG/discussions/categories/q-a).
12 | 
13 |         Have a feature request? Please search the ideas [on our forum](https://github.com/RapidAI/RapidRAG/discussions/categories/feature-requests) to make sure that the feature has not yet been requested. If you cannot find what you had in mind, please [submit your feature request here](https://github.com/RapidAI/RapidRAG/discussions/new?category=feature-requests).
14 | 
15 |         Want to show off what you built with RapidRAG? Post a link, screenshot (optional), and details in [our Show & tell forum](https://github.com/RapidAI/RapidRAG/discussions/categories/show-and-tell).
16 | 17 | **Thanks!** 18 | - type: checkboxes 19 | attributes: 20 | label: Past Issues Searched 21 | options: 22 | - label: >- 23 | I have searched open and closed issues to make sure that the bug has 24 | not yet been reported 25 | required: true 26 | - type: checkboxes 27 | attributes: 28 | label: Issue is a Bug Report 29 | options: 30 | - label: >- 31 | This is a bug report and not a feature request, nor asking for support 32 | required: true 33 | - type: textarea 34 | id: bug-description 35 | attributes: 36 | label: Describe the bug 37 | description: A clear and concise description of what the bug is 38 | placeholder: Tell us what happened! 39 | validations: 40 | required: true 41 | - type: textarea 42 | id: bug-expectation 43 | attributes: 44 | label: Expected behavior 45 | description: A clear and concise description of what you expected to happen 46 | placeholder: Tell us what you expected 47 | validations: 48 | required: true 49 | - type: textarea 50 | id: bug-screenshots 51 | attributes: 52 | label: Screenshots 53 | description: 'If applicable, add screenshots to help explain your problem' 54 | placeholder: Insert screenshots here 55 | - type: textarea 56 | attributes: 57 | label: Environment 58 | description: | 59 | examples: 60 | - **OS**: MacOS 61 | - **Browser**: Firefox 62 | - **Browser Version**: 115 63 | value: | 64 | - OS: 65 | - Browser: 66 | - Browser Version: 67 | render: markdown -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: ❓ Questions 4 | url: https://github.com/RapidAI/RapidRAG/discussions/categories/q-a 5 | about: Please use the community forum for help and questions regarding RapidRAG Docs 6 | - name: 💡 Feature requests and ideas 7 | url: https://github.com/RapidAI/RapidRAG/discussions/new?category=feature-requests 8 | about: Please vote for and post new feature ideas in the community forum 9 | - name: 📖 Documentation 10 | url: https://rapidai.github.io/RapidRAG/ 11 | about: A great place to find instructions and answers on how to run your custom RapidRAG. 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.db 2 | assets/models/m3e-small 3 | assets/raw_upload_files 4 | log/ 5 | 6 | # Created by .ignore support plugin (hsz.mobi) 7 | ### Python template 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | .pytest_cache 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | # *.manifest 42 | # *.spec 43 | *.res 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | #idea 141 | .vs 142 | .vscode 143 | .idea 144 | /images 145 | 146 | #models 147 | *.onnx 148 | 149 | *.ttf 150 | *.ttc 151 | 152 | long1.jpg 153 | 154 | *.bin 155 | *.mapping 156 | *.xml 157 | 158 | *.pdiparams 159 | *.pdiparams.info 160 | *.pdmodel 161 | 162 | .DS_Store 163 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://gitee.com/SWHL/autoflake 3 | rev: v2.1.1 4 | hooks: 5 | - id: autoflake 6 | args: 7 | [ 8 | "--recursive", 9 | "--in-place", 10 | "--remove-all-unused-imports", 11 | "--remove-unused-variable", 12 | "--ignore-init-module-imports", 13 | ] 14 | files: \.py$ 15 | - repo: https://gitee.com/SWHL/black 16 | rev: 23.1.0 17 | hooks: 18 | - id: black 19 | files: \.py$ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | 
4 | 🧐 Rapid RAG
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 
12 | [简体中文](./docs/README_zh.md) | English
13 | 
14 | 
15 | ### 📣 We're looking for front-end engineers interested in knowledge QA with LLMs who can help us separate the front end from the back end in the current implementation.
16 | 
17 | ### Introduction
18 | 
19 | - Question answering over a local knowledge base with an LLM.
20 | - Why this project:
21 |     - The idea comes from [Langchain-Chatchat](https://github.com/chatchat-space/Langchain-Chatchat).
22 |     - We have used that project before, but found it inflexible and unfriendly to deploy.
23 |     - It borrows the ideas in [How to build a knowledge question answering system with a large language model](https://mp.weixin.qq.com/s/movaNCWjJGBaes6KxhpYpg) and puts them into practice.
24 | - Advantages:
25 |     - The whole project is modular and does not depend on the `langchain` library; each part can easily be replaced, and the code is simple and easy to understand. A condensed usage sketch follows this list.
26 |     - Apart from the large language model interface, which needs to be deployed separately, everything else runs on CPU.
27 |     - Supports documents in common formats, including `txt`, `md`, `pdf`, `docx`, `pptx`, and `xlsx`; other document types can be supported with a custom loader.
28 | 
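The sketch below condenses the repo's own `cli.py` into the four moving parts (file loader, encoder, vector store, LLM). It assumes the `m3e-small` weights sit under `assets/models/` and that a valid ERNIE Bot token is configured in `rapid_rag/config.yaml`; the sample question is illustrative:

```python
import uuid
from pathlib import Path

from rapid_rag.encoder import EncodeText
from rapid_rag.file_loader import FileLoader
from rapid_rag.llm import ERNIEBot
from rapid_rag.utils import make_prompt, read_yaml
from rapid_rag.vector_utils import DBUtils

config = read_yaml("rapid_rag/config.yaml")

# 1. Load a document and split it into sentences.
file_path = "tests/test_files/office/word_example.docx"
sentences = FileLoader()(file_path)[Path(file_path).name]

# 2. Embed the sentences and store them in the SQLite-backed vector store.
encoder = EncodeText(**config.get("Encoder")["m3e-small"])
db = DBUtils(config.get("vector_db_path"))
db.insert(file_path, encoder(sentences), sentences, uid=str(uuid.uuid1()))

# 3. Retrieve context for a query and ask the LLM; any class in
#    rapid_rag/llm with the same __call__ signature can be swapped in here.
query = "文档讲了什么?"  # example question
search_res, _ = db.search_local(embedding_query=encoder(query))
context = "\n".join(sum(search_res.values(), []))
llm = ERNIEBot(**config.get("LLM_API")["ERNIEBot"])
print(llm(make_prompt(query, context, custom_prompt=config.get("DEFAULT_PROMPT"))))
```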
29 | ### Demo
30 | 
31 | ⚠️ If you have a Baidu account, you can visit the [online demo](https://aistudio.baidu.com/projectdetail/6675380?contributionType=1) based on ERNIE Bot.
32 | 
33 | 
34 | 
35 | 
36 | 
37 | ### Documentation
38 | 
39 | Full documentation is available at [docs](https://rapidai.github.io/RapidRAG/docs/) (in Chinese).
40 | 
41 | ### TODO
42 | 
43 | - [ ] Support keyword + vector hybrid search.
44 | - [ ] Vue.js-based UI.
45 | 
46 | ### Code Contributors
47 | 
48 | 

49 | 50 | 51 | 52 |

53 | 
54 | ### Contributing
55 | 
56 | - Pull requests are welcome. For major changes, please open an issue first
57 |   to discuss what you would like to change.
58 | - Please make sure to update tests as appropriate.
59 | 
60 | ### [Sponsor](https://rapidai.github.io/RapidRAG/docs/sponsor/)
61 | 
62 | If you would like to sponsor the project, click the **Buy me a coffee** image below and leave a note (e.g. your GitHub account name) so you can be added to the sponsorship list.
63 | 
64 | 
65 | 66 |
67 | 
68 | ### License
69 | 
70 | [Apache 2.0](https://choosealicense.com/licenses/apache-2.0/)
71 | 
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # Roadmap
2 | 
3 | ### Standard Evaluation Process
4 | 
5 | Before proceeding with feature development and strategy optimization, we need a standard evaluation process to ensure that every feature and strategy we introduce is actually effective.
6 | 
7 | Create test sets from any dataset using advanced models and Ragas, then validate the effectiveness of each solution using basic models.
8 | 
9 | ### Feature Development and Strategy Optimization
10 | 
11 | 1. BM25 Keyword Search
12 | 2. Hybrid Search (BM25 + Vector)
13 | 3. GraphRAG
14 | 4. ReRanking
15 | 5. Query Rewriting
16 | 6. Small-to-big
17 | 7. ...
18 | 
--------------------------------------------------------------------------------
/assets/db/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/assets/db/.gitkeep
--------------------------------------------------------------------------------
/assets/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/assets/models/.gitkeep
--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import uuid
5 | from pathlib import Path
6 | 
7 | from rapid_rag.encoder import EncodeText
8 | from rapid_rag.file_loader import FileLoader
9 | from rapid_rag.llm import ERNIEBot
10 | from rapid_rag.utils import make_prompt, read_yaml
11 | from rapid_rag.vector_utils import DBUtils
12 | 
13 | config = read_yaml("rapid_rag/config.yaml")
14 | 
15 | extract = FileLoader()
16 | 
17 | # Parse the document
18 | file_path = "tests/test_files/office/word_example.docx"
19 | text = extract(file_path)
20 | sentences = text.get(Path(file_path).name)
21 | 
22 | # Extract embeddings
23 | encoder_params = config.get("Encoder")["m3e-small"]
24 | embedding_model = EncodeText(**encoder_params)
25 | embeddings = embedding_model(sentences)
26 | 
27 | # Insert the data into the database
28 | db_tools = DBUtils(config.get("vector_db_path"))
29 | uid = str(uuid.uuid1())
30 | db_tools.insert(file_path, embeddings, sentences, uid=uid)
31 | 
32 | params = config.get("LLM_API")["ERNIEBot"]
33 | llm_engine = ERNIEBot(**params)
34 | 
35 | print("欢迎使用 🧐 Knowledge QA LLM,输入“stop”终止程序 ")
36 | while True:
37 |     query = input("\n😀 用户: ")
38 |     if query.strip() == "stop":
39 |         break
40 | 
41 |     embedding = embedding_model(query)
42 | 
43 |     search_res, search_elapse = db_tools.search_local(embedding_query=embedding)
44 | 
45 |     context = "\n".join(sum(search_res.values(), []))
46 |     print(f"上下文:\n{context}\n")
47 | 
48 |     prompt = make_prompt(query, context, custom_prompt=config.get("DEFAULT_PROMPT"))
49 |     response = llm_engine(prompt, history=None)
50 |     print(f"🤖 LLM:\n {response}")
51 | 
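Since every backend in `rapid_rag/llm` exposes the same `__call__(prompt, history, **kwargs)` interface, the ERNIE Bot engine in `cli.py` above can be swapped for another one in a single line. A sketch using the bundled `OpenAI` wrapper — the endpoint, key, and model below are placeholders, not values shipped with the repo:

```python
from rapid_rag.llm import OpenAI

# Placeholder endpoint/key: any OpenAI-compatible server (e.g. a local gateway) works.
llm_engine = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="your_api_key",
    model="gpt-4o",
)
print(llm_engine("你好", history=None))
```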
--------------------------------------------------------------------------------
/docs/README_zh.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 🧐 Rapid RAG
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 
12 | 简体中文 | [English](../README.md)
13 | 
14 | 
15 | ### 简介
16 | 
17 | 基于本地知识库+LLM的问答系统。该项目的思路是由[langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)启发而来。
18 | 
19 | - 缘由:
20 |     - 之前使用过这个项目,感觉不是太灵活,部署不太友好。
21 |     - 借鉴[如何用大语言模型构建一个知识问答系统](https://mp.weixin.qq.com/s/movaNCWjJGBaes6KxhpYpg)中思路,尝试以此作为实践。
22 | - 优势:
23 |     - 整个项目为模块化配置,不依赖`langchain`库,各部分可轻易替换,代码简单易懂。
24 |     - 除需要单独部署大模型接口外,其他部分用CPU即可。
25 |     - 支持常见格式文档,包括txt、md、pdf、docx、pptx、excel等等。当然,也可自定义支持其他类型文档。
26 | 
27 | ### [Demo](https://aistudio.baidu.com/projectdetail/6675380?contributionType=1)
28 | 
29 | 
30 | 31 |
32 | 
33 | ### 文档
34 | 
35 | 完整文档请移步:[docs](https://rapidai.github.io/RapidRAG/docs)。
36 | 
37 | ### TODO
38 | 
39 | - [ ] Support keyword + vector hybrid search.
40 | - [ ] Vue.js-based UI.
41 | 
42 | ### 贡献者
43 | 
44 | 

45 | 46 | 47 | 48 |

49 | 
50 | ### 贡献指南
51 | 
52 | 我们感谢所有的贡献者为改进和提升 RapidRAG 所作出的努力。
53 | 
54 | - 欢迎提交请求。对于重大更改,请先打开issue讨论您想要改变的内容。
55 | - 请确保适当更新测试。
56 | 
57 | ### [赞助](https://rapidai.github.io/RapidRAG/docs/sponsor/)
58 | 
59 | 如果您想要赞助该项目,可直接点击当前页最上面的Sponsor按钮,请写好备注(**您的Github账号名称**),方便添加到赞助列表中。
60 | 
61 | ### 开源许可证
62 | 
63 | 该项目采用[Apache 2.0](https://choosealicense.com/licenses/apache-2.0/)开源许可证。
64 | 
--------------------------------------------------------------------------------
/rapid_rag/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | 
--------------------------------------------------------------------------------
/rapid_rag/config.yaml:
--------------------------------------------------------------------------------
1 | title: 🧐 Knowledge QA LLM
2 | version: 0.0.10
3 | 
4 | LLM_API:
5 |   ERNIEBot:
6 |     api_type: aistudio
7 |     access_token: your_token
8 |   Qwen7B_Chat:
9 |     api_url: your_api
10 |   ChatGLM2_6B:
11 |     api_url: your_api
12 |   BaiChuan7B:
13 |     api_url: your_api
14 |   InternLM_7B:
15 |     api_url: your_api
16 | 
17 | DEFAULT_PROMPT: 问题是:$query,从下面文章里,找出能回答以上问题的答案。如果文中没有答案,回答“没找到答案”。 文章:$context\n
18 | 
19 | upload_dir: assets/raw_upload_files
20 | vector_db_path: assets/db/DefaultVector.db
21 | 
22 | encoder_batch_size: 16
23 | Encoder:
24 |   ERNIEBot:
25 |     api_type: aistudio
26 |     access_token: your_token
27 |   m3e-small:
28 |     model_path: assets/models/m3e-small
29 | 
30 | # text splitter
31 | SENTENCE_SIZE: 200
32 | 
33 | top_k: 5
34 | 
35 | Parameter:
36 |   max_length:
37 |     min_value: 0
38 |     max_value: 4096
39 |     default: 1024
40 |     step: 1
41 |     tip: 生成结果时的最大token数
42 |   top_p:
43 |     min_value: 0.0
44 |     max_value: 1.0
45 |     default: 0.7
46 |     step: 0.01
47 |     tip: 用于控制模型生成文本时,选择下一个单词的概率分布的范围。
48 |   temperature:
49 |     min_value: 0.01
50 |     max_value: 1.0
51 |     default: 0.01
52 |     step: 0.01
53 |     tip: 用于调整模型生成文本时的创造性程度,较高的temperature将使模型更有可能生成新颖、独特的文本,而较低的温度则更有可能生成常见或常规的文本
--------------------------------------------------------------------------------
/rapid_rag/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .sentence_transformer import EncodeText
5 | from .erniebot import ErnieEncodeText
--------------------------------------------------------------------------------
/rapid_rag/encoder/erniebot.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import random
5 | import time
6 | from typing import List
7 | 
8 | import erniebot
9 | import numpy as np
10 | 
11 | 
12 | class ErnieEncodeText:
13 |     def __init__(self, api_type: str, access_token: str):
14 |         erniebot.api_type = api_type
15 |         erniebot.access_token = access_token
16 | 
17 |     def __call__(self, sentences: List[str]):
18 |         if not isinstance(sentences, List):
19 |             sentences = [sentences]
20 | 
21 |         time.sleep(random.randint(3, 10))
22 |         response = erniebot.Embedding.create(
23 |             model="ernie-text-embedding",
24 |             input=sentences,
25 |         )
26 |         datas = response.get("data", None)
27 |         if not datas:
28 |             return None
29 | 
30 |         embeddings = np.array([v["embedding"] for v in datas])
31 |         return embeddings
32 | 
--------------------------------------------------------------------------------
/rapid_rag/encoder/sentence_transformer.py:
-------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from typing import List, Optional 5 | 6 | from sentence_transformers import SentenceTransformer 7 | 8 | 9 | class EncodeText: 10 | def __init__(self, model_path: Optional[str] = None) -> None: 11 | if model_path is None: 12 | raise EncodeTextError("model_path is None.") 13 | self.model = SentenceTransformer(model_path) 14 | 15 | def __call__(self, sentences: List[str]): 16 | if not isinstance(sentences, List): 17 | sentences = [sentences] 18 | return self.model.encode(sentences) 19 | 20 | 21 | class EncodeTextError(Exception): 22 | pass 23 | -------------------------------------------------------------------------------- /rapid_rag/file_loader/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from .main import FileLoader 5 | -------------------------------------------------------------------------------- /rapid_rag/file_loader/image_loader.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from pathlib import Path 5 | from typing import List, Union 6 | 7 | from rapidocr_onnxruntime import RapidOCR 8 | 9 | from ..text_splitter.chinese_text_splitter import ChineseTextSplitter 10 | 11 | 12 | class ImageLoader: 13 | def __init__( 14 | self, 15 | ): 16 | self.ocr = RapidOCR() 17 | self.splitter = ChineseTextSplitter() 18 | 19 | def __call__(self, img_path: Union[str, Path]) -> List[str]: 20 | ocr_results, _ = self.ocr(img_path) 21 | _, rec_res, _ = list(zip(*ocr_results)) 22 | split_contents = [self.splitter.split_text(v) for v in rec_res] 23 | return sum(split_contents, []) 24 | -------------------------------------------------------------------------------- /rapid_rag/file_loader/main.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from pathlib import Path 5 | from typing import Dict, List, Union 6 | 7 | import filetype 8 | 9 | from ..utils import logger 10 | from .image_loader import ImageLoader 11 | from .office_loader import OfficeLoader 12 | from .pdf_loader import PDFLoader 13 | from .txt_loader import TXTLoader 14 | 15 | INPUT_TYPE = Union[str, Path] 16 | 17 | 18 | class FileLoader: 19 | def __init__(self) -> None: 20 | self.file_map = { 21 | "office": ["docx", "doc", "ppt", "pptx", "xlsx", "xlx"], 22 | "image": ["jpg", "png", "bmp", "tif", "jpeg"], 23 | "txt": ["txt", "md"], 24 | "pdf": ["pdf"], 25 | } 26 | 27 | self.img_loader = ImageLoader() 28 | self.office_loader = OfficeLoader() 29 | self.pdf_loader = PDFLoader() 30 | self.txt_loader = TXTLoader() 31 | 32 | def __call__(self, file_path: INPUT_TYPE) -> Dict[str, List[str]]: 33 | all_content = {} 34 | 35 | file_list = self.get_file_list(file_path) 36 | for file_path in file_list: 37 | file_name = file_path.name 38 | 39 | if file_path.suffix[1:] in self.file_map["txt"]: 40 | content = self.txt_loader(file_path) 41 | all_content[file_name] = content 42 | continue 43 | 44 | file_type = self.which_type(file_path) 45 | if file_type in self.file_map["office"]: 46 | content = self.office_loader(file_path) 47 | elif file_type in self.file_map["pdf"]: 48 | content = self.pdf_loader(file_path) 49 
| elif file_type in self.file_map["image"]: 50 | content = self.img_loader(file_path) 51 | else: 52 | logger.warning("%s does not support.", file_path) 53 | continue 54 | 55 | all_content[file_name] = content 56 | return all_content 57 | 58 | def get_file_list(self, file_path: INPUT_TYPE): 59 | if not isinstance(file_path, Path): 60 | file_path = Path(file_path) 61 | 62 | if file_path.is_dir(): 63 | return file_path.rglob("*.*") 64 | return [file_path] 65 | 66 | @staticmethod 67 | def which_type(content: Union[bytes, str, Path]) -> str: 68 | kind = filetype.guess(content) 69 | if kind is None: 70 | raise TypeError(f"The type of {content} does not support.") 71 | 72 | return kind.extension 73 | 74 | def sorted_by_suffix(self, file_list: List[str]) -> Dict[str, str]: 75 | sorted_res = {k: [] for k in self.file_map} 76 | 77 | for file_path in file_list: 78 | if file_path.suffix[1:] in self.file_map["txt"]: 79 | sorted_res["txt"].append(file_path) 80 | continue 81 | 82 | file_type = self.which_type(file_path) 83 | if file_type in self.file_map["office"]: 84 | sorted_res["office"].append(file_path) 85 | continue 86 | 87 | if file_type in self.file_map["pdf"]: 88 | sorted_res["pdf"].append(file_path) 89 | continue 90 | 91 | if file_type in self.file_map["image"]: 92 | sorted_res["image"].append(file_path) 93 | continue 94 | 95 | return sorted_res 96 | -------------------------------------------------------------------------------- /rapid_rag/file_loader/office_loader.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | from extract_office_content import ExtractOfficeContent 8 | 9 | from ..text_splitter.chinese_text_splitter import ChineseTextSplitter 10 | 11 | 12 | class OfficeLoader: 13 | def __init__(self) -> None: 14 | self.extracter = ExtractOfficeContent() 15 | self.splitter = ChineseTextSplitter() 16 | 17 | def __call__(self, office_path: Union[str, Path]) -> str: 18 | contents = self.extracter(office_path) 19 | split_contents = [self.splitter.split_text(v) for v in contents] 20 | return sum(split_contents, []) 21 | -------------------------------------------------------------------------------- /rapid_rag/file_loader/pdf_loader.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from pathlib import Path 5 | from typing import List, Union 6 | 7 | from rapidocr_pdf import PDFExtracter 8 | 9 | from ..text_splitter.chinese_text_splitter import ChineseTextSplitter 10 | 11 | 12 | class PDFLoader: 13 | def __init__( 14 | self, 15 | ): 16 | self.extracter = PDFExtracter() 17 | self.splitter = ChineseTextSplitter(pdf=True) 18 | 19 | def __call__(self, pdf_path: Union[str, Path]) -> List[str]: 20 | contents = self.extracter(pdf_path) 21 | split_contents = [self.splitter.split_text(v[1]) for v in contents] 22 | return sum(split_contents, []) 23 | -------------------------------------------------------------------------------- /rapid_rag/file_loader/txt_loader.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from pathlib import Path 5 | from typing import List, Union 6 | 7 | from ..text_splitter.chinese_text_splitter import ChineseTextSplitter 8 | from ..utils.utils import 
read_txt 9 | 10 | 11 | class TXTLoader: 12 | def __init__(self) -> None: 13 | self.splitter = ChineseTextSplitter() 14 | 15 | def __call__(self, txt_path: Union[str, Path]) -> List[str]: 16 | contents = read_txt(txt_path) 17 | split_contents = [self.splitter.split_text(v) for v in contents] 18 | return sum(split_contents, []) 19 | -------------------------------------------------------------------------------- /rapid_rag/llm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from .baichuan_7b import BaiChuan7B 5 | from .chatglm2_6b import ChatGLM2_6B 6 | from .ernie_bot_turbo import ERNIEBot 7 | from .internlm_7b import InternLM_7B 8 | from .qwen7b_chat import Qwen7B_Chat 9 | from .openai import OpenAI 10 | from .ollama import Ollama 11 | 12 | __all__ = [ 13 | "BaiChuan7B", 14 | "ChatGLM2_6B", 15 | "ERNIEBot", 16 | "Qwen7B_Chat", 17 | "InternLM_7B", 18 | "OpenAI", 19 | "Ollama", 20 | ] 21 | -------------------------------------------------------------------------------- /rapid_rag/llm/baichuan_7b.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import json 5 | from typing import List, Optional 6 | 7 | import requests 8 | 9 | 10 | class BaiChuan7B: 11 | def __init__(self, api_url: str = None): 12 | self.api_url = api_url 13 | 14 | def __call__(self, prompt: str, history: Optional[List] = None, **kwargs): 15 | if not history: 16 | history = [] 17 | 18 | data = {"input_text": prompt} 19 | if kwargs: 20 | temperature = kwargs.get("temperature", 0.1) 21 | top_p = kwargs.get("top_p", 0.7) 22 | max_length = kwargs.get("max_length", 4096) 23 | 24 | data.update( 25 | {"temperature": temperature, "top_p": top_p, "max_length": max_length} 26 | ) 27 | req = requests.post(self.api_url, data=json.dumps(data), timeout=60) 28 | try: 29 | rdata = req.json() 30 | if rdata["status"] == 200: 31 | return rdata["response"] 32 | return "网络出错" 33 | except Exception as e: 34 | return f"网络出错:{e}" 35 | 36 | 37 | if __name__ == "__main__": 38 | prompt = "你是谁?" 39 | history = [] 40 | t = BaiChuan7B() 41 | 42 | res = t(prompt, history) 43 | print(res) 44 | -------------------------------------------------------------------------------- /rapid_rag/llm/chatglm2_6b.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import json 5 | from typing import List, Optional 6 | 7 | import requests 8 | 9 | 10 | class ChatGLM2_6B: 11 | def __init__(self, api_url: str = None): 12 | self.api_url = api_url 13 | 14 | def __call__(self, prompt: str, history: Optional[List] = None, **kwargs): 15 | if not history: 16 | history = [] 17 | 18 | data = {"prompt": prompt, "history": history} 19 | if kwargs: 20 | temperature = kwargs.get("temperature", 0.1) 21 | top_p = kwargs.get("top_p", 0.7) 22 | max_length = kwargs.get("max_length", 4096) 23 | 24 | data.update( 25 | {"temperature": temperature, "top_p": top_p, "max_length": max_length} 26 | ) 27 | req = requests.post(self.api_url, data=json.dumps(data), timeout=60) 28 | try: 29 | rdata = req.json() 30 | if rdata["status"] == 200: 31 | return rdata["response"] 32 | return "网络出错" 33 | except Exception as e: 34 | return f"网络出错:{e}" 35 | 36 | 37 | if __name__ == "__main__": 38 | prompt = "你是谁?" 
39 | history = [] 40 | t = ChatGLM2_6B() 41 | 42 | res = t(prompt, history) 43 | print(res) 44 | -------------------------------------------------------------------------------- /rapid_rag/llm/ernie_bot_turbo.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from typing import List, Optional 5 | 6 | import erniebot 7 | 8 | 9 | class ERNIEBot: 10 | def __init__(self, api_type: str = None, access_token: str = None): 11 | self.api_type = api_type 12 | self.access_token = access_token 13 | 14 | def __call__(self, prompt: str, history: Optional[List] = None, **kwargs): 15 | if not history: 16 | history = [] 17 | 18 | response = erniebot.ChatCompletion.create( 19 | _config_={ 20 | "api_type": self.api_type, 21 | "access_token": self.access_token, 22 | }, 23 | model="ernie-bot", 24 | messages=[ 25 | { 26 | "role": "user", 27 | "content": prompt, 28 | } 29 | ], 30 | ) 31 | result = response.get("result", None) 32 | return result 33 | -------------------------------------------------------------------------------- /rapid_rag/llm/internlm_7b.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import json 5 | from typing import List, Optional 6 | 7 | import requests 8 | 9 | 10 | class InternLM_7B: 11 | def __init__(self, api_url: str = None): 12 | self.api_url = api_url 13 | 14 | def __call__(self, prompt: str, history: Optional[List] = None, **kwargs): 15 | if not history: 16 | history = [] 17 | 18 | data = {"prompt": prompt, "history": history} 19 | if kwargs: 20 | temperature = kwargs.get("temperature", 0.1) 21 | top_p = kwargs.get("top_p", 0.7) 22 | max_length = kwargs.get("max_length", 4096) 23 | 24 | data.update( 25 | {"temperature": temperature, "top_p": top_p, "max_length": max_length} 26 | ) 27 | req = requests.post(self.api_url, data=json.dumps(data), timeout=60) 28 | try: 29 | rdata = req.json() 30 | if rdata["status"] == 200: 31 | return rdata["response"] 32 | return "Network error" 33 | except Exception as e: 34 | return f"Network error:{e}" 35 | -------------------------------------------------------------------------------- /rapid_rag/llm/llama2.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import json 5 | from typing import List, Optional 6 | 7 | import requests 8 | 9 | 10 | class Llama2_7BChat: 11 | def __init__(self, api_url: str = None): 12 | self.api_url = api_url 13 | 14 | def __call__(self, prompt: str, history: Optional[List] = None, **kwargs): 15 | if not history: 16 | history = [] 17 | 18 | data = {"prompt": prompt} 19 | if kwargs: 20 | temperature = kwargs.get("temperature", 0.1) 21 | top_p = kwargs.get("top_p", 0.7) 22 | max_length = kwargs.get("max_length", 4096) 23 | 24 | data.update( 25 | {"temperature": temperature, "top_p": top_p, "max_length": max_length} 26 | ) 27 | req = requests.post(self.api_url, data=json.dumps(data), timeout=60) 28 | try: 29 | rdata = req.json() 30 | if rdata["status"] == 200: 31 | return rdata["response"] 32 | return "网络出错" 33 | except Exception as e: 34 | return f"网络出错:{e}" 35 | 36 | 37 | if __name__ == "__main__": 38 | prompt = "你是谁?" 
39 |     history = []
40 |     t = Llama2_7BChat()
41 | 
42 |     res = t(prompt, history)
43 |     print(res)
44 | 
--------------------------------------------------------------------------------
/rapid_rag/llm/ollama.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: Leo Peng
3 | # @Contact: leo@promptcn.com
4 | from typing import List, Optional
5 | 
6 | import ollama
7 | 
8 | 
9 | class Ollama:
10 |     def __init__(self, host: str = "http://localhost:11434", model: str = None):
11 |         self.host = host
12 |         self.model = model
13 |         self.client = ollama.Client(host=self.host)
14 | 
15 |     def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
16 |         if not history:
17 |             history = []
18 | 
19 |         response = self.client.chat(
20 |             messages=[
21 |                 {
22 |                     "role": "user",
23 |                     "content": prompt,
24 |                 }
25 |             ],
26 |             model=self.model,
27 |         )
28 |         result = response["message"]["content"]
29 |         return result
30 | 
--------------------------------------------------------------------------------
/rapid_rag/llm/openai.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: Leo Peng
3 | # @Contact: leo@promptcn.com
4 | from typing import List, Optional
5 | 
6 | import openai
7 | 
8 | 
9 | class OpenAI:
10 |     def __init__(
11 |         self, base_url: str = None, api_key: str = None, model: str = "gpt-4o"
12 |     ):
13 |         self.base_url = base_url
14 |         self.api_key = api_key
15 |         self.model = model
16 |         self.client = openai.OpenAI(base_url=self.base_url, api_key=self.api_key)
17 | 
18 |     def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
19 |         if not history:
20 |             history = []
21 | 
22 |         response = self.client.chat.completions.create(
23 |             messages=[
24 |                 {
25 |                     "role": "user",
26 |                     "content": prompt,
27 |                 }
28 |             ],
29 |             model=self.model,
30 |         )
31 |         result = response.choices[0].message.content
32 |         return result
33 | 
--------------------------------------------------------------------------------
/rapid_rag/llm/qwen7b_chat.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import json
5 | from typing import List, Optional
6 | 
7 | import requests
8 | 
9 | 
10 | class Qwen7B_Chat:
11 |     def __init__(self, api_url: str = None):
12 |         self.api_url = api_url
13 | 
14 |     def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
15 |         if not history:
16 |             history = []
17 | 
18 |         data = {"prompt": prompt, "history": history}
19 |         if kwargs:
20 |             temperature = kwargs.get("temperature", 0.1)
21 |             top_p = kwargs.get("top_p", 0.7)
22 |             max_length = kwargs.get("max_length", 4096)
23 | 
24 |             data.update(
25 |                 {"temperature": temperature, "top_p": top_p, "max_length": max_length}
26 |             )
27 |         req = requests.post(self.api_url, data=json.dumps(data), timeout=60)
28 |         try:
29 |             rdata = req.json()
30 |             if rdata["status"] == 200:
31 |                 return rdata["response"]
32 |             return "网络出错"
33 |         except Exception as e:
34 |             return f"网络出错:{e}"
35 | 
36 | 
37 | if __name__ == "__main__":
38 |     prompt = "你是谁?"
39 |     history = []
40 |     t = Qwen7B_Chat()
41 | 
42 |     res = t(prompt, history)
43 |     print(res)
44 | 
--------------------------------------------------------------------------------
/rapid_rag/text_splitter/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | 
--------------------------------------------------------------------------------
/rapid_rag/text_splitter/chinese_text_splitter.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | # Modified from https://github.com/chatchat-space/langchain-ChatGLM/blob/master/configs/model_config.py
5 | import re
6 | from pathlib import Path
7 | from typing import List
8 | 
9 | from ..utils.utils import read_yaml
10 | 
11 | # rapid_rag
12 | root_dir = Path(__file__).resolve().parent.parent
13 | config_path = root_dir / "config.yaml"
14 | config = read_yaml(config_path)
15 | 
16 | 
17 | class ChineseTextSplitter:
18 |     def __init__(
19 |         self,
20 |         pdf: bool = False,
21 |         sentence_size: int = config.get("SENTENCE_SIZE"),
22 |     ):
23 |         self.pdf = pdf
24 |         self.sentence_size = sentence_size
25 | 
26 |     def split_text1(self, text: str) -> List[str]:
27 |         if self.pdf:
28 |             text = re.sub(r"\n{3,}", "\n", text)
29 |             text = re.sub(r"\s", " ", text)
30 |             text = text.replace("\n\n", "")
31 |         sent_sep_pattern = re.compile(
32 |             '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))'
33 |         )  # del :;
34 |         sent_list = []
35 |         for ele in sent_sep_pattern.split(text):
36 |             ele = ele.strip()
37 |             if sent_sep_pattern.match(ele) and sent_list:
38 |                 sent_list[-1] += ele
39 |             elif ele:
40 |                 sent_list.append(ele)
41 |         return sent_list
42 | 
43 |     def split_text(self, text: str) -> List[str]:  # TODO: this logic still needs refinement
44 |         if self.pdf:
45 |             text = re.sub(r"\n{3,}", r"\n", text)
46 |             text = re.sub(r"\s", " ", text)
47 |             text = re.sub("\n\n", "", text)
48 | 
49 |         text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text)  # single-character sentence terminators
50 |         text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
51 |         text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
52 |         text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text)
53 |         # A closing quote only ends the sentence when a terminator precedes it, so the split \n goes after the quote; the rules above carefully preserve the quotes.
54 |         text = text.rstrip()  # drop any extra trailing \n at the end of the paragraph
55 |         # Many rule sets also handle the semicolon; it is deliberately ignored here, as are dashes and English double quotes; adjust the rules if you need them.
56 |         ls = [i for i in text.split("\n") if i]
57 |         for ele in ls:
58 |             if len(ele) > self.sentence_size:
59 |                 ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele)
60 |                 ele1_ls = ele1.split("\n")
61 |                 for ele_ele1 in ele1_ls:
62 |                     if len(ele_ele1) > self.sentence_size:
63 |                         ele_ele2 = re.sub(
64 |                             r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1
65 |                         )
66 |                         ele2_ls = ele_ele2.split("\n")
67 |                         for ele_ele2 in ele2_ls:
68 |                             if len(ele_ele2) > self.sentence_size:
69 |                                 ele_ele3 = re.sub(
70 |                                     '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
71 |                                 )
72 |                                 ele2_id = ele2_ls.index(ele_ele2)
73 |                                 ele2_ls = (
74 |                                     ele2_ls[:ele2_id]
75 |                                     + [i for i in ele_ele3.split("\n") if i]
76 |                                     + ele2_ls[ele2_id + 1 :]
77 |                                 )
78 |                         ele_id = ele1_ls.index(ele_ele1)
79 |                         ele1_ls = (
80 |                             ele1_ls[:ele_id]
81 |                             + [i for i in ele2_ls if i]
82 |                             + ele1_ls[ele_id + 1 :]
83 |                         )
84 | 
85 |                 id = ls.index(ele)
86 |                 ls = ls[:id] + [i.strip() for i in ele1_ls if i] + ls[id + 1 :]
87 |         return ls
88 | 
--------------------------------------------------------------------------------
/rapid_rag/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from .logger import logger 5 | from .utils import get_timestamp, make_prompt, mkdir, read_yaml 6 | -------------------------------------------------------------------------------- /rapid_rag/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import functools 5 | import sys 6 | from pathlib import Path 7 | 8 | from loguru import logger 9 | 10 | 11 | @functools.lru_cache() 12 | def get_logger(save_dir: str = "."): 13 | loguru_format = ( 14 | "{time:YYYY-MM-DD HH:mm:ss} | " 15 | "{level: <8} | " 16 | "{name}:{line} - {message}" 17 | ) 18 | 19 | logger.remove() 20 | logger.add( 21 | sys.stderr, 22 | format=loguru_format, 23 | level="INFO", 24 | enqueue=True, 25 | ) 26 | save_file = Path(save_dir) / "{time:YYYY-MM-DD-HH-mm-ss}.log" 27 | logger.add(save_file, rotation=None, retention="5 days") 28 | return logger 29 | 30 | 31 | log_dir = Path(__file__).resolve().parent.parent.parent / "log" 32 | logger = get_logger(str(log_dir)) 33 | -------------------------------------------------------------------------------- /rapid_rag/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from datetime import datetime 5 | from pathlib import Path 6 | from string import Template 7 | from typing import List, Union 8 | 9 | import yaml 10 | 11 | 12 | def make_prompt(query: str, context: str = None, custom_prompt: str = None) -> str: 13 | if context is None: 14 | return query 15 | 16 | if "$query" not in custom_prompt or "$context" not in custom_prompt: 17 | raise ValueError("prompt中必须含有$query和$context两个值") 18 | 19 | msg_template = Template(custom_prompt) 20 | message = msg_template.substitute(query=query, context=context) 21 | return message 22 | 23 | 24 | def read_yaml(yaml_path: Union[str, Path]): 25 | with open(str(yaml_path), "rb") as f: 26 | data = yaml.load(f, Loader=yaml.Loader) 27 | return data 28 | 29 | 30 | def mkdir(dir_path): 31 | Path(dir_path).mkdir(parents=True, exist_ok=True) 32 | 33 | 34 | def get_timestamp(): 35 | return datetime.strftime(datetime.now(), "%Y-%m-%d") 36 | 37 | 38 | def read_txt(txt_path: Union[Path, str]) -> List[str]: 39 | if not isinstance(txt_path, str): 40 | txt_path = str(txt_path) 41 | 42 | with open(txt_path, "r", encoding="utf-8") as f: 43 | data = list(map(lambda x: x.rstrip("\n"), f)) 44 | return data 45 | -------------------------------------------------------------------------------- /rapid_rag/vector_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from .sqlite_version import DBUtils 5 | -------------------------------------------------------------------------------- /rapid_rag/vector_utils/sqlite_version.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import io 5 | import sqlite3 6 | import time 7 | from typing import Dict, List, Optional 8 | 9 | import faiss 10 | import numpy as np 11 | 12 | from ..utils.logger import logger 13 | 14 | 15 | 
def adapt_array(arr): 16 | out = io.BytesIO() 17 | np.save(out, arr) 18 | out.seek(0) 19 | return sqlite3.Binary(out.read()) 20 | 21 | 22 | def convert_array(text): 23 | out = io.BytesIO(text) 24 | out.seek(0) 25 | return np.load(out, allow_pickle=True) 26 | 27 | 28 | sqlite3.register_adapter(np.ndarray, adapt_array) 29 | sqlite3.register_converter("array", convert_array) 30 | 31 | 32 | class DBUtils: 33 | def __init__( 34 | self, 35 | db_path: str, 36 | ) -> None: 37 | self.db_path = db_path 38 | 39 | self.table_name = "embedding_texts" 40 | self.vector_nums = 0 41 | 42 | self.max_prompt_length = 4096 43 | 44 | self.connect_db() 45 | 46 | def connect_db( 47 | self, 48 | ): 49 | con = sqlite3.connect(self.db_path, detect_types=sqlite3.PARSE_DECLTYPES) 50 | cur = con.cursor() 51 | cur.execute( 52 | f"create table if not exists {self.table_name} (id integer primary key autoincrement, file_name TEXT, embeddings array UNIQUE, texts TEXT, uids TEXT)" 53 | ) 54 | return cur, con 55 | 56 | def load_vectors(self, uid: Optional[str] = None): 57 | cur, _ = self.connect_db() 58 | 59 | search_sql = f"select file_name, embeddings, texts from {self.table_name}" 60 | if uid: 61 | search_sql = f'select file_name, embeddings, texts from {self.table_name} where uids="{uid}"' 62 | 63 | cur.execute(search_sql) 64 | all_vectors = cur.fetchall() 65 | 66 | self.file_names = np.array([v[0] for v in all_vectors]) 67 | all_embeddings = np.array([v[1] for v in all_vectors]) 68 | self.all_texts = np.array([v[2] for v in all_vectors]) 69 | 70 | self.search_index = faiss.IndexFlatL2(all_embeddings.shape[1]) 71 | self.search_index.add(all_embeddings) 72 | self.vector_nums = len(all_vectors) 73 | 74 | def count_vectors( 75 | self, 76 | ): 77 | cur, _ = self.connect_db() 78 | 79 | cur.execute(f"select file_name from {self.table_name}") 80 | all_vectors = cur.fetchall() 81 | return len(all_vectors) 82 | 83 | def search_local( 84 | self, 85 | embedding_query: np.ndarray, 86 | top_k: int = 5, 87 | uid: Optional[str] = None, 88 | ) -> Optional[Dict[str, List[str]]]: 89 | s = time.perf_counter() 90 | 91 | cur_vector_nums = self.count_vectors() 92 | if cur_vector_nums == 0: 93 | return None, 0 94 | 95 | if cur_vector_nums != self.vector_nums: 96 | self.load_vectors(uid) 97 | 98 | # cur_vector_nums 小于 top_k 时,返回 cur_vector_nums 个结果 99 | _, I = self.search_index.search(embedding_query, min(top_k, cur_vector_nums)) 100 | top_index = I.squeeze().tolist() 101 | 102 | # 处理只有一个结果的情况 103 | if isinstance(top_index, int): 104 | top_index = [top_index] 105 | 106 | search_contents = self.all_texts[top_index] 107 | file_names = [self.file_names[idx] for idx in top_index] 108 | dup_file_names = list(set(file_names)) 109 | dup_file_names.sort(key=file_names.index) 110 | 111 | search_res = {v: [] for v in dup_file_names} 112 | for file_name, content in zip(file_names, search_contents): 113 | search_res[file_name].append(content) 114 | 115 | elapse = time.perf_counter() - s 116 | return search_res, elapse 117 | 118 | def insert( 119 | self, file_name: str, embeddings: np.ndarray, texts: List[str], uid: str 120 | ): 121 | cur, con = self.connect_db() 122 | 123 | file_names = [file_name] * len(embeddings) 124 | uids = [uid] * len(embeddings) 125 | 126 | t1 = time.perf_counter() 127 | insert_sql = f"insert or ignore into {self.table_name} (file_name, embeddings, texts, uids) values (?, ?, ?, ?)" 128 | cur.executemany(insert_sql, list(zip(file_names, embeddings, texts, uids))) 129 | elapse = time.perf_counter() - t1 130 | logger.info( 131 | 
/requirements.txt:
--------------------------------------------------------------------------------
 1 | numpy>=1.21.6
 2 | streamlit>=1.25.0
 3 | transformers>=4.27.0.dev0,<4.47.0
 4 | faiss-cpu
 5 | filetype
 6 | extract-office-content>=0.0.6
 7 | sentence_transformers
 8 | rapidocr_onnxruntime
 9 | rapidocr_pdf>=0.0.5
10 | loguru
11 | erniebot
12 | openai>=1.58.1
13 | ollama>=0.4.5
14 | ragas>=0.2.9
15 | 
--------------------------------------------------------------------------------
/tests/demo_store_embedding.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from extract_office_content import ExtractWord
 5 | 
 6 | from rapid_rag.encoder import EncodeText
 7 | from rapid_rag.vector_utils import DBUtils
 8 | 
 9 | # Read the document
10 | word_extract = ExtractWord()
11 | 
12 | file_path = "tests/test_files/office/word_example.docx"
13 | text = word_extract(file_path)
14 | sentences = [v.strip() for v in text if v.strip()]
15 | 
16 | # Extract embeddings
17 | model = EncodeText()
18 | embeddings = model(sentences)
19 | 
20 | db_path = "db/Vector.db"
21 | db_tools = DBUtils(db_path)
22 | 
23 | db_tools.insert(file_path, embeddings, sentences, uid="demo")  # any session id works for scripts
24 | 
25 | print("ok")
26 | 
--------------------------------------------------------------------------------
/tests/test_bge.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from sentence_transformers import SentenceTransformer
 5 | 
 6 | queries = ["手机开不了机怎么办?"]
 7 | passages = ["样例段落-1", "样例段落-2"]
 8 | instruction = "为这个句子生成表示以用于检索相关文章:"
 9 | model = SentenceTransformer("assets/models/bge-small-zh")
10 | q_embeddings = model.encode(
11 |     [instruction + q for q in queries], normalize_embeddings=True
12 | )
13 | p_embeddings = model.encode(passages, normalize_embeddings=True)
14 | scores = q_embeddings @ p_embeddings.T
15 | 
16 | print(scores)
17 | 
--------------------------------------------------------------------------------
/tests/test_chatglm2_6b.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | import sys
 5 | from pathlib import Path
 6 | 
 7 | cur_dir = Path(__file__).resolve().parent
 8 | root_dir = cur_dir.parent
 9 | sys.path.append(str(root_dir))
10 | 
11 | from rapid_rag.llm import ChatGLM2_6B
12 | from rapid_rag.utils import read_yaml
13 | 
14 | config_path = root_dir / "rapid_rag" / "config.yaml"
15 | config = read_yaml(config_path)
16 | 
17 | llm_model = 
ChatGLM2_6B(config.get("LLM_API")["ChatGLM2_6B"]) 18 | 19 | 20 | def test_normal_input(): 21 | prompt = "你是谁?" 22 | history = [] 23 | 24 | res = llm_model(prompt, history) 25 | 26 | assert ( 27 | res 28 | == "我是一个名为 ChatGLM2-6B 的人工智能助手,是基于清华大学 KEG 实验室和智谱 AI 公司于 2023 年共同训练的语言模型开发的。我的任务是针对用户的问题和要求提供适当的答复和支持。" 29 | ) 30 | -------------------------------------------------------------------------------- /tests/test_file_loader.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from rapid_rag.file_loader.main import FileLoader 5 | 6 | loader = FileLoader() 7 | 8 | file_dir = "tests/test_files" 9 | 10 | res = loader(file_dir) 11 | print("ok") 12 | -------------------------------------------------------------------------------- /tests/test_files/office/excel_with_image.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/office/excel_with_image.xlsx -------------------------------------------------------------------------------- /tests/test_files/office/ppt_example.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/office/ppt_example.pptx -------------------------------------------------------------------------------- /tests/test_files/office/word_example.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/office/word_example.docx -------------------------------------------------------------------------------- /tests/test_files/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/test.jpg -------------------------------------------------------------------------------- /tests/test_files/test.md: -------------------------------------------------------------------------------- 1 | 我与父亲不相见已二年余了,我最不能忘记的是他的背影。 2 | 3 | 那年冬天,祖母死了,父亲的差使也交卸了,正是祸不单行的日子。我从北京到徐州,打算跟着父亲奔丧回家。到徐州见着父亲,看见满院狼藉的东西,又想起祖母,不禁簌簌地流下眼泪。父亲说:“事已如此,不必难过,好在天无绝人之路! 4 | -------------------------------------------------------------------------------- /tests/test_files/test.txt: -------------------------------------------------------------------------------- 1 | 我与父亲不相见已二年余了,我最不能忘记的是他的背影。 2 | 3 | 那年冬天,祖母死了,父亲的差使也交卸了,正是祸不单行的日子。我从北京到徐州,打算跟着父亲奔丧回家。到徐州见着父亲,看见满院狼藉的东西,又想起祖母,不禁簌簌地流下眼泪。父亲说:“事已如此,不必难过,好在天无绝人之路! 
 4 | 
--------------------------------------------------------------------------------
/tests/test_files/word_example.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/word_example.pdf
--------------------------------------------------------------------------------
/tests/test_files/长安三万里.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/长安三万里.pdf
--------------------------------------------------------------------------------
/tests/test_llama2_7b_chat.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from rapid_rag.llm.llama2 import Llama2_7BChat
 5 | 
 6 | api = ""  # fill in the deployed inference service URL before running
 7 | llm = Llama2_7BChat(api_url=api)
 8 | 
 9 | 
10 | prompt = "你是谁?"
11 | 
12 | response = llm(prompt)
13 | print(response)
14 | 
--------------------------------------------------------------------------------
/tests/test_m3e.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | import sys
 5 | from pathlib import Path
 6 | 
 7 | cur_dir = Path(__file__).resolve().parent
 8 | root_dir = cur_dir.parent
 9 | sys.path.append(str(root_dir))
10 | 
11 | from rapid_rag.encoder import EncodeText
12 | from rapid_rag.utils import read_yaml
13 | 
14 | config_path = root_dir / "rapid_rag" / "config.yaml"
15 | config = read_yaml(config_path)
16 | model = EncodeText(config["encoder_model_path"])
17 | 
18 | 
19 | def test_normal_input():
20 |     sentences = [
21 |         "* Moka 此文本嵌入模型由 MokaAI 训练并开源,训练脚本使用 uniem",
22 |         "* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练",
23 |         "* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算,异质文本检索等功能,未来还会支持代码检索,ALL in one",
24 |     ]
25 | 
26 |     embeddings = model(sentences)
27 |     assert embeddings.shape == (3, 512)
28 | 
--------------------------------------------------------------------------------
/tests/test_office_loader.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | import sys
 5 | from pathlib import Path
 6 | 
 7 | cur_dir = Path(__file__).resolve().parent
 8 | root_dir = cur_dir.parent
 9 | sys.path.append(str(root_dir))
10 | 
11 | import pytest
12 | 
13 | from rapid_rag.file_loader.office_loader import ExtractOfficeLoader
14 | 
15 | extractor_office = ExtractOfficeLoader()
16 | 
17 | 
18 | test_file_dir = cur_dir / "test_files" / "office"
19 | 
20 | 
21 | @pytest.mark.parametrize(
22 |     "file_path, gt1, gt2",
23 |     [
24 |         ("word_example.docx", 221, "我与父亲不"),
25 |         ("ppt_example.pptx", 350, "| 0 "),
26 |         ("excel_with_image.xlsx", 361, "| "),
27 |     ],
28 | )
29 | def test_extract(file_path, gt1, gt2):
30 |     file_path = test_file_dir / file_path
31 |     extract_res = extractor_office([file_path])
32 | 
33 |     assert len(extract_res[0][1][0]) == gt1
34 |     assert extract_res[0][1][0][:5] == gt2
35 | 
--------------------------------------------------------------------------------
/tests/test_qwen.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from rapid_rag.llm.qwen7b_chat import Qwen7B_Chat
 5 | 
 6 | api = ""  # fill in the deployed inference service URL before running
 7 | llm = 
Qwen7B_Chat(api_url=api)
 8 | 
 9 | 
10 | prompt = "杭州有哪些景点?"
11 | 
12 | response = llm(prompt, history=None)
13 | print(response)
14 | 
--------------------------------------------------------------------------------
/tests/test_search.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from pathlib import Path
 5 | 
 6 | cur_dir = Path(__file__).resolve().parent
 7 | 
 8 | from rapid_rag.encoder import EncodeText
 9 | from rapid_rag.utils import read_yaml
10 | from rapid_rag.vector_utils import DBUtils
11 | 
12 | config_path = Path("rapid_rag") / "config.yaml"
13 | config = read_yaml(config_path)
14 | 
15 | model = EncodeText(config["encoder_model_path"])
16 | db = DBUtils(config["vector_db_path"])
17 | 
18 | query = "蔡徐坤"
19 | embedding = model(query)
20 | search_res, elapse = db.search_local(embedding_query=embedding, top_k=3)
21 | 
22 | print(search_res)
23 | print("ok")
24 | 
--------------------------------------------------------------------------------
/tests/test_sql_insert.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from rapid_rag.encoder import EncodeText
 5 | from rapid_rag.file_loader import FileLoader
 6 | from rapid_rag.utils import read_yaml
 7 | from rapid_rag.vector_utils import DBUtils
 8 | 
 9 | config = read_yaml("rapid_rag/config.yaml")
10 | 
11 | extract = FileLoader()
12 | 
13 | # Parse the document
14 | file_path = "tests/test_files/长安三万里.pdf"
15 | text = extract(file_path)
16 | sentences = text[file_path][0]
17 | 
18 | # Extract embeddings
19 | embedding_model = EncodeText(config.get("encoder_model_path"))
20 | embeddings = embedding_model(sentences)
21 | 
22 | # Insert the records into the database
23 | db_tools = DBUtils(config.get("vector_db_path"))
24 | db_tools.insert(file_path, embeddings, sentences, uid="test")  # any session id works for scripts
25 | 
--------------------------------------------------------------------------------
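webui.py below wires this same flow into Streamlit. Stripped of the UI, the retrieval-augmented loop it implements looks roughly like the following sketch (the template string is illustrative; the shipped default prompt lives in rapid_rag/config.yaml under DEFAULT_PROMPT):

from rapid_rag.encoder import EncodeText
from rapid_rag.utils import make_prompt, read_yaml
from rapid_rag.vector_utils import DBUtils

config = read_yaml("rapid_rag/config.yaml")
encoder = EncodeText(config.get("encoder_model_path"))
db = DBUtils(config.get("vector_db_path"))

query = "杭州有哪些景点?"
query_embedding = encoder(query)
search_res, elapse = db.search_local(query_embedding, top_k=5)

if search_res is None:
    final_prompt = query  # empty knowledge base: fall back to the bare question
else:
    context = "\n".join(sum(search_res.values(), []))
    template = "Context: $context\nQuestion: $query"  # illustrative template
    final_prompt = make_prompt(query, context, template)
print(final_prompt)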
/webui.py:
--------------------------------------------------------------------------------
  1 | # -*- encoding: utf-8 -*-
  2 | # @Author: SWHL
  3 | # @Contact: liekkaskono@163.com
  4 | import importlib
  5 | import shutil
  6 | import time
  7 | import uuid
  8 | from pathlib import Path
  9 | from typing import Dict
 10 | 
 11 | import numpy as np
 12 | import streamlit as st
 13 | 
 14 | from rapid_rag.encoder import EncodeText, ErnieEncodeText
 15 | from rapid_rag.file_loader import FileLoader
 16 | from rapid_rag.utils import get_timestamp, logger, make_prompt, mkdir, read_yaml
 17 | from rapid_rag.vector_utils import DBUtils
 18 | 
 19 | config = read_yaml("rapid_rag/config.yaml")
 20 | 
 21 | st.set_page_config(
 22 |     page_title=config.get("title"),
 23 |     page_icon=":robot:",
 24 | )
 25 | 
 26 | 
 27 | def init_ui_parameters():
 28 |     st.session_state["params"] = {}
 29 |     param = config.get("Parameter")
 30 | 
 31 |     st.sidebar.markdown("### 🛶 参数设置")
 32 | 
 33 |     param_max_length = param.get("max_length")
 34 |     max_length = st.sidebar.slider(
 35 |         "max_length",
 36 |         min_value=param_max_length.get("min_value"),
 37 |         max_value=param_max_length.get("max_value"),
 38 |         value=param_max_length.get("default"),
 39 |         step=param_max_length.get("step"),
 40 |         help=param_max_length.get("tip"),
 41 |     )
 42 |     st.session_state["params"]["max_length"] = max_length
 43 | 
 44 |     param_top = param.get("top_p")
 45 |     top_p = st.sidebar.slider(
 46 |         "top_p",
 47 |         min_value=param_top.get("min_value"),
 48 |         max_value=param_top.get("max_value"),
 49 |         value=param_top.get("default"),
 50 |         step=param_top.get("step"),
 51 |         help=param_top.get("tip"),
 52 |     )
 53 |     st.session_state["params"]["top_p"] = top_p
 54 | 
 55 |     param_temp = param.get("temperature")
 56 |     temperature = st.sidebar.slider(
 57 |         "temperature",
 58 |         min_value=param_temp.get("min_value"),
 59 |         max_value=param_temp.get("max_value"),
 60 |         value=param_temp.get("default"),
 61 |         step=param_temp.get("step"),
 62 |         help=param_temp.get("tip"),
 63 |     )
 64 |     st.session_state["params"]["temperature"] = temperature
 65 | 
 66 | 
 67 | def init_ui_db():
 68 |     st.sidebar.markdown("### 🧻 知识库")
 69 |     uploaded_files = st.sidebar.file_uploader(
 70 |         "default",
 71 |         accept_multiple_files=True,
 72 |         label_visibility="hidden",
 73 |         help="支持多个文件的选取",
 74 |     )
 75 | 
 76 |     upload_dir = config.get("upload_dir")
 77 |     btn_upload = st.sidebar.button("上传文档并加载")
 78 |     if btn_upload:
 79 |         time_stamp = get_timestamp()
 80 |         doc_dir = Path(upload_dir) / time_stamp
 81 | 
 82 |         tips("正在上传文件到平台中...", icon="⏳")
 83 |         for file_data in uploaded_files:
 84 |             bytes_data = file_data.getvalue()
 85 | 
 86 |             mkdir(doc_dir)
 87 |             save_path = doc_dir / file_data.name
 88 |             with open(save_path, "wb") as f:
 89 |                 f.write(bytes_data)
 90 |         tips("上传完毕!")
 91 | 
 92 |         with st.spinner(f"正在从{doc_dir}提取内容...."):
 93 |             all_doc_contents = file_loader(doc_dir)
 94 | 
 95 |         pro_text = "提取语义向量..."
 96 |         batch_size = config.get("encoder_batch_size", 16)
 97 |         uid = str(uuid.uuid1())
 98 |         st.session_state["connect_id"] = uid
 99 |         for file_path, one_doc_contents in all_doc_contents.items():
100 |             my_bar = st.sidebar.progress(0, text=pro_text)
101 |             content_nums = len(one_doc_contents)
102 |             all_embeddings = []
103 |             for i in range(0, content_nums, batch_size):
104 |                 start_idx = i
105 |                 end_idx = start_idx + batch_size
106 |                 end_idx = min(end_idx, content_nums)
107 | 
108 |                 cur_contents = one_doc_contents[start_idx:end_idx]
109 |                 if not cur_contents:
110 |                     continue
111 | 
112 |                 embeddings = embedding_extract(cur_contents)
113 |                 if embeddings is None or embeddings.size == 0:
114 |                     continue
115 | 
116 |                 all_embeddings.append(embeddings)
117 |                 my_bar.progress(
118 |                     end_idx / content_nums,
119 |                     f"Extract {file_path} data: [{end_idx}/{content_nums}]",
120 |                 )
121 |             my_bar.empty()
122 | 
123 |             if all_embeddings:
124 |                 all_embeddings = np.vstack(all_embeddings)
125 |                 db_tools.insert(file_path, all_embeddings, one_doc_contents, uid)
126 |             else:
127 |                 tips(f"从{file_path}提取向量为空。")
128 | 
129 |         shutil.rmtree(doc_dir.resolve())
130 |         tips("现在可以提问问题了哈!")
131 | 
132 |     clear_db_btn = st.sidebar.button("清空知识库")
133 |     if clear_db_btn:
134 |         db_tools.clear_db()
135 |         tips("知识库已经被清空!")
136 | 
137 |     if "connect_id" in st.session_state:
138 |         had_files = db_tools.get_files(uid=st.session_state.connect_id)
139 |     else:
140 |         had_files = db_tools.get_files()
141 | 
142 |     st.session_state.had_file_nums = len(had_files) if had_files else 0
143 |     if had_files:
144 |         st.sidebar.markdown("已有文档:")
145 |         st.sidebar.markdown("\n".join([f" - {v}" for v in had_files]))
146 | 
147 | 
148 | @st.cache_resource
149 | def init_encoder(encoder_name: str, **kwargs):
150 |     if "ERNIEBot" in encoder_name:
151 |         return ErnieEncodeText(**kwargs)
152 |     return EncodeText(**kwargs)
153 | 
154 | 
155 | def predict(
156 |     text,
157 |     search_res,
158 |     model,
159 |     custom_prompt=None,
160 | ):
161 |     for file, content in search_res.items():
162 |         content = "\n".join(content)
163 |         one_context = f"**从《{file}》** 检索到相关内容: \n{content}"
164 |         bot_print(one_context, avatar="📄")
165 | 
166 |         logger.info(f"Context:\n{one_context}\n")
167 | 
168 |     context = "\n".join(sum(search_res.values(), []))
169 |     response, elapse = get_model_response(text, context, custom_prompt, model)
170 | 
171 |     print_res = f"**推理耗时:{elapse:.5f}s**"
172 |     bot_print(print_res, avatar="📄")
173 |     bot_print(response)
174 | 
175 | 
176 | def predict_only_model(text, model):
177 |     params_dict = st.session_state["params"]
178 |     response = model(text, history=None, **params_dict)
179 |     bot_print(response)
180 | 
181 | 
182 | def bot_print(content, avatar: str = "🤖"):
183 |     with st.chat_message("assistant", avatar=avatar):
184 |         message_placeholder = st.empty()
185 |         full_response = ""
186 |         for chunk in content.split():
187 |             full_response += chunk + " "
188 |             time.sleep(0.05)
189 |             message_placeholder.markdown(full_response + "▌")
190 |         message_placeholder.markdown(full_response)
191 | 
192 | 
193 | def get_model_response(text, context, custom_prompt, model):
194 |     params_dict = st.session_state["params"]
195 | 
196 |     s_model = time.perf_counter()
197 |     prompt_msg = make_prompt(text, context, custom_prompt)
198 |     logger.info(f"Final prompt: \n{prompt_msg}\n")
199 | 
200 |     response = model(prompt_msg, history=None, **params_dict)
201 |     elapse = time.perf_counter() - s_model
202 | 
203 |     logger.info(f"Response of LLM: \n{response}\n")
204 |     if not response:
205 |         response = "抱歉,我并不能正确回答该问题。"
206 |     return response, elapse
207 | 
208 | 
209 | def tips(txt: str, wait_time: int = 2, icon: str = "🎉"):
210 |     st.toast(txt, icon=icon)
211 |     time.sleep(wait_time)
212 | 
213 | 
214 | if __name__ == "__main__":
215 |     title = config.get("title")
216 |     version = config.get("version", "0.0.1")
217 |     st.markdown(
218 |         f"<h3 style='text-align: center;'>{title} v{version}</h3>",
219 |         unsafe_allow_html=True,
220 |     )
221 | 
222 |     init_ui_parameters()
223 | 
224 |     file_loader = FileLoader()
225 | 
226 |     db_path = config.get("vector_db_path")
227 |     db_tools = DBUtils(db_path)
228 | 
229 |     llm_module = importlib.import_module("rapid_rag.llm")
230 |     llm_params: Dict[str, Dict] = config.get("LLM_API")
231 | 
232 |     menu_col1, menu_col2, menu_col3 = st.columns([1, 1, 1])
233 |     select_model = menu_col1.selectbox("🎨LLM:", llm_params.keys())
234 |     if "ERNIEBot" in select_model:
235 |         with st.expander("LLM ErnieBot", expanded=True):
236 |             opt_col1, opt_col2 = st.columns([1, 1])
237 |             api_type = opt_col1.selectbox(
238 |                 "API Type(必选)",
239 |                 options=["aistudio", "qianfan", "yinian"],
240 |                 help="提供对话能力的后端平台",
241 |             )
242 |             access_token = opt_col2.text_input(
243 |                 "Access Token(必填)  [如何获得?](https://github.com/PaddlePaddle/ERNIE-Bot-SDK/blob/develop/docs/authentication.md)",
244 |                 "",
245 |                 help="用于访问后端平台的access token(参考使用说明获取),如果设置了AK、SK则无需设置此参数",
246 |             )
247 |             llm_params[select_model]["api_type"] = api_type
248 | 
249 |             if access_token:
250 |                 llm_params[select_model]["access_token"] = access_token
251 | 
252 |     MODEL_OPTIONS = {
253 |         name: getattr(llm_module, name)(**params) for name, params in llm_params.items()
254 |     }
255 | 
256 |     encoder_params = config.get("Encoder")
257 |     select_encoder = menu_col2.selectbox("🧬提取向量模型:", encoder_params.keys())
258 |     if "ERNIEBot" in select_encoder:
259 |         with st.expander("提取语义向量 ErnieBot", expanded=True):
260 |             opt_col1, opt_col2 = st.columns([1, 1])
261 |             extract_api_type = opt_col1.selectbox(
262 |                 "API Type(必选)",
263 |                 options=["aistudio", "qianfan", "yinian"],
264 |                 help="提供对话能力的后端平台",
265 |                 key="Extract_type",
266 |             )
267 |             encoder_params[select_encoder]["api_type"] = extract_api_type
268 | 
269 |             extract_access_token = opt_col2.text_input(
270 |                 "Access Token(必填)  [如何获得?](https://github.com/PaddlePaddle/ERNIE-Bot-SDK/blob/develop/docs/authentication.md)",
271 |                 "",
272 |                 help="用于访问后端平台的access token(参考使用说明获取),如果设置了AK、SK则无需设置此参数",
273 |                 key="Extract_token",
274 |             )
275 |             if extract_access_token:
276 |                 encoder_params[select_encoder]["access_token"] = extract_access_token
277 | 
278 |     embedding_extract = init_encoder(select_encoder, **encoder_params[select_encoder])
279 | 
280 |     TOP_OPTIONS = [5, 10, 15]
281 |     search_top = menu_col3.selectbox("🔍搜索 Top_K:", TOP_OPTIONS)
282 | 
283 |     init_ui_db()
284 | 
285 |     with st.expander("💡Prompt", expanded=False):
286 |         text_area = st.empty()
287 |         input_prompt = text_area.text_area(
288 |             label="Input",
289 |             max_chars=500,
290 |             height=200,
291 |             label_visibility="hidden",
292 |             value=config.get("DEFAULT_PROMPT"),
293 |             key="input_prompt",
294 |         )
295 | 
296 |     input_txt = st.chat_input("问点啥吧!")
297 |     if input_txt:
298 |         with st.chat_message("user", avatar="😀"):
299 |             st.markdown(input_txt)
300 | 
301 |         llm = MODEL_OPTIONS[select_model]
302 | 
303 |         if not input_prompt:
304 |             input_prompt = config.get("DEFAULT_PROMPT")
305 | 
306 |         query_embedding = embedding_extract(input_txt)
307 |         with st.spinner("正在搜索相关文档..."):
308 |             uid = st.session_state.get("connect_id", None)
309 |             search_res, search_elapse = db_tools.search_local(
310 |                 query_embedding, top_k=search_top, uid=uid
311 |             )
312 | 
313 |         if search_res is None:
314 |             bot_print("从知识库中抽取结果为空,直接采用LLM的本身能力回答。", avatar="📄")
315 |             predict_only_model(input_txt, llm)
316 |         else:
317 |             logger.info(f"使用 {type(llm).__name__}")
318 | 
319 |             res_cxt = f"**Top{search_top}\n(得分从高到低,耗时:{search_elapse:.5f}s):** \n"
320 |             bot_print(res_cxt, avatar="📄")
321 | 
322 | 
predict( 323 | input_txt, 324 | search_res, 325 | llm, 326 | input_prompt, 327 | ) 328 | --------------------------------------------------------------------------------
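Two closing usage notes. The Streamlit front end is started with: streamlit run webui.py. And because DBUtils defines __enter__/__exit__, it can also be used as a context manager that closes its most recent SQLite connection on exit, as in this minimal sketch (the db path is illustrative):

from rapid_rag.vector_utils import DBUtils

with DBUtils("assets/db/Vector.db") as db:  # illustrative path
    print(db.count_vectors())  # number of stored embedding rows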