├── .gitignore ├── .streamlit └── config.toml ├── LICENSE ├── README.md ├── README_zh.md ├── SECURITY.md ├── app.py ├── config.py ├── docs ├── Code_of_Conduct.md ├── HowToDownloadModels.md ├── HowToUsePythonVirtualEnv.md └── images │ ├── KB_File.png │ ├── KB_Manage.png │ ├── KB_Web.png │ ├── Model_LLM.png │ ├── Model_Reranker.png │ ├── Query.png │ ├── Settings_Advanced.png │ └── ThinkRAG_Architecture.png ├── frontend ├── Document_QA.py ├── KB_File.py ├── KB_Manage.py ├── KB_Web.py ├── Model_Embed.py ├── Model_LLM.py ├── Model_Rerank.py ├── Setting_Advanced.py ├── Storage.py ├── images │ └── ThinkRAG_Logo.png └── state.py ├── requirements.txt └── server ├── engine.py ├── index.py ├── ingestion.py ├── models ├── embedding.py ├── llm_api.py ├── ollama.py └── reranker.py ├── prompt.py ├── readers ├── beautiful_soup_web.py └── jina_web.py ├── retriever.py ├── splitters ├── __init__.py ├── chinese_recursive_text_splitter.py ├── chinese_text_splitter.py └── zh_title_enhance.py ├── stores ├── chat_store.py ├── config_store.py ├── doc_store.py ├── index_store.py ├── ingestion_cache.py ├── strage_context.py └── vector_store.py ├── text_splitter.py └── utils ├── file.py └── hf_mirror.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Local models 2 | localmodels/ 3 | 4 | # Uploaded files 5 | data/ 6 | 7 | # Local Storage 8 | storage/ 9 | .chroma/ 10 | .lancedb/ 11 | # Test files 12 | test_*.* 13 | 14 | # Python venv 15 | .venv/ 16 | bin/ 17 | include/ 18 | lib/ 19 | pyvenv.cfg 20 | etc/ 21 | 22 | # The following are ignored by default. 23 | 24 | # Byte-compiled / optimized / DLL files 25 | __pycache__/ 26 | *.py[cod] 27 | *$py.class 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | *.py,cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | cover/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | db.sqlite3-journal 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | .pybuilder/ 99 | target/ 100 | 101 | # Jupyter Notebook 102 | .ipynb_checkpoints 103 | 104 | # IPython 105 | profile_default/ 106 | ipython_config.py 107 | 108 | # pyenv 109 | # For a library or package, you might want to ignore these files since the code is 110 | # intended to run in multiple environments; otherwise, check them in: 111 | # .python-version 112 | 113 | # pipenv 114 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
115 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 116 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 117 | # install all needed dependencies. 118 | #Pipfile.lock 119 | 120 | # poetry 121 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 122 | # This is especially recommended for binary packages to ensure reproducibility, and is more 123 | # commonly ignored for libraries. 124 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 125 | #poetry.lock 126 | 127 | # pdm 128 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 129 | #pdm.lock 130 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 131 | # in version control. 132 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 133 | .pdm.toml 134 | .pdm-python 135 | .pdm-build/ 136 | 137 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 138 | __pypackages__/ 139 | 140 | # Celery stuff 141 | celerybeat-schedule 142 | celerybeat.pid 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | # Pyre type checker 172 | .pyre/ 173 | 174 | # pytype static type analyzer 175 | .pytype/ 176 | 177 | # Cython debug symbols 178 | cython_debug/ 179 | 180 | # .DS_Store files 181 | **/.DS_Store 182 | 183 | # PyCharm 184 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 185 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 186 | # and can be added to the global gitignore or merged into this file. For a more nuclear 187 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 188 | #.idea/ 189 | -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [client] 2 | toolbarMode = "minimal" 3 | showSidebarNavigation = false 4 | 5 | [theme] 6 | primaryColor = "#F63366" 7 | backgroundColor = "white" 8 | font = "sans serif" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 David Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | English | 3 | 简体中文 4 |

5 | 6 |
7 | 8 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](./LICENSE) [![support: Ollama](https://img.shields.io/badge/Support-Ollama-green.svg)](https://ollama.com/) [![support: LlamaIndex](https://img.shields.io/badge/Support-LlamaIndex-purple.svg)](https://www.llamaindex.ai/) 9 | 10 |
11 | 12 | ### Table of Contents 13 | 14 | - 🤔 [Overview](#What-is-ThinkRAG) 15 | - ✨ [Features](#Key-Features) 16 | - 🧸 [Model Support](#Support-Models) 17 | - 🛫 [Quick Start](#quick-start) 18 | - 📖 [User Guide](#Instructions) 19 | - 🔬 [Architecture](#Architecture) 20 | - 📜 [Roadmap](#Roadmap) 21 | - 📄 [License](#License) 22 | 23 |
24 | 25 | # ThinkRAG 26 | 27 | ThinkRAG is an LLM RAG system that can be easily deployed on a laptop to enable Q&A over a local knowledge base. 28 | 29 | The system is built on LlamaIndex and Streamlit, and has been optimized for Chinese users in areas such as model selection and text processing. 30 | 31 |
32 | 33 | # Key Features 34 | 35 | ThinkRAG is an LLM application developed for professionals, researchers, students, and other knowledge workers. It can be used directly on a laptop, with all knowledge and data stored locally on the computer. 36 | 37 | ThinkRAG has the following features: 38 | - Complete application of the LlamaIndex framework 39 | - Development mode supports local file storage, without the need to install any databases 40 | - Runs on a laptop without requiring a GPU 41 | - Supports locally deployed models and offline use 42 | 43 | In addition, ThinkRAG includes many customizations and optimizations for Chinese users: 44 | - Uses the SpaCy text splitter for better handling of Chinese text 45 | - Employs Chinese title enhancement features 46 | - Uses Chinese prompt templates for Q&A and refinement processes 47 | - Default support for Chinese LLM service providers such as DeepSeek, Moonshot, and ZhiPu 48 | - Uses bilingual embedding models, such as bge-large-zh-v1.5 from BAAI 49 | 50 |
51 | 52 | # Model Support 53 | 54 | ThinkRAG can use all models supported by the LlamaIndex data framework. For the list of supported models, please refer to the [relevant documentation](https://docs.llamaindex.ai/en/stable/module_guides/models/llms/modules/). 55 | 56 | ThinkRAG is committed to creating an application system that is directly usable, useful, and easy to use. 57 | 58 | Therefore, we have made careful selections and trade-offs among various models, components, and technologies. 59 | 60 | First, for large language models, ThinkRAG supports the OpenAI API and all compatible LLM APIs, including LLM service providers in China, such as: 61 | 62 | - DeepSeek 63 | - Moonshot 64 | - ZhiPu 65 | - ... 66 | 67 | If you want to deploy LLMs locally, ThinkRAG uses Ollama, which is easy to use. You can download models through Ollama and run them locally. 68 | 69 | Currently, Ollama supports the local deployment of almost all mainstream large language models, including Llama, Gemma, GLM, Mistral, Phi, Llava, etc. For details, please visit the [Ollama official website](https://ollama.com/). 70 | 71 | The system also uses embedding and reranking models, and supports most such models from Hugging Face. Currently, ThinkRAG mainly uses the BGE series models from BAAI. Chinese users can visit the [mirror website](https://hf-mirror.com/BAAI) to learn about and download them. 72 | 73 | ## Known Issues 74 | 75 | Some Windows users have reported issues that have not yet been reproduced or resolved. Please use ThinkRAG on Linux or macOS. 76 | 77 | Because llama_index is not yet compatible with the latest ollama 0.4, please install ollama 0.3.3; this version is pinned in the requirements.txt file. 78 | 79 |
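For reference, each provider above corresponds to an entry in the `LLM_API_LIST` dictionary in `config.py` (included later in this repository). A minimal sketch of what an additional OpenAI-compatible entry could look like is shown below; the provider name, endpoint, and model names are placeholders rather than services ThinkRAG ships with.

```python
# Sketch only: a hypothetical OpenAI-compatible provider entry for LLM_API_LIST in config.py.
# "MyProvider", the endpoint URL and the model names are placeholders - substitute real values.
import os

LLM_API_LIST = {
    "MyProvider": {
        "api_key": os.getenv("MYPROVIDER_API_KEY", ""),  # key is read from an environment variable
        "api_base": "https://api.example.com/v1/",       # the provider's OpenAI-compatible endpoint
        "models": ["example-chat-model"],                # model names offered in the UI drop-down
        "provider": "MyProvider",
    },
}
```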
80 | 81 | # Quick Start 82 | 83 | ## Step 1 Download and Installation 84 | 85 | After downloading the code from Github, use pip to install the required components. 86 | ```zsh 87 | pip3 install -r requirements.txt 88 | ``` 89 | If you want to run the system offline, please first download Ollama from the official website. Then, use the Ollama command to download LLMs such as DeepSeek, Qwen, and Gemma. 90 | 91 | Then, download the embedding model (BAAI/bge-large-zh-v1.5) and reranking model (BAAI/bge-reranker-base) from Hugging Face to the `localmodels` directory. 92 | 93 | For specific steps, please refer to the document in the `docs` directory: HowToDownloadModels.md 94 | 95 | ## Step 2 System Configuration 96 | 97 | For better performance, it is recommended to use commercial LLM APIs. 98 | 99 | First, obtain the API key from the LLM service provider and configure the following environment variables. 100 | 101 | ```zsh 102 | OPENAI_API_KEY = "" 103 | DEEPSEEK_API_KEY = "" 104 | MOONSHOT_API_KEY = "" 105 | ZHIPU_API_KEY = "" 106 | ``` 107 | 108 | You can skip this step and configure the API keys through the application interface after the system is running. 109 | 110 | If you choose to use one or more LLM APIs, please delete the unused service providers in the config.py configuration file. 111 | 112 | Of course, you can also add other service providers compatible with the OpenAI API in the configuration file. 113 | 114 | ThinkRAG runs in development mode by default. In this mode, the system uses local file storage, and you do not need to install any databases. 115 | 116 | If you want to switch to production mode, you can configure the environment variables as follows. 117 | 118 | ```zsh 119 | THINKRAG_ENV = production 120 | ``` 121 | 122 | In production mode, the system uses vector databases Chroma or LanceDB, and key-value databases Redis. 123 | 124 | If you do not have Redis installed, it is recommended to install it through Docker or use an existing Redis instance. Please configure the parameters of the Redis instance in the config.py file. 125 | 126 | ## Step 3 Running the System 127 | 128 | Now, you are ready to run ThinkRAG. 129 | 130 | Please run the following command in the directory containing the app.py file. 131 | 132 | ```zsh 133 | streamlit run app.py 134 | ``` 135 | 136 | The system will run and automatically open the following URL in the browser to display the application interface. 137 | 138 | http://localhost:8501/ 139 | 140 | The first run may take a moment. If you have not downloaded the embedding model from Hugging Face in advance, the system will automatically download the model, which will take a longer time. 141 | 142 |
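As an optional sanity check before starting the app, the snippet below mirrors how `config.py` (included later in this repository) reads the API keys and the `THINKRAG_ENV` switch; it is only an illustration, not part of the code base.

```python
# Sketch: verify that the environment variables from Step 2 are visible to Python.
import os

print("OPENAI_API_KEY set:", bool(os.getenv("OPENAI_API_KEY")))
print("ZHIPU_API_KEY set:", bool(os.getenv("ZHIPU_API_KEY")))

# config.py derives the runtime mode the same way:
THINKRAG_ENV = os.getenv("THINKRAG_ENV", "development")
DEV_MODE = THINKRAG_ENV == "development"
print("Mode:", "development (local file storage)" if DEV_MODE else "production (Redis / vector database)")
```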
143 | 144 | # User Guide 145 | 146 | ## 1. System Configuration 147 | 148 | ThinkRAG supports configuration and selection of large models in the user interface, including: the Base URL and API key of the large model LLM API, and the specific model to be used, such as ZhiPu's glm-4. 149 | 150 |
151 | file_uploads 152 | 153 |
154 | 155 | The system automatically detects whether the API and key are available and, if so, displays the currently selected large model instance in green text at the bottom. 156 | 157 | Similarly, the system can automatically detect the models downloaded through Ollama, and users can select the model they need in the user interface. 158 | 159 |
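For illustration, one way to list the models a local Ollama instance has already downloaded is its REST endpoint `/api/tags`, reached at the `OLLAMA_API_URL` address (`http://localhost:11434`) defined in `config.py`. The sketch below uses the `requests` package; ThinkRAG's own integration lives in `server/models/ollama.py` and may work differently.

```python
# Sketch: ask a local Ollama server which models have been downloaded.
import requests

OLLAMA_API_URL = "http://localhost:11434"  # default address, see config.py

def list_ollama_models() -> list[str]:
    response = requests.get(f"{OLLAMA_API_URL}/api/tags", timeout=5)
    response.raise_for_status()
    return [model["name"] for model in response.json().get("models", [])]

if __name__ == "__main__":
    print(list_ollama_models())  # e.g. ['qwen2:7b', 'gemma:2b'], depending on what you pulled
```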
160 | file_uploads 161 | 162 |
163 | 164 | If you have already downloaded the embedding model and reranking model to the local `localmodels` directory, you can switch between models in the user interface and set the parameters of the reranking model, such as Top N. 165 | 166 |
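The options shown in these selectors come from `config.py` (included later in this repository); the following excerpt-style sketch lists the relevant settings, including the Top N value mentioned above.

```python
# Sketch: the config.py settings behind the embedding and reranking model selectors.
DEFAULT_EMBEDDING_MODEL = "bge-small-zh-v1.5"
EMBEDDING_MODEL_PATH = {
    "bge-small-zh-v1.5": "BAAI/bge-small-zh-v1.5",
    "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
}

DEFAULT_RERANKER_MODEL = "bge-reranker-base"
USE_RERANKER = False         # reranking is off by default; enable it in the UI or here
RERANKER_MODEL_TOP_N = 2     # the "Top N" parameter mentioned above
```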
167 | file_uploads 168 | 169 |
170 | 171 | In the left navigation bar, click on Advanced Settings (Settings-Advanced), and you can also set the following parameters: 172 | - Top K 173 | - Temperature 174 | - System Prompt 175 | - Response Mode 176 | 177 | By using different parameters, we can compare the output results of large models and find the most effective parameter combination. 178 | 179 | ## 2. Knowledge Base Management 180 | 181 | ThinkRAG supports uploading various types of files such as PDF, DOCX, PPTX, and also supports uploading web page URLs. 182 | 183 |
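The defaults for these parameters, and the list of response modes the UI offers, are defined in `config.py` (included later in this repository), as the excerpt-style sketch below shows.

```python
# Sketch: advanced-setting defaults and available response modes, as defined in config.py.
TOP_K = 5          # number of text chunks retrieved per query
TEMPERATURE = 0.1  # sampling temperature passed to the LLM

RESPONSE_MODE = [
    "compact",
    "refine",
    "tree_summarize",
    "simple_summarize",
    "accumulate",
    "compact_accumulate",
]
DEFAULT_RESPONSE_MODE = "simple_summarize"
```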
184 | file_uploads 185 | 186 |
187 | 188 | Click the `Browse files` button, select the files on your computer, and then click the `Load` button; all loaded files will then be listed. 189 | 190 | Then, click the `Save` button, and the system will process the files, including text splitting and embedding, and save them to the knowledge base. 191 | 192 |
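Behind the `Save` button, the frontend hands the loaded files to the index manager together with the text-splitter settings (see `frontend/KB_File.py` later in this repository). The sketch below is a simplified rendering of that call, with the default chunk sizes taken from `config.py`.

```python
# Simplified sketch of the "Save" flow in frontend/KB_File.py (not a standalone script).
DEFAULT_CHUNK_SIZE = 2048    # from config.py
DEFAULT_CHUNK_OVERLAP = 512  # from config.py

def save_to_knowledge_base(index_manager, uploaded_files,
                           chunk_size=DEFAULT_CHUNK_SIZE,
                           chunk_overlap=DEFAULT_CHUNK_OVERLAP):
    """Split the files into chunks, embed them, and store them in the knowledge base index."""
    index_manager.load_files(uploaded_files, chunk_size, chunk_overlap)
```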
193 | file_uploads 194 | 195 |
196 | 197 | Similarly, you can enter or paste web page URLs to fetch their content, process it, and save it to the knowledge base. 198 | 199 | The system also supports managing the knowledge base. 200 | 201 |
202 | file_uploads 203 | 204 |
205 | 206 | As shown in the figure above, ThinkRAG can list all documents in the knowledge base page by page. 207 | 208 | Select the documents to be deleted; a `Delete selected documents` button will appear. Click it to remove them from the knowledge base. 209 | 210 | ## 3. Query 211 | 212 | In the left navigation bar, click on `Query`, and the Q&A page will appear. 213 | 214 | After entering a question, the system searches the knowledge base and provides an answer. In this process, it uses hybrid retrieval and reranking to obtain accurate content from the knowledge base. 215 | 216 | For example, we have uploaded a Word document about business process management to the knowledge base. 217 | 218 | Now enter the question: "What are the three characteristics of the process?" 219 | 220 |
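Under the hood, the Query page builds a query engine from the stored index and the settings described earlier; the sketch below is condensed from `frontend/Document_QA.py` (included later in this repository) and is illustrative rather than a drop-in script.

```python
# Condensed sketch of the query flow in frontend/Document_QA.py.
from server.engine import create_query_engine  # ThinkRAG's own engine module

def answer(index_manager, question, settings):
    """Build a query engine with the current settings and run a single question."""
    index_manager.load_index()
    query_engine = create_query_engine(
        index=index_manager.index,
        use_reranker=settings["use_reranker"],   # hybrid retrieval with optional reranking
        response_mode=settings["response_mode"],
        top_k=settings["top_k"],
        top_n=settings["top_n"],
        reranker=settings["reranker_model"],
    )
    return query_engine.query(question)          # response object includes the answer and source nodes
```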
221 | file_uploads 222 | 223 |
224 | 225 | As shown in the figure, the system took 2.49 seconds to provide an accurate answer: a process is characterized by being goal-oriented, repeatable, and procedural. The system also listed the 2 reference documents retrieved from the knowledge base. 226 | 227 | This shows that ThinkRAG fully and effectively implements retrieval-augmented generation with large models over a local knowledge base. 228 | 229 |
230 | 231 | # Architecture 232 | 233 | ThinkRAG is developed using the LlamaIndex data framework, with Streamlit for the front end. The development mode and production mode of the system use different technical components, as shown in the table below: 234 | 235 | | |Development Mode|Production Mode| 236 | |:----|:----|:----| 237 | |RAG Framework|LlamaIndex|LlamaIndex| 238 | |Frontend Framework|Streamlit|Streamlit| 239 | |Embedding Model|BAAI/bge-small-zh-v1.5|BAAI/bge-large-zh-v1.5| 240 | |Reranking Model|BAAI/bge-reranker-base|BAAI/bge-reranker-large| 241 | |Text Splitter|SentenceSplitter|SpacyTextSplitter| 242 | |Conversation Storage|SimpleChatStore|Redis| 243 | |Document Storage|SimpleDocumentStore|Redis| 244 | |Index Storage|SimpleIndexStore|Redis| 245 | |Vector Storage|SimpleVectorStore|LanceDB| 246 | 247 | These technical components are organized into six architectural layers: Frontend, Framework, LLM, Tools, Storage, and Infrastructure. 248 | 249 | As shown in the figure below: 250 | 251 |
252 | file_uploads 253 | 254 |
255 | 256 |
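For reference, the switch between the development and production columns above is driven by the `THINKRAG_ENV` variable in `config.py`; the actual wiring lives in `server/stores/` (for example `chat_store.py` and `vector_store.py`). The sketch below only illustrates the idea, using LlamaIndex's chat stores as an example; it assumes the matching `llama-index` storage packages are installed and a Redis instance is reachable.

```python
# Illustrative sketch of the development/production storage switch (not ThinkRAG's actual code).
import os

THINKRAG_ENV = os.getenv("THINKRAG_ENV", "development")
DEV_MODE = THINKRAG_ENV == "development"

if DEV_MODE:
    from llama_index.core.storage.chat_store import SimpleChatStore
    chat_store = SimpleChatStore()  # conversation history kept in a local JSON file
else:
    from llama_index.storage.chat_store.redis import RedisChatStore
    chat_store = RedisChatStore(redis_url="redis://localhost:6379")  # REDIS_URI in config.py
```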
257 | 258 | # Roadmap 259 | 260 | ThinkRAG will continue to optimize core functions and continuously improve the efficiency and accuracy of retrieval, including: 261 | 262 | - Optimize the processing of documents and web pages, support multimodal knowledge bases and multimodal retrieval 263 | - Build a knowledge graph, enhance retrieval through the knowledge graph, and reason based on the graph 264 | - Process complex scenarios through intelligent agents, especially accurately calling other tools and data to complete tasks 265 | 266 | At the same time, we will further improve the application architecture and enhance the user experience, mainly including: 267 | - Design: A user interface with a sense of design and excellent user experience 268 | - Frontend: Build a desktop client application based on Electron, React, Vite, etc., to provide users with an extremely simple way to download, install, and run 269 | - Backend: Provide interfaces through FastAPI, and improve overall performance and scalability through technologies such as message queues 270 | 271 | Welcome to join the ThinkRAG open source project, and together create AI products that users love! 272 | 273 |
274 | 275 | # License 276 | 277 | ThinkRAG uses the [MIT License](LICENSE). -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 |

2 | English | 3 | 简体中文 4 |

5 | 6 |
7 | 8 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](./LICENSE) [![support: Ollama](https://img.shields.io/badge/Support-Ollama-green.svg)](https://ollama.com/) [![support: LlamaIndex](https://img.shields.io/badge/Support-LlamaIndex-purple.svg)](https://www.llamaindex.ai/) 9 | 10 |
11 | 12 | ### 目录 13 | 14 | - 🤔 [项目简介](#What-is-ThinkRAG) 15 | - ✨ [主要特性](#Key-Features) 16 | - 🧸 [模型支持](#Support-Models) 17 | - 🛫 [快速开始](#quick-start) 18 | - 📖 [使用指南](#Instructions) 19 | - 🔬 [技术架构](#Architecture) 20 | - 📜 [开发计划](#Roadmap) 21 | - 📄 [许可协议](#License) 22 | 23 |
24 | 25 | # ThinkRAG 26 | 27 | ThinkRAG 大模型检索增强生成系统,可以轻松部署在笔记本电脑上,实现本地知识库智能问答。 28 | 29 | 该系统基于 LlamaIndex 和 Streamlit 构建,针对国内用户在模型选择、文本处理等诸多领域进行了优化。 30 | 31 |
32 | 33 | # 主要特性 34 | 35 | ThinkRAG 是为专业人士、科研人员、学生等知识工作者开发的大模型应用系统,可在笔记本电脑上直接使用,且知识库数据都保存在电脑本地。 36 | 37 | ThinkRAG 具备以下特点: 38 | - LlamaIndex 框架的完整应用 39 | - 开发模式支持本地文件存储,无需安装任何数据库 40 | - 无需 GPU 支持,即可在笔记本电脑上运行 41 | - 支持本地部署的模型和离线使用 42 | 43 | 特别地,ThinkRAG 还为国内用户做了大量定制和优化: 44 | - 使用 Spacy 文本分割器,更好地处理中文字符 45 | - 采用中文标题增强功能 46 | - 使用中文提示词模板进行问答和细化过程 47 | - 默认支持国内大模型厂商,如DeepSeek,Moonshot和Zhipu等 48 | - 使用双语嵌入模型,如 BAAI的bge-large-zh-v1.5 49 | 50 |
51 | 52 | # 模型支持 53 | 54 | ThinkRAG 可使用 LlamaIndex 数据框架支持的所有模型。关于模型列表信息,请参考[相关文档](https://docs.llamaindex.ai/en/stable/module_guides/models/llms/modules/)。 55 | 56 | ThinkRAG致力于打造一个直接能用、有用、易用的应用系统。 57 | 58 | 因此,在各种模型、组件与技术上,我们做了精心的选择与取舍。 59 | 60 | 首先,使用大模型,ThinkRAG支持OpenAI API 以及所有兼容的 LLM API,包括国内主流大模型厂商,例如: 61 | 62 | - 深度求索(DeepSeek) 63 | - 月之暗面(Moonshot) 64 | - 智谱(Zhipu) 65 | - …… 66 | 67 | 如果要本地化部署大模型,ThinkRAG 选用了简单易用的 Ollama。我们可以通过 Ollama 将大模型下载到本地运行。 68 | 69 | 目前 Ollama 支持几乎所有主流大模型本地化部署,包括 DeepSeek、Llama、Gemma、GLM、Mistral、Phi、Llava等。具体可访问 [Ollama 官网](https://ollama.com/)了解。 70 | 71 | 系统也使用了嵌入模型和重排模型,可支持来自 Hugging Face 的大多数模型。目前,ThinkRAG主要选用了BAAI的BGE系列模型。国内用户可访问[镜像网址](https://hf-mirror.com/BAAI)了解和下载。 72 | 73 | ## 已知问题 74 | 75 | 目前有Windows用户报告的问题尚未复现和解决,请在Linux或MacOS系统上使用ThinkRAG。 76 | 77 | 由于LlamaIndex与最新的ollama 0.4未完成兼容,请安装使用ollama 0.3.3,指定版本已在requirements.txt中体现。 78 | 79 |
80 | 81 | # 快速开始 82 | 83 | ## Step 1 下载与安装 84 | 85 | 从Github下载代码后,用pip安装所需组件。 86 | ```zsh 87 | pip3 install -r requirements.txt 88 | ``` 89 | 若要离线运行系统,请首先从官网下载 Ollama。然后,使用 Ollama 命令下载如DeepSeek、 QWen 和 Gemma 等大模型。 90 | 91 | 同步,从Hugging Face将嵌入模型(BAAI/bge-large-zh-v1.5)和重排模型(BAAI/bge-reranker-base)下载到 localmodels 目录中。 92 | 93 | 具体步骤,可参考 docs 目录下的文档:HowToDownloadModels.md 94 | 95 | ## Step 2 系统配置 96 | 97 | 为了获得更好的性能,推荐使用千亿级参数的商用大模型 LLM API。 98 | 99 | 首先,从 LLM 服务商获取 API 密钥,配置如下环境变量。 100 | 101 | ```zsh 102 | OPENAI_API_KEY = "" 103 | DEEPSEEK_API_KEY = "" 104 | MOONSHOT_API_KEY = "" 105 | ZHIPU_API_KEY = "" 106 | ``` 107 | 108 | 你可以跳过这一步,在系统运行后,再通过应用界面配置 API 密钥。 109 | 110 | 如果选择使用其中一个或多个 LLM API,请在 config.py 配置文件中删除不再使用的服务商。 111 | 112 | 当然,你也可以在配置文件中,添加兼容 OpenAI API 的其他服务商。 113 | 114 | ThinkRAG 默认以开发模式运行。在此模式下,系统使用本地文件存储,你不需要安装任何数据库。 115 | 116 | 若要切换到生产模式,你可以按照以下方式配置环境变量。 117 | 118 | ```zsh 119 | THINKRAG_ENV = production 120 | ``` 121 | 122 | 在生产模式下,系统可使用向量数据库 Chroma 或 LanceDB,以及键值数据库 Redis。 123 | 124 | 如果你没有安装 Redis,建议通过 Docker 安装,或使用已有的 Redis 实例。请在 config.py 文件里,配置 Redis 实例的参数信息。 125 | 126 | ## Step 3 运行系统 127 | 128 | 现在,你已经准备好运行 ThinkRAG。 129 | 130 | 请在包含 app.py 文件的目录中运行以下命令。 131 | 132 | ```zsh 133 | streamlit run app.py 134 | ``` 135 | 136 | 系统将运行,并在浏览器上自动打开以下网址,展示应用界面。 137 | 138 | http://localhost:8501/ 139 | 140 | 第一次运行可能会需要等待片刻。如果没有提前下载 Hugging Face 上的嵌入模型,系统还会自动下载模型,将需要等待更长时间。 141 | 142 |
143 | 144 | # 使用指南 145 | 146 | ## 1.系统配置 147 | 148 | ThinkRAG 支持在用户界面,对大模型进行配置与选择,包括:大模型 LLM API 的 Base URL 和 API 密钥,并可以选择使用的具体模型,例如:智谱的 glm-4。 149 | 150 |
151 | file_uploads 152 | 153 |
154 | 155 | 系统将自动检测 API 和密钥是否可用,若可用则在底部用绿色文字,显示当前选择的大模型实例。 156 | 157 | 同样,系统可以自动获取 Ollama 下载的模型,用户可以在用户界面上选择所需的模型。 158 | 159 |
160 | file_uploads 161 | 162 |
163 | 164 | 若你已经将嵌入模型和重排模型下载到本地 localmodels 目录下。在用户界面上,可以切换选择使用的模型,并设置重排模型的参数,比如 Top N。 165 | 166 |
167 | file_uploads 168 | 169 |
170 | 171 | 在左侧导航栏,点击高级设置(Settings-Advanced),你还可以对下列参数进行设置: 172 | - Top K 173 | - Temperature 174 | - System Prompt 175 | - Response Mode 176 | 177 | 通过使用不同参数,我们可以对比大模型输出结果,找到最有效的参数组合。 178 | 179 | ## 2.管理知识库 180 | 181 | ThinkRAG 支持上传 PDF、DOCX、PPTX 等各类文件,也支持上传网页 URL。 182 | 183 |
184 | file_uploads 185 | 186 |
187 | 188 | 点击 Browse files 按钮,选择电脑上的文件,然后点击 Load 按钮加载,此时会列出所有加载的文件。 189 | 190 | 然后,点击 Save 按钮,系统会对文件进行处理,包括文本分割和嵌入,保存到知识库中。 191 | 192 |
193 | file_uploads 194 | 195 |
196 | 197 | 同样,你可以输入或粘贴网页 URL,获取网页信息,处理后保存到知识库中。 198 | 199 | 系统支持对知识库进行管理。 200 | 201 |
202 | file_uploads 203 | 204 |
205 | 206 | 如上图所示,ThinkRAG 可以分页列出,知识库中所有的文档。 207 | 208 | 选择要删除的文档,将出现 Delete selected documents 按钮,点击该按钮可以将文档从知识库中删除。 209 | 210 | ## 3.智能问答 211 | 212 | 在左侧导航栏,点击 Query,将会出现智能问答页面。 213 | 214 | 输入问题后,系统会对知识库进行检索,并给出回答。在这个过程当中,系统将采用混合检索和重排等技术,从知识库获取准确的内容。 215 | 216 | 例如,我们已经在知识库中上传了一个 Word 文档:“大卫说流程.docx“。 217 | 218 | 现在输入问题:”流程有哪三个特征?” 219 | 220 |
221 | file_uploads 222 | 223 |
224 | 225 | 如图所示,系统用时2.49秒,给出了准确的回答:流程具备目标性、重复性与过程性。同时,系统还给出了从知识库检索到的2个相关文档。 226 | 227 | 可以看到,ThinkRAG 完整和有效地实现了,基于本地知识库的大模型检索增强生成的功能。 228 | 229 |
230 | 231 | # 技术架构 232 | 233 | ThinkRAG 采用 LlamaIndex 数据框架开发,前端使用Streamlit。系统的开发模式和生产模式,分别选用了不同的技术组件,如下表所示: 234 | 235 | | |开发模式|生产模式| 236 | |:----|:----|:----| 237 | |RAG框架|LlamaIndex|LlamaIndex| 238 | |前端框架|Streamlit|Streamlit| 239 | |嵌入模型|BAAI/bge-small-zh-v1.5|BAAI/bge-large-zh-v1.5| 240 | |重排模型|BAAI/bge-reranker-base|BAAI/bge-reranker-large| 241 | |文本分割器|SentenceSplitter|SpacyTextSplitter| 242 | |对话存储|SimpleChatStore|Redis| 243 | |文档存储|SimpleDocumentStore|Redis| 244 | |索引存储|SimpleIndexStore|Redis| 245 | |向量存储|SimpleVectorStore|LanceDB| 246 | 247 | 这些技术组件,按照前端、框架、大模型、工具、存储、基础设施,这六个部分进行架构设计。 248 | 249 | 如下图所示: 250 | 251 |
252 | file_uploads 253 | 254 |
255 | 256 |
257 | 258 | # 开发计划 259 | 260 | ThinkRAG 将继续优化核心功能,持续提升检索的效率和准确性,主要包括: 261 | 262 | - 优化对文档和网页的处理,支持多模态知识库和多模态检索 263 | - 构建知识图谱,通过知识图谱增强检索,并基于图进行推理 264 | - 通过智能体处理复杂场景,尤其是准确调用其他工具和数据,完成任务 265 | 266 | 同时,我们还将进一步完善应用架构、提升用户体验,主要包括: 267 | - 设计:有设计感和极佳用户体验的用户界面 268 | - 前端:基于Electron、React、Vite等技术,构建桌面客户端应用,为用户提供极致简洁的下载、安装和运行方式 269 | - 后端:通过FastAPI提供接口,以及消息队列等技术提升整体性能和可扩展性 270 | 271 | 欢迎你加入 ThinkRAG 开源项目,一起打造用户喜爱的 AI 产品! 272 | 273 |
274 | 275 | # 许可协议 276 | 277 | ThinkRAG 使用 [MIT 协议](LICENSE). -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | Thanks for helping make GitHub safe for everyone. 2 | 3 | ## Security 4 | 5 | GitHub takes the security of our software products and services seriously, including all of the open source code repositories managed through our GitHub organizations, such as [GitHub](https://github.com/GitHub). 6 | 7 | Even though [open source repositories are outside of the scope of our bug bounty program](https://bounty.github.com/index.html#scope) and therefore not eligible for bounty rewards, we will ensure that your finding gets passed along to the appropriate maintainers for remediation. 8 | 9 | ## Reporting Security Issues 10 | 11 | If you believe you have found a security vulnerability in any GitHub-owned repository, please report it to us through coordinated disclosure. 12 | 13 | **Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.** 14 | 15 | Instead, please send an email to opensource-security@github.com. 16 | 17 | Please include as much of the information listed below as you can to help us better understand and resolve the issue: 18 | 19 | * The type of issue (e.g., buffer overflow, SQL injection, or cross-site scripting) 20 | * Full paths of source file(s) related to the manifestation of the issue 21 | * The location of the affected source code (tag/branch/commit or direct URL) 22 | * Any special configuration required to reproduce the issue 23 | * Step-by-step instructions to reproduce the issue 24 | * Proof-of-concept or exploit code (if possible) 25 | * Impact of the issue, including how an attacker might exploit the issue 26 | 27 | This information will help us triage your report more quickly. 
28 | 29 | ## Policy 30 | 31 | See [GitHub's Safe Harbor Policy](https://docs.github.com/en/site-policy/security-policies/github-bug-bounty-program-legal-safe-harbor#1-safe-harbor-terms) -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Log configuration 2 | import logging 3 | import sys 4 | 5 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 6 | logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout)) 7 | 8 | # Configure the Streamlit Web Application 9 | import streamlit as st 10 | from frontend.state import init_state 11 | 12 | if __name__ == '__main__': 13 | 14 | st.set_page_config( 15 | page_title="ThinkRAG - LLM RAG system runs on laptop", 16 | page_icon="🧊", 17 | layout="wide", 18 | initial_sidebar_state="auto", 19 | menu_items=None, 20 | ) 21 | 22 | st.logo("frontend/images/ThinkRAG_Logo.png") 23 | 24 | init_state() 25 | 26 | pages = { 27 | "Application" : [ 28 | st.Page("frontend/Document_QA.py", title="Query"), 29 | ], 30 | "Knowledge Base" : [ 31 | st.Page("frontend/KB_File.py", title="File"), 32 | st.Page("frontend/KB_Web.py", title="Web"), 33 | st.Page("frontend/KB_Manage.py", title="Manage"), 34 | ], 35 | "Model & Tool" : [ 36 | st.Page("frontend/Model_LLM.py", title="LLM"), 37 | st.Page("frontend/Model_Embed.py", title="Embed"), 38 | st.Page("frontend/Model_Rerank.py", title="Rerank"), 39 | st.Page("frontend/Storage.py", title="Storage"), 40 | ], 41 | "Settings" : [ 42 | st.Page("frontend/Setting_Advanced.py", title="Advanced"), 43 | ], 44 | } 45 | 46 | pg = st.navigation(pages, position="sidebar") 47 | 48 | pg.run() -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | STORAGE_DIR = "storage" # directory to cache the generated index 4 | DATA_DIR = "data" # directory containing the documents to index 5 | MODEL_DIR = "localmodels" # directory containing the model files, use None if use remote model 6 | CONFIG_STORE_FILE = "config_store.json" # local storage for configurations 7 | 8 | # The device that used for running the model. 9 | # Set it to 'auto' will automatically detect (with warnings), or it can be manually set to one of 'cuda', 'mps', 'cpu', or 'xpu'. 10 | LLM_DEVICE = "auto" 11 | EMBEDDING_DEVICE = "auto" 12 | 13 | # LLM Settings 14 | 15 | HISTORY_LEN = 3 16 | 17 | MAX_TOKENS = 2048 18 | 19 | TEMPERATURE = 0.1 20 | 21 | TOP_K = 5 22 | 23 | SYSTEM_PROMPT = "You are an AI assistant that helps users to find accurate information. You can answer questions, provide explanations, and generate text based on the input. Please answer the user's question exactly in the same language as the question or follow user's instructions. For example, if user's question is in Chinese, please generate answer in Chinese as well. If you don't know the answer, please reply the user that you don't know. If you need more information, you can ask the user for clarification. Please be professional to the user." 
24 | 25 | RESPONSE_MODE = [ # Configure the response mode of the query engine 26 | "compact", 27 | "refine", 28 | "tree_summarize", 29 | "simple_summarize", 30 | "accumulate", 31 | "compact_accumulate", 32 | ] 33 | DEFAULT_RESPONSE_MODE = "simple_summarize" 34 | 35 | OLLAMA_API_URL = "http://localhost:11434" 36 | 37 | # Models' API configuration,set the KEY in environment variables 38 | ZHIPU_API_KEY = os.getenv("ZHIPU_API_KEY", "") 39 | MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY", "") 40 | DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "") 41 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") 42 | 43 | LLM_API_LIST = { 44 | # Ollama API 45 | "Ollama": { 46 | "api_base": OLLAMA_API_URL, 47 | "models": [], 48 | "provider": "Ollama", 49 | }, 50 | # OpenAI API 51 | "OpenAI": { 52 | "api_key": OPENAI_API_KEY, 53 | "api_base": "https://api.openai.com/v1/", 54 | "models": ["gpt-4", "gpt-3.5", "gpt-4o"], 55 | "provider": "OpenAI", 56 | }, 57 | # DeepSeek API 58 | "DeepSeek": { 59 | "api_key": DEEPSEEK_API_KEY, 60 | "api_base": "https://api.deepseek.com/v1/", 61 | "models": ["deepseek-chat","deepseek-reasoner"], 62 | "provider": "DeepSeek", 63 | }, 64 | # Moonshot API 65 | "Moonshot": { 66 | "api_key": MOONSHOT_API_KEY, 67 | "api_base": "https://api.moonshot.cn/v1/", 68 | "models": ["moonshot-v1-8k","moonshot-v1-32k","moonshot-v1-128k"], 69 | "provider": "Moonshot", 70 | }, 71 | # ZhiPu API 72 | "Zhipu": { 73 | "api_key": ZHIPU_API_KEY, 74 | "api_base": "https://open.bigmodel.cn/api/paas/v4/", 75 | "models": ["glm-4-plus", "glm-4-0520", "glm-4", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flashx", "glm-4-flash", "glm-4v-plus", "glm-4v"], 76 | "provider": "Zhipu", 77 | }, 78 | } 79 | 80 | # Text splitter configuration 81 | 82 | DEFAULT_CHUNK_SIZE = 2048 83 | DEFAULT_CHUNK_OVERLAP = 512 84 | ZH_TITLE_ENHANCE = False # Chinese title enhance 85 | 86 | # Storage configuration 87 | 88 | MONGO_URI = "mongodb://localhost:27017" 89 | REDIS_URI = "redis://localhost:6379" 90 | REDIS_HOST = "localhost" 91 | REDIS_PORT = 6379 92 | ES_URI = "http://localhost:9200" 93 | 94 | # Default vector database type, including "es" and "chroma" 95 | DEFAULT_VS_TYPE = "es" 96 | 97 | # Chat store type,including "simple" and "redis" 98 | DEFAULT_CHAT_STORE = "redis" 99 | CHAT_STORE_FILE_NAME = "chat_store.json" 100 | CHAT_STORE_KEY = "user1" 101 | 102 | # Use HuggingFace model,Configure domestic mirror 103 | HF_ENDPOINT = "https://hf-mirror.com" # Default to be "https://huggingface.co" 104 | 105 | # Configure Embedding model 106 | DEFAULT_EMBEDDING_MODEL = "bge-small-zh-v1.5" 107 | EMBEDDING_MODEL_PATH = { 108 | "bge-small-zh-v1.5": "BAAI/bge-small-zh-v1.5", 109 | "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5", 110 | } 111 | 112 | # Configure Reranker model 113 | DEFAULT_RERANKER_MODEL = "bge-reranker-base" 114 | RERANKER_MODEL_PATH = { 115 | "bge-reranker-base": "BAAI/bge-reranker-base", 116 | "bge-reranker-large": "BAAI/bge-reranker-large", 117 | } 118 | 119 | # Use reranker model or not 120 | USE_RERANKER = False 121 | RERANKER_MODEL_TOP_N = 2 122 | RERANKER_MAX_LENGTH = 1024 123 | 124 | # Evironment variable, default to be "development", set to "production" for production environment 125 | THINKRAG_ENV = os.getenv("THINKRAG_ENV", "development") 126 | DEV_MODE = THINKRAG_ENV == "development" 127 | 128 | # For creating IndexManager 129 | DEFAULT_INDEX_NAME = "knowledge_base" -------------------------------------------------------------------------------- /docs/Code_of_Conduct.md: 
-------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. 6 | 7 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 8 | 9 | ## Our Standards 10 | 11 | Examples of behavior that contributes to a positive environment for our community include: 12 | 13 | * Demonstrating empathy and kindness toward other people 14 | * Being respectful of differing opinions, viewpoints, and experiences 15 | * Giving and gracefully accepting constructive feedback 16 | * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience 17 | * Focusing on what is best not just for us as individuals, but for the overall community 18 | 19 | Examples of unacceptable behavior include: 20 | 21 | * The use of sexualized language or imagery, and sexual attention or advances of any kind 22 | * Trolling, insulting or derogatory comments, and personal or political attacks 23 | * Public or private harassment 24 | * Publishing others' private information, such as a physical or email address, without their explicit permission 25 | * Other conduct which could reasonably be considered inappropriate in a professional setting 26 | 27 | ## Enforcement Responsibilities 28 | 29 | Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. 30 | 31 | Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. 32 | 33 | ## Scope 34 | 35 | This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official email address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 36 | 37 | ## Enforcement 38 | 39 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at _*[INSERT CONTACT METHOD]*_. All complaints will be reviewed and investigated promptly and fairly. 40 | 41 | All community leaders are obligated to respect the privacy and security of the reporter of any incident. 42 | 43 | ## Enforcement Guidelines 44 | 45 | Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: 46 | 47 | ### 1. Correction 48 | 49 | **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. 
50 | 51 | **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. 52 | 53 | ### 2. Warning 54 | 55 | **Community Impact**: A violation through a single incident or series of 56 | actions. 57 | 58 | **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. 59 | 60 | ### 3. Temporary Ban 61 | 62 | **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. 63 | 64 | **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. 65 | 66 | ### 4. Permanent Ban 67 | 68 | **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. 69 | 70 | **Consequence**: A permanent ban from any sort of public interaction within the community. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 75 | 76 | Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 77 | 78 | For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. 79 | 80 | [homepage]: https://www.contributor-covenant.org 81 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 82 | [Mozilla CoC]: https://github.com/mozilla/diversity 83 | [FAQ]: https://www.contributor-covenant.org/faq 84 | [translations]: https://www.contributor-covenant.org/translations -------------------------------------------------------------------------------- /docs/HowToDownloadModels.md: -------------------------------------------------------------------------------- 1 | # Download embedding models from HuggingFace 2 | ### 1. Install or upgrade the huggingface_hub package 3 | ```zsh 4 | >>>pip install -U huggingface_hub 5 | ``` 6 | ### 2. Use HF-Mirror to help downloading required models 7 | ```zsh 8 | >>>export HF_ENDPOINT=https://hf-mirror.com 9 | ``` 10 | ### 3. Create and change the current working directory to ~/ThinkRAG/localmodels 11 | ```zsh 12 | >>>mkdir localmodels && cd localmodels 13 | ``` 14 | ### 4. Create and change the current working directory to ~/ThinkRAG/localmodels/BAAI 15 | ```zsh 16 | >>>mkdir BAAI && cd BAAI 17 | ``` 18 | ### 5. 
Download required models 19 | ```zsh 20 | >>>huggingface-cli download --resume-download BAAI/bge-small-zh-v1.5 --local-dir bge-small-zh-v1.5 21 | ``` 22 | ```zsh 23 | >>>huggingface-cli download --resume-download BAAI/bge-reranker-base --local-dir bge-reranker-base 24 | ``` -------------------------------------------------------------------------------- /docs/HowToUsePythonVirtualEnv.md: -------------------------------------------------------------------------------- 1 | # Create an environment with venv and installing Streamlit and other packages with pip. 2 | # Refer to: https://docs.streamlit.io/get-started/installation/command-line 3 | ### 1. Open a terminal and navigate to your project folder 4 | ```zsh 5 | cd ThinkRAG 6 | ``` 7 | ### 2. In your terminal, type: 8 | ```zsh 9 | python -m venv .venv 10 | ``` 11 | A folder named ".venv" will appear in your project. This directory is where your virtual environment and its dependencies are installed. 12 | ### 3. Activate your environment with one of the following commands 13 | ```zsh 14 | # Windows command prompt 15 | .venv\Scripts\activate.bat 16 | 17 | # Windows PowerShell 18 | .venv\Scripts\Activate.ps1 19 | 20 | # macOS and Linux 21 | source .venv/bin/activate 22 | ``` 23 | Once activated, you will see your environment name in parentheses before your prompt. "(.venv)" 24 | ### 4. In the terminal with your environment activated, install all required packages: 25 | ```zsh 26 | pip3 install -r requirements.txt 27 | ``` 28 | ### 5. Run your Streamlit app 29 | ```zsh 30 | python3 -m streamlit run app.py 31 | ``` 32 | ### 6. When you're done using this environment, return to your normal shell by typing: 33 | ```zsh 34 | deactivate 35 | ``` -------------------------------------------------------------------------------- /docs/images/KB_File.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/docs/images/KB_File.png -------------------------------------------------------------------------------- /docs/images/KB_Manage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/docs/images/KB_Manage.png -------------------------------------------------------------------------------- /docs/images/KB_Web.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/docs/images/KB_Web.png -------------------------------------------------------------------------------- /docs/images/Model_LLM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/docs/images/Model_LLM.png -------------------------------------------------------------------------------- /docs/images/Model_Reranker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/docs/images/Model_Reranker.png -------------------------------------------------------------------------------- /docs/images/Query.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/docs/images/Query.png -------------------------------------------------------------------------------- /docs/images/Settings_Advanced.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/docs/images/Settings_Advanced.png -------------------------------------------------------------------------------- /docs/images/ThinkRAG_Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/docs/images/ThinkRAG_Architecture.png -------------------------------------------------------------------------------- /frontend/Document_QA.py: -------------------------------------------------------------------------------- 1 | # Document-based Q&A 2 | import time 3 | import re 4 | import streamlit as st 5 | import pandas as pd 6 | from server.stores.chat_store import CHAT_MEMORY 7 | from llama_index.core.llms import ChatMessage, MessageRole 8 | from server.engine import create_query_engine 9 | from server.stores.config_store import CONFIG_STORE 10 | 11 | def perform_query(prompt): 12 | if not st.session_state.query_engine: 13 | print("Index is not initialized yet") 14 | if (not prompt) or prompt.strip() == "": 15 | print("Query text is required") 16 | try: 17 | query_response = st.session_state.query_engine.query(prompt) 18 | return query_response 19 | except Exception as e: 20 | # print(f"An error occurred while processing the query: {e}") 21 | print(f"An error occurred while processing the query: {type(e).__name__}: {e}") 22 | 23 | # https://github.com/halilergul1/QA-app 24 | def simple_format_response_and_sources(response): 25 | primary_response = getattr(response, 'response', '') 26 | output = {"response": primary_response} 27 | sources = [] 28 | if hasattr(response, 'source_nodes'): 29 | for node in response.source_nodes: 30 | node_data = getattr(node, 'node', None) 31 | if node_data: 32 | metadata = getattr(node_data, 'metadata', {}) 33 | text = getattr(node_data, 'text', '') 34 | text = re.sub(r'\n\n|\n|\u2028', lambda m: {'\n\n': '\u2028', '\n': ' ', '\u2028': '\n\n'}[m.group()], text) 35 | source_info = { 36 | "file": metadata.get('file_name', 'N/A'), 37 | "page": metadata.get('page_label', 'N/A'), 38 | "text": text 39 | } 40 | sources.append(source_info) 41 | output['sources'] = sources 42 | return output 43 | 44 | def chatbox(): 45 | 46 | # Load Q&A history 47 | messages = CHAT_MEMORY.get() 48 | if len(messages) == 0: 49 | # Initialize Q&A record 50 | CHAT_MEMORY.put(ChatMessage(role=MessageRole.ASSISTANT, content="Feel free to ask about anything in the knowledge base")) 51 | messages = CHAT_MEMORY.get() 52 | 53 | # Show Q&A records 54 | for message in messages: 55 | with st.chat_message(message.role): 56 | st.write(message.content) 57 | 58 | if prompt := st.chat_input("Input your question"): # Prompt the user to input the question then add it to the message history 59 | with st.chat_message(MessageRole.USER): 60 | st.write(prompt) 61 | CHAT_MEMORY.put(ChatMessage(role=MessageRole.USER, content=prompt)) 62 | with st.chat_message(MessageRole.ASSISTANT): 63 | with st.spinner("Thinking..."): 64 | start_time = time.time() 65 | response = perform_query(prompt) 66 | end_time = time.time() 67 | query_time = round(end_time - start_time, 2) 68 | if response is None: 
69 | st.write("Couldn't come up with an answer.") 70 | else: 71 | response_text = st.write_stream(response.response_gen) 72 | st.write(f"Took {query_time} second(s)") 73 | details_title = f"Found {len(response.source_nodes)} document(s)" 74 | with st.expander( 75 | details_title, 76 | expanded=False, 77 | ): 78 | source_nodes = [] 79 | for item in response.source_nodes: 80 | node = item.node 81 | score = item.score 82 | title = node.metadata.get('file_name', None) 83 | if title is None: 84 | title = node.metadata.get('title', 'N/A') # if the document is a webpage, use the title 85 | continue 86 | page_label = node.metadata.get('page_label', 'N/A') 87 | text = node.text 88 | short_text = text[:50] + "..." if len(text) > 50 else text 89 | source_nodes.append({"Title": title, "Page": page_label, "Text": short_text, "Score": f"{score:.2f}"}) 90 | df = pd.DataFrame(source_nodes) 91 | st.table(df) 92 | # store the answer in the chat history 93 | CHAT_MEMORY.put(ChatMessage(role=MessageRole.ASSISTANT, content=response_text)) 94 | def main(): 95 | st.header("Query") 96 | if st.session_state.llm is not None: 97 | current_llm_info = CONFIG_STORE.get(key="current_llm_info") 98 | current_llm_settings = CONFIG_STORE.get(key="current_llm_settings") 99 | st.caption("LLM `" + current_llm_info["service_provider"] + "` `" + current_llm_info["model"] + 100 | "` Response mode `" + current_llm_settings["response_mode"] + 101 | "` Top K `" + str(current_llm_settings["top_k"]) + 102 | "` Temperature `" + str(current_llm_settings["temperature"]) + 103 | "` Reranking `" + str(current_llm_settings["use_reranker"]) + 104 | "` Top N `" + str(current_llm_settings["top_n"]) + 105 | "` Reranker `" + current_llm_settings["reranker_model"] + "`" 106 | ) 107 | if st.session_state.index_manager is not None: 108 | if st.session_state.index_manager.check_index_exists(): 109 | st.session_state.index_manager.load_index() 110 | st.session_state.query_engine = create_query_engine( 111 | index=st.session_state.index_manager.index, 112 | use_reranker=current_llm_settings["use_reranker"], 113 | response_mode=current_llm_settings["response_mode"], 114 | top_k=current_llm_settings["top_k"], 115 | top_n=current_llm_settings["top_n"], 116 | reranker=current_llm_settings["reranker_model"]) 117 | print("Index loaded and query engine created") 118 | chatbox() 119 | else: 120 | print("Index does not exist yet") 121 | st.warning("Your knowledge base is empty. Please upload some documents into it first.") 122 | else: 123 | print("IndexManager is not initialized yet.") 124 | st.warning("Please upload documents into your knowledge base first.") 125 | else: 126 | st.warning("Please configure LLM first.") 127 | 128 | main() -------------------------------------------------------------------------------- /frontend/KB_File.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pandas as pd 3 | import streamlit as st 4 | from server.utils.file import save_uploaded_file, get_save_dir 5 | 6 | def handle_file(): 7 | 8 | st.header("Load Files") 9 | st.caption("Load Files like PDF, DOCX, TXT, etc. 
to create a knowledge base index.") 10 | 11 | with st.form("my-form", clear_on_submit=True): 12 | st.session_state.selected_files = st.file_uploader("Upload files: ", accept_multiple_files=True, label_visibility="hidden") 13 | submitted = st.form_submit_button( 14 | "Load", 15 | help="Click here to load it after you select a file.", 16 | ) 17 | if len(st.session_state.selected_files) > 0 and submitted: 18 | print("Starting to upload files...") 19 | print(st.session_state.selected_files) 20 | for selected_file in st.session_state.selected_files: 21 | with st.spinner(f"Uploading {selected_file.name}..."): 22 | save_dir = get_save_dir() 23 | save_uploaded_file(selected_file, save_dir) 24 | st.session_state.uploaded_files.append({"name": selected_file.name, "type": selected_file.type, "size": selected_file.size}) 25 | st.toast('✔️ Upload successful', icon='🎉') 26 | 27 | if len(st.session_state.uploaded_files) > 0: 28 | with st.expander( 29 | "The following files are uploaded successfully.", 30 | expanded=True, 31 | ): 32 | df = pd.DataFrame(st.session_state.uploaded_files) 33 | st.dataframe( 34 | df, 35 | column_config={ 36 | "name": "File name", 37 | "size": st.column_config.NumberColumn( 38 | "size", format="%d byte", 39 | ), 40 | "type": "type", 41 | }, 42 | hide_index=True, 43 | ) 44 | 45 | with st.expander( 46 | "Text Splitter Settings", 47 | expanded=True, 48 | ): 49 | cols = st.columns(2) 50 | chunk_size = cols[0].number_input("Maximum length of a single text block: ", 1, 4096, st.session_state.chunk_size) 51 | chunk_overlap = cols[1].number_input("Adjacent text overlap length: ", 0, st.session_state.chunk_size, st.session_state.chunk_overlap) 52 | 53 | if st.button( 54 | "Save", 55 | disabled=len(st.session_state.uploaded_files) == 0, 56 | help="After uploading files, click here to generate the index and save it to the knowledge base.", 57 | ): 58 | print("Generating index...") 59 | with st.spinner(text="Loading documents and building the index, may take a minute or two"): 60 | st.session_state.index_manager.load_files(st.session_state.uploaded_files, chunk_size, chunk_overlap) 61 | st.toast('✔️ Knowledge base index generation complete', icon='🎉') 62 | st.session_state.uploaded_files = [] 63 | time.sleep(4) 64 | st.rerun() 65 | 66 | handle_file() 67 | 68 | 69 | -------------------------------------------------------------------------------- /frontend/KB_Manage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from math import ceil 4 | import pandas as pd 5 | import streamlit as st 6 | 7 | def get_unique_files_info(ref_doc_info): 8 | docs = [] 9 | seen_paths = set() 10 | 11 | for ref_doc_id, ref_doc in ref_doc_info.items(): 12 | 13 | metadata = ref_doc.metadata 14 | file_path = metadata.get('file_path', None) 15 | 16 | if file_path is None: 17 | title = metadata.get('title', None) 18 | url = metadata.get('url_source', None) 19 | docs.append({ 20 | 'id': ref_doc_id, 21 | 'name': title, 22 | 'type': "url", 23 | 'path': url, 24 | 'date': metadata['creation_date'] 25 | }) 26 | 27 | if file_path and file_path not in seen_paths: 28 | base_name, extension = os.path.splitext(metadata['file_name']) 29 | # Remove the leading dot from the extension 30 | extension = extension.lstrip('.') 31 | 32 | file_info = { 33 | 'id': ref_doc_id, 34 | 'name': base_name, 35 | 'type': extension, 36 | 'path': file_path, 37 | #'file_size': metadata['file_size'], 38 | 'date': metadata['creation_date'] 39 | } 40 | docs.append(file_info) 41 | 
seen_paths.add(file_path) 42 | 43 | return docs 44 | 45 | 46 | def handle_knowledgebase(): 47 | st.header("Manage Knowledge Base") 48 | st.caption("Manage documents and web urls in your knowledge base.") 49 | 50 | from server.stores.strage_context import STORAGE_CONTEXT 51 | doc_store = STORAGE_CONTEXT.docstore 52 | if len(doc_store.docs) > 0: 53 | ref_doc_info = doc_store.get_all_ref_doc_info() 54 | unique_files= get_unique_files_info(ref_doc_info) 55 | st.write("You have total", len(unique_files), "documents.") 56 | df = pd.DataFrame(unique_files) 57 | 58 | # Pagination settings 59 | 60 | page_size = 5 61 | total_pages = ceil(len(df)/page_size) 62 | 63 | if "curr_page" not in st.session_state.keys(): 64 | st.session_state.curr_page = 1 65 | 66 | curr_page = min(st.session_state['curr_page'], total_pages) 67 | 68 | # Displaying pagination buttons 69 | if total_pages > 1: 70 | prev, next, _, col3 = st.columns([1,1,6,2]) 71 | 72 | if next.button("Next"): 73 | curr_page = min(curr_page + 1, total_pages) 74 | st.session_state['curr_page'] = curr_page 75 | 76 | if prev.button("Prev"): 77 | curr_page = max(curr_page - 1, 1) 78 | st.session_state['curr_page'] = curr_page 79 | 80 | with col3: 81 | st.write("Page: ", curr_page, "/", total_pages) 82 | 83 | start_index = (curr_page - 1) * page_size 84 | end_index = curr_page * page_size 85 | df_paginated = df.iloc[start_index:end_index] 86 | 87 | # Displaying the paginated dataframe 88 | docs = st.dataframe( 89 | df_paginated, 90 | width=2000, 91 | column_config={ 92 | "id": None, #hidden 93 | "name": "name", 94 | "type": "type", 95 | "path": None, 96 | "date": "Creation date", 97 | #"file_size": st.column_config.NumberColumn( 98 | #"size", format="%d byte", 99 | #), 100 | }, 101 | hide_index=True, 102 | on_select="rerun", 103 | selection_mode="multi-row", 104 | ) 105 | 106 | selected_docs = docs.selection.rows 107 | if len(selected_docs) > 0: 108 | delete_button = st.button("Delete selected documents", key="delete_docs") 109 | if delete_button: 110 | print("Deleting documents...") 111 | with st.spinner(text="Deleting documents and related index. 
It may take several minutes."): 112 | for item in selected_docs: 113 | path = df_paginated.iloc[item]['path'] 114 | for ref_doc_id, ref_doc in ref_doc_info.items(): # a file may have multiple documents 115 | metadata = ref_doc.metadata 116 | file_path = metadata.get('file_path', None) 117 | if file_path: 118 | if file_path == path: 119 | st.session_state.index_manager.delete_ref_doc(ref_doc_id) 120 | elif metadata.get('url_source', None) == path: 121 | st.session_state.index_manager.delete_ref_doc(ref_doc_id) 122 | st.toast('✔️ The selected documents are deleted.', icon='🎉') 123 | time.sleep(4) 124 | st.rerun() 125 | 126 | st.write("Selected documents:") 127 | for item in selected_docs: 128 | st.write(f"- {df_paginated.iloc[item]['name']}") 129 | 130 | else: 131 | st.write("Knowledge base is empty") 132 | 133 | handle_knowledgebase() 134 | 135 | 136 | -------------------------------------------------------------------------------- /frontend/KB_Web.py: -------------------------------------------------------------------------------- 1 | import time 2 | import streamlit as st 3 | 4 | def handle_website(): 5 | st.header("Load Web Pages") 6 | st.caption("Enter a list of URLs to extract text and metadata from web pages.") 7 | 8 | with st.form("website-form", clear_on_submit=True): 9 | 10 | col1, col2 = st.columns([1, 0.2]) 11 | with col1: 12 | new_website = st.text_input("Please enter the web page address", label_visibility="collapsed") 13 | with col2: 14 | add_button = st.form_submit_button("Load") 15 | if add_button and new_website != "": 16 | st.session_state["websites"].append(new_website) 17 | 18 | if st.session_state["websites"] != []: 19 | st.markdown(f"

Website(s)

", unsafe_allow_html=True) 20 | for site in st.session_state["websites"]: 21 | st.caption(f"- {site}") 22 | st.write("") 23 | 24 | with st.expander( 25 | "Text processing parameter configuration", 26 | expanded=True, 27 | ): 28 | cols = st.columns(2) 29 | chunk_size = cols[0].number_input("Maximum length of a single text block: ", 1, 4096, st.session_state.chunk_size, key="web_chunk_size") 30 | chunk_overlap = cols[1].number_input("Adjacent text overlap length: ", 0, st.session_state.chunk_size, st.session_state.chunk_overlap, key="web_chunk_overlap") 31 | 32 | process_button = st.button("Save", 33 | key="process_website", 34 | disabled=len(st.session_state["websites"]) == 0) 35 | if process_button: 36 | print("Generating index...") 37 | with st.spinner(text="Loading documents and building the index, may take a minute or two"): 38 | st.session_state.index_manager.load_websites(st.session_state["websites"], chunk_size, chunk_overlap) 39 | st.toast('✔️ Knowledge base index generation complete', icon='🎉') 40 | st.session_state.websites = [] 41 | time.sleep(4) 42 | st.rerun() 43 | 44 | handle_website() 45 | 46 | 47 | -------------------------------------------------------------------------------- /frontend/Model_Embed.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from config import EMBEDDING_MODEL_PATH 3 | from server.stores.config_store import CONFIG_STORE 4 | from server.stores.strage_context import STORAGE_CONTEXT 5 | from server.models.embedding import create_embedding_model 6 | 7 | st.header("Embedding Model") 8 | st.caption("Configure embedding models", 9 | help="Embeddings are numerical representations of data, useful for tasks like document clustering and similarity detection when processing files, as they encode semantic meaning for efficient manipulation and retrieval.", 10 | ) 11 | 12 | def change_embedding_model(): 13 | st.session_state["current_llm_settings"]["embedding_model"] = st.session_state["selected_embedding_model"] 14 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 15 | create_embedding_model(st.session_state["current_llm_settings"]["embedding_model"]) 16 | 17 | doc_store = STORAGE_CONTEXT.docstore 18 | if len(doc_store.docs) > 0: 19 | disabled = True 20 | else: 21 | disabled = False 22 | embedding_settings = st.container(border=True) 23 | with embedding_settings: 24 | embedding_model_list = list(EMBEDDING_MODEL_PATH.keys()) 25 | embedding_model = st.selectbox( 26 | "Embedding models", 27 | embedding_model_list, 28 | key="selected_embedding_model", 29 | index=embedding_model_list.index(st.session_state["current_llm_settings"]["embedding_model"]), 30 | disabled=disabled, 31 | on_change=change_embedding_model, 32 | ) 33 | if disabled: 34 | st.info("You cannot change the embedding model once documents have been added to the knowledge base.") 35 | st.caption("ThinkRAG supports most embedding models from `Hugging Face`. You may specify the models you want to use in the `config.py` file.") 36 | st.caption("It is recommended to download the models to the `localmodels` directory, in case you need to run the system without an Internet connection.
Plase refer to the instructions in `docs` directory.") -------------------------------------------------------------------------------- /frontend/Model_LLM.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from config import LLM_API_LIST 3 | import server.models.ollama as ollama 4 | from server.stores.config_store import CONFIG_STORE 5 | from server.models.llm_api import check_openai_llm 6 | from frontend.state import init_llm_sp, init_ollama_endpoint, init_api_base, init_api_model, init_api_key, create_llm_instance 7 | 8 | st.header("Large Language Model") 9 | st.caption("Support local models from Ollama and OpenAI compatible LLM APIs.", 10 | help="Large language models (LLMs) are powerful models that can generate human-like text based on the input they receive. LLMs can be used for a wide range of natural language processing tasks, including text generation, question answering, and summarization.", 11 | ) 12 | 13 | init_llm_sp() 14 | 15 | sp = st.session_state.llm_service_provider_selected 16 | llm = LLM_API_LIST[sp] 17 | 18 | init_ollama_endpoint() 19 | init_api_base(sp) 20 | init_api_model(sp) 21 | init_api_key(sp) 22 | 23 | def save_current_llm_info(): 24 | sp = st.session_state.llm_service_provider_selected 25 | if sp == "Ollama": 26 | if st.session_state.ollama_model_selected is not None: 27 | CONFIG_STORE.put(key="current_llm_info", val={ 28 | "service_provider": sp, 29 | "model": st.session_state.ollama_model_selected, 30 | }) 31 | else: 32 | api_key = sp + "_api_key" 33 | model_key = sp + "_model_selected" 34 | base_key = sp + "_api_base" 35 | if st.session_state[model_key] is not None and st.session_state[api_key] is not None and st.session_state[base_key] is not None: 36 | CONFIG_STORE.put(key="current_llm_info", val={ 37 | "service_provider": sp, 38 | "model": st.session_state[model_key], 39 | "api_base": st.session_state[base_key], 40 | "api_key": st.session_state[api_key], 41 | "api_key_valid": st.session_state[api_key + "_valid"], 42 | }) 43 | else: 44 | st.warning("Please fill in all the required fields") 45 | 46 | def update_llm_service_provider(): 47 | selected_option = st.session_state["llm_service_provider"] 48 | st.session_state.llm_service_provider_selected = selected_option 49 | CONFIG_STORE.put(key="llm_service_provider_selected", val={"llm_service_provider_selected": selected_option}) 50 | if selected_option != "Ollama": 51 | init_api_base(selected_option) 52 | init_api_model(selected_option) 53 | init_api_key(selected_option) 54 | save_current_llm_info() 55 | 56 | def init_llm_options(): 57 | llm_options = list(LLM_API_LIST.keys()) 58 | col1, _, col2 = st.columns([5, 4, 1], vertical_alignment="bottom") 59 | with col1: 60 | option = st.selectbox( 61 | "Please select one of the options.", 62 | llm_options, 63 | index=llm_options.index(st.session_state.llm_service_provider_selected), 64 | key="llm_service_provider", 65 | on_change=update_llm_service_provider, 66 | ) 67 | 68 | if option is not None and option != st.session_state.llm_service_provider_selected: 69 | CONFIG_STORE.put(key="llm_service_provider_selected", val={ 70 | "llm_service_provider_selected": option, 71 | }) 72 | 73 | current_llm_info = CONFIG_STORE.get(key="current_llm_info") 74 | 75 | if current_llm_info is None: 76 | save_current_llm_info() 77 | 78 | init_llm_options() 79 | 80 | option = st.session_state.llm_service_provider_selected 81 | 82 | def change_ollama_endpoint(): 83 | st.session_state.ollama_api_url = 
st.session_state.ollama_endpoint 84 | if ollama.is_alive(): 85 | name = option + "_api_url" # e.g. "Ollama_api_url" 86 | CONFIG_STORE.put(key=name, val={ 87 | name: st.session_state.ollama_api_url, 88 | }) 89 | save_current_llm_info() 90 | else: 91 | st.warning("Failed to connect to Ollama") 92 | 93 | def change_ollama_model(): 94 | st.session_state.ollama_model_selected = st.session_state.ollama_model_name 95 | name = option + "_model_selected" # e.g. "Ollama_model_selected" 96 | CONFIG_STORE.put(key=name, val={ 97 | name: st.session_state.ollama_model_selected, 98 | }) 99 | save_current_llm_info() 100 | 101 | def change_llm_api_base(): 102 | name = option + "_api_base" # e.g. "OpenAI_api_base" 103 | st.session_state[name] = st.session_state.llm_api_endpoint 104 | CONFIG_STORE.put(key=name, val={ 105 | name: st.session_state.llm_api_endpoint, 106 | }) 107 | save_current_llm_info() 108 | 109 | def change_llm_api_key(): 110 | name = option + "_api_key" # e.g. "OpenAI_api_key" 111 | st.session_state[name] = st.session_state.llm_api_key 112 | CONFIG_STORE.put(key=name, val={ 113 | name: st.session_state.llm_api_key, 114 | }) 115 | print("Checking API key...") 116 | print(st.session_state.llm_api_key) 117 | is_valid = check_openai_llm(st.session_state.llm_api_model, st.session_state.llm_api_endpoint, st.session_state.llm_api_key) 118 | st.session_state[name + "_valid"] = is_valid 119 | CONFIG_STORE.put(key=name + "_valid", val={ # e.g. "OpenAI_api_key_valid" 120 | name + "_valid": is_valid, 121 | }) 122 | save_current_llm_info() 123 | if is_valid: 124 | print("API key is valid") 125 | else: 126 | print("API key is invalid") 127 | 128 | def change_llm_api_model(): 129 | name = option + "_model_selected" # e.g. "OpenAI_model_selected" 130 | st.session_state[name] = st.session_state.llm_api_model 131 | CONFIG_STORE.put(key=name, val={ 132 | name: st.session_state.llm_api_model, 133 | }) 134 | save_current_llm_info() 135 | 136 | def llm_configuration_page(): 137 | llm_api_settings = st.container(border=True) 138 | with llm_api_settings: 139 | if option == "Ollama": 140 | st.subheader("Configure for Ollama") 141 | st.text_input( 142 | "Ollama Endpoint", 143 | key="ollama_endpoint", 144 | value=st.session_state.ollama_api_url, 145 | on_change=change_ollama_endpoint, 146 | ) 147 | if ollama.is_alive(): 148 | ollama.get_model_list() 149 | st.write("🟢 Ollama is running") 150 | st.selectbox('Local LLM', st.session_state.ollama_models, 151 | index=st.session_state.ollama_models.index(st.session_state.ollama_model_selected), 152 | help='Select locally deployed LLM from Ollama', 153 | on_change=change_ollama_model, 154 | key='ollama_model_name', # session_state key 155 | ) 156 | else: 157 | st.write("🔴 Ollama is not running") 158 | 159 | st.button( 160 | "Refresh models", 161 | on_click=ollama.get_model_list, 162 | help="Refresh the list of available models from the Ollama API.", 163 | ) 164 | 165 | else: # OpenAI, Zhipu, Moonshot, Deepseek 166 | st.subheader(f"Configure for {llm['provider']}") 167 | st.text_input( 168 | "Base URL", 169 | key="llm_api_endpoint", 170 | value=st.session_state[option + "_api_base"], 171 | on_change=change_llm_api_base, 172 | ) 173 | st.text_input( 174 | "API key", 175 | key="llm_api_key", 176 | value=st.session_state[option + "_api_key"], 177 | type="password", 178 | on_change=change_llm_api_key, 179 | ) 180 | st.selectbox('Choose LLM API', llm['models'], 181 | help='Choose LLMs API service', 182 | on_change=change_llm_api_model, 183 | key='llm_api_model', 184 | 
index=llm['models'].index(st.session_state[option + "_model_selected"]), 185 | ) 186 | 187 | def show_llm_instance(): 188 | create_llm_instance() 189 | if st.session_state.llm is not None: 190 | current_llm_info = CONFIG_STORE.get(key="current_llm_info") 191 | st.success("Current LLM instance: " + current_llm_info["service_provider"] + " / " + current_llm_info["model"]) 192 | else: 193 | st.warning("No LLM instance available") 194 | 195 | llm_configuration_page() 196 | 197 | show_llm_instance() 198 | 199 | st.caption("ThinkRAG supports `OpenAI` and all compatible LLM API like `DeepSeek`, `Moonshot` or `Zhipu`. You may specify the LLMs you want to use in the `config.py` file.") 200 | st.caption("It is recommended to use `Ollama` if you need run the system without an Internet connection. Plase refer to the Ollama docs to download and use Ollama models.") -------------------------------------------------------------------------------- /frontend/Model_Rerank.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from config import RERANKER_MODEL_PATH 3 | from server.stores.config_store import CONFIG_STORE 4 | 5 | st.header("Reranking Model") 6 | st.caption("Configure reranking models", 7 | help="Reranking is the process of reordering a list of items based on a set of criteria. In the context of search engines, reranking is used to improve the relevance of search results by taking into account additional information about the items being ranked.", 8 | ) 9 | 10 | def change_use_reranker(): 11 | st.session_state["current_llm_settings"]["use_reranker"] = st.session_state["use_reranker"] 12 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 13 | 14 | def change_top_n(): 15 | st.session_state["current_llm_settings"]["top_n"] = st.session_state["top_n"] 16 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 17 | 18 | def change_reranker_model(): 19 | st.session_state["current_llm_settings"]["reranker_model"] = st.session_state["selected_reranker_model"] 20 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 21 | 22 | reranking_settings = st.container(border=True) 23 | with reranking_settings: 24 | st.toggle("Use reranker", 25 | key="use_reranker", 26 | value= st.session_state["current_llm_settings"]["use_reranker"], 27 | on_change=change_use_reranker, 28 | ) 29 | if st.session_state["current_llm_settings"]["use_reranker"] == True: 30 | st.number_input( 31 | "Top N", 32 | min_value=1, 33 | max_value=st.session_state["current_llm_settings"]["top_k"], 34 | help="The number of most similar documents to retrieve in response to a query.", 35 | value=st.session_state["current_llm_settings"]["top_n"], 36 | key="top_n", 37 | on_change=change_top_n, 38 | ) 39 | 40 | reranker_model_list = list(RERANKER_MODEL_PATH.keys()) 41 | reranker_model = st.selectbox( 42 | "Reranking models", 43 | reranker_model_list, 44 | key="selected_reranker_model", 45 | index=reranker_model_list.index(st.session_state["current_llm_settings"]["reranker_model"]), 46 | on_change=change_reranker_model, 47 | ) 48 | 49 | st.caption("ThinkRAG supports most reranking models from `Hugging Face`. You may specify the models you want to use in the `config.py` file.") 50 | st.caption("It is recommended to download the models to the `localmodels` directory, in case you need run the system without an Internet connection. 
Plase refer to the instructions in `docs` directory.") -------------------------------------------------------------------------------- /frontend/Setting_Advanced.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from server.stores.config_store import CONFIG_STORE 3 | from frontend.state import create_llm_instance 4 | from config import RESPONSE_MODE 5 | 6 | st.header("Advanced settings") 7 | advanced_settings = st.container(border=True) 8 | 9 | def change_top_k(): 10 | st.session_state["current_llm_settings"]["top_k"] = st.session_state["top_k"] 11 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 12 | create_llm_instance() 13 | 14 | def change_temperature(): 15 | st.session_state["current_llm_settings"]["temperature"] = st.session_state["temperature"] 16 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 17 | create_llm_instance() 18 | 19 | def change_system_prompt(): 20 | st.session_state["current_llm_settings"]["system_prompt"] = st.session_state["system_prompt"] 21 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 22 | create_llm_instance() 23 | 24 | def change_response_mode(): 25 | st.session_state["current_llm_settings"]["response_mode"] = st.session_state["response_mode"] 26 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 27 | create_llm_instance() 28 | 29 | with advanced_settings: 30 | col_1, _, col_2 = st.columns([4, 2, 4]) 31 | with col_1: 32 | st.number_input( 33 | "Top K", 34 | min_value=1, 35 | max_value=100, 36 | help="The number of most similar documents to retrieve in response to a query.", 37 | value=st.session_state["current_llm_settings"]["top_k"], 38 | key="top_k", 39 | on_change=change_top_k, 40 | ) 41 | with col_2: 42 | st.select_slider( 43 | "Temperature", 44 | options=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1], 45 | help="The temperature to use when generating responses. Higher temperatures result in more random responses.", 46 | value=st.session_state["current_llm_settings"]["temperature"], 47 | key="temperature", 48 | on_change=change_temperature, 49 | ) 50 | st.text_area( 51 | "System Prompt", 52 | help="The prompt to use when generating responses. The system prompt is used to provide context to the model.", 53 | value=st.session_state["current_llm_settings"]["system_prompt"], 54 | key="system_prompt", 55 | height=240, 56 | on_change=change_system_prompt, 57 | ) 58 | st.selectbox( 59 | "Response Mode", 60 | options=RESPONSE_MODE, 61 | help="Sets the Llama Index Query Engine response mode used when creating the Query Engine. 
Default: `compact`.", 62 | key="response_mode", 63 | index=RESPONSE_MODE.index(st.session_state["current_llm_settings"]["response_mode"]), # simple_summarize by default 64 | on_change=change_response_mode, 65 | ) 66 | 67 | # For debug purpost only 68 | def show_session_state(): 69 | st.write("") 70 | with st.expander("List of current application parameters"): 71 | state = dict(sorted(st.session_state.items())) 72 | st.write(state) 73 | 74 | # show_session_state() -------------------------------------------------------------------------------- /frontend/Storage.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from config import THINKRAG_ENV 3 | 4 | st.header("Storage") 5 | st.caption("All your data is stored in local file system or the database you configured.", 6 | help="You may change the storage settings in the config.py file.", 7 | ) 8 | 9 | embedding_settings = st.container(border=True) 10 | with embedding_settings: 11 | st.info("You are running ThinkRAG in " + THINKRAG_ENV + " mode.") 12 | st.dataframe(data={ 13 | "Storage Type": ["Vector Store","Doc Store","Index Store","Chat Store","Config Store"], 14 | "Development": ["Simple Vector Store","Simple Document Store","Simple Index Store","Simple Chat Store (in memory)","Simple KV Store"], 15 | "Production": ["Chroma","Redis","Redis","Redis","Simple KV Store"], 16 | #"Enterprise": ["Elasticsearch","MongoDB","MongoDB","Redis","Simple KV Store"], 17 | },hide_index=True) 18 | 19 | st.caption("You may change the storage settings in the config.py file.") 20 | st.caption("`Development Mode` uses local storage which means you need not install any extra tools. All the data is stored as local files in the 'storage' directory where you run ThinkRAG.") 21 | st.caption("`Production Mode`: is recommended to use for production on your laptop. 
You need a redis instance, either running locally or using a cloud service.") 22 | st.caption("If you want to deploy ThinkRAG on a server and handle large volume of data, please contact the author of ThinkRAG (wzdavid@gmail.com)") -------------------------------------------------------------------------------- /frontend/images/ThinkRAG_Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/frontend/images/ThinkRAG_Logo.png -------------------------------------------------------------------------------- /frontend/state.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import config as config 3 | from server.models import ollama 4 | from server.models.llm_api import create_openai_llm, check_openai_llm 5 | from server.models.ollama import create_ollama_llm 6 | from server.models.embedding import create_embedding_model 7 | from server.index import IndexManager 8 | from server.stores.config_store import CONFIG_STORE 9 | 10 | def find_api_by_model(model_name): 11 | for api_name, api_info in config.LLM_API_LIST.items(): 12 | if model_name in api_info['models']: 13 | return api_info 14 | 15 | # Initialize st.session_state 16 | def init_keys(): 17 | 18 | # Initialize LLM 19 | if "llm" not in st.session_state.keys(): 20 | st.session_state.llm = None 21 | 22 | # Initialize index 23 | if "index_manager" not in st.session_state.keys(): 24 | st.session_state.index_manager = IndexManager(config.DEFAULT_INDEX_NAME) 25 | 26 | # Initialize model selection 27 | if "ollama_api_url" not in st.session_state.keys(): 28 | st.session_state.ollama_api_url = config.OLLAMA_API_URL 29 | 30 | if "ollama_models" not in st.session_state.keys(): 31 | ollama.get_model_list() 32 | if (st.session_state.ollama_models is not None and len(st.session_state.ollama_models) > 0): 33 | st.session_state.ollama_model_selected = st.session_state.ollama_models[0] 34 | create_ollama_llm(st.session_state.ollama_model_selected) 35 | if "ollama_model_selected" not in st.session_state.keys(): 36 | st.session_state.ollama_model_selected = None 37 | if "llm_api_list" not in st.session_state.keys(): 38 | st.session_state.llm_api_list = [model for api in config.LLM_API_LIST.values() for model in api['models']] 39 | if "llm_api_selected" not in st.session_state.keys(): 40 | st.session_state.llm_api_selected = st.session_state.llm_api_list[0] 41 | if st.session_state.ollama_model_selected is None: 42 | api_object = find_api_by_model(st.session_state.llm_api_selected) 43 | create_openai_llm(st.session_state.llm_api_selected, api_object['api_base'], api_object['api_key']) 44 | 45 | # Initialize query engine 46 | if "query_engine" not in st.session_state.keys(): 47 | st.session_state.query_engine = None 48 | 49 | if "system_prompt" not in st.session_state.keys(): 50 | st.session_state.system_prompt = "Chat with me!" 
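    # Streamlit re-runs this script on every interaction, so each value below is
    # only seeded when its key is missing from st.session_state; settings saved in
    # an earlier run are read back from CONFIG_STORE where one exists.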
51 | 52 | if "response_mode" not in st.session_state.keys(): 53 | response_mode_result = CONFIG_STORE.get(key="response_mode") 54 | if response_mode_result is not None: 55 | st.session_state.response_mode = response_mode_result["response_mode"] 56 | else: 57 | st.session_state.response_mode = config.DEFAULT_RESPONSE_MODE 58 | 59 | if "ollama_endpoint" not in st.session_state.keys(): 60 | st.session_state.ollama_endpoint = "http://localhost:11434" 61 | 62 | if "chunk_size" not in st.session_state.keys(): 63 | st.session_state.chunk_size = config.DEFAULT_CHUNK_SIZE 64 | 65 | if "chunk_overlap" not in st.session_state.keys(): 66 | st.session_state.chunk_overlap = config.DEFAULT_CHUNK_OVERLAP 67 | 68 | if "zh_title_enhance" not in st.session_state.keys(): 69 | st.session_state.zh_title_enhance = config.ZH_TITLE_ENHANCE 70 | 71 | if "max_tokens" not in st.session_state.keys(): 72 | st.session_state.max_tokens = 100 73 | 74 | if "top_p" not in st.session_state.keys(): 75 | st.session_state.top_p = 1.0 76 | 77 | # contents related to the knowledge base 78 | if "websites" not in st.session_state: 79 | st.session_state["websites"] = [] 80 | 81 | if 'uploaded_files' not in st.session_state: 82 | st.session_state.uploaded_files = [] 83 | if 'selected_files' not in st.session_state: 84 | st.session_state.selected_files = None 85 | 86 | # Initialize user data 87 | # TODO: supposed to be loaded from database 88 | st.session_state.user_id = "user_1" 89 | st.session_state.kb_id = "kb_1" 90 | st.session_state.kb_name = "My knowledge base" 91 | 92 | def init_llm_sp(): 93 | 94 | llm_options = list(config.LLM_API_LIST.keys()) 95 | 96 | # LLM service provider selection 97 | if "llm_service_provider_selected" not in st.session_state: 98 | sp = CONFIG_STORE.get(key="llm_service_provider_selected") 99 | if sp: 100 | st.session_state.llm_service_provider_selected = sp["llm_service_provider_selected"] 101 | else: 102 | st.session_state.llm_service_provider_selected = llm_options[0] 103 | 104 | def init_ollama_endpoint(): 105 | # Initialize Ollama endpoint 106 | if "ollama_api_url" not in st.session_state.keys(): 107 | ollama_api_url = CONFIG_STORE.get(key="Ollama_api_url") 108 | if ollama_api_url: 109 | st.session_state.ollama_api_url = ollama_api_url["Ollama_api_url"] 110 | else: 111 | st.session_state.ollama_api_url = config.LLM_API_LIST["Ollama"]["api_base"] 112 | 113 | # Initialize llm api model 114 | def init_api_model(sp): 115 | if sp != "Ollama": 116 | model_key = sp + "_model_selected" 117 | if model_key not in st.session_state.keys(): 118 | model_result = CONFIG_STORE.get(key=model_key) 119 | if model_result: 120 | st.session_state[model_key] = model_result[model_key] 121 | else: 122 | st.session_state[model_key] = config.LLM_API_LIST[sp]["models"][0] 123 | 124 | 125 | # Initialize llm api base 126 | def init_api_base(sp): 127 | if sp != "Ollama": 128 | api_base = sp + "_api_base" 129 | if api_base not in st.session_state.keys(): 130 | api_key_result = CONFIG_STORE.get(key=api_base) 131 | if api_key_result is not None: 132 | st.session_state[api_base] = api_key_result[api_base] 133 | else: 134 | st.session_state[api_base] = config.LLM_API_LIST[sp]["api_base"] 135 | 136 | # Initialize llm api key 137 | def init_api_key(sp): 138 | if sp != "Ollama": 139 | api_key = sp + "_api_key" 140 | if api_key not in st.session_state.keys(): 141 | api_key_result = CONFIG_STORE.get(key=api_key) 142 | if api_key_result is not None: 143 | st.session_state[api_key] = api_key_result[api_key] 144 | else: 145 | 
st.session_state[api_key] = config.LLM_API_LIST[sp]["api_key"] 146 | 147 | valid_key = api_key + "_valid" 148 | if valid_key not in st.session_state.keys(): 149 | valid_result = CONFIG_STORE.get(key=valid_key) 150 | if valid_result is None and st.session_state[api_key] is not None: 151 | is_valid = check_openai_llm(st.session_state[sp + "_model_selected"], config.LLM_API_LIST[sp]["api_base"], st.session_state[api_key]) 152 | CONFIG_STORE.put(key=valid_key, val={valid_key: is_valid}) 153 | st.session_state[valid_key] = is_valid 154 | else: 155 | st.session_state[valid_key] = valid_result[valid_key] 156 | 157 | # Initialize LLM settings, like temperature, system prompt, etc. 158 | def init_llm_settings(): 159 | if "current_llm_settings" not in st.session_state.keys(): 160 | current_llm_settings = CONFIG_STORE.get(key="current_llm_settings") 161 | if current_llm_settings: 162 | st.session_state.current_llm_settings = current_llm_settings 163 | else: 164 | st.session_state.current_llm_settings = { 165 | "temperature": config.TEMPERATURE, 166 | "system_prompt": config.SYSTEM_PROMPT, 167 | "top_k": config.TOP_K, 168 | "response_mode": config.DEFAULT_RESPONSE_MODE, 169 | "use_reranker": config.USE_RERANKER, 170 | "top_n": config.RERANKER_MODEL_TOP_N, 171 | "embedding_model": config.DEFAULT_EMBEDDING_MODEL, 172 | "reranker_model": config.DEFAULT_RERANKER_MODEL, 173 | } 174 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state.current_llm_settings) 175 | 176 | 177 | # Create LLM instance if there is related information 178 | def create_llm_instance(): 179 | current_llm_info = CONFIG_STORE.get(key="current_llm_info") 180 | if current_llm_info is not None: 181 | print("Current LLM info: ", current_llm_info) 182 | if current_llm_info["service_provider"] == "Ollama": 183 | if ollama.is_alive(): 184 | model_name = current_llm_info["model"] 185 | st.session_state.llm = ollama.create_ollama_llm( 186 | model=model_name, 187 | temperature=st.session_state.current_llm_settings["temperature"], 188 | system_prompt=st.session_state.current_llm_settings["system_prompt"], 189 | ) 190 | else: 191 | model_name = current_llm_info["model"] 192 | api_base = current_llm_info["api_base"] 193 | api_key = current_llm_info["api_key"] 194 | api_key_valid = current_llm_info["api_key_valid"] 195 | if api_key_valid: 196 | print("API key is valid when creating LLM instance") 197 | st.session_state.llm = create_openai_llm( 198 | model_name=model_name, 199 | api_base=api_base, 200 | api_key=api_key, 201 | temperature=st.session_state.current_llm_settings["temperature"], 202 | system_prompt=st.session_state.current_llm_settings["system_prompt"], 203 | ) 204 | else: 205 | print("API key is invalid when creating LLM instance") 206 | st.session_state.llm = None 207 | else: 208 | print("No current LLM infomation") 209 | st.session_state.llm = None 210 | 211 | def init_state(): 212 | init_keys() 213 | init_llm_sp() 214 | init_llm_settings() 215 | init_ollama_endpoint() 216 | sp = st.session_state.llm_service_provider_selected 217 | init_api_model(sp) 218 | init_api_key(sp) 219 | create_embedding_model(st.session_state["current_llm_settings"]["embedding_model"]) 220 | create_llm_instance() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | llama_index==0.11.19 2 | streamlit==1.39.0 3 | langchain==0.3.4 4 | langchain-community==0.3.3 5 | langchain_openai==0.2.3 6 | ollama==0.3.3 7 | 
llama-index-embeddings-huggingface==0.3.1 8 | llama-index-embeddings-langchain==0.2.1 9 | llama-index-llms-langchain==0.4.2 10 | llama-index-readers-web==0.2.4 11 | llama-index-retrievers-bm25==0.4.0 12 | llama-index-storage-kvstore-redis==0.2.0 13 | llama-index-vector-stores-chroma==0.2.1 14 | llama-index-vector-stores-elasticsearch==0.3.3 15 | llama-index-vector-stores-lancedb==0.2.4 16 | llama-index-llms-ollama==0.3.4 17 | llama-index-storage-chat_store-redis==0.3.2 18 | llama_index.storage.docstore.redis==0.2.0 19 | llama_index.storage.index_store.redis==0.3.0 20 | docx2txt==0.8 -------------------------------------------------------------------------------- /server/engine.py: -------------------------------------------------------------------------------- 1 | # Create and manage query/chat engine 2 | import config as config 3 | from server.models.reranker import create_reranker_model 4 | from server.prompt import text_qa_template, refine_template 5 | from server.retriever import SimpleFusionRetriever 6 | from llama_index.core.query_engine import RetrieverQueryEngine 7 | 8 | # Create a query engine 9 | def create_query_engine(index, 10 | top_k=config.TOP_K, 11 | response_mode=config.RESPONSE_MODE, 12 | use_reranker=config.USE_RERANKER, 13 | top_n=config.RERANKER_MODEL_TOP_N, 14 | reranker=config.DEFAULT_RERANKER_MODEL): 15 | # Customized query engine with hybrid search and reranker 16 | node_postprocessors = [create_reranker_model(model_name=reranker, top_n=top_n)] if use_reranker else [] 17 | retriever = SimpleFusionRetriever(vector_index=index, top_k=top_k) 18 | 19 | query_engine = RetrieverQueryEngine.from_args( 20 | retriever=retriever, 21 | text_qa_template=text_qa_template, 22 | refine_template=refine_template, 23 | node_postprocessors=node_postprocessors, 24 | response_mode=response_mode, # https://docs.llamaindex.ai/en/stable/api_reference/response_synthesizers/ 25 | verbose=True, 26 | streaming=True, 27 | ) 28 | 29 | return query_engine 30 | -------------------------------------------------------------------------------- /server/index.py: -------------------------------------------------------------------------------- 1 | # Index management - create, load and insert 2 | import os 3 | from llama_index.core import Settings, StorageContext, VectorStoreIndex 4 | from llama_index.core import load_index_from_storage, load_indices_from_storage 5 | from llama_index.core import VectorStoreIndex, SimpleDirectoryReader 6 | from server.utils.file import get_save_dir 7 | from server.stores.strage_context import STORAGE_CONTEXT 8 | from server.ingestion import AdvancedIngestionPipeline 9 | from config import DEV_MODE 10 | 11 | class IndexManager: 12 | def __init__(self, index_name): 13 | self.index_name: str = index_name 14 | self.storage_context: StorageContext = STORAGE_CONTEXT 15 | self.index_id: str = None 16 | self.index: VectorStoreIndex = None 17 | 18 | def check_index_exists(self): 19 | indices = load_indices_from_storage(self.storage_context) 20 | print(f"Loaded {len(indices)} indices") 21 | if len(indices) > 0: 22 | self.index = indices[0] 23 | self.index_id = indices[0].index_id 24 | return True 25 | else: 26 | return False 27 | 28 | def init_index(self, nodes): 29 | self.index = VectorStoreIndex(nodes, 30 | storage_context=self.storage_context, 31 | store_nodes_override=True) # note: no nodes in doc store if using vector database, set store_nodes_override=True to add nodes to doc store 32 | self.index_id = self.index.index_id 33 | if DEV_MODE: 34 | self.storage_context.persist() 
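            # The simple in-memory stores used in development mode are not durable,
            # so persist() flushes them to local files under the 'storage' directory;
            # the production stores (Chroma / Redis) persist on their own, which is
            # why this call is skipped when DEV_MODE is off.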
35 | print(f"Created index {self.index.index_id}") 36 | return self.index 37 | 38 | def load_index(self): # TODO: load index based on index_id 39 | self.index = load_index_from_storage(self.storage_context) 40 | if not DEV_MODE: 41 | self.index._store_nodes_override = True 42 | print(f"Loaded index {self.index.index_id}") 43 | return self.index 44 | 45 | def insert_nodes(self, nodes): 46 | if self.index is not None: 47 | self.index.insert_nodes(nodes=nodes) 48 | if DEV_MODE: 49 | self.storage_context.persist() 50 | print(f"Inserted {len(nodes)} nodes into index {self.index.index_id}") 51 | else: 52 | self.init_index(nodes=nodes) 53 | return self.index 54 | 55 | # Build index based on documents under 'data' folder 56 | def load_dir(self, input_dir, chunk_size, chunk_overlap): 57 | Settings.chunk_size = chunk_size 58 | Settings.chunk_overlap = chunk_overlap 59 | documents = SimpleDirectoryReader(input_dir=input_dir, recursive=True).load_data() 60 | if len(documents) > 0: 61 | pipeline = AdvancedIngestionPipeline() 62 | nodes = pipeline.run(documents=documents) 63 | index = self.insert_nodes(nodes) 64 | return nodes 65 | else: 66 | print("No documents found") 67 | return [] 68 | 69 | # get file's directory and create index 70 | def load_files(self, uploaded_files, chunk_size, chunk_overlap): 71 | Settings.chunk_size = chunk_size 72 | Settings.chunk_overlap = chunk_overlap 73 | save_dir = get_save_dir() 74 | files = [os.path.join(save_dir, file["name"]) for file in uploaded_files] 75 | print(files) 76 | documents = SimpleDirectoryReader(input_files=files).load_data() 77 | if len(documents) > 0: 78 | pipeline = AdvancedIngestionPipeline() 79 | nodes = pipeline.run(documents=documents) 80 | index = self.insert_nodes(nodes) 81 | return nodes 82 | else: 83 | print("No documents found") 84 | return [] 85 | 86 | # Get URL and create index 87 | # https://docs.llamaindex.ai/en/stable/examples/data_connectors/WebPageDemo/ 88 | def load_websites(self, websites, chunk_size, chunk_overlap): 89 | Settings.chunk_size = chunk_size 90 | Settings.chunk_overlap = chunk_overlap 91 | 92 | from server.readers.beautiful_soup_web import BeautifulSoupWebReader 93 | documents = BeautifulSoupWebReader().load_data(websites) 94 | if len(documents) > 0: 95 | pipeline = AdvancedIngestionPipeline() 96 | nodes = pipeline.run(documents=documents) 97 | index = self.insert_nodes(nodes) 98 | return nodes 99 | else: 100 | print("No documents found") 101 | return [] 102 | 103 | # Delete a document and all related nodes 104 | def delete_ref_doc(self, ref_doc_id): 105 | self.index.delete_ref_doc(ref_doc_id=ref_doc_id, delete_from_docstore=True) 106 | self.storage_context.persist() 107 | print("Deleted document", ref_doc_id) -------------------------------------------------------------------------------- /server/ingestion.py: -------------------------------------------------------------------------------- 1 | # Import pipeline IngestionPipeline 2 | # https://docs.llamaindex.ai/en/stable/api_reference/ingestion/ 3 | # https://docs.llamaindex.ai/en/stable/examples/ingestion/advanced_ingestion_pipeline/ 4 | 5 | from llama_index.core import Settings 6 | from llama_index.core.ingestion import IngestionPipeline, DocstoreStrategy 7 | from server.splitters import ChineseTitleExtractor 8 | from server.stores.strage_context import STORAGE_CONTEXT 9 | from server.stores.ingestion_cache import INGESTION_CACHE 10 | 11 | class AdvancedIngestionPipeline(IngestionPipeline): 12 | def __init__( 13 | self, 14 | ): 15 | # Initialize the embedding model, 
text splitter 16 | embed_model = Settings.embed_model 17 | text_splitter = Settings.text_splitter 18 | 19 | # Call the super class's __init__ method with the necessary arguments 20 | super().__init__( 21 | transformations=[ 22 | text_splitter, 23 | embed_model, 24 | ChineseTitleExtractor(), # modified Chinese title enhance: zh_title_enhance 25 | ], 26 | docstore=STORAGE_CONTEXT.docstore, 27 | vector_store=STORAGE_CONTEXT.vector_store, 28 | cache=INGESTION_CACHE, 29 | docstore_strategy=DocstoreStrategy.UPSERTS, # UPSERTS: Update or insert 30 | ) 31 | 32 | # If you need to override the run method or add new methods, you can do so here 33 | def run(self, documents): 34 | print(f"Load {len(documents)} Documents") 35 | nodes = super().run(documents=documents) 36 | print(f"Ingested {len(nodes)} Nodes") 37 | return nodes -------------------------------------------------------------------------------- /server/models/embedding.py: -------------------------------------------------------------------------------- 1 | # Create embedding models 2 | import os 3 | from llama_index.core import Settings 4 | from llama_index.embeddings.huggingface import HuggingFaceEmbedding 5 | from config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_MODEL_PATH, MODEL_DIR 6 | from server.utils.hf_mirror import use_hf_mirror 7 | 8 | def create_embedding_model(model_name = DEFAULT_EMBEDDING_MODEL) -> HuggingFaceEmbedding: 9 | try: 10 | use_hf_mirror() 11 | model_path = EMBEDDING_MODEL_PATH[model_name] 12 | if MODEL_DIR is not None: 13 | path = f"./{MODEL_DIR}/{model_path}" 14 | if os.path.exists(path): # Use local models if the path exists 15 | model_path = path 16 | embed_model = HuggingFaceEmbedding(model_name=model_path) 17 | Settings.embed_model = embed_model 18 | print(f"created embed model: {model_path}") 19 | except Exception as e: 20 | print(f"An error occurred while creating the embedding model: {type(e).__name__}: {e}") 21 | Settings.embed_model = None 22 | 23 | return Settings.embed_model -------------------------------------------------------------------------------- /server/models/llm_api.py: -------------------------------------------------------------------------------- 1 | # Create LLM with API compatible with OpenAI 2 | from llama_index.core import Settings 3 | from langchain_openai import ChatOpenAI 4 | from llama_index.llms.langchain import LangChainLLM 5 | 6 | def create_openai_llm(model_name:str, api_base:str, api_key:str, temperature:float = 0.5, system_prompt:str = None) -> ChatOpenAI: 7 | try: 8 | llm = LangChainLLM( 9 | llm=ChatOpenAI( 10 | openai_api_base=api_base, 11 | openai_api_key=api_key, 12 | model_name=model_name, 13 | temperature=temperature, 14 | ), 15 | system_prompt=system_prompt, 16 | ) 17 | Settings.llm = llm 18 | return llm 19 | except Exception as e: 20 | print(f"An error occurred while creating the OpenAI compatibale model: {type(e).__name__}: {e}") 21 | return None 22 | 23 | def check_openai_llm(model_name, api_base, api_key) -> bool: 24 | # Make a simple API call to verify the key 25 | try: 26 | llm = ChatOpenAI( 27 | openai_api_base=api_base, 28 | openai_api_key=api_key, 29 | model_name=model_name, 30 | timeout=5, 31 | max_retries=1 32 | ) 33 | response = llm.invoke("Hello, World!") 34 | print(response) 35 | if response: 36 | return True 37 | else: 38 | return False 39 | except Exception as e: 40 | print(f"An error occurred while verifying the LLM API: {type(e).__name__}: {e}") 41 | return False 42 | -------------------------------------------------------------------------------- 
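# --- Editor's usage sketch (illustrative, not part of the repository) ---
# Shows how the two helpers defined in llm_api.py above are typically combined.
# The endpoint, API key and model name are hypothetical placeholders, not
# defaults taken from config.py.
from server.models.llm_api import create_openai_llm, check_openai_llm

api_base = "https://api.openai.com/v1"  # placeholder endpoint
api_key = "sk-your-key-here"            # placeholder key
model_name = "gpt-4o-mini"              # placeholder model id

if check_openai_llm(model_name, api_base, api_key):  # quick round-trip to validate the key
    # create_openai_llm also assigns the instance to Settings.llm, so other
    # llama_index components pick it up without being passed it explicitly.
    llm = create_openai_llm(model_name, api_base, api_key, temperature=0.5)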
/server/models/ollama.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import streamlit as st 3 | from ollama import Client 4 | from llama_index.core import Settings 5 | from llama_index.llms.ollama import Ollama 6 | 7 | def is_alive(): 8 | try: 9 | response = requests.get(st.session_state.ollama_api_url) 10 | return response.status_code == 200 11 | except requests.exceptions.RequestException: 12 | print("Failed to connect to Ollama") 13 | return False 14 | 15 | def get_model_list(): 16 | st.session_state.ollama_models = [] 17 | if is_alive(): 18 | client = Client(host=st.session_state.ollama_api_url) 19 | response = client.list() 20 | models = response["models"] 21 | # Initialize the list of model names 22 | for model in models: 23 | st.session_state.ollama_models.append(model["name"]) 24 | return response["models"] 25 | else: 26 | print("Ollama is not alive") 27 | return None 28 | 29 | # Create Ollama LLM 30 | def create_ollama_llm(model:str, temperature:float = 0.5, system_prompt:str = None) -> Ollama: 31 | try: 32 | llm = Ollama( 33 | model=model, 34 | base_url=st.session_state.ollama_api_url, 35 | request_timeout=600, 36 | temperature=temperature, 37 | system_prompt=system_prompt, 38 | ) 39 | print(f"created ollama model for query: {model}") 40 | Settings.llm = llm 41 | return llm 42 | except Exception as e: 43 | print(f"An error occurred while creating Ollama LLM: {e}") 44 | return None 45 | -------------------------------------------------------------------------------- /server/models/reranker.py: -------------------------------------------------------------------------------- 1 | # Create Rerank model 2 | # https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/SentenceTransformerRerank/ 3 | import os 4 | from llama_index.core.postprocessor import SentenceTransformerRerank 5 | from config import DEFAULT_RERANKER_MODEL, RERANKER_MODEL_TOP_N, RERANKER_MODEL_PATH, MODEL_DIR 6 | from server.utils.hf_mirror import use_hf_mirror 7 | 8 | def create_reranker_model(model_name = DEFAULT_RERANKER_MODEL, top_n = RERANKER_MODEL_TOP_N) -> SentenceTransformerRerank: 9 | try: 10 | use_hf_mirror() 11 | model_path = RERANKER_MODEL_PATH[model_name] 12 | if MODEL_DIR is not None: 13 | path = f"./{MODEL_DIR}/{model_path}" 14 | if os.path.exists(path): # Use local models if the path exists 15 | model_path = path 16 | rerank_model = SentenceTransformerRerank(model=model_path, top_n=top_n) 17 | print(f"created rerank model: {model_name}") 18 | return rerank_model 19 | except Exception as e: 20 | return None -------------------------------------------------------------------------------- /server/prompt.py: -------------------------------------------------------------------------------- 1 | # https://docs.llamaindex.ai/en/stable/examples/customization/prompts/completion_prompts/ 2 | # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/prompts/default_prompts.py 3 | # https://docs.llamaindex.ai/en/stable/module_guides/models/prompts/usage_pattern/ 4 | 5 | from llama_index.core import PromptTemplate 6 | 7 | text_qa_template_str = ( 8 | "以下为上下文信息\n" 9 | "---------------------\n" 10 | "{context_str}\n" 11 | "---------------------\n" 12 | "请根据上下文信息回答我的问题或回复我的指令。前面的上下文信息可能有用,也可能没用,你需要从我给出的上下文信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。我的问题或指令是什么语种,你就用什么语种回复。\n" 13 | "问题:{query_str}\n" 14 | "你的回复: " 15 | ) 16 | 17 | 18 | text_qa_template = PromptTemplate(text_qa_template_str) 19 | 20 | 
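# Illustrative note (editor's addition, not in the repository): the query engine
# renders this template itself, but a PromptTemplate can also be formatted
# directly, e.g.
#   text_qa_template.format(context_str="<retrieved chunks>", query_str="<user question>")
# which substitutes the two placeholders declared in the string above.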
refine_template_str = ( 21 | "这是原本的问题: {query_str}\n" 22 | "我们已经提供了回答: {existing_answer}\n" 23 | "现在我们有机会改进这个回答 " 24 | "使用以下更多上下文(仅当需要用时)\n" 25 | "------------\n" 26 | "{context_msg}\n" 27 | "------------\n" 28 | "根据新的上下文, 请改进原来的回答。" 29 | "如果新的上下文没有用, 直接返回原本的回答。\n" 30 | "改进的回答: " 31 | ) 32 | refine_template = PromptTemplate(refine_template_str) 33 | -------------------------------------------------------------------------------- /server/readers/beautiful_soup_web.py: -------------------------------------------------------------------------------- 1 | """Beautiful Soup Web scraper.""" 2 | 3 | import logging 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | from urllib.parse import urljoin 6 | from datetime import datetime 7 | 8 | from llama_index.core.bridge.pydantic import PrivateAttr 9 | from llama_index.core.readers.base import BasePydanticReader 10 | from llama_index.core.schema import Document 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def _mpweixin_reader(soup: Any, **kwargs) -> Tuple[str, Dict[str, Any]]: 16 | """Extract text from Substack blog post.""" 17 | meta_tag_title = soup.find('meta', attrs={'property': 'og:title'}) 18 | title = meta_tag_title['content'] 19 | extra_info = { 20 | "title": title, 21 | #"Author": soup.select_one("span #js_author_name").getText(), 22 | } 23 | text = soup.select_one("div #page-content").getText() 24 | return text, extra_info 25 | 26 | 27 | DEFAULT_WEBSITE_EXTRACTOR: Dict[ 28 | str, Callable[[Any, str], Tuple[str, Dict[str, Any]]] 29 | ] = { 30 | "mp.weixin.qq.com": _mpweixin_reader, 31 | } 32 | 33 | 34 | class BeautifulSoupWebReader(BasePydanticReader): 35 | """BeautifulSoup web page reader. 36 | 37 | Reads pages from the web. 38 | Requires the `bs4` and `urllib` packages. 39 | 40 | Args: 41 | website_extractor (Optional[Dict[str, Callable]]): A mapping of website 42 | hostname (e.g. google.com) to a function that specifies how to 43 | extract text from the BeautifulSoup obj. See DEFAULT_WEBSITE_EXTRACTOR. 44 | """ 45 | 46 | is_remote: bool = True 47 | _website_extractor: Dict[str, Callable] = PrivateAttr() 48 | 49 | def __init__(self, website_extractor: Optional[Dict[str, Callable]] = None) -> None: 50 | super().__init__() 51 | self._website_extractor = website_extractor or DEFAULT_WEBSITE_EXTRACTOR 52 | 53 | @classmethod 54 | def class_name(cls) -> str: 55 | """Get the name identifier of the class.""" 56 | return "BeautifulSoupWebReader" 57 | 58 | def load_data( 59 | self, 60 | urls: List[str], 61 | custom_hostname: Optional[str] = None, 62 | include_url_in_text: Optional[bool] = True, 63 | ) -> List[Document]: 64 | """Load data from the urls. 65 | 66 | Args: 67 | urls (List[str]): List of URLs to scrape. 68 | custom_hostname (Optional[str]): Force a certain hostname in the case 69 | a website is displayed under custom URLs (e.g. Substack blogs) 70 | include_url_in_text (Optional[bool]): Include the reference url in the text of the document 71 | 72 | Returns: 73 | List[Document]: List of documents. 
74 | 75 | """ 76 | from urllib.parse import urlparse 77 | 78 | import requests 79 | from bs4 import BeautifulSoup 80 | 81 | documents = [] 82 | for url in urls: 83 | try: 84 | page = requests.get(url) 85 | hostname = custom_hostname or urlparse(url).hostname or "" 86 | 87 | soup = BeautifulSoup(page.content, "html.parser") 88 | 89 | data = "" 90 | extra_info = { 91 | "title": soup.select_one("title"), 92 | "url_source": url, 93 | "creation_date": datetime.now().date().isoformat(), # Convert datetime to ISO format string 94 | } 95 | if hostname in self._website_extractor: 96 | data, metadata = self._website_extractor[hostname]( 97 | soup=soup, url=url, include_url_in_text=include_url_in_text 98 | ) 99 | extra_info.update(metadata) 100 | 101 | else: 102 | data = soup.getText() 103 | 104 | documents.append(Document(text=data, id_=url, extra_info=extra_info)) 105 | except Exception: 106 | print(f"Could not scrape {url}") 107 | raise ValueError(f"One of the inputs is not a valid url: {url}") 108 | 109 | return documents 110 | -------------------------------------------------------------------------------- /server/readers/jina_web.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict, Callable 2 | from datetime import datetime 3 | 4 | import requests, re 5 | from llama_index.core.readers.base import BasePydanticReader 6 | from llama_index.core.schema import Document 7 | 8 | 9 | class JinaWebReader(BasePydanticReader): 10 | """Jina web page reader. 11 | 12 | Reads pages from the web. 13 | 14 | """ 15 | 16 | def __init__(self) -> None: 17 | """Initialize with parameters.""" 18 | 19 | def load_data(self, urls: List[str]) -> List[Document]: 20 | """Load data from the input directory. 21 | 22 | Args: 23 | urls (List[str]): List of URLs to scrape. 24 | 25 | Returns: 26 | List[Document]: List of documents. 
27 | 28 | """ 29 | if not isinstance(urls, list): 30 | raise ValueError("urls must be a list of strings.") 31 | 32 | documents = [] 33 | for url in urls: 34 | new_url = "https://r.jina.ai/" + url 35 | response = requests.get(new_url) 36 | text = response.text 37 | 38 | # Extract Title 39 | title_match = re.search(r"Title:\s*(.*)", text) 40 | title = title_match.group(1) if title_match else None 41 | 42 | # Extract URL Source 43 | url_match = re.search(r"URL Source:\s*(.*)", text) 44 | url_source = url_match.group(1) if url_match else None 45 | 46 | # Extract Markdown Content 47 | markdown_match = re.search(r"Markdown Content:\s*(.*)", text, re.DOTALL) 48 | markdown_content = markdown_match.group(1).strip() if markdown_match else None 49 | 50 | # Compose metadata 51 | metadata: Dict = { 52 | "title": title, 53 | "url_source": url_source, 54 | "creation_date": datetime.now().date().isoformat(), # Convert datetime to ISO format string 55 | } 56 | 57 | documents.append(Document(text=markdown_content, id_=url, metadata=metadata or {})) 58 | 59 | return documents 60 | -------------------------------------------------------------------------------- /server/retriever.py: -------------------------------------------------------------------------------- 1 | # Retriever method 2 | 3 | from llama_index.core.retrievers import BaseRetriever 4 | from llama_index.core.retrievers import VectorIndexRetriever 5 | from llama_index.retrievers.bm25 import BM25Retriever 6 | 7 | # A simple BM25 retrieval method, customized for document storage and tokenization 8 | 9 | # BM25Retriever's default tokenizer does not support Chinese 10 | # Reference:https://github.com/run-llama/llama_index/issues/13866 11 | 12 | import jieba 13 | from typing import List 14 | def chinese_tokenizer(text: str) -> List[str]: 15 | return list(jieba.cut(text)) 16 | 17 | class SimpleBM25Retriever(BM25Retriever): 18 | @classmethod 19 | def from_defaults(cls, index, similarity_top_k, **kwargs) -> "BM25Retriever": 20 | docstore = index.docstore 21 | return BM25Retriever.from_defaults( 22 | docstore=docstore, similarity_top_k=similarity_top_k, verbose=True, 23 | tokenizer=chinese_tokenizer, **kwargs 24 | ) 25 | 26 | # A simple hybrid retriever method 27 | # Reference:https://docs.llamaindex.ai/en/stable/examples/retrievers/bm25_retriever/ 28 | 29 | class SimpleHybridRetriever(BaseRetriever): 30 | def __init__(self, vector_index, top_k=2): 31 | self.top_k = top_k 32 | 33 | # Build vector retriever from vector index 34 | self.vector_retriever = VectorIndexRetriever( 35 | index=vector_index, similarity_top_k=top_k, verbose=True, 36 | ) 37 | 38 | # Build BM25 retriever from document storage 39 | self.bm25_retriever = SimpleBM25Retriever.from_defaults( 40 | index=vector_index, similarity_top_k=top_k, 41 | ) 42 | 43 | super().__init__() 44 | 45 | def _retrieve(self, query, **kwargs): 46 | bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs) 47 | 48 | # the score is related to the query and may exceed 1, thus normalization is required 49 | # calculate min and max value 50 | min_score = min(item.score for item in bm25_nodes) 51 | max_score = max(item.score for item in bm25_nodes) 52 | 53 | # normalize score 54 | normalized_data = [(item.score - min_score) / (max_score - min_score) for item in bm25_nodes] 55 | 56 | # Assign normalized score back to the original object 57 | for item, normalized_score in zip(bm25_nodes, normalized_data): 58 | item.score = normalized_score 59 | 60 | vector_nodes = self.vector_retriever.retrieve(query, **kwargs) 61 | 62 
| # Merge two retrieval results, remove duplicates, and return only the Top_K results 63 | all_nodes = [] 64 | node_ids = set() 65 | count = 0 66 | for n in vector_nodes + bm25_nodes: 67 | if n.node.node_id not in node_ids: 68 | all_nodes.append(n) 69 | node_ids.add(n.node.node_id) 70 | count += 1 71 | if count >= self.top_k: 72 | break 73 | for node in all_nodes: 74 | print(f"Hybrid Retrieved Node: {node.node_id} - Score: {node.score:.2f} - {node.text[:10]}...\n-----") 75 | return all_nodes 76 | 77 | # Fusion retriever method 78 | # Reference: https://docs.llamaindex.ai/en/stable/examples/retrievers/relative_score_dist_fusion/ 79 | # https://medium.com/plain-simple-software/distribution-based-score-fusion-dbsf-a-new-approach-to-vector-search-ranking-f87c37488b18 80 | # https://docs.llamaindex.ai/en/stable/examples/low_level/fusion_retriever/?h=retrieverqueryengine 81 | from llama_index.core.retrievers import QueryFusionRetriever 82 | from enum import Enum 83 | 84 | # Three different modes, from LlamaIndex's source code 85 | class FUSION_MODES(str, Enum): 86 | RECIPROCAL_RANK = "reciprocal_rerank" # apply reciprocal rank fusion 87 | RELATIVE_SCORE = "relative_score" # apply relative score fusion 88 | DIST_BASED_SCORE = "dist_based_score" # apply distance-based score fusion 89 | SIMPLE = "simple" # simple re-ordering of results based on original scores 90 | 91 | class SimpleFusionRetriever(QueryFusionRetriever): 92 | def __init__(self, vector_index, top_k=2, mode=FUSION_MODES.DIST_BASED_SCORE): 93 | self.top_k = top_k 94 | self.mode = mode 95 | 96 | # Build vector retriever from vector index 97 | self.vector_retriever = VectorIndexRetriever( 98 | index=vector_index, similarity_top_k=top_k, verbose=True, 99 | ) 100 | 101 | # Build BM25 retriever from document storage 102 | self.bm25_retriever = SimpleBM25Retriever.from_defaults( 103 | index=vector_index, similarity_top_k=top_k, 104 | ) 105 | 106 | super().__init__( 107 | [self.vector_retriever, self.bm25_retriever], 108 | retriever_weights=[0.6, 0.4], 109 | similarity_top_k=top_k, 110 | num_queries=1, # set this to 1 to disable query generation 111 | mode=mode, 112 | use_async=True, 113 | verbose=True, 114 | ) -------------------------------------------------------------------------------- /server/splitters/__init__.py: -------------------------------------------------------------------------------- 1 | from .chinese_text_splitter import ChineseTextSplitter 2 | from .zh_title_enhance import ChineseTitleExtractor 3 | from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter -------------------------------------------------------------------------------- /server/splitters/chinese_recursive_text_splitter.py: -------------------------------------------------------------------------------- 1 | # Chinese recursive text splitter 2 | # Source:LangchainChatChat, QAnything 3 | 4 | import re 5 | from typing import List, Optional, Any 6 | from langchain.text_splitter import RecursiveCharacterTextSplitter 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def _split_text_with_regex_from_end( 13 | text: str, separator: str, keep_separator: bool 14 | ) -> List[str]: 15 | # Now that we have the separator, split the text 16 | if separator: 17 | if keep_separator: 18 | # The parentheses in the pattern keep the delimiters in the result. 
19 | _splits = re.split(f"({separator})", text) 20 | splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] 21 | if len(_splits) % 2 == 1: 22 | splits += _splits[-1:] 23 | # splits = [_splits[0]] + splits 24 | else: 25 | splits = re.split(separator, text) 26 | else: 27 | splits = list(text) 28 | return [s for s in splits if s != ""] 29 | 30 | 31 | class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter): 32 | def __init__( 33 | self, 34 | separators: Optional[List[str]] = None, 35 | keep_separator: bool = True, 36 | is_separator_regex: bool = True, 37 | **kwargs: Any, 38 | ) -> None: 39 | """Create a new TextSplitter.""" 40 | super().__init__(keep_separator=keep_separator, **kwargs) 41 | self._separators = separators or [ 42 | "\n\n", 43 | "\n", 44 | "。|!|?", 45 | "\.\s|\!\s|\?\s", 46 | ";|;\s", 47 | ",|,\s" 48 | ] 49 | self._is_separator_regex = is_separator_regex 50 | 51 | def _split_text(self, text: str, separators: List[str]) -> List[str]: 52 | """Split incoming text and return chunks.""" 53 | final_chunks = [] 54 | # Get appropriate separator to use 55 | separator = separators[-1] 56 | new_separators = [] 57 | for i, _s in enumerate(separators): 58 | _separator = _s if self._is_separator_regex else re.escape(_s) 59 | if _s == "": 60 | separator = _s 61 | break 62 | if re.search(_separator, text): 63 | separator = _s 64 | new_separators = separators[i + 1:] 65 | break 66 | 67 | _separator = separator if self._is_separator_regex else re.escape(separator) 68 | splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator) 69 | 70 | # Now go merging things, recursively splitting longer texts. 71 | _good_splits = [] 72 | _separator = "" if self._keep_separator else separator 73 | for s in splits: 74 | if self._length_function(s) < self._chunk_size: 75 | _good_splits.append(s) 76 | else: 77 | if _good_splits: 78 | merged_text = self._merge_splits(_good_splits, _separator) 79 | final_chunks.extend(merged_text) 80 | _good_splits = [] 81 | if not new_separators: 82 | final_chunks.append(s) 83 | else: 84 | other_info = self._split_text(s, new_separators) 85 | final_chunks.extend(other_info) 86 | if _good_splits: 87 | merged_text = self._merge_splits(_good_splits, _separator) 88 | final_chunks.extend(merged_text) 89 | return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""] 90 | 91 | 92 | if __name__ == "__main__": 93 | text_splitter = ChineseRecursiveTextSplitter( 94 | keep_separator=True, 95 | is_separator_regex=True, 96 | chunk_size=50, 97 | chunk_overlap=0 98 | ) 99 | ls = [ 100 | """中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 
的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""", 101 | ] 102 | # text = """""" 103 | for inum, text in enumerate(ls): 104 | print(inum) 105 | chunks = text_splitter.split_text(text) 106 | for chunk in chunks: 107 | print(chunk) 108 | -------------------------------------------------------------------------------- /server/splitters/chinese_text_splitter.py: -------------------------------------------------------------------------------- 1 | # Chinese text splitter 2 | # Source:LangchainChatChat, QAnything 3 | 4 | from langchain.text_splitter import CharacterTextSplitter 5 | import re 6 | from typing import List 7 | 8 | class ChineseTextSplitter(CharacterTextSplitter): 9 | def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): 10 | super().__init__(**kwargs) 11 | self.pdf = pdf 12 | self.sentence_size = sentence_size 13 | 14 | def split_text1(self, text: str) -> List[str]: 15 | if self.pdf: 16 | text = re.sub(r"\n{3,}", "\n", text) 17 | text = re.sub('\s', ' ', text) 18 | text = text.replace("\n\n", "") 19 | sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :; 20 | sent_list = [] 21 | for ele in sent_sep_pattern.split(text): 22 | if sent_sep_pattern.match(ele) and sent_list: 23 | sent_list[-1] += ele 24 | elif ele: 25 | sent_list.append(ele) 26 | return sent_list 27 | 28 | def split_text(self, text: str) -> List[str]: ## Need further logical optimization here 29 | if self.pdf: 30 | text = re.sub(r"\n{3,}", r"\n", text) 31 | text = re.sub('\s', " ", text) 32 | text = re.sub("\n\n", "", text) 33 | 34 | text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # Single-character delimiter 35 | text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # English ellipsis 36 | text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # Chinese ellipsis 37 | text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) 38 | # If there is an ending punctuation before the double quotes, then the double quotes are considered to be the end of the sentence. 39 | # Place the sentence delimiter \n after the double quotes, and be aware that the double quotes in the previous sentences are preserved. 40 | text = text.rstrip() # Remove the extra \n at the end of the paragraph(if any) 41 | # Semicolons was not considered in this case, along with dashes and English double quotes. If needed, all we need is some simple adjustment. 
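# The loop below takes each remaining line and, while it is still longer than
# sentence_size, re-splits it on progressively finer delimiters: first commas and
# periods, then newlines or runs of multiple spaces, then single spaces, splicing
# the finer pieces back into the list at the position of the original sentence.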
42 | ls = [i for i in text.split("\n") if i] 43 | for ele in ls: 44 | if len(ele) > self.sentence_size: 45 | ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) 46 | ele1_ls = ele1.split("\n") 47 | for ele_ele1 in ele1_ls: 48 | if len(ele_ele1) > self.sentence_size: 49 | ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) 50 | ele2_ls = ele_ele2.split("\n") 51 | for ele_ele2 in ele2_ls: 52 | if len(ele_ele2) > self.sentence_size: 53 | ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) 54 | ele2_id = ele2_ls.index(ele_ele2) 55 | ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ 56 | ele2_id + 1:] 57 | ele_id = ele1_ls.index(ele_ele1) 58 | ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] 59 | 60 | id = ls.index(ele) 61 | ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] 62 | return ls 63 | -------------------------------------------------------------------------------- /server/splitters/zh_title_enhance.py: -------------------------------------------------------------------------------- 1 | # Chinese title enhance 2 | # Source:LangchainChatChat, QAnything 3 | 4 | from llama_index.core.schema import BaseNode # modified based on Document in Langchain 5 | from typing import List 6 | import re 7 | 8 | 9 | def under_non_alpha_ratio(text: str, threshold: float = 0.5): 10 | """Checks if the proportion of non-alpha characters in the text snippet exceeds a given 11 | threshold. This helps prevent text like "-----------BREAK---------" from being tagged 12 | as a title or narrative text. The ratio does not count spaces. 13 | 14 | Parameters 15 | ---------- 16 | text 17 | The input string to test 18 | threshold 19 | If the proportion of non-alpha characters exceeds this threshold, the function 20 | returns False 21 | """ 22 | if len(text) == 0: 23 | return False 24 | 25 | alpha_count = len([char for char in text if char.strip() and char.isalpha()]) 26 | total_count = len([char for char in text if char.strip()]) 27 | try: 28 | ratio = alpha_count / total_count 29 | return ratio < threshold 30 | except: 31 | return False 32 | 33 | 34 | def is_possible_title( 35 | text: str, 36 | title_max_word_length: int = 20, 37 | non_alpha_threshold: float = 0.5, 38 | ) -> bool: 39 | """Checks to see if the text passes all of the checks for a valid title. 40 | 41 | Parameters 42 | ---------- 43 | text 44 | The input text to check 45 | title_max_word_length 46 | The maximum number of words a title can contain 47 | non_alpha_threshold 48 | The minimum number of alpha characters the text needs to be considered a title 49 | """ 50 | 51 | # If the text length is zero, it is not a title 52 | if len(text) == 0: 53 | print("Not a title. Text is empty.") 54 | return False 55 | 56 | # If the text has punctuation, it is not a title 57 | ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z" 58 | ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN) 59 | if ENDS_IN_PUNCT_RE.search(text) is not None: 60 | return False 61 | 62 | # The text length must not exceed the set value, which is set to be 20 by default. 63 | # NOTE(robinson) - splitting on spaces here instead of word tokenizing because it 64 | # is less expensive and actual tokenization doesn't add much value for the length check 65 | if len(text) > title_max_word_length: 66 | return False 67 | 68 | # The ratio of numbers in the text should not be too high, otherwise it is not a title. 
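# More precisely, under_non_alpha_ratio() returns True (and the candidate title is
# rejected) when the share of alphabetic characters, ignoring spaces, falls below
# non_alpha_threshold; this filters out digit- and symbol-heavy lines such as
# "-----------BREAK---------".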
69 | if under_non_alpha_ratio(text, threshold=non_alpha_threshold): 70 | return False 71 | 72 | # NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles 73 | if text.endswith((",", ".", ",", "。")): 74 | return False 75 | 76 | if text.isnumeric(): 77 | print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore 78 | return False 79 | 80 | # "The initial characters should contain numbers, typically within the first 5 characters by default." 81 | if len(text) < 5: 82 | text_5 = text 83 | else: 84 | text_5 = text[:5] 85 | alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5)))) 86 | if not alpha_in_text_5: 87 | return False 88 | 89 | return True 90 | 91 | 92 | def zh_title_enhance(docs: List[BaseNode]) -> List[BaseNode]: # modified based on Document in Langchain 93 | title = None 94 | if len(docs) > 0: 95 | for doc in docs: 96 | if is_possible_title(doc.text): # modified based on doc.page_content in Langchain 97 | doc.metadata['category'] = 'cn_Title' 98 | title = doc.text 99 | elif title: 100 | doc.text = f"下文与({title})有关。{doc.text}" 101 | return docs 102 | else: 103 | print("文件不存在") 104 | 105 | # The following is an encapsulation based on LlamaIndex 106 | 107 | import re 108 | from llama_index.core.schema import TransformComponent 109 | 110 | class ChineseTitleExtractor(TransformComponent): 111 | def __call__(self, nodes, **kwargs): 112 | nodes = zh_title_enhance(nodes) 113 | return nodes -------------------------------------------------------------------------------- /server/stores/chat_store.py: -------------------------------------------------------------------------------- 1 | # Chat Store 2 | 3 | from config import DEV_MODE, REDIS_URI, CHAT_STORE_KEY 4 | 5 | def create_chat_memory(): 6 | 7 | if DEV_MODE: 8 | # Development environment: SimpleChatStore 9 | # https://docs.llamaindex.ai/en/stable/module_guides/storing/chat_stores/ 10 | from llama_index.core.storage.chat_store import SimpleChatStore 11 | from llama_index.core.memory import ChatMemoryBuffer 12 | 13 | simple_chat_store = SimpleChatStore() 14 | 15 | simple_chat_memory = ChatMemoryBuffer.from_defaults( 16 | token_limit=3000, 17 | chat_store=simple_chat_store, 18 | chat_store_key=CHAT_STORE_KEY, 19 | ) 20 | return simple_chat_memory 21 | else: 22 | # Production environment: Redis 23 | # https://docs.llamaindex.ai/en/stable/examples/vector_stores/RedisIndexDemo/ 24 | 25 | # Start redis locally: 26 | # docker run --name redis-vecdb -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest 27 | 28 | from llama_index.core.memory import ChatMemoryBuffer 29 | from llama_index.storage.chat_store.redis import RedisChatStore 30 | 31 | redis_chat_store = RedisChatStore(redis_url=REDIS_URI, ttl=3600) 32 | 33 | redis_chat_memory = ChatMemoryBuffer.from_defaults( 34 | token_limit=3000, 35 | chat_store=redis_chat_store, 36 | chat_store_key=CHAT_STORE_KEY, 37 | ) 38 | return redis_chat_memory 39 | 40 | CHAT_MEMORY = create_chat_memory() -------------------------------------------------------------------------------- /server/stores/config_store.py: -------------------------------------------------------------------------------- 1 | # Config Store 2 | # Save configuration in local kv store or database 3 | 4 | import os 5 | from typing import Optional, Dict 6 | from llama_index.core.storage.kvstore import SimpleKVStore 7 | from config import STORAGE_DIR, CONFIG_STORE_FILE 8 | 9 | DATA_TYPE = Dict[str, Dict[str, dict]] 10 | 11 | PERSISIT_PATH = "./" + STORAGE_DIR + "/" + CONFIG_STORE_FILE 12 | 13 | 
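# LocalKVStore (defined below) extends SimpleKVStore so that every put() and
# delete() immediately persists the whole store to PERSISIT_PATH, and
# from_persist_path() falls back to an empty in-memory store when no file exists yet.
# A minimal usage sketch (hypothetical key and value, not taken from the application code):
#   CONFIG_STORE.put("current_llm", {"provider": "ollama", "model": "qwen2"})
#   CONFIG_STORE.get("current_llm")  # -> {"provider": "ollama", "model": "qwen2"}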
class LocalKVStore(SimpleKVStore): 14 | #Simple Key-Value store with local persistent. 15 | 16 | def __init__( 17 | self, 18 | data: Optional[DATA_TYPE] = None, 19 | ) -> None: 20 | """Init a SimpleKVStore.""" 21 | super().__init__(data) 22 | 23 | def put(self, key: str, val: dict) -> None: 24 | """Put a key-value pair into the store.""" 25 | super().put(key=key, val=val) 26 | super().persist(persist_path=self.persist_path) 27 | 28 | def delete(self, key: str) -> bool: 29 | """Delete a value from the store.""" 30 | try: 31 | super().delete(key) 32 | super().persist(persist_path=self.persist_path) 33 | return True 34 | except KeyError: 35 | return False 36 | 37 | @classmethod 38 | def from_persist_path( 39 | cls, persist_path: str = PERSISIT_PATH 40 | ) -> "LocalKVStore": 41 | """Load a SimpleKVStore from a persist path and filesystem.""" 42 | cls.persist_path = persist_path 43 | if (os.path.exists(persist_path)): 44 | return super().from_persist_path(persist_path=persist_path) 45 | else: 46 | return cls({}) 47 | 48 | CONFIG_STORE = LocalKVStore.from_persist_path() -------------------------------------------------------------------------------- /server/stores/doc_store.py: -------------------------------------------------------------------------------- 1 | # Document Store 2 | # https://docs.llamaindex.ai/en/stable/examples/docstore/MongoDocstoreDemo/ 3 | # https://docs.llamaindex.ai/en/stable/examples/docstore/RedisDocstoreIndexStoreDemo/ 4 | import config 5 | 6 | if config.THINKRAG_ENV == "production": 7 | from llama_index.storage.docstore.redis import RedisDocumentStore 8 | DOC_STORE = RedisDocumentStore.from_host_and_port( 9 | host=config.REDIS_HOST, port=config.REDIS_PORT, namespace="think" 10 | ) 11 | elif config.THINKRAG_ENV == "development": 12 | from llama_index.core.storage.docstore import SimpleDocumentStore 13 | DOC_STORE = SimpleDocumentStore() -------------------------------------------------------------------------------- /server/stores/index_store.py: -------------------------------------------------------------------------------- 1 | # Index Store 2 | import config 3 | 4 | if config.THINKRAG_ENV == "production": 5 | from llama_index.storage.index_store.redis import RedisIndexStore 6 | INDEX_STORE = RedisIndexStore.from_host_and_port( 7 | host=config.REDIS_HOST, port=config.REDIS_PORT, namespace="think" 8 | ) 9 | elif config.THINKRAG_ENV == "development": 10 | from llama_index.core.storage.index_store import SimpleIndexStore 11 | INDEX_STORE = SimpleIndexStore() -------------------------------------------------------------------------------- /server/stores/ingestion_cache.py: -------------------------------------------------------------------------------- 1 | from llama_index.core.ingestion import IngestionCache 2 | from llama_index.storage.kvstore.redis import RedisKVStore as RedisCache 3 | from config import REDIS_URI, DEV_MODE 4 | 5 | redis_cache=IngestionCache( 6 | cache=RedisCache(redis_uri=REDIS_URI), 7 | collection="redis_pipeline_cache", 8 | ) 9 | 10 | INGESTION_CACHE = redis_cache if not DEV_MODE else None -------------------------------------------------------------------------------- /server/stores/strage_context.py: -------------------------------------------------------------------------------- 1 | # Store context 2 | # https://docs.llamaindex.ai/en/stable/module_guides/storing/customization/ 3 | 4 | from llama_index.core import StorageContext 5 | from config import THINKRAG_ENV 6 | from server.stores.doc_store import DOC_STORE 7 | from 
server.stores.vector_store import VECTOR_STORE 8 | from server.stores.index_store import INDEX_STORE 9 | 10 | def create_storage_context(): 11 | if THINKRAG_ENV == "development": 12 | # Development environment 13 | import os 14 | from config import STORAGE_DIR 15 | persist_dir = "./" + STORAGE_DIR 16 | if os.path.exists(STORAGE_DIR + "/docstore.json"): 17 | dev_storage_context = StorageContext.from_defaults( 18 | persist_dir=persist_dir # Load from the persist directory 19 | ) 20 | print(f"Loaded storage context from {persist_dir}") 21 | return dev_storage_context 22 | else: 23 | dev_storage_context = StorageContext.from_defaults() # Created new storage context, need persistence 24 | print(f"Created new storage context") 25 | return dev_storage_context 26 | elif THINKRAG_ENV == "production": 27 | pro_storage_context = StorageContext.from_defaults( 28 | docstore=DOC_STORE, 29 | index_store=INDEX_STORE, 30 | vector_store=VECTOR_STORE, 31 | ) 32 | return pro_storage_context 33 | 34 | STORAGE_CONTEXT = create_storage_context() -------------------------------------------------------------------------------- /server/stores/vector_store.py: -------------------------------------------------------------------------------- 1 | # Vector database 2 | 3 | # https://docs.llamaindex.ai/en/stable/examples/vector_stores/ChromaIndexDemo/ 4 | # https://docs.llamaindex.ai/en/stable/module_guides/storing/customization/ 5 | 6 | import config 7 | 8 | def create_vector_store(type=config.DEFAULT_VS_TYPE): 9 | if type == "chroma": 10 | # Vector database Chroma 11 | 12 | # Install Chroma vector database 13 | """ pip install chromadb """ 14 | 15 | import chromadb 16 | from llama_index.vector_stores.chroma import ChromaVectorStore 17 | 18 | db = chromadb.PersistentClient(path=".chroma") 19 | chroma_collection = db.get_or_create_collection("think") 20 | chroma_vector_store = ChromaVectorStore(chroma_collection=chroma_collection) 21 | return chroma_vector_store 22 | elif type == "es": 23 | # Todo: use Metadata Filters 24 | 25 | # Vector database ES 26 | # https://docs.llamaindex.ai/en/stable/examples/vector_stores/ElasticsearchIndexDemo/ 27 | 28 | # Run ES locally 29 | """ docker run -p 9200:9200 \ 30 | -e "discovery.type=single-node" \ 31 | -e "xpack.security.enabled=false" \ 32 | -e "xpack.license.self_generated.type=trial" \ 33 | docker.elastic.co/elasticsearch/elasticsearch:8.13.2 """ 34 | 35 | from llama_index.vector_stores.elasticsearch import ElasticsearchStore 36 | from llama_index.vector_stores.elasticsearch import AsyncDenseVectorStrategy 37 | 38 | es_vector_store = ElasticsearchStore( 39 | es_url="http://localhost:9200", 40 | index_name="think", 41 | retrieval_strategy=AsyncDenseVectorStrategy(hybrid=False), 42 | ) 43 | return es_vector_store 44 | elif type == "lancedb": 45 | # Vector database LanceDB 46 | # https://docs.llamaindex.ai/en/stable/examples/vector_stores/LanceDBIndexDemo/ 47 | # https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/ 48 | from llama_index.vector_stores.lancedb import LanceDBVectorStore 49 | from lancedb.rerankers import LinearCombinationReranker 50 | reranker = LinearCombinationReranker(weight=0.9) 51 | 52 | lance_vector_store = LanceDBVectorStore( 53 | uri=".lancedb", mode="overwrite", query_type="vector", reranker=reranker 54 | ) 55 | return lance_vector_store 56 | elif type == "simple": 57 | from llama_index.core.vector_stores import SimpleVectorStore 58 | return SimpleVectorStore() 59 | else: 60 | raise ValueError(f"Invalid vector store type: {type}") 61 | 62 | if 
config.THINKRAG_ENV == "production": 63 | VECTOR_STORE = create_vector_store(type="chroma") 64 | else: 65 | VECTOR_STORE = create_vector_store(type="simple") -------------------------------------------------------------------------------- /server/text_splitter.py: -------------------------------------------------------------------------------- 1 | # Text splitter 2 | 3 | from config import DEV_MODE 4 | from llama_index.core import Settings 5 | 6 | def create_text_splitter(chunk_size=2048, chunk_overlap=512): 7 | if DEV_MODE: 8 | # Development environment 9 | # SentenceSplitter 10 | from llama_index.core.node_parser import SentenceSplitter 11 | 12 | sentence_splitter = SentenceSplitter( 13 | chunk_size=chunk_size, 14 | chunk_overlap=chunk_overlap, 15 | ) 16 | 17 | return sentence_splitter 18 | 19 | else: 20 | # Production environment 21 | # SpacyTextSplitter 22 | # https://zhuanlan.zhihu.com/p/638827267 23 | # pip install spacy 24 | # spacy download zh_core_web_sm 25 | from langchain.text_splitter import SpacyTextSplitter 26 | from llama_index.core.node_parser import LangchainNodeParser 27 | 28 | spacy_text_splitter = LangchainNodeParser(SpacyTextSplitter( 29 | pipeline="zh_core_web_sm", 30 | chunk_size=chunk_size, 31 | chunk_overlap=chunk_overlap, 32 | )) 33 | 34 | return spacy_text_splitter 35 | 36 | Settings.text_splitter = create_text_splitter() -------------------------------------------------------------------------------- /server/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from config import DATA_DIR 3 | 4 | def get_save_dir(): 5 | save_dir = os.getcwd() + "/" + DATA_DIR 6 | return save_dir 7 | 8 | def save_uploaded_file(uploaded_file: bytes, save_dir: str): 9 | try: 10 | if not os.path.exists(save_dir): 11 | os.makedirs(save_dir) 12 | path = os.path.join(save_dir, uploaded_file.name) 13 | with open(path, "wb") as f: 14 | f.write(uploaded_file.getbuffer()) 15 | print(f"已保存 {path}") 16 | except Exception as e: 17 | print(f"Error saving upload to disk: {e}") -------------------------------------------------------------------------------- /server/utils/hf_mirror.py: -------------------------------------------------------------------------------- 1 | # Setting up a HuggingFace mirror 2 | 3 | def use_hf_mirror(): 4 | import os 5 | from config import HF_ENDPOINT 6 | os.environ['HF_ENDPOINT'] = HF_ENDPOINT 7 | print(f"Use HF mirror: {os.environ['HF_ENDPOINT']}") 8 | return os.environ['HF_ENDPOINT'] 9 | --------------------------------------------------------------------------------