├── .gitattributes
├── .github
└── ISSUE_TEMPLATE
│ ├── bug_report.yaml
│ └── config.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── ROADMAP.md
├── assets
├── db
│ └── .gitkeep
└── models
│ └── .gitkeep
├── cli.py
├── docs
└── README_zh.md
├── rapid_rag
├── __init__.py
├── config.yaml
├── encoder
│ ├── __init__.py
│ ├── erniebot.py
│ └── sentence_transformer.py
├── file_loader
│ ├── __init__.py
│ ├── image_loader.py
│ ├── main.py
│ ├── office_loader.py
│ ├── pdf_loader.py
│ └── txt_loader.py
├── llm
│ ├── __init__.py
│ ├── baichuan_7b.py
│ ├── chatglm2_6b.py
│ ├── ernie_bot_turbo.py
│ ├── internlm_7b.py
│ ├── llama2.py
│ ├── ollama.py
│ ├── openai.py
│ └── qwen7b_chat.py
├── text_splitter
│ ├── __init__.py
│ └── chinese_text_splitter.py
├── utils
│ ├── __init__.py
│ ├── logger.py
│ └── utils.py
└── vector_utils
│ ├── __init__.py
│ └── sqlite_version.py
├── requirements.txt
├── tests
├── demo_store_embedding.py
├── test_bge.py
├── test_chatglm2_6b.py
├── test_file_loader.py
├── test_files
│ ├── office
│ │ ├── excel_with_image.xlsx
│ │ ├── ppt_example.pptx
│ │ └── word_example.docx
│ ├── test.jpg
│ ├── test.md
│ ├── test.txt
│ ├── word_example.pdf
│ └── 长安三万里.pdf
├── test_llama2_7b_chat.py
├── test_m3e.py
├── test_office_loader.py
├── test_qwen.py
├── test_search.py
└── test_sql_insert.py
└── webui.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Set the default behavior, in case people don't have core.autocrlf set.
2 | * text=auto
3 |
4 | # Explicitly declare text files you want to always be normalized and converted
5 | # to native line endings on checkout.
6 | *.c text
7 | *.h text
8 | *.py text
9 | *.md text
10 | *.js text
11 | *.cpp text
12 |
13 | # Declare files that will always have CRLF line endings on checkout.
14 | *.sln text eol=crlf
15 |
16 | # Denote all files that are truly binary and should not be modified.
17 | *.png binary
18 | *.jpg binary
19 | *.pdf binary
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: "🐛 Bug Report"
2 | description: Create a report to help us improve RapidRAG
3 | body:
4 | - type: markdown
5 | attributes:
6 | value: |
7 | Thanks for taking the time to fill out this bug report!
8 |
9 | Please note that this tracker is only for bugs. Do not use the issue tracker for help or feature requests.
10 |
11 | [Our docs](https://rapidai.github.io/RapidRAG/) are a great place for most answers, but if you can't find your answer there, you can ask in [community discussion forum](https://github.com/RapidAI/RapidRAG/discussions/categories/q-a).
12 |
13 | Have a feature request? Please search the ideas [on our forum](https://github.com/RapidAI/RapidRAG/discussions/categories/feature-requests) to make sure that the feature has not yet been requested. If you cannot find what you had in mind, please [submit your feature request here](https://github.com/RapidAI/RapidRAG/discussions/new?category=feature-requests).
14 |
15 | Want to show off your RapidRAG powered project? Post a link, screenshot (optional), and details in [our Show & tell forum](https://github.com/RapidAI/RapidRAG/discussions/categories/show-and-tell).
16 |
17 | **Thanks!**
18 | - type: checkboxes
19 | attributes:
20 | label: Past Issues Searched
21 | options:
22 | - label: >-
23 | I have searched open and closed issues to make sure that the bug has
24 | not yet been reported
25 | required: true
26 | - type: checkboxes
27 | attributes:
28 | label: Issue is a Bug Report
29 | options:
30 | - label: >-
31 | This is a bug report and not a feature request, nor asking for support
32 | required: true
33 | - type: textarea
34 | id: bug-description
35 | attributes:
36 | label: Describe the bug
37 | description: A clear and concise description of what the bug is
38 | placeholder: Tell us what happened!
39 | validations:
40 | required: true
41 | - type: textarea
42 | id: bug-expectation
43 | attributes:
44 | label: Expected behavior
45 | description: A clear and concise description of what you expected to happen
46 | placeholder: Tell us what you expected
47 | validations:
48 | required: true
49 | - type: textarea
50 | id: bug-screenshots
51 | attributes:
52 | label: Screenshots
53 | description: 'If applicable, add screenshots to help explain your problem'
54 | placeholder: Insert screenshots here
55 | - type: textarea
56 | attributes:
57 | label: Environment
58 | description: |
59 | examples:
60 | - **OS**: MacOS
61 | - **Browser**: Firefox
62 | - **Browser Version**: 115
63 | value: |
64 | - OS:
65 | - Browser:
66 | - Browser Version:
67 | render: markdown
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: ❓ Questions
4 | url: https://github.com/RapidAI/RapidRAG/discussions/categories/q-a
5 | about: Please use the community forum for help and questions regarding RapidRAG Docs
6 | - name: 💡 Feature requests and ideas
7 | url: https://github.com/RapidAI/RapidRAG/discussions/new?category=feature-requests
8 | about: Please vote for and post new feature ideas in the community forum
9 | - name: 📖 Documentation
10 | url: https://rapidai.github.io/RapidRAG/
11 | about: A great place to find instructions and answers on how to run your custom RapidRAG.
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.db
2 | assets/models/m3e-small
3 | assets/raw_upload_files
4 | log/
5 |
6 | # Created by .ignore support plugin (hsz.mobi)
7 | ### Python template
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 | .pytest_cache
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | pip-wheel-metadata/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | # *.manifest
42 | # *.spec
43 | *.res
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
140 | #idea
141 | .vs
142 | .vscode
143 | .idea
144 | /images
145 |
146 | #models
147 | *.onnx
148 |
149 | *.ttf
150 | *.ttc
151 |
152 | long1.jpg
153 |
154 | *.bin
155 | *.mapping
156 | *.xml
157 |
158 | *.pdiparams
159 | *.pdiparams.info
160 | *.pdmodel
161 |
162 | .DS_Store
163 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://gitee.com/SWHL/autoflake
3 | rev: v2.1.1
4 | hooks:
5 | - id: autoflake
6 | args:
7 | [
8 | "--recursive",
9 | "--in-place",
10 | "--remove-all-unused-imports",
11 | "--remove-unused-variable",
12 | "--ignore-init-module-imports",
13 | ]
14 | files: \.py$
15 | - repo: https://gitee.com/SWHL/black
16 | rev: 23.1.0
17 | hooks:
18 | - id: black
19 | files: \.py$
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
🧐 Rapid RAG
4 |
5 |

6 |

7 |

8 |

9 |

10 |

11 |
12 | [简体中文](./docs/README_zh.md) | English
13 |
14 |
15 | ### 📣 We're looking for front-end development engineers interested in Knowledge QA with LLM, who can help us achieve front-end and back-end separation with our current implementation
16 |
17 | ### Introduction
18 |
19 | - Questions & Answers based on local knowledge base + LLM.
20 | - Reason:
21 | - The idea of this project comes from [Langchain-Chatchat](https://github.com/chatchat-space/Langchain-Chatchat).
22 | - I have used this project before, but it is not very flexible and deployment is not very friendly.
23 | - Learn from the ideas in [How to build a knowledge question answering system with a large language model](https://mp.weixin.qq.com/s/movaNCWjJGBaes6KxhpYpg), and try to use this as a practice.
24 | - Advantage:
25 | - The whole project is modularized and does not depend on the `langchain` library, each part can be easily replaced, and the code is simple and easy to understand.
26 | - In addition to the large language model interface that needs to be deployed separately, other parts can use CPU.
27 | - Support documents in common formats, including `txt, md, pdf, docx, pptx, excel` etc. Of course, other types of documents can also be customized and supported.
28 |
29 | ### Demo
30 |
31 | ⚠️ If you have Baidu Account, you can visit the [online demo](https://aistudio.baidu.com/projectdetail/6675380?contributionType=1) based on ERNIE Bot.
32 |
33 |
34 |

35 |
36 |
37 | ### Documentation
38 |
39 | Full documentation can be found on [docs](https://rapidai.github.io/RapidRAG/docs/), in Chinese.
40 |
41 | ### TODO
42 |
43 | - [ ] Support keyword + vector hybrid search.
44 | - [ ] Vue.js based UI .
45 |
46 | ### Code Contributors
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | ### Contributing
55 |
56 | - Pull requests are welcome. For major changes, please open an issue first
57 | to discuss what you would like to change.
58 | - Please make sure to update tests as appropriate.
59 |
60 | ### [Sponsor](https://swhl.github.io/RapidVideOCR/docs/sponsor/)
61 |
62 | If you want to sponsor the project, you can directly click the **Buy me a coffee** image, please write a note (e.g. your github account name) to facilitate adding to the sponsorship list below.
63 |
64 |
65 |

66 |
67 |
68 | ### License
69 |
70 | [Apache 2.0](https://choosealicense.com/licenses/apache-2.0/)
71 |
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # Roadmap
2 |
3 | ### Standard Evaluation Process
4 |
5 | Before proceeding with feature development and strategy optimization, we need a standard evaluation process to ensure all the features and strategies we introduce are effective.
6 |
7 | Create testsets using any dataset with advanced models and Ragas, then validate solution effectiveness using basic models.
8 |
9 | ### Feature Development and Strategy Optimization
10 |
11 | 1. BM25 Keyword Search
12 | 2. Hybrid Search (BM25 + Vector)
13 | 3. GraphRAG
14 | 4. ReRanking
15 | 5. Query Rewriting
16 | 6. Small-to-big
17 | 7. ...
18 |
--------------------------------------------------------------------------------
/assets/db/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/assets/db/.gitkeep
--------------------------------------------------------------------------------
/assets/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/assets/models/.gitkeep
--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
"""Command-line demo: index one document, then answer questions about it.

Pipeline: load config -> parse the document into sentences -> embed the
sentences -> store them in the vector DB -> loop, answering user queries
with the LLM over retrieved context. Type "stop" to exit the loop.
"""
import uuid
from pathlib import Path

from rapid_rag.encoder import EncodeText
from rapid_rag.file_loader import FileLoader
from rapid_rag.llm import ERNIEBot
from rapid_rag.utils import make_prompt, read_yaml
from rapid_rag.vector_utils import DBUtils

# FIX: the config lives inside the package directory (rapid_rag/config.yaml,
# see the repository layout) — the old "knowledge_qa_llm/..." path no longer
# exists and made this script fail at startup.
config = read_yaml("rapid_rag/config.yaml")

extract = FileLoader()

# Parse the document into {file_name: [sentence, ...]}.
file_path = "tests/test_files/office/word_example.docx"
text = extract(file_path)
sentences = text.get(Path(file_path).name)

# Embed the extracted sentences.
model_path = config.get("Encoder")["m3e-small"]
embedding_model = EncodeText(**model_path)
embeddings = embedding_model(sentences)

# Insert the embeddings into the vector database under a fresh UUID.
db_tools = DBUtils(config.get("vector_db_path"))
uid = str(uuid.uuid1())
db_tools.insert(file_path, embeddings, sentences, uid=uid)

params = config.get("LLM_API")["ERNIEBot"]
llm_engine = ERNIEBot(**params)

print("欢迎使用 🧐 Knowledge QA LLM,输入“stop”终止程序 ")
while True:
    query = input("\n😀 用户: ")
    if query.strip() == "stop":
        break

    embedding = embedding_model(query)

    search_res, search_elapse = db_tools.search_local(embedding_query=embedding)

    # Flatten the retrieved sentence lists into one context string.
    context = "\n".join(sum(search_res.values(), []))
    print(f"上下文:\n{context}\n")

    prompt = make_prompt(query, context, custom_prompt=config.get("DEFAULT_PROMPT"))
    response = llm_engine(prompt, history=None)
    print(f"🤖 LLM:\n {response}")
--------------------------------------------------------------------------------
/docs/README_zh.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
🧐 Rapid RAG
4 |
5 |

6 |

7 |

8 |

9 |

10 |

11 |
12 | 简体中文 | [English](../README.md)
13 |
14 |
15 | ### 简介
16 |
17 | 基于本地知识库+LLM的问答系统。该项目的思路是由[langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)启发而来。
18 |
19 | - 缘由:
20 | - 之前使用过这个项目,感觉不是太灵活,部署不太友好。
21 | - 借鉴[如何用大语言模型构建一个知识问答系统](https://mp.weixin.qq.com/s/movaNCWjJGBaes6KxhpYpg)中思路,尝试以此作为实践。
22 | - 优势:
23 | - 整个项目为模块化配置,不依赖`langchain`库,各部分可轻易替换,代码简单易懂。
24 | - 除需要单独部署大模型接口外,其他部分用CPU即可。
25 | - 支持常见格式文档,包括txt、md、pdf, docx, pptx, excel等等。当然,也可自定义支持其他类型文档。
26 |
27 | ### [Demo](https://aistudio.baidu.com/projectdetail/6675380?contributionType=1)
28 |
29 |
30 |

31 |
32 |
33 | ### 文档
34 |
35 | 完整文档请移步:[docs](https://rapidai.github.io/RapidRAG/docs).
36 |
37 | ### TODO
38 |
39 | - [ ] Support keyword + vector hybrid search.
40 | - [ ] Vue.js based UI .
41 |
42 | ### 贡献者
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 | ### 贡献指南
51 |
52 | 我们感谢所有的贡献者为改进和提升 RapidRAG 所作出的努力。
53 |
54 | - 欢迎提交请求。对于重大更改,请先打开issue讨论您想要改变的内容。
55 | - 请确保适当更新测试。
56 |
57 | ### [赞助](https://rapidai.github.io/RapidRAG/docs/sponsor/)
58 |
59 | 如果您想要赞助该项目,可直接点击当前页最上面的Sponsor按钮,请写好备注(**您的Github账号名称**),方便添加到赞助列表中。
60 |
61 | ### 开源许可证
62 |
63 | 该项目采用[Apache 2.0](https://choosealicense.com/licenses/apache-2.0/)开源许可证。
64 |
--------------------------------------------------------------------------------
/rapid_rag/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 |
--------------------------------------------------------------------------------
/rapid_rag/config.yaml:
--------------------------------------------------------------------------------
1 | title: 🧐 Knowledge QA LLM
2 | version: 0.0.10
3 |
4 | LLM_API:
5 | ERNIEBot:
6 | api_type: aistudio
7 | access_token: your_token
8 | Qwen7B_Chat:
9 | api_url: your_api
10 | ChatGLM2_6B:
11 | api_url: your_api
12 | BaiChuan7B:
13 | api_url: your_api
14 | InternLM_7B:
15 | api_url: your_api
16 |
17 | DEFAULT_PROMPT: 问题是:$query,从下面文章里,找出能回答以上问题的答案。如果文中没有答案,回答“没找到答案”。 文章:$context\n
18 |
19 | upload_dir: assets/raw_upload_files
20 | vector_db_path: assets/db/DefaultVector.db
21 |
22 | encoder_batch_size: 16
23 | Encoder:
24 | ERNIEBot:
25 | api_type: aistudio
26 | access_token: your_token
27 | m3e-small:
28 | model_path: assets/models/m3e-small
29 |
30 | # text splitter
31 | SENTENCE_SIZE: 200
32 |
33 | top_k: 5
34 |
35 | Parameter:
36 | max_length:
37 | min_value: 0
38 | max_value: 4096
39 | default: 1024
40 | step: 1
41 | tip: 生成结果时的最大token数
42 | top_p:
43 | min_value: 0.0
44 | max_value: 1.0
45 | default: 0.7
46 | step: 0.01
47 | tip: 用于控制模型生成文本时,选择下一个单词的概率分布的范围。
48 | temperature:
49 | min_value: 0.01
50 | max_value: 1.0
51 | default: 0.01
52 | step: 0.01
53 | tip: 用于调整模型生成文本时的创造性程度,较高的temperature将使模型更有可能生成新颖、独特的文本,而较低的温度则更有可能生成常见或常规的文本
54 |
--------------------------------------------------------------------------------
/rapid_rag/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .sentence_transformer import EncodeText
5 | from .erniebot import ErnieEncodeText
6 |
--------------------------------------------------------------------------------
/rapid_rag/encoder/erniebot.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import random
5 | import time
6 | from typing import List
7 |
8 | import erniebot
9 | import numpy as np
10 |
11 |
class ErnieEncodeText:
    """Embed text with Baidu's ERNIE text-embedding API.

    Stores the credentials in the ``erniebot`` module globals on
    construction; calling the instance maps a sentence (or list of
    sentences) to a 2-D numpy array with one embedding row per sentence.
    """

    def __init__(self, api_type: str, access_token: str):
        # erniebot keeps its credentials as module-level state.
        erniebot.api_type = api_type
        erniebot.access_token = access_token

    def __call__(self, sentences: List[str]):
        """Return embeddings for *sentences*, or ``None`` if the API sent no data."""
        if not isinstance(sentences, list):
            sentences = [sentences]

        # Crude rate limiting: wait a random 3-10 s before each API call.
        time.sleep(random.randint(3, 10))
        response = erniebot.Embedding.create(
            model="ernie-text-embedding",
            input=sentences,
        )

        payload = response.get("data", None)
        if not payload:
            return None

        vectors = [item["embedding"] for item in payload]
        return np.array(vectors)
32 |
--------------------------------------------------------------------------------
/rapid_rag/encoder/sentence_transformer.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from typing import List, Optional
5 |
6 | from sentence_transformers import SentenceTransformer
7 |
8 |
class EncodeText:
    """Sentence embedder backed by a local SentenceTransformer model.

    Raises:
        EncodeTextError: if ``model_path`` is not provided.
    """

    def __init__(self, model_path: Optional[str] = None) -> None:
        if model_path is None:
            raise EncodeTextError("model_path is None.")
        self.model = SentenceTransformer(model_path)

    def __call__(self, sentences: List[str]):
        """Encode one sentence or a list of sentences into embedding vectors."""
        # Accept a bare string for convenience; normalize to a list first.
        sentences = sentences if isinstance(sentences, list) else [sentences]
        return self.model.encode(sentences)
19 |
20 |
class EncodeTextError(Exception):
    """Raised when the text encoder cannot be constructed."""
23 |
--------------------------------------------------------------------------------
/rapid_rag/file_loader/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .main import FileLoader
5 |
--------------------------------------------------------------------------------
/rapid_rag/file_loader/image_loader.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from pathlib import Path
5 | from typing import List, Union
6 |
7 | from rapidocr_onnxruntime import RapidOCR
8 |
9 | from ..text_splitter.chinese_text_splitter import ChineseTextSplitter
10 |
11 |
class ImageLoader:
    """OCR an image with RapidOCR and split the recognized text into segments."""

    def __init__(self):
        self.ocr = RapidOCR()
        self.splitter = ChineseTextSplitter()

    def __call__(self, img_path: Union[str, Path]) -> List[str]:
        """Return the split sentence segments recognized in img_path.

        Fixed: returns [] when OCR detects no text (RapidOCR then yields a
        falsy result — confirm against rapidocr_onnxruntime docs), which
        previously crashed the zip() unpacking.
        """
        ocr_results, _ = self.ocr(img_path)
        if not ocr_results:
            return []

        _, rec_res, _ = list(zip(*ocr_results))
        split_contents = [self.splitter.split_text(v) for v in rec_res]
        return sum(split_contents, [])
24 |
--------------------------------------------------------------------------------
/rapid_rag/file_loader/main.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from pathlib import Path
5 | from typing import Dict, List, Union
6 |
7 | import filetype
8 |
9 | from ..utils import logger
10 | from .image_loader import ImageLoader
11 | from .office_loader import OfficeLoader
12 | from .pdf_loader import PDFLoader
13 | from .txt_loader import TXTLoader
14 |
INPUT_TYPE = Union[str, Path]


class FileLoader:
    """Dispatch a file (or directory of files) to the loader matching its type."""

    def __init__(self) -> None:
        # Extensions as reported by filetype.guess(); txt/md are matched by
        # suffix instead, since plain text has no magic bytes.
        # Fixed typo: legacy Excel extension is "xls", not "xlx".
        self.file_map = {
            "office": ["docx", "doc", "ppt", "pptx", "xlsx", "xls"],
            "image": ["jpg", "png", "bmp", "tif", "jpeg"],
            "txt": ["txt", "md"],
            "pdf": ["pdf"],
        }

        self.img_loader = ImageLoader()
        self.office_loader = OfficeLoader()
        self.pdf_loader = PDFLoader()
        self.txt_loader = TXTLoader()

    def __call__(self, file_path: INPUT_TYPE) -> Dict[str, List[str]]:
        """Load every supported file under file_path.

        Returns a dict mapping file name -> list of split text segments.
        Unsupported files are logged and skipped.
        """
        all_content = {}

        file_list = self.get_file_list(file_path)
        for file_path in file_list:
            file_name = file_path.name

            # txt/md cannot be sniffed from content; match by suffix.
            if file_path.suffix[1:] in self.file_map["txt"]:
                all_content[file_name] = self.txt_loader(file_path)
                continue

            file_type = self.which_type(file_path)
            if file_type in self.file_map["office"]:
                content = self.office_loader(file_path)
            elif file_type in self.file_map["pdf"]:
                content = self.pdf_loader(file_path)
            elif file_type in self.file_map["image"]:
                content = self.img_loader(file_path)
            else:
                logger.warning("%s does not support.", file_path)
                continue

            all_content[file_name] = content
        return all_content

    def get_file_list(self, file_path: INPUT_TYPE):
        """Return an iterable of files: rglob for a directory, else [file_path]."""
        if not isinstance(file_path, Path):
            file_path = Path(file_path)

        if file_path.is_dir():
            return file_path.rglob("*.*")
        return [file_path]

    @staticmethod
    def which_type(content: Union[bytes, str, Path]) -> str:
        """Sniff the file type from magic bytes.

        Raises:
            TypeError: when filetype cannot identify the content.
        """
        kind = filetype.guess(content)
        if kind is None:
            raise TypeError(f"The type of {content} does not support.")

        return kind.extension

    def sorted_by_suffix(self, file_list: List[Path]) -> Dict[str, List[Path]]:
        """Group paths into the file_map categories (txt matched by suffix first).

        Annotations fixed: the input holds Path objects and each category maps
        to a list of paths, not a single string.
        """
        sorted_res = {k: [] for k in self.file_map}

        for file_path in file_list:
            if file_path.suffix[1:] in self.file_map["txt"]:
                sorted_res["txt"].append(file_path)
                continue

            file_type = self.which_type(file_path)
            for category in ("office", "pdf", "image"):
                if file_type in self.file_map[category]:
                    sorted_res[category].append(file_path)
                    break

        return sorted_res
96 |
--------------------------------------------------------------------------------
/rapid_rag/file_loader/office_loader.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
from pathlib import Path
from typing import List, Union
6 |
7 | from extract_office_content import ExtractOfficeContent
8 |
9 | from ..text_splitter.chinese_text_splitter import ChineseTextSplitter
10 |
11 |
class OfficeLoader:
    """Extract text from Office documents (Word/PPT/Excel) and split it."""

    def __init__(self) -> None:
        self.extracter = ExtractOfficeContent()
        self.splitter = ChineseTextSplitter()

    def __call__(self, office_path: Union[str, Path]) -> List[str]:
        """Return the flattened split text segments of the document.

        Fixed: the return annotation said `str`, but this has always returned
        a flattened list of segments.
        """
        contents = self.extracter(office_path)
        split_contents = [self.splitter.split_text(v) for v in contents]
        return sum(split_contents, [])
21 |
--------------------------------------------------------------------------------
/rapid_rag/file_loader/pdf_loader.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from pathlib import Path
5 | from typing import List, Union
6 |
7 | from rapidocr_pdf import PDFExtracter
8 |
9 | from ..text_splitter.chinese_text_splitter import ChineseTextSplitter
10 |
11 |
class PDFLoader:
    """Extract page text from a PDF with rapidocr_pdf and split it into segments."""

    def __init__(self):
        self.extracter = PDFExtracter()
        self.splitter = ChineseTextSplitter(pdf=True)

    def __call__(self, pdf_path: Union[str, Path]) -> List[str]:
        """Flatten the split segments of every extracted page into one list."""
        pages = self.extracter(pdf_path)
        flattened: List[str] = []
        for page in pages:
            # index 1 of each extracted item holds the page text — confirm
            # against the rapidocr_pdf return format.
            flattened.extend(self.splitter.split_text(page[1]))
        return flattened
23 |
--------------------------------------------------------------------------------
/rapid_rag/file_loader/txt_loader.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from pathlib import Path
5 | from typing import List, Union
6 |
7 | from ..text_splitter.chinese_text_splitter import ChineseTextSplitter
8 | from ..utils.utils import read_txt
9 |
10 |
class TXTLoader:
    """Read a plain-text/markdown file and split each line into segments."""

    def __init__(self) -> None:
        self.splitter = ChineseTextSplitter()

    def __call__(self, txt_path: Union[str, Path]) -> List[str]:
        """Return the flattened split segments of every line in the file."""
        lines = read_txt(txt_path)
        return [seg for line in lines for seg in self.splitter.split_text(line)]
19 |
--------------------------------------------------------------------------------
/rapid_rag/llm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
# Public API of the llm subpackage.
# Fixed: Llama2_7BChat (defined in .llama2) was never imported or exported.
from .baichuan_7b import BaiChuan7B
from .chatglm2_6b import ChatGLM2_6B
from .ernie_bot_turbo import ERNIEBot
from .internlm_7b import InternLM_7B
from .llama2 import Llama2_7BChat
from .ollama import Ollama
from .openai import OpenAI
from .qwen7b_chat import Qwen7B_Chat

__all__ = [
    "BaiChuan7B",
    "ChatGLM2_6B",
    "ERNIEBot",
    "InternLM_7B",
    "Llama2_7BChat",
    "Ollama",
    "OpenAI",
    "Qwen7B_Chat",
]
21 |
--------------------------------------------------------------------------------
/rapid_rag/llm/baichuan_7b.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import json
5 | from typing import List, Optional
6 |
7 | import requests
8 |
9 |
class BaiChuan7B:
    """Client for a BaiChuan-7B HTTP inference endpoint."""

    def __init__(self, api_url: str = None):
        self.api_url = api_url

    def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
        """POST the prompt to the service and return its reply text.

        history is accepted for interface parity with the other LLM clients
        but is not forwarded to this endpoint.
        """
        history = history or []

        payload = {"input_text": prompt}
        if kwargs:
            payload["temperature"] = kwargs.get("temperature", 0.1)
            payload["top_p"] = kwargs.get("top_p", 0.7)
            payload["max_length"] = kwargs.get("max_length", 4096)

        resp = requests.post(self.api_url, data=json.dumps(payload), timeout=60)
        try:
            body = resp.json()
            if body["status"] == 200:
                return body["response"]
            return "网络出错"
        except Exception as e:
            return f"网络出错:{e}"
35 |
36 |
if __name__ == "__main__":
    # Manual smoke test; api_url defaults to None, so this needs a live service.
    model = BaiChuan7B()
    answer = model("你是谁?", [])
    print(answer)
44 |
--------------------------------------------------------------------------------
/rapid_rag/llm/chatglm2_6b.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import json
5 | from typing import List, Optional
6 |
7 | import requests
8 |
9 |
class ChatGLM2_6B:
    """Client for a ChatGLM2-6B HTTP inference endpoint."""

    def __init__(self, api_url: str = None):
        self.api_url = api_url

    def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
        """POST prompt plus chat history to the service; return the reply text."""
        if not history:
            history = []

        payload = {"prompt": prompt, "history": history}
        if kwargs:
            sampling = {
                "temperature": kwargs.get("temperature", 0.1),
                "top_p": kwargs.get("top_p", 0.7),
                "max_length": kwargs.get("max_length", 4096),
            }
            payload.update(sampling)

        resp = requests.post(self.api_url, data=json.dumps(payload), timeout=60)
        try:
            body = resp.json()
            if body["status"] == 200:
                return body["response"]
            return "网络出错"
        except Exception as e:
            return f"网络出错:{e}"
35 |
36 |
if __name__ == "__main__":
    # Manual smoke test; api_url defaults to None, so this needs a live service.
    llm = ChatGLM2_6B()
    print(llm("你是谁?", []))
44 |
--------------------------------------------------------------------------------
/rapid_rag/llm/ernie_bot_turbo.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from typing import List, Optional
5 |
6 | import erniebot
7 |
8 |
class ERNIEBot:
    """Client for Baidu's ERNIE-Bot chat completion API (erniebot SDK)."""

    def __init__(self, api_type: str = None, access_token: str = None):
        self.api_type = api_type
        self.access_token = access_token

    def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
        """Send prompt as a single user message; return the reply text or None."""
        if not history:
            history = []

        messages = [{"role": "user", "content": prompt}]
        response = erniebot.ChatCompletion.create(
            _config_={
                "api_type": self.api_type,
                "access_token": self.access_token,
            },
            model="ernie-bot",
            messages=messages,
        )
        return response.get("result", None)
33 |
--------------------------------------------------------------------------------
/rapid_rag/llm/internlm_7b.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import json
5 | from typing import List, Optional
6 |
7 | import requests
8 |
9 |
class InternLM_7B:
    """Client for an InternLM-7B HTTP inference endpoint."""

    def __init__(self, api_url: str = None):
        self.api_url = api_url

    def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
        """POST prompt plus history to the service; return the reply or an error string."""
        history = history if history else []

        payload = {"prompt": prompt, "history": history}
        if kwargs:
            payload.update(
                {
                    "temperature": kwargs.get("temperature", 0.1),
                    "top_p": kwargs.get("top_p", 0.7),
                    "max_length": kwargs.get("max_length", 4096),
                }
            )

        resp = requests.post(self.api_url, data=json.dumps(payload), timeout=60)
        try:
            body = resp.json()
            if body["status"] == 200:
                return body["response"]
            return "Network error"
        except Exception as e:
            return f"Network error:{e}"
35 |
--------------------------------------------------------------------------------
/rapid_rag/llm/llama2.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import json
5 | from typing import List, Optional
6 |
7 | import requests
8 |
9 |
class Llama2_7BChat:
    """Client for a Llama2-7B-Chat HTTP inference endpoint."""

    def __init__(self, api_url: str = None):
        self.api_url = api_url

    def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
        """POST the prompt to the service and return its reply.

        history is kept for interface parity with the other clients but is
        not forwarded to this endpoint.
        """
        if not history:
            history = []

        payload = {"prompt": prompt}
        if kwargs:
            for key, default in (
                ("temperature", 0.1),
                ("top_p", 0.7),
                ("max_length", 4096),
            ):
                payload[key] = kwargs.get(key, default)

        resp = requests.post(self.api_url, data=json.dumps(payload), timeout=60)
        try:
            body = resp.json()
            if body["status"] == 200:
                return body["response"]
            return "网络出错"
        except Exception as e:
            return f"网络出错:{e}"
35 |
36 |
if __name__ == "__main__":
    # Fixed copy-paste bug: this demo instantiated BaiChuan7B, which is not
    # defined in this module and raised NameError.
    prompt = "你是谁?"
    history = []
    t = Llama2_7BChat()

    res = t(prompt, history)
    print(res)
44 |
--------------------------------------------------------------------------------
/rapid_rag/llm/ollama.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: Leo Peng
3 | # @Contact: leo@promptcn.com
4 | from typing import List, Optional
5 |
6 | import ollama
7 |
8 |
9 | class Ollama:
10 | def __init__(self, host: str = "http://localhost:11434", model: str = None):
11 | self.host = host
12 | self.model = model
13 | self.client = ollama.Client(host=self.host)
14 |
15 | def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
16 | if not history:
17 | history = []
18 |
19 | response = self.client.chat(
20 | messages=[
21 | {
22 | "role": "user",
23 | "content": prompt,
24 | }
25 | ],
26 | model=self.model,
27 | )
28 | result = response["message"]["content"]
29 | return result
30 |
--------------------------------------------------------------------------------
/rapid_rag/llm/openai.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: Leo Peng
3 | # @Contact: leo@promptcn.com
4 | from typing import List, Optional
5 |
6 | import openai
7 |
8 |
9 | class OpenAI:
10 | def __init__(
11 | self, base_url: str = None, api_key: str = None, model: str = "gpt-4o"
12 | ):
13 | self.base_url = base_url
14 | self.api_key = api_key
15 | self.model = model
16 | self.client = openai.OpenAI(base_url=self.base_url, api_key=self.api_key)
17 |
18 | def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
19 | if not history:
20 | history = []
21 |
22 | response = self.client.chat.completions.create(
23 | messages=[
24 | {
25 | "role": "user",
26 | "content": prompt,
27 | }
28 | ],
29 | model=self.model,
30 | )
31 | result = response.choices[0].message.content
32 | return result
33 |
--------------------------------------------------------------------------------
/rapid_rag/llm/qwen7b_chat.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import json
5 | from typing import List, Optional
6 |
7 | import requests
8 |
9 |
class Qwen7B_Chat:
    """Client for a Qwen-7B-Chat HTTP inference endpoint."""

    def __init__(self, api_url: str = None):
        self.api_url = api_url

    def __call__(self, prompt: str, history: Optional[List] = None, **kwargs):
        """POST prompt plus history to the service; return the reply or an error string."""
        if not history:
            history = []

        payload = {"prompt": prompt, "history": history}
        if kwargs:
            payload["temperature"] = kwargs.get("temperature", 0.1)
            payload["top_p"] = kwargs.get("top_p", 0.7)
            payload["max_length"] = kwargs.get("max_length", 4096)

        resp = requests.post(self.api_url, data=json.dumps(payload), timeout=60)
        try:
            body = resp.json()
            if body["status"] == 200:
                return body["response"]
            return "网络出错"
        except Exception as e:
            return f"网络出错:{e}"
35 |
36 |
if __name__ == "__main__":
    # Fixed NameError: the demo constructed the undefined name `Qwen7B`.
    prompt = "你是谁?"
    history = []
    t = Qwen7B_Chat()

    res = t(prompt, history)
    print(res)
44 |
--------------------------------------------------------------------------------
/rapid_rag/text_splitter/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 |
--------------------------------------------------------------------------------
/rapid_rag/text_splitter/chinese_text_splitter.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | # Modified from https://github.com/chatchat-space/langchain-ChatGLM/blob/master/configs/model_config.py
5 | import re
6 | from pathlib import Path
7 | from typing import List
8 |
9 | from ..utils.utils import read_yaml
10 |
# rapid_rag package root: config.yaml lives next to the subpackages.
root_dir = Path(__file__).resolve().parent.parent
config_path = root_dir / "config.yaml"
config = read_yaml(config_path)


class ChineseTextSplitter:
    """Split Chinese text into sentence-sized chunks.

    Chunks are cut at Chinese/Western sentence-terminal punctuation; pieces
    longer than `sentence_size` are re-split on progressively weaker
    separators (commas, whitespace runs, single spaces).

    Fixes vs. the previous revision (logic unchanged): regex patterns use raw
    strings ("\\s" was an invalid escape, a SyntaxWarning on Python 3.12+),
    and the local `id` no longer shadows the builtin.
    """

    def __init__(
        self,
        pdf: bool = False,
        sentence_size: int = config.get("SENTENCE_SIZE"),
    ):
        # pdf=True enables extra cleanup of PDF-extraction artifacts
        # (collapsed blank lines, stray whitespace). Note the default
        # sentence_size is read from config once, at import time.
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        """Simpler splitter variant: cut only on terminal punctuation."""
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r"\s", " ", text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile(
            '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))'
        )  # del :;
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            ele = ele.strip()
            if sent_sep_pattern.match(ele) and sent_list:
                # Trailing punctuation fragment: glue it onto the previous sentence.
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks no longer than sentence_size where possible.

        TODO: this nested re-splitting logic deserves further optimization.
        """
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)

        text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text)  # single-char sentence terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text)
        # A closing quote only ends a sentence when preceded by a terminator,
        # so the patterns above keep quotes attached and put the break after them.
        text = text.rstrip()  # drop trailing newlines at the end of the paragraph
        # Semicolons, dashes and Western double quotes are deliberately ignored.
        ls = [i for i in text.split("\n") if i]
        for ele in ls:
            if len(ele) > self.sentence_size:
                # Still too long: re-split on commas/periods first...
                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        # ...then on newline / multi-space runs...
                        ele_ele2 = re.sub(
                            r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1
                        )
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                # ...and finally on single spaces.
                                ele_ele3 = re.sub(
                                    '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
                                )
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = (
                                    ele2_ls[:ele2_id]
                                    + [i for i in ele_ele3.split("\n") if i]
                                    + ele2_ls[ele2_id + 1 :]
                                )
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = (
                            ele1_ls[:ele_id]
                            + [i for i in ele2_ls if i]
                            + ele1_ls[ele_id + 1 :]
                        )

                # Splice the re-split pieces back in place of the long chunk.
                idx = ls.index(ele)
                ls = ls[:idx] + [i.strip() for i in ele1_ls if i] + ls[idx + 1 :]
        return ls
88 |
--------------------------------------------------------------------------------
/rapid_rag/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .logger import logger
5 | from .utils import get_timestamp, make_prompt, mkdir, read_yaml
6 |
--------------------------------------------------------------------------------
/rapid_rag/utils/logger.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import functools
5 | import sys
6 | from pathlib import Path
7 |
8 | from loguru import logger
9 |
10 |
@functools.lru_cache()
def get_logger(save_dir: str = "."):
    """Configure and return the shared loguru logger (cached per save_dir).

    Logs go to stderr at INFO level and to a timestamped file under save_dir
    retained for 5 days.
    """
    fmt = (
        "{time:YYYY-MM-DD HH:mm:ss} | "
        "{level: <8} | "
        "{name}:{line} - {message}"
    )

    logger.remove()
    logger.add(sys.stderr, format=fmt, level="INFO", enqueue=True)

    log_path = Path(save_dir) / "{time:YYYY-MM-DD-HH-mm-ss}.log"
    logger.add(log_path, rotation=None, retention="5 days")
    return logger
29 |
30 |
# Module-level singleton: the rest of the package imports this pre-configured
# logger; log files are written to <repo root>/log.
log_dir = Path(__file__).resolve().parent.parent.parent / "log"
logger = get_logger(str(log_dir))
33 |
--------------------------------------------------------------------------------
/rapid_rag/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from datetime import datetime
5 | from pathlib import Path
6 | from string import Template
7 | from typing import List, Union
8 |
9 | import yaml
10 |
11 |
def make_prompt(query: str, context: str = None, custom_prompt: str = None) -> str:
    """Build the final LLM prompt.

    Without a context the query is returned unchanged; otherwise custom_prompt
    is used as a string.Template and must contain both $query and $context.

    Raises:
        ValueError: if custom_prompt is None or lacks either placeholder
            (previously a None custom_prompt crashed with TypeError).
    """
    if context is None:
        return query

    if (
        custom_prompt is None
        or "$query" not in custom_prompt
        or "$context" not in custom_prompt
    ):
        raise ValueError("prompt中必须含有$query和$context两个值")

    msg_template = Template(custom_prompt)
    return msg_template.substitute(query=query, context=context)
22 |
23 |
def read_yaml(yaml_path: Union[str, Path]):
    """Load and return the parsed content of a YAML file.

    NOTE(review): yaml.Loader can construct arbitrary Python objects. That is
    acceptable for the project's own config file, but do not point this at
    untrusted input (yaml.safe_load would be the hardened choice).
    """
    with open(str(yaml_path), "rb") as f:
        parsed = yaml.load(f, Loader=yaml.Loader)
    return parsed
28 |
29 |
def mkdir(dir_path):
    """Create dir_path (including parents); succeed silently if it exists."""
    target = Path(dir_path)
    target.mkdir(parents=True, exist_ok=True)
32 |
33 |
def get_timestamp():
    """Return today's date formatted as YYYY-MM-DD."""
    return datetime.now().strftime("%Y-%m-%d")
36 |
37 |
def read_txt(txt_path: Union[Path, str]) -> List[str]:
    """Read a UTF-8 text file and return its lines without trailing newlines."""
    with open(str(txt_path), "r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]
45 |
--------------------------------------------------------------------------------
/rapid_rag/vector_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .sqlite_version import DBUtils
5 |
--------------------------------------------------------------------------------
/rapid_rag/vector_utils/sqlite_version.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import io
5 | import sqlite3
6 | import time
7 | from typing import Dict, List, Optional
8 |
9 | import faiss
10 | import numpy as np
11 |
12 | from ..utils.logger import logger
13 |
14 |
def adapt_array(arr):
    """Serialize a numpy array into a sqlite3 BLOB using np.save's .npy format."""
    buffer = io.BytesIO()
    np.save(buffer, arr)
    return sqlite3.Binary(buffer.getvalue())
20 |
21 |
def convert_array(text):
    """Deserialize a BLOB produced by adapt_array back into a numpy array."""
    buffer = io.BytesIO(text)
    return np.load(buffer, allow_pickle=True)
26 |
27 |
28 | sqlite3.register_adapter(np.ndarray, adapt_array)
29 | sqlite3.register_converter("array", convert_array)
30 |
31 |
class DBUtils:
    """SQLite-backed store for text embeddings with faiss similarity search.

    Embeddings are stored as BLOBs (via the np.ndarray adapter registered at
    module level) and loaded into an in-memory faiss index on demand.
    """

    def __init__(self, db_path: str) -> None:
        self.db_path = db_path

        self.table_name = "embedding_texts"
        self.vector_nums = 0

        self.max_prompt_length = 4096

        # Keep the initial connection so __exit__ has something to close
        # (previously self.cur / self.con were never assigned, so using the
        # class as a context manager crashed with AttributeError).
        self.cur, self.con = self.connect_db()

    def connect_db(self):
        """Open a connection and ensure the embeddings table exists."""
        con = sqlite3.connect(self.db_path, detect_types=sqlite3.PARSE_DECLTYPES)
        cur = con.cursor()
        cur.execute(
            f"create table if not exists {self.table_name} (id integer primary key autoincrement, file_name TEXT, embeddings array UNIQUE, texts TEXT, uids TEXT)"
        )
        return cur, con

    def load_vectors(self, uid: Optional[str] = None):
        """(Re)build the faiss index from all rows, optionally filtered by uid."""
        cur, _ = self.connect_db()

        search_sql = f"select file_name, embeddings, texts from {self.table_name}"
        params = ()
        if uid:
            # Parameterized to avoid SQL injection; the old f-string version
            # also relied on SQLite's non-standard double-quote fallback.
            search_sql += " where uids=?"
            params = (uid,)

        cur.execute(search_sql, params)
        all_vectors = cur.fetchall()

        self.file_names = np.array([v[0] for v in all_vectors])
        all_embeddings = np.array([v[1] for v in all_vectors])
        self.all_texts = np.array([v[2] for v in all_vectors])

        self.search_index = faiss.IndexFlatL2(all_embeddings.shape[1])
        self.search_index.add(all_embeddings)
        self.vector_nums = len(all_vectors)

    def count_vectors(self):
        """Return the total number of stored rows."""
        cur, _ = self.connect_db()

        cur.execute(f"select file_name from {self.table_name}")
        all_vectors = cur.fetchall()
        return len(all_vectors)

    def search_local(
        self,
        embedding_query: np.ndarray,
        top_k: int = 5,
        uid: Optional[str] = None,
    ):
        """Return (results, elapse): top_k nearest texts grouped by file name,
        plus the elapsed seconds; (None, 0) when the store is empty.

        (The previous return annotation claimed a bare dict and was wrong.)
        """
        s = time.perf_counter()

        cur_vector_nums = self.count_vectors()
        if cur_vector_nums == 0:
            return None, 0

        if cur_vector_nums != self.vector_nums:
            self.load_vectors(uid)

        # Clamp top_k so faiss never asks for more neighbors than rows exist.
        _, I = self.search_index.search(embedding_query, min(top_k, cur_vector_nums))
        top_index = I.squeeze().tolist()

        # A single hit squeezes down to a bare int.
        if isinstance(top_index, int):
            top_index = [top_index]

        search_contents = self.all_texts[top_index]
        file_names = [self.file_names[idx] for idx in top_index]
        # De-duplicate file names while preserving first-seen order.
        dup_file_names = list(set(file_names))
        dup_file_names.sort(key=file_names.index)

        search_res = {v: [] for v in dup_file_names}
        for file_name, content in zip(file_names, search_contents):
            search_res[file_name].append(content)

        elapse = time.perf_counter() - s
        return search_res, elapse

    def insert(
        self, file_name: str, embeddings: np.ndarray, texts: List[str], uid: str
    ):
        """Bulk-insert one row per (embedding, text) pair tagged with uid.

        Duplicate embeddings are silently skipped (UNIQUE column +
        insert-or-ignore).
        """
        cur, con = self.connect_db()

        file_names = [file_name] * len(embeddings)
        uids = [uid] * len(embeddings)

        t1 = time.perf_counter()
        insert_sql = f"insert or ignore into {self.table_name} (file_name, embeddings, texts, uids) values (?, ?, ?, ?)"
        cur.executemany(insert_sql, list(zip(file_names, embeddings, texts, uids)))
        elapse = time.perf_counter() - t1
        logger.info(
            f"Insert {len(embeddings)} data, total is {len(embeddings)}, cost: {elapse:4f}s"
        )
        con.commit()

    def get_files(self, uid: Optional[str] = None):
        """Return the distinct file names stored for uid, or None without a uid."""
        if not uid:
            return None

        cur, _ = self.connect_db()
        # Parameterized: the previous f-string interpolation was injectable.
        search_sql = f"select distinct file_name from {self.table_name} where uids=?"
        cur.execute(search_sql, (uid,))
        return [v[0] for v in cur.fetchall()]

    def clear_db(self):
        """Delete every stored row."""
        cur, con = self.connect_db()

        cur.execute(f"delete from {self.table_name}")
        con.commit()

    def __enter__(self):
        return self

    def __exit__(self, *a):
        # Close the connection opened in __init__.
        self.cur.close()
        self.con.close()
166 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.21.6
2 | streamlit>=1.25.0
3 | transformers>=4.27.0.dev0,<4.47.0
4 | faiss-cpu
5 | filetype
6 | extract-office-content>=0.0.6
7 | sentence_transformers
8 | rapidocr_onnxruntime
9 | rapidocr_pdf>=0.0.5
10 | loguru
11 | erniebot
12 | openai>=1.58.1
13 | ollama>=0.4.5
14 | ragas>=0.2.9
15 |
--------------------------------------------------------------------------------
/tests/demo_store_embedding.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
from extract_office_content import ExtractWord

# Fixed imports: EncodeText lives in rapid_rag.encoder, DBUtils in
# rapid_rag.vector_utils (there is no top-level `vector_utils` module).
from rapid_rag.encoder.sentence_transformer import EncodeText
from rapid_rag.vector_utils import DBUtils

# Read the document.
word_extract = ExtractWord()

file_path = "tests/test_files/office/word_example.docx"
text = word_extract(file_path)
sentences = [v.strip() for v in text if v.strip()]

# Extract embeddings. EncodeText raises without an explicit model path;
# the path below matches the one used in tests/test_bge.py.
model = EncodeText("assets/models/bge-small-zh")
embeddings = model(sentences)

db_path = "db/Vector.db"
db_tools = DBUtils(db_path)

# Fixed: DBUtils.insert requires a uid tag for the inserted rows.
db_tools.insert(file_path, embeddings, sentences, uid="demo")

print("ok")
25 |
--------------------------------------------------------------------------------
/tests/test_bge.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
from sentence_transformers import SentenceTransformer

# BGE usage: the retrieval instruction is prepended to queries only,
# never to the passages.
queries = ["手机开不了机怎么办?"]
passages = ["样例段落-1", "样例段落-2"]
instruction = "为这个句子生成表示以用于检索相关文章:"

model = SentenceTransformer("assets/models/bge-small-zh")
prefixed_queries = [instruction + q for q in queries]
q_embeddings = model.encode(prefixed_queries, normalize_embeddings=True)
p_embeddings = model.encode(passages, normalize_embeddings=True)

scores = q_embeddings @ p_embeddings.T
print(scores)
17 |
--------------------------------------------------------------------------------
/tests/test_chatglm2_6b.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
import sys
from pathlib import Path

cur_dir = Path(__file__).resolve().parent
root_dir = cur_dir.parent
sys.path.append(str(root_dir))

from rapid_rag.llm import ChatGLM2_6B
from rapid_rag.utils import read_yaml

# Fixed stale path: the package directory is rapid_rag, not knowledge_qa_llm,
# so the old path raised FileNotFoundError.
config_path = root_dir / "rapid_rag" / "config.yaml"
config = read_yaml(config_path)

llm_model = ChatGLM2_6B(config.get("LLM_API")["ChatGLM2_6B"])


def test_normal_input():
    prompt = "你是谁?"
    history = []

    res = llm_model(prompt, history)

    assert (
        res
        == "我是一个名为 ChatGLM2-6B 的人工智能助手,是基于清华大学 KEG 实验室和智谱 AI 公司于 2023 年共同训练的语言模型开发的。我的任务是针对用户的问题和要求提供适当的答复和支持。"
    )
30 |
--------------------------------------------------------------------------------
/tests/test_file_loader.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
from rapid_rag.file_loader.main import FileLoader

# Smoke test: load every sample file in the fixtures directory.
loader = FileLoader()
file_dir = "tests/test_files"

res = loader(file_dir)
print("ok")
12 |
--------------------------------------------------------------------------------
/tests/test_files/office/excel_with_image.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/office/excel_with_image.xlsx
--------------------------------------------------------------------------------
/tests/test_files/office/ppt_example.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/office/ppt_example.pptx
--------------------------------------------------------------------------------
/tests/test_files/office/word_example.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/office/word_example.docx
--------------------------------------------------------------------------------
/tests/test_files/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/test.jpg
--------------------------------------------------------------------------------
/tests/test_files/test.md:
--------------------------------------------------------------------------------
1 | 我与父亲不相见已二年余了,我最不能忘记的是他的背影。
2 |
那年冬天,祖母死了,父亲的差使也交卸了,正是祸不单行的日子。我从北京到徐州,打算跟着父亲奔丧回家。到徐州见着父亲,看见满院狼藉的东西,又想起祖母,不禁簌簌地流下眼泪。父亲说:“事已如此,不必难过,好在天无绝人之路!”
4 |
--------------------------------------------------------------------------------
/tests/test_files/test.txt:
--------------------------------------------------------------------------------
1 | 我与父亲不相见已二年余了,我最不能忘记的是他的背影。
2 |
那年冬天,祖母死了,父亲的差使也交卸了,正是祸不单行的日子。我从北京到徐州,打算跟着父亲奔丧回家。到徐州见着父亲,看见满院狼藉的东西,又想起祖母,不禁簌簌地流下眼泪。父亲说:“事已如此,不必难过,好在天无绝人之路!”
4 |
--------------------------------------------------------------------------------
/tests/test_files/word_example.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/word_example.pdf
--------------------------------------------------------------------------------
/tests/test_files/长安三万里.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidRAG/47647f80e68a469e6055e362e5a6c6abd8161f80/tests/test_files/长安三万里.pdf
--------------------------------------------------------------------------------
/tests/test_llama2_7b_chat.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from rapid_rag.llm.llama2 import Llama2_7BChat

# Endpoint of the deployed Llama2-7B-Chat service; fill in before running.
api = ""
llm = Llama2_7BChat(api_url=api)

question = "你是谁?"

answer = llm(question)
print(answer)
14 |
--------------------------------------------------------------------------------
/tests/test_m3e.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import sys
from pathlib import Path

cur_dir = Path(__file__).resolve().parent
root_dir = cur_dir.parent
sys.path.append(str(root_dir))

from rapid_rag.utils import read_yaml
from rapid_rag.vector_utils import EncodeText

# Fix: config.yaml lives inside the rapid_rag package, not at the repo root
# (there is no top-level config.yaml in this repository).
config_path = root_dir / "rapid_rag" / "config.yaml"
config = read_yaml(config_path)
model = EncodeText(config["encoder_model_path"])


def test_normal_input():
    """The m3e encoder should map 3 sentences to a (3, 512) embedding matrix."""
    sentences = [
        "* Moka 此文本嵌入模型由 MokaAI 训练并开源,训练脚本使用 uniem",
        "* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练",
        "* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算,异质文本检索等功能,未来还会支持代码检索,ALL in one",
    ]

    embeddings = model(sentences)
    assert embeddings.shape == (3, 512)
28 |
--------------------------------------------------------------------------------
/tests/test_office_loader.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import sys
from pathlib import Path

cur_dir = Path(__file__).resolve().parent
root_dir = cur_dir.parent
sys.path.append(str(root_dir))

import pytest

from rapid_rag.file_loader.office_loader import ExtractOfficeLoader

# One shared extractor instance reused by every parametrized case.
office_extractor = ExtractOfficeLoader()

office_files_dir = cur_dir / "test_files" / "office"


@pytest.mark.parametrize(
    "file_path, gt1, gt2",
    [
        ("word_example.docx", 221, "我与父亲不"),
        ("ppt_example.pptx", 350, "| 0 "),
        ("excel_with_image.xlsx", 361, "| "),
    ],
)
def test_extract(file_path, gt1, gt2):
    """Each office document should yield the expected text length and 5-char prefix."""
    doc_path = office_files_dir / file_path
    extract_res = office_extractor([doc_path])

    first_text = extract_res[0][1][0]
    assert len(first_text) == gt1
    assert first_text[:5] == gt2
35 |
--------------------------------------------------------------------------------
/tests/test_qwen.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from rapid_rag.llm.qwen7b_chat import Qwen7B_Chat

# URL of the deployed Qwen-7B-Chat API service; fill in before running.
api = ""
llm = Qwen7B_Chat(api_url=api)

question = "杭州有哪些景点?"

answer = llm(question, history=None)
print(answer)
14 |
--------------------------------------------------------------------------------
/tests/test_search.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from pathlib import Path

from rapid_rag.utils import read_yaml
from rapid_rag.vector_utils import DBUtils, EncodeText

# Fix: the package was renamed from knowledge_qa_llm to rapid_rag, so the
# config now lives at rapid_rag/config.yaml.
config_path = Path("rapid_rag") / "config.yaml"
config = read_yaml(config_path)

model = EncodeText(config["encoder_model_path"])
db = DBUtils(config["vector_db_path"])

query = "蔡徐坤"
embedding = model(query)
# NOTE(review): webui.py unpacks search_local() into (results, elapse) —
# confirm whether this single-assignment form receives a tuple here.
search_res = db.search_local(embedding_query=embedding, top_k=3)

print(search_res)
print("ok")
23 |
--------------------------------------------------------------------------------
/tests/test_sql_insert.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from rapid_rag.file_loader import FileLoader
from rapid_rag.utils import read_yaml
from rapid_rag.vector_utils import DBUtils, EncodeText

# Fix: the package was renamed from knowledge_qa_llm to rapid_rag, so the
# config now lives at rapid_rag/config.yaml.
config = read_yaml("rapid_rag/config.yaml")

extract = FileLoader()

# Parse the document into sentence chunks.
file_path = "长安三万里.pdf"
text = extract(file_path)
sentences = text[file_path][0]

# Extract an embedding for every sentence.
embedding_model = EncodeText(config.get("encoder_model_path"))
embeddings = embedding_model(sentences)

# Insert the embeddings and raw sentences into the vector database.
db_tools = DBUtils(config.get("vector_db_path"))
db_tools.insert(file_path, embeddings, sentences)
24 |
--------------------------------------------------------------------------------
/webui.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import importlib
import shutil
import time
import uuid
from pathlib import Path
from typing import Dict

import numpy as np
import streamlit as st

from rapid_rag.encoder import EncodeText, ErnieEncodeText
from rapid_rag.file_loader import FileLoader
from rapid_rag.utils import get_timestamp, logger, make_prompt, mkdir, read_yaml
from rapid_rag.vector_utils import DBUtils

# Fix: the package was renamed from knowledge_qa_llm to rapid_rag, so the
# config now lives at rapid_rag/config.yaml.
config = read_yaml("rapid_rag/config.yaml")

st.set_page_config(
    page_title=config.get("title"),
    page_icon=":robot:",
)
25 |
26 |
def init_ui_parameters():
    """Render the sidebar generation-parameter sliders.

    Reads slider bounds/defaults from the ``Parameter`` section of the config
    and stores the chosen values in ``st.session_state["params"]`` under the
    keys ``max_length``, ``top_p`` and ``temperature``.
    """
    st.session_state["params"] = {}
    param = config.get("Parameter")

    st.sidebar.markdown("### 🛶 参数设置")

    param_max_length = param.get("max_length")
    max_length = st.sidebar.slider(
        "max_length",
        min_value=param_max_length.get("min_value"),
        max_value=param_max_length.get("max_value"),
        value=param_max_length.get("default"),
        step=param_max_length.get("step"),
        help=param_max_length.get("tip"),
    )
    st.session_state["params"]["max_length"] = max_length

    param_top = param.get("top_p")
    top_p = st.sidebar.slider(
        "top_p",
        min_value=param_top.get("min_value"),
        max_value=param_top.get("max_value"),
        value=param_top.get("default"),
        step=param_top.get("step"),
        help=param_top.get("tip"),
    )
    st.session_state["params"]["top_p"] = top_p

    param_temp = param.get("temperature")
    temperature = st.sidebar.slider(
        "temperature",
        min_value=param_temp.get("min_value"),
        max_value=param_temp.get("max_value"),
        value=param_temp.get("default"),
        # Fix: was param_temp.get("stemp") — a typo that always returned None,
        # so the temperature slider silently ignored the configured step.
        step=param_temp.get("step"),
        help=param_temp.get("tip"),
    )
    st.session_state["params"]["temperature"] = temperature
65 |
66 |
def init_ui_db():
    """Render the knowledge-base sidebar: upload, embed, clear and list documents.

    Uploading saves files under ``upload_dir/<timestamp>``, extracts text via the
    module-level ``file_loader``, embeds it batch-by-batch with
    ``embedding_extract``, inserts into ``db_tools`` tagged with a fresh session
    uid, and finally deletes the temporary upload directory.
    """
    st.sidebar.markdown("### 🧻 知识库")
    uploaded_files = st.sidebar.file_uploader(
        "default",
        accept_multiple_files=True,
        label_visibility="hidden",
        help="支持多个文件的选取",
    )

    upload_dir = config.get("upload_dir")
    btn_upload = st.sidebar.button("上传文档并加载")
    if btn_upload:
        # Each upload batch lands in its own timestamped sub-directory.
        time_stamp = get_timestamp()
        doc_dir = Path(upload_dir) / time_stamp

        tips("正在上传文件到平台中...", icon="⏳")
        for file_data in uploaded_files:
            bytes_data = file_data.getvalue()

            # mkdir is idempotent; calling it per file is harmless.
            mkdir(doc_dir)
            save_path = doc_dir / file_data.name
            with open(save_path, "wb") as f:
                f.write(bytes_data)
            # NOTE(review): this toast fires once per file (it is inside the
            # loop); likely intended after the loop — confirm.
            tips("上传完毕!")

        with st.spinner(f"正在从{doc_dir}提取内容...."):
            all_doc_contents = file_loader(doc_dir)

        pro_text = "提取语义向量..."
        batch_size = config.get("encoder_batch_size", 16)
        # A fresh uid ties all documents of this upload to the current session.
        uid = str(uuid.uuid1())
        st.session_state["connect_id"] = uid
        for file_path, one_doc_contents in all_doc_contents.items():
            my_bar = st.sidebar.progress(0, text=pro_text)
            content_nums = len(one_doc_contents)
            all_embeddings = []
            # Embed in fixed-size batches, clamping the last slice to the end.
            for i in range(0, content_nums, batch_size):
                start_idx = i
                end_idx = start_idx + batch_size
                end_idx = content_nums if end_idx > content_nums else end_idx

                cur_contents = one_doc_contents[start_idx:end_idx]
                if not cur_contents:
                    continue

                # Skip batches the encoder could not embed.
                embeddings = embedding_extract(cur_contents)
                if embeddings is None or embeddings.size == 0:
                    continue

                all_embeddings.append(embeddings)
                my_bar.progress(
                    end_idx / content_nums,
                    f"Extract {file_path} datas: [{end_idx}/{content_nums}]",
                )
            my_bar.empty()

            if all_embeddings:
                all_embeddings = np.vstack(all_embeddings)
                db_tools.insert(file_path, all_embeddings, one_doc_contents, uid)
            else:
                tips(f"从{file_path}提取向量为空。")

        # Uploaded files are only needed transiently; remove the staging dir.
        shutil.rmtree(doc_dir.resolve())
        tips("现在可以提问问题了哈!")

    clear_db_btn = st.sidebar.button("清空知识库")
    if clear_db_btn:
        db_tools.clear_db()
        tips("知识库已经被清空!")

    # List only this session's documents when a session uid exists.
    if "connect_id" in st.session_state:
        had_files = db_tools.get_files(uid=st.session_state.connect_id)
    else:
        had_files = db_tools.get_files()

    st.session_state.had_file_nums = len(had_files) if had_files else 0
    if had_files:
        st.sidebar.markdown("已有文档:")
        st.sidebar.markdown("\n".join([f" - {v}" for v in had_files]))
146 |
147 |
@st.cache_resource
def init_encoder(encoder_name: str, **kwargs):
    """Build (and cache via Streamlit) the embedding backend for *encoder_name*."""
    use_ernie = "ERNIEBot" in encoder_name
    return ErnieEncodeText(**kwargs) if use_ernie else EncodeText(**kwargs)
153 |
154 |
def predict(
    text,
    search_res,
    model,
    custom_prompt=None,
):
    """Show each file's retrieved snippets, then answer *text* with *model*."""
    for file, content in search_res.items():
        joined = "\n".join(content)
        one_context = f"**从《{file}》** 检索到相关内容: \n{joined}"
        bot_print(one_context, avatar="📄")

        logger.info(f"Context:\n{one_context}\n")

    # Flatten every file's snippet list into one context string.
    context = "\n".join(sum(search_res.values(), []))
    response, elapse = get_model_response(text, context, custom_prompt, model)

    bot_print(f"**推理耗时:{elapse:.5f}s**", avatar="📄")
    bot_print(response)
174 |
175 |
def predict_only_model(text, model):
    """Answer *text* with the bare LLM, without any retrieved context."""
    generation_params = st.session_state["params"]
    bot_print(model(text, history=None, **generation_params))
180 |
181 |
def bot_print(content, avatar: str = "🤖"):
    """Render *content* in an assistant chat bubble with a typing animation.

    Reveals one whitespace-split chunk every 50 ms, showing a cursor ("▌")
    while streaming and the full text once finished.
    """
    with st.chat_message("assistant", avatar=avatar):
        placeholder = st.empty()
        shown = ""
        for piece in content.split():
            shown += piece + " "
            time.sleep(0.05)
            placeholder.markdown(shown + "▌")
        placeholder.markdown(shown)
191 |
192 |
def get_model_response(text, context, custom_prompt, model):
    """Build the final prompt, query *model*, and return ``(response, elapse)``.

    ``elapse`` covers prompt construction plus the model call (seconds).
    Falls back to a fixed apology string when the model returns an empty
    response.
    """
    params_dict = st.session_state["params"]

    s_model = time.perf_counter()
    prompt_msg = make_prompt(text, context, custom_prompt)
    logger.info(f"Final prompt: \n{prompt_msg}\n")

    response = model(prompt_msg, history=None, **params_dict)
    elapse = time.perf_counter() - s_model

    # Fix: log message typo "Reponse" -> "Response".
    logger.info(f"Response of LLM: \n{response}\n")
    if not response:
        response = "抱歉,我并不能正确回答该问题。"
    return response, elapse
207 |
208 |
def tips(txt: str, wait_time: int = 2, icon: str = "🎉") -> None:
    """Show a toast notification and pause so the user has time to read it."""
    st.toast(txt, icon=icon)
    time.sleep(wait_time)
212 |
213 |
if __name__ == "__main__":
    title = config.get("title")
    version = config.get("version", "0.0.1")
    # Fix: the heading literal was broken (unterminated string with its HTML
    # stripped). Restored as a centered HTML heading — unsafe_allow_html=True
    # shows it was always meant to carry inline HTML.
    st.markdown(
        f"<h3 style='text-align: center;'>{title} v{version}</h3>",
        unsafe_allow_html=True,
    )

    init_ui_parameters()

    file_loader = FileLoader()

    db_path = config.get("vector_db_path")
    db_tools = DBUtils(db_path)

    # Fix: the package was renamed from knowledge_qa_llm to rapid_rag;
    # importing the old module name would raise ModuleNotFoundError.
    llm_module = importlib.import_module("rapid_rag.llm")
    llm_params: Dict[str, Dict] = config.get("LLM_API")

    menu_col1, menu_col2, menu_col3 = st.columns([1, 1, 1])
    select_model = menu_col1.selectbox("🎨LLM:", llm_params.keys())
    # ErnieBot needs extra credentials collected from the UI.
    if "ERNIEBot" in select_model:
        with st.expander("LLM ErnieBot", expanded=True):
            opt_col1, opt_col2 = st.columns([1, 1])
            api_type = opt_col1.selectbox(
                "API Type(必选)",
                options=["aistudio", "qianfan", "yinian"],
                help="提供对话能力的后端平台",
            )
            access_token = opt_col2.text_input(
                "Access Token(必填) [如何获得?](https://github.com/PaddlePaddle/ERNIE-Bot-SDK/blob/develop/docs/authentication.md)",
                "",
                help="用于访问后端平台的access token(参考使用说明获取),如果设置了AK、SK则无需设置此参数",
            )
            llm_params[select_model]["api_type"] = api_type

            if access_token:
                llm_params[select_model]["access_token"] = access_token

    # Instantiate every configured LLM backend, keyed by class name.
    MODEL_OPTIONS = {
        name: getattr(llm_module, name)(**params) for name, params in llm_params.items()
    }

    encoder_params = config.get("Encoder")
    select_encoder = menu_col2.selectbox("🧬提取向量模型:", encoder_params.keys())
    if "ERNIEBot" in select_encoder:
        with st.expander("提取语义向量 ErnieBot", expanded=True):
            opt_col1, opt_col2 = st.columns([1, 1])
            extract_api_type = opt_col1.selectbox(
                "API Type(必选)",
                options=["aistudio", "qianfan", "yinian"],
                help="提供对话能力的后端平台",
                key="Extract_type",
            )
            encoder_params[select_encoder]["api_type"] = extract_api_type

            extract_access_token = opt_col2.text_input(
                "Access Token(必填) [如何获得?](https://github.com/PaddlePaddle/ERNIE-Bot-SDK/blob/develop/docs/authentication.md)",
                "",
                help="用于访问后端平台的access token(参考使用说明获取),如果设置了AK、SK则无需设置此参数",
                key="Extract_token",
            )
            if extract_access_token:
                encoder_params[select_encoder]["access_token"] = extract_access_token

    embedding_extract = init_encoder(select_encoder, **encoder_params[select_encoder])

    TOP_OPTIONS = [5, 10, 15]
    search_top = menu_col3.selectbox("🔍搜索 Top_K:", TOP_OPTIONS)

    init_ui_db()

    with st.expander("💡Prompt", expanded=False):
        text_area = st.empty()
        input_prompt = text_area.text_area(
            label="Input",
            max_chars=500,
            height=200,
            label_visibility="hidden",
            value=config.get("DEFAULT_PROMPT"),
            key="input_prompt",
        )

    input_txt = st.chat_input("问点啥吧!")
    if input_txt:
        with st.chat_message("user", avatar="😀"):
            st.markdown(input_txt)

        llm = MODEL_OPTIONS[select_model]

        if not input_prompt:
            input_prompt = config.get("DEFAULT_PROMPT")

        query_embedding = embedding_extract(input_txt)
        with st.spinner("正在搜索相关文档..."):
            uid = st.session_state.get("connect_id", None)
            search_res, search_elapse = db_tools.search_local(
                query_embedding, top_k=search_top, uid=uid
            )

        # No hits: answer with the bare LLM; otherwise show the hits and
        # answer with retrieved context.
        if search_res is None:
            bot_print("从知识库中抽取结果为空,直接采用LLM的本身能力回答。", avatar="📄")
            predict_only_model(input_txt, llm)
        else:
            logger.info(f"使用 {type(llm).__name__}")

            res_cxt = f"**Top{search_top}\n(得分从高到低,耗时:{search_elapse:.5f}s):** \n"
            bot_print(res_cxt, avatar="📄")

            predict(
                input_txt,
                search_res,
                llm,
                input_prompt,
            )
328 |
--------------------------------------------------------------------------------