├── .gitignore ├── .streamlit └── config.toml ├── LICENSE ├── README.md ├── README_zh.md ├── SECURITY.md ├── app.py ├── config.py ├── docs ├── Code_of_Conduct.md ├── HowToDownloadModels.md ├── HowToUsePythonVirtualEnv.md └── images │ ├── KB_File.png │ ├── KB_Manage.png │ ├── KB_Web.png │ ├── Model_LLM.png │ ├── Model_Reranker.png │ ├── Query.png │ ├── Settings_Advanced.png │ └── ThinkRAG_Architecture.png ├── frontend ├── Document_QA.py ├── KB_File.py ├── KB_Manage.py ├── KB_Web.py ├── Model_Embed.py ├── Model_LLM.py ├── Model_Rerank.py ├── Setting_Advanced.py ├── Storage.py ├── images │ └── ThinkRAG_Logo.png └── state.py ├── requirements.txt └── server ├── engine.py ├── index.py ├── ingestion.py ├── models ├── embedding.py ├── llm_api.py ├── ollama.py └── reranker.py ├── prompt.py ├── readers ├── beautiful_soup_web.py └── jina_web.py ├── retriever.py ├── splitters ├── __init__.py ├── chinese_recursive_text_splitter.py ├── chinese_text_splitter.py └── zh_title_enhance.py ├── stores ├── chat_store.py ├── config_store.py ├── doc_store.py ├── index_store.py ├── ingestion_cache.py ├── strage_context.py └── vector_store.py ├── text_splitter.py └── utils ├── file.py └── hf_mirror.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Local models 2 | localmodels/ 3 | 4 | # Uploaded files 5 | data/ 6 | 7 | # Local Storage 8 | storage/ 9 | .chroma/ 10 | .lancedb/ 11 | # Test files 12 | test_*.* 13 | 14 | # Python venv 15 | .venv/ 16 | bin/ 17 | include/ 18 | lib/ 19 | pyvenv.cfg 20 | etc/ 21 | 22 | # The following are ignored by default. 23 | 24 | # Byte-compiled / optimized / DLL files 25 | __pycache__/ 26 | *.py[cod] 27 | *$py.class 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | *.py,cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | cover/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | db.sqlite3-journal 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | .pybuilder/ 99 | target/ 100 | 101 | # Jupyter Notebook 102 | .ipynb_checkpoints 103 | 104 | # IPython 105 | profile_default/ 106 | ipython_config.py 107 | 108 | # pyenv 109 | # For a library or package, you might want to ignore these files since the code is 110 | # intended to run in multiple environments; otherwise, check them in: 111 | # .python-version 112 | 113 | # pipenv 114 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
115 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 116 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 117 | # install all needed dependencies. 118 | #Pipfile.lock 119 | 120 | # poetry 121 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 122 | # This is especially recommended for binary packages to ensure reproducibility, and is more 123 | # commonly ignored for libraries. 124 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 125 | #poetry.lock 126 | 127 | # pdm 128 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 129 | #pdm.lock 130 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 131 | # in version control. 132 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 133 | .pdm.toml 134 | .pdm-python 135 | .pdm-build/ 136 | 137 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 138 | __pypackages__/ 139 | 140 | # Celery stuff 141 | celerybeat-schedule 142 | celerybeat.pid 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | # Pyre type checker 172 | .pyre/ 173 | 174 | # pytype static type analyzer 175 | .pytype/ 176 | 177 | # Cython debug symbols 178 | cython_debug/ 179 | 180 | # .DS_Store files 181 | **/.DS_Store 182 | 183 | # PyCharm 184 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 185 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 186 | # and can be added to the global gitignore or merged into this file. For a more nuclear 187 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 188 | #.idea/ 189 | -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [client] 2 | toolbarMode = "minimal" 3 | showSidebarNavigation = false 4 | 5 | [theme] 6 | primaryColor = "#F63366" 7 | backgroundColor = "white" 8 | font = "sans serif" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 David Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
5 | 6 | Website(s)
", unsafe_allow_html=True) 20 | for site in st.session_state["websites"]: 21 | st.caption(f"- {site}") 22 | st.write("") 23 | 24 | with st.expander( 25 | "Text processing parameter configuration", 26 | expanded=True, 27 | ): 28 | cols = st.columns(2) 29 | chunk_size = cols[0].number_input("Maximum length of a single text block: ", 1, 4096, st.session_state.chunk_size, key="web_chunk_size") 30 | chunk_overlap = cols[1].number_input("Adjacent text overlap length: ", 0, st.session_state.chunk_size, st.session_state.chunk_overlap, key="web_chunk_overlap") 31 | 32 | process_button = st.button("Save", 33 | key="process_website", 34 | disabled=len(st.session_state["websites"]) == 0) 35 | if process_button: 36 | print("Generating index...") 37 | with st.spinner(text="Loading documents and building the index, may take a minute or two"): 38 | st.session_state.index_manager.load_websites(st.session_state["websites"], chunk_size, chunk_overlap) 39 | st.toast('✔️ Knowledge base index generation complete', icon='🎉') 40 | st.session_state.websites = [] 41 | time.sleep(4) 42 | st.rerun() 43 | 44 | handle_website() 45 | 46 | 47 | -------------------------------------------------------------------------------- /frontend/Model_Embed.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from config import EMBEDDING_MODEL_PATH 3 | from server.stores.config_store import CONFIG_STORE 4 | from server.stores.strage_context import STORAGE_CONTEXT 5 | from server.models.embedding import create_embedding_model 6 | 7 | st.header("Embedding Model") 8 | st.caption("Configure embedding models", 9 | help="Embeddings are numerical representations of data, useful for tasks like document clustering and similarity detection when processing files, as they encode semantic meaning for efficient manipulation and retrieval.", 10 | ) 11 | 12 | def change_embedding_model(): 13 | st.session_state["current_llm_settings"]["embedding_model"] = st.session_state["selected_embedding_model"] 14 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 15 | create_embedding_model(st.session_state["current_llm_settings"]["embedding_model"]) 16 | 17 | doc_store = STORAGE_CONTEXT.docstore 18 | if len(doc_store.docs) > 0: 19 | disabled = True 20 | else: 21 | disabled = False 22 | embedding_settings = st.container(border=True) 23 | with embedding_settings: 24 | embedding_model_list = list(EMBEDDING_MODEL_PATH.keys()) 25 | embedding_model = st.selectbox( 26 | "Embedding models", 27 | embedding_model_list, 28 | key="selected_embedding_model", 29 | index=embedding_model_list.index(st.session_state["current_llm_settings"]["embedding_model"]), 30 | disabled=disabled, 31 | on_change=change_embedding_model, 32 | ) 33 | if disabled: 34 | st.info("You cannot change embedding model once you add documents in the knowledge base.") 35 | st.caption("ThinkRAG supports most reranking models from `Hugging Face`. You may specify the models you want to use in the `config.py` file.") 36 | st.caption("It is recommended to download the models to the `localmodels` directory, in case you need run the system without an Internet connection. 
Plase refer to the instructions in `docs` directory.") -------------------------------------------------------------------------------- /frontend/Model_LLM.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from config import LLM_API_LIST 3 | import server.models.ollama as ollama 4 | from server.stores.config_store import CONFIG_STORE 5 | from server.models.llm_api import check_openai_llm 6 | from frontend.state import init_llm_sp, init_ollama_endpoint, init_api_base, init_api_model, init_api_key, create_llm_instance 7 | 8 | st.header("Large Language Model") 9 | st.caption("Support local models from Ollama and OpenAI compatible LLM APIs.", 10 | help="Large language models (LLMs) are powerful models that can generate human-like text based on the input they receive. LLMs can be used for a wide range of natural language processing tasks, including text generation, question answering, and summarization.", 11 | ) 12 | 13 | init_llm_sp() 14 | 15 | sp = st.session_state.llm_service_provider_selected 16 | llm = LLM_API_LIST[sp] 17 | 18 | init_ollama_endpoint() 19 | init_api_base(sp) 20 | init_api_model(sp) 21 | init_api_key(sp) 22 | 23 | def save_current_llm_info(): 24 | sp = st.session_state.llm_service_provider_selected 25 | if sp == "Ollama": 26 | if st.session_state.ollama_model_selected is not None: 27 | CONFIG_STORE.put(key="current_llm_info", val={ 28 | "service_provider": sp, 29 | "model": st.session_state.ollama_model_selected, 30 | }) 31 | else: 32 | api_key = sp + "_api_key" 33 | model_key = sp + "_model_selected" 34 | base_key = sp + "_api_base" 35 | if st.session_state[model_key] is not None and st.session_state[api_key] is not None and st.session_state[base_key] is not None: 36 | CONFIG_STORE.put(key="current_llm_info", val={ 37 | "service_provider": sp, 38 | "model": st.session_state[model_key], 39 | "api_base": st.session_state[base_key], 40 | "api_key": st.session_state[api_key], 41 | "api_key_valid": st.session_state[api_key + "_valid"], 42 | }) 43 | else: 44 | st.warning("Please fill in all the required fields") 45 | 46 | def update_llm_service_provider(): 47 | selected_option = st.session_state["llm_service_provider"] 48 | st.session_state.llm_service_provider_selected = selected_option 49 | CONFIG_STORE.put(key="llm_service_provider_selected", val={"llm_service_provider_selected": selected_option}) 50 | if selected_option != "Ollama": 51 | init_api_base(selected_option) 52 | init_api_model(selected_option) 53 | init_api_key(selected_option) 54 | save_current_llm_info() 55 | 56 | def init_llm_options(): 57 | llm_options = list(LLM_API_LIST.keys()) 58 | col1, _, col2 = st.columns([5, 4, 1], vertical_alignment="bottom") 59 | with col1: 60 | option = st.selectbox( 61 | "Please select one of the options.", 62 | llm_options, 63 | index=llm_options.index(st.session_state.llm_service_provider_selected), 64 | key="llm_service_provider", 65 | on_change=update_llm_service_provider, 66 | ) 67 | 68 | if option is not None and option != st.session_state.llm_service_provider_selected: 69 | CONFIG_STORE.put(key="llm_service_provider_selected", val={ 70 | "llm_service_provider_selected": option, 71 | }) 72 | 73 | current_llm_info = CONFIG_STORE.get(key="current_llm_info") 74 | 75 | if current_llm_info is None: 76 | save_current_llm_info() 77 | 78 | init_llm_options() 79 | 80 | option = st.session_state.llm_service_provider_selected 81 | 82 | def change_ollama_endpoint(): 83 | st.session_state.ollama_api_url = 
st.session_state.ollama_endpoint 84 | if ollama.is_alive(): 85 | name = option + "_api_url" # e.g. "Ollama_api_url" 86 | CONFIG_STORE.put(key=name, val={ 87 | name: st.session_state.ollama_api_url, 88 | }) 89 | save_current_llm_info() 90 | else: 91 | st.warning("Failed to connect to Ollama") 92 | 93 | def change_ollama_model(): 94 | st.session_state.ollama_model_selected = st.session_state.ollama_model_name 95 | name = option + "_model_selected" # e.g. "Ollama_model_selected" 96 | CONFIG_STORE.put(key=name, val={ 97 | name: st.session_state.ollama_model_selected, 98 | }) 99 | save_current_llm_info() 100 | 101 | def change_llm_api_base(): 102 | name = option + "_api_base" # e.g. "OpenAI_api_base" 103 | st.session_state[name] = st.session_state.llm_api_endpoint 104 | CONFIG_STORE.put(key=name, val={ 105 | name: st.session_state.llm_api_endpoint, 106 | }) 107 | save_current_llm_info() 108 | 109 | def change_llm_api_key(): 110 | name = option + "_api_key" # e.g. "OpenAI_api_key" 111 | st.session_state[name] = st.session_state.llm_api_key 112 | CONFIG_STORE.put(key=name, val={ 113 | name: st.session_state.llm_api_key, 114 | }) 115 | print("Checking API key...") 116 | print(st.session_state.llm_api_key) 117 | is_valid = check_openai_llm(st.session_state.llm_api_model, st.session_state.llm_api_endpoint, st.session_state.llm_api_key) 118 | st.session_state[name + "_valid"] = is_valid 119 | CONFIG_STORE.put(key=name + "_valid", val={ # e.g. "OpenAI_api_key_valid" 120 | name + "_valid": is_valid, 121 | }) 122 | save_current_llm_info() 123 | if is_valid: 124 | print("API key is valid") 125 | else: 126 | print("API key is invalid") 127 | 128 | def change_llm_api_model(): 129 | name = option + "_model_selected" # e.g. "OpenAI_model_selected" 130 | st.session_state[name] = st.session_state.llm_api_model 131 | CONFIG_STORE.put(key=name, val={ 132 | name: st.session_state.llm_api_model, 133 | }) 134 | save_current_llm_info() 135 | 136 | def llm_configuration_page(): 137 | llm_api_settings = st.container(border=True) 138 | with llm_api_settings: 139 | if option == "Ollama": 140 | st.subheader("Configure for Ollama") 141 | st.text_input( 142 | "Ollama Endpoint", 143 | key="ollama_endpoint", 144 | value=st.session_state.ollama_api_url, 145 | on_change=change_ollama_endpoint, 146 | ) 147 | if ollama.is_alive(): 148 | ollama.get_model_list() 149 | st.write("🟢 Ollama is running") 150 | st.selectbox('Local LLM', st.session_state.ollama_models, 151 | index=st.session_state.ollama_models.index(st.session_state.ollama_model_selected), 152 | help='Select locally deployed LLM from Ollama', 153 | on_change=change_ollama_model, 154 | key='ollama_model_name', # session_state key 155 | ) 156 | else: 157 | st.write("🔴 Ollama is not running") 158 | 159 | st.button( 160 | "Refresh models", 161 | on_click=ollama.get_model_list, 162 | help="Refresh the list of available models from the Ollama API.", 163 | ) 164 | 165 | else: # OpenAI, Zhipu, Moonshot, Deepseek 166 | st.subheader(f"Configure for {llm['provider']}") 167 | st.text_input( 168 | "Base URL", 169 | key="llm_api_endpoint", 170 | value=st.session_state[option + "_api_base"], 171 | on_change=change_llm_api_base, 172 | ) 173 | st.text_input( 174 | "API key", 175 | key="llm_api_key", 176 | value=st.session_state[option + "_api_key"], 177 | type="password", 178 | on_change=change_llm_api_key, 179 | ) 180 | st.selectbox('Choose LLM API', llm['models'], 181 | help='Choose LLMs API service', 182 | on_change=change_llm_api_model, 183 | key='llm_api_model', 184 | 
index=llm['models'].index(st.session_state[option + "_model_selected"]), 185 | ) 186 | 187 | def show_llm_instance(): 188 | create_llm_instance() 189 | if st.session_state.llm is not None: 190 | current_llm_info = CONFIG_STORE.get(key="current_llm_info") 191 | st.success("Current LLM instance: " + current_llm_info["service_provider"] + " / " + current_llm_info["model"]) 192 | else: 193 | st.warning("No LLM instance available") 194 | 195 | llm_configuration_page() 196 | 197 | show_llm_instance() 198 | 199 | st.caption("ThinkRAG supports `OpenAI` and all compatible LLM API like `DeepSeek`, `Moonshot` or `Zhipu`. You may specify the LLMs you want to use in the `config.py` file.") 200 | st.caption("It is recommended to use `Ollama` if you need run the system without an Internet connection. Plase refer to the Ollama docs to download and use Ollama models.") -------------------------------------------------------------------------------- /frontend/Model_Rerank.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from config import RERANKER_MODEL_PATH 3 | from server.stores.config_store import CONFIG_STORE 4 | 5 | st.header("Reranking Model") 6 | st.caption("Configure reranking models", 7 | help="Reranking is the process of reordering a list of items based on a set of criteria. In the context of search engines, reranking is used to improve the relevance of search results by taking into account additional information about the items being ranked.", 8 | ) 9 | 10 | def change_use_reranker(): 11 | st.session_state["current_llm_settings"]["use_reranker"] = st.session_state["use_reranker"] 12 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 13 | 14 | def change_top_n(): 15 | st.session_state["current_llm_settings"]["top_n"] = st.session_state["top_n"] 16 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 17 | 18 | def change_reranker_model(): 19 | st.session_state["current_llm_settings"]["reranker_model"] = st.session_state["selected_reranker_model"] 20 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 21 | 22 | reranking_settings = st.container(border=True) 23 | with reranking_settings: 24 | st.toggle("Use reranker", 25 | key="use_reranker", 26 | value= st.session_state["current_llm_settings"]["use_reranker"], 27 | on_change=change_use_reranker, 28 | ) 29 | if st.session_state["current_llm_settings"]["use_reranker"] == True: 30 | st.number_input( 31 | "Top N", 32 | min_value=1, 33 | max_value=st.session_state["current_llm_settings"]["top_k"], 34 | help="The number of most similar documents to retrieve in response to a query.", 35 | value=st.session_state["current_llm_settings"]["top_n"], 36 | key="top_n", 37 | on_change=change_top_n, 38 | ) 39 | 40 | reranker_model_list = list(RERANKER_MODEL_PATH.keys()) 41 | reranker_model = st.selectbox( 42 | "Reranking models", 43 | reranker_model_list, 44 | key="selected_reranker_model", 45 | index=reranker_model_list.index(st.session_state["current_llm_settings"]["reranker_model"]), 46 | on_change=change_reranker_model, 47 | ) 48 | 49 | st.caption("ThinkRAG supports most reranking models from `Hugging Face`. You may specify the models you want to use in the `config.py` file.") 50 | st.caption("It is recommended to download the models to the `localmodels` directory, in case you need run the system without an Internet connection. 
Plase refer to the instructions in `docs` directory.") -------------------------------------------------------------------------------- /frontend/Setting_Advanced.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from server.stores.config_store import CONFIG_STORE 3 | from frontend.state import create_llm_instance 4 | from config import RESPONSE_MODE 5 | 6 | st.header("Advanced settings") 7 | advanced_settings = st.container(border=True) 8 | 9 | def change_top_k(): 10 | st.session_state["current_llm_settings"]["top_k"] = st.session_state["top_k"] 11 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 12 | create_llm_instance() 13 | 14 | def change_temperature(): 15 | st.session_state["current_llm_settings"]["temperature"] = st.session_state["temperature"] 16 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 17 | create_llm_instance() 18 | 19 | def change_system_prompt(): 20 | st.session_state["current_llm_settings"]["system_prompt"] = st.session_state["system_prompt"] 21 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 22 | create_llm_instance() 23 | 24 | def change_response_mode(): 25 | st.session_state["current_llm_settings"]["response_mode"] = st.session_state["response_mode"] 26 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state["current_llm_settings"]) 27 | create_llm_instance() 28 | 29 | with advanced_settings: 30 | col_1, _, col_2 = st.columns([4, 2, 4]) 31 | with col_1: 32 | st.number_input( 33 | "Top K", 34 | min_value=1, 35 | max_value=100, 36 | help="The number of most similar documents to retrieve in response to a query.", 37 | value=st.session_state["current_llm_settings"]["top_k"], 38 | key="top_k", 39 | on_change=change_top_k, 40 | ) 41 | with col_2: 42 | st.select_slider( 43 | "Temperature", 44 | options=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1], 45 | help="The temperature to use when generating responses. Higher temperatures result in more random responses.", 46 | value=st.session_state["current_llm_settings"]["temperature"], 47 | key="temperature", 48 | on_change=change_temperature, 49 | ) 50 | st.text_area( 51 | "System Prompt", 52 | help="The prompt to use when generating responses. The system prompt is used to provide context to the model.", 53 | value=st.session_state["current_llm_settings"]["system_prompt"], 54 | key="system_prompt", 55 | height=240, 56 | on_change=change_system_prompt, 57 | ) 58 | st.selectbox( 59 | "Response Mode", 60 | options=RESPONSE_MODE, 61 | help="Sets the Llama Index Query Engine response mode used when creating the Query Engine. 
Default: `compact`.", 62 | key="response_mode", 63 | index=RESPONSE_MODE.index(st.session_state["current_llm_settings"]["response_mode"]), # simple_summarize by default 64 | on_change=change_response_mode, 65 | ) 66 | 67 | # For debug purpost only 68 | def show_session_state(): 69 | st.write("") 70 | with st.expander("List of current application parameters"): 71 | state = dict(sorted(st.session_state.items())) 72 | st.write(state) 73 | 74 | # show_session_state() -------------------------------------------------------------------------------- /frontend/Storage.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from config import THINKRAG_ENV 3 | 4 | st.header("Storage") 5 | st.caption("All your data is stored in local file system or the database you configured.", 6 | help="You may change the storage settings in the config.py file.", 7 | ) 8 | 9 | embedding_settings = st.container(border=True) 10 | with embedding_settings: 11 | st.info("You are running ThinkRAG in " + THINKRAG_ENV + " mode.") 12 | st.dataframe(data={ 13 | "Storage Type": ["Vector Store","Doc Store","Index Store","Chat Store","Config Store"], 14 | "Development": ["Simple Vector Store","Simple Document Store","Simple Index Store","Simple Chat Store (in memory)","Simple KV Store"], 15 | "Production": ["Chroma","Redis","Redis","Redis","Simple KV Store"], 16 | #"Enterprise": ["Elasticsearch","MongoDB","MongoDB","Redis","Simple KV Store"], 17 | },hide_index=True) 18 | 19 | st.caption("You may change the storage settings in the config.py file.") 20 | st.caption("`Development Mode` uses local storage which means you need not install any extra tools. All the data is stored as local files in the 'storage' directory where you run ThinkRAG.") 21 | st.caption("`Production Mode`: is recommended to use for production on your laptop. 
You need a redis instance, either running locally or using a cloud service.") 22 | st.caption("If you want to deploy ThinkRAG on a server and handle large volume of data, please contact the author of ThinkRAG (wzdavid@gmail.com)") -------------------------------------------------------------------------------- /frontend/images/ThinkRAG_Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzdavid/ThinkRAG/989d6b3d4ea163e75f2dafacf44eccfee7af221c/frontend/images/ThinkRAG_Logo.png -------------------------------------------------------------------------------- /frontend/state.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import config as config 3 | from server.models import ollama 4 | from server.models.llm_api import create_openai_llm, check_openai_llm 5 | from server.models.ollama import create_ollama_llm 6 | from server.models.embedding import create_embedding_model 7 | from server.index import IndexManager 8 | from server.stores.config_store import CONFIG_STORE 9 | 10 | def find_api_by_model(model_name): 11 | for api_name, api_info in config.LLM_API_LIST.items(): 12 | if model_name in api_info['models']: 13 | return api_info 14 | 15 | # Initialize st.session_state 16 | def init_keys(): 17 | 18 | # Initialize LLM 19 | if "llm" not in st.session_state.keys(): 20 | st.session_state.llm = None 21 | 22 | # Initialize index 23 | if "index_manager" not in st.session_state.keys(): 24 | st.session_state.index_manager = IndexManager(config.DEFAULT_INDEX_NAME) 25 | 26 | # Initialize model selection 27 | if "ollama_api_url" not in st.session_state.keys(): 28 | st.session_state.ollama_api_url = config.OLLAMA_API_URL 29 | 30 | if "ollama_models" not in st.session_state.keys(): 31 | ollama.get_model_list() 32 | if (st.session_state.ollama_models is not None and len(st.session_state.ollama_models) > 0): 33 | st.session_state.ollama_model_selected = st.session_state.ollama_models[0] 34 | create_ollama_llm(st.session_state.ollama_model_selected) 35 | if "ollama_model_selected" not in st.session_state.keys(): 36 | st.session_state.ollama_model_selected = None 37 | if "llm_api_list" not in st.session_state.keys(): 38 | st.session_state.llm_api_list = [model for api in config.LLM_API_LIST.values() for model in api['models']] 39 | if "llm_api_selected" not in st.session_state.keys(): 40 | st.session_state.llm_api_selected = st.session_state.llm_api_list[0] 41 | if st.session_state.ollama_model_selected is None: 42 | api_object = find_api_by_model(st.session_state.llm_api_selected) 43 | create_openai_llm(st.session_state.llm_api_selected, api_object['api_base'], api_object['api_key']) 44 | 45 | # Initialize query engine 46 | if "query_engine" not in st.session_state.keys(): 47 | st.session_state.query_engine = None 48 | 49 | if "system_prompt" not in st.session_state.keys(): 50 | st.session_state.system_prompt = "Chat with me!" 
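    # Note: this only seeds an initial default; the prompt actually applied when the LLM instance is built comes from current_llm_settings["system_prompt"] (see init_llm_settings and create_llm_instance below).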
51 | 52 | if "response_mode" not in st.session_state.keys(): 53 | response_mode_result = CONFIG_STORE.get(key="response_mode") 54 | if response_mode_result is not None: 55 | st.session_state.response_mode = response_mode_result["response_mode"] 56 | else: 57 | st.session_state.response_mode = config.DEFAULT_RESPONSE_MODE 58 | 59 | if "ollama_endpoint" not in st.session_state.keys(): 60 | st.session_state.ollama_endpoint = "http://localhost:11434" 61 | 62 | if "chunk_size" not in st.session_state.keys(): 63 | st.session_state.chunk_size = config.DEFAULT_CHUNK_SIZE 64 | 65 | if "chunk_overlap" not in st.session_state.keys(): 66 | st.session_state.chunk_overlap = config.DEFAULT_CHUNK_OVERLAP 67 | 68 | if "zh_title_enhance" not in st.session_state.keys(): 69 | st.session_state.zh_title_enhance = config.ZH_TITLE_ENHANCE 70 | 71 | if "max_tokens" not in st.session_state.keys(): 72 | st.session_state.max_tokens = 100 73 | 74 | if "top_p" not in st.session_state.keys(): 75 | st.session_state.top_p = 1.0 76 | 77 | # contents related to the knowledge base 78 | if "websites" not in st.session_state: 79 | st.session_state["websites"] = [] 80 | 81 | if 'uploaded_files' not in st.session_state: 82 | st.session_state.uploaded_files = [] 83 | if 'selected_files' not in st.session_state: 84 | st.session_state.selected_files = None 85 | 86 | # Initialize user data 87 | # TODO: supposed to be loaded from database 88 | st.session_state.user_id = "user_1" 89 | st.session_state.kb_id = "kb_1" 90 | st.session_state.kb_name = "My knowledge base" 91 | 92 | def init_llm_sp(): 93 | 94 | llm_options = list(config.LLM_API_LIST.keys()) 95 | 96 | # LLM service provider selection 97 | if "llm_service_provider_selected" not in st.session_state: 98 | sp = CONFIG_STORE.get(key="llm_service_provider_selected") 99 | if sp: 100 | st.session_state.llm_service_provider_selected = sp["llm_service_provider_selected"] 101 | else: 102 | st.session_state.llm_service_provider_selected = llm_options[0] 103 | 104 | def init_ollama_endpoint(): 105 | # Initialize Ollama endpoint 106 | if "ollama_api_url" not in st.session_state.keys(): 107 | ollama_api_url = CONFIG_STORE.get(key="Ollama_api_url") 108 | if ollama_api_url: 109 | st.session_state.ollama_api_url = ollama_api_url["Ollama_api_url"] 110 | else: 111 | st.session_state.ollama_api_url = config.LLM_API_LIST["Ollama"]["api_base"] 112 | 113 | # Initialize llm api model 114 | def init_api_model(sp): 115 | if sp != "Ollama": 116 | model_key = sp + "_model_selected" 117 | if model_key not in st.session_state.keys(): 118 | model_result = CONFIG_STORE.get(key=model_key) 119 | if model_result: 120 | st.session_state[model_key] = model_result[model_key] 121 | else: 122 | st.session_state[model_key] = config.LLM_API_LIST[sp]["models"][0] 123 | 124 | 125 | # Initialize llm api base 126 | def init_api_base(sp): 127 | if sp != "Ollama": 128 | api_base = sp + "_api_base" 129 | if api_base not in st.session_state.keys(): 130 | api_key_result = CONFIG_STORE.get(key=api_base) 131 | if api_key_result is not None: 132 | st.session_state[api_base] = api_key_result[api_base] 133 | else: 134 | st.session_state[api_base] = config.LLM_API_LIST[sp]["api_base"] 135 | 136 | # Initialize llm api key 137 | def init_api_key(sp): 138 | if sp != "Ollama": 139 | api_key = sp + "_api_key" 140 | if api_key not in st.session_state.keys(): 141 | api_key_result = CONFIG_STORE.get(key=api_key) 142 | if api_key_result is not None: 143 | st.session_state[api_key] = api_key_result[api_key] 144 | else: 145 | 
st.session_state[api_key] = config.LLM_API_LIST[sp]["api_key"] 146 | 147 | valid_key = api_key + "_valid" 148 | if valid_key not in st.session_state.keys(): 149 | valid_result = CONFIG_STORE.get(key=valid_key) 150 | if valid_result is None and st.session_state[api_key] is not None: 151 | is_valid = check_openai_llm(st.session_state[sp + "_model_selected"], config.LLM_API_LIST[sp]["api_base"], st.session_state[api_key]) 152 | CONFIG_STORE.put(key=valid_key, val={valid_key: is_valid}) 153 | st.session_state[valid_key] = is_valid 154 | else: 155 | st.session_state[valid_key] = valid_result[valid_key] 156 | 157 | # Initialize LLM settings, like temperature, system prompt, etc. 158 | def init_llm_settings(): 159 | if "current_llm_settings" not in st.session_state.keys(): 160 | current_llm_settings = CONFIG_STORE.get(key="current_llm_settings") 161 | if current_llm_settings: 162 | st.session_state.current_llm_settings = current_llm_settings 163 | else: 164 | st.session_state.current_llm_settings = { 165 | "temperature": config.TEMPERATURE, 166 | "system_prompt": config.SYSTEM_PROMPT, 167 | "top_k": config.TOP_K, 168 | "response_mode": config.DEFAULT_RESPONSE_MODE, 169 | "use_reranker": config.USE_RERANKER, 170 | "top_n": config.RERANKER_MODEL_TOP_N, 171 | "embedding_model": config.DEFAULT_EMBEDDING_MODEL, 172 | "reranker_model": config.DEFAULT_RERANKER_MODEL, 173 | } 174 | CONFIG_STORE.put(key="current_llm_settings", val=st.session_state.current_llm_settings) 175 | 176 | 177 | # Create LLM instance if there is related information 178 | def create_llm_instance(): 179 | current_llm_info = CONFIG_STORE.get(key="current_llm_info") 180 | if current_llm_info is not None: 181 | print("Current LLM info: ", current_llm_info) 182 | if current_llm_info["service_provider"] == "Ollama": 183 | if ollama.is_alive(): 184 | model_name = current_llm_info["model"] 185 | st.session_state.llm = ollama.create_ollama_llm( 186 | model=model_name, 187 | temperature=st.session_state.current_llm_settings["temperature"], 188 | system_prompt=st.session_state.current_llm_settings["system_prompt"], 189 | ) 190 | else: 191 | model_name = current_llm_info["model"] 192 | api_base = current_llm_info["api_base"] 193 | api_key = current_llm_info["api_key"] 194 | api_key_valid = current_llm_info["api_key_valid"] 195 | if api_key_valid: 196 | print("API key is valid when creating LLM instance") 197 | st.session_state.llm = create_openai_llm( 198 | model_name=model_name, 199 | api_base=api_base, 200 | api_key=api_key, 201 | temperature=st.session_state.current_llm_settings["temperature"], 202 | system_prompt=st.session_state.current_llm_settings["system_prompt"], 203 | ) 204 | else: 205 | print("API key is invalid when creating LLM instance") 206 | st.session_state.llm = None 207 | else: 208 | print("No current LLM infomation") 209 | st.session_state.llm = None 210 | 211 | def init_state(): 212 | init_keys() 213 | init_llm_sp() 214 | init_llm_settings() 215 | init_ollama_endpoint() 216 | sp = st.session_state.llm_service_provider_selected 217 | init_api_model(sp) 218 | init_api_key(sp) 219 | create_embedding_model(st.session_state["current_llm_settings"]["embedding_model"]) 220 | create_llm_instance() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | llama_index==0.11.19 2 | streamlit==1.39.0 3 | langchain==0.3.4 4 | langchain-community==0.3.3 5 | langchain_openai==0.2.3 6 | ollama==0.3.3 7 | 
llama-index-embeddings-huggingface==0.3.1 8 | llama-index-embeddings-langchain==0.2.1 9 | llama-index-llms-langchain==0.4.2 10 | llama-index-readers-web==0.2.4 11 | llama-index-retrievers-bm25==0.4.0 12 | llama-index-storage-kvstore-redis==0.2.0 13 | llama-index-vector-stores-chroma==0.2.1 14 | llama-index-vector-stores-elasticsearch==0.3.3 15 | llama-index-vector-stores-lancedb==0.2.4 16 | llama-index-llms-ollama==0.3.4 17 | llama-index-storage-chat_store-redis==0.3.2 18 | llama_index.storage.docstore.redis==0.2.0 19 | llama_index.storage.index_store.redis==0.3.0 20 | docx2txt==0.8 -------------------------------------------------------------------------------- /server/engine.py: -------------------------------------------------------------------------------- 1 | # Create and manage query/chat engine 2 | import config as config 3 | from server.models.reranker import create_reranker_model 4 | from server.prompt import text_qa_template, refine_template 5 | from server.retriever import SimpleFusionRetriever 6 | from llama_index.core.query_engine import RetrieverQueryEngine 7 | 8 | # Create a query engine 9 | def create_query_engine(index, 10 | top_k=config.TOP_K, 11 | response_mode=config.RESPONSE_MODE, 12 | use_reranker=config.USE_RERANKER, 13 | top_n=config.RERANKER_MODEL_TOP_N, 14 | reranker=config.DEFAULT_RERANKER_MODEL): 15 | # Customized query engine with hybrid search and reranker 16 | node_postprocessors = [create_reranker_model(model_name=reranker, top_n=top_n)] if use_reranker else [] 17 | retriever = SimpleFusionRetriever(vector_index=index, top_k=top_k) 18 | 19 | query_engine = RetrieverQueryEngine.from_args( 20 | retriever=retriever, 21 | text_qa_template=text_qa_template, 22 | refine_template=refine_template, 23 | node_postprocessors=node_postprocessors, 24 | response_mode=response_mode, # https://docs.llamaindex.ai/en/stable/api_reference/response_synthesizers/ 25 | verbose=True, 26 | streaming=True, 27 | ) 28 | 29 | return query_engine 30 | -------------------------------------------------------------------------------- /server/index.py: -------------------------------------------------------------------------------- 1 | # Index management - create, load and insert 2 | import os 3 | from llama_index.core import Settings, StorageContext, VectorStoreIndex 4 | from llama_index.core import load_index_from_storage, load_indices_from_storage 5 | from llama_index.core import VectorStoreIndex, SimpleDirectoryReader 6 | from server.utils.file import get_save_dir 7 | from server.stores.strage_context import STORAGE_CONTEXT 8 | from server.ingestion import AdvancedIngestionPipeline 9 | from config import DEV_MODE 10 | 11 | class IndexManager: 12 | def __init__(self, index_name): 13 | self.index_name: str = index_name 14 | self.storage_context: StorageContext = STORAGE_CONTEXT 15 | self.index_id: str = None 16 | self.index: VectorStoreIndex = None 17 | 18 | def check_index_exists(self): 19 | indices = load_indices_from_storage(self.storage_context) 20 | print(f"Loaded {len(indices)} indices") 21 | if len(indices) > 0: 22 | self.index = indices[0] 23 | self.index_id = indices[0].index_id 24 | return True 25 | else: 26 | return False 27 | 28 | def init_index(self, nodes): 29 | self.index = VectorStoreIndex(nodes, 30 | storage_context=self.storage_context, 31 | store_nodes_override=True) # note: no nodes in doc store if using vector database, set store_nodes_override=True to add nodes to doc store 32 | self.index_id = self.index.index_id 33 | if DEV_MODE: 34 | self.storage_context.persist() 
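        # In development mode the default simple stores are in-memory, so persist() writes the doc store, index store and vector store to the local 'storage' directory (see frontend/Storage.py); in production mode Chroma/Redis handle persistence themselves, so persist() is skipped.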
35 | print(f"Created index {self.index.index_id}") 36 | return self.index 37 | 38 | def load_index(self): # TODO: load index based on index_id 39 | self.index = load_index_from_storage(self.storage_context) 40 | if not DEV_MODE: 41 | self.index._store_nodes_override = True 42 | print(f"Loaded index {self.index.index_id}") 43 | return self.index 44 | 45 | def insert_nodes(self, nodes): 46 | if self.index is not None: 47 | self.index.insert_nodes(nodes=nodes) 48 | if DEV_MODE: 49 | self.storage_context.persist() 50 | print(f"Inserted {len(nodes)} nodes into index {self.index.index_id}") 51 | else: 52 | self.init_index(nodes=nodes) 53 | return self.index 54 | 55 | # Build index based on documents under 'data' folder 56 | def load_dir(self, input_dir, chunk_size, chunk_overlap): 57 | Settings.chunk_size = chunk_size 58 | Settings.chunk_overlap = chunk_overlap 59 | documents = SimpleDirectoryReader(input_dir=input_dir, recursive=True).load_data() 60 | if len(documents) > 0: 61 | pipeline = AdvancedIngestionPipeline() 62 | nodes = pipeline.run(documents=documents) 63 | index = self.insert_nodes(nodes) 64 | return nodes 65 | else: 66 | print("No documents found") 67 | return [] 68 | 69 | # get file's directory and create index 70 | def load_files(self, uploaded_files, chunk_size, chunk_overlap): 71 | Settings.chunk_size = chunk_size 72 | Settings.chunk_overlap = chunk_overlap 73 | save_dir = get_save_dir() 74 | files = [os.path.join(save_dir, file["name"]) for file in uploaded_files] 75 | print(files) 76 | documents = SimpleDirectoryReader(input_files=files).load_data() 77 | if len(documents) > 0: 78 | pipeline = AdvancedIngestionPipeline() 79 | nodes = pipeline.run(documents=documents) 80 | index = self.insert_nodes(nodes) 81 | return nodes 82 | else: 83 | print("No documents found") 84 | return [] 85 | 86 | # Get URL and create index 87 | # https://docs.llamaindex.ai/en/stable/examples/data_connectors/WebPageDemo/ 88 | def load_websites(self, websites, chunk_size, chunk_overlap): 89 | Settings.chunk_size = chunk_size 90 | Settings.chunk_overlap = chunk_overlap 91 | 92 | from server.readers.beautiful_soup_web import BeautifulSoupWebReader 93 | documents = BeautifulSoupWebReader().load_data(websites) 94 | if len(documents) > 0: 95 | pipeline = AdvancedIngestionPipeline() 96 | nodes = pipeline.run(documents=documents) 97 | index = self.insert_nodes(nodes) 98 | return nodes 99 | else: 100 | print("No documents found") 101 | return [] 102 | 103 | # Delete a document and all related nodes 104 | def delete_ref_doc(self, ref_doc_id): 105 | self.index.delete_ref_doc(ref_doc_id=ref_doc_id, delete_from_docstore=True) 106 | self.storage_context.persist() 107 | print("Deleted document", ref_doc_id) -------------------------------------------------------------------------------- /server/ingestion.py: -------------------------------------------------------------------------------- 1 | # Import pipeline IngestionPipeline 2 | # https://docs.llamaindex.ai/en/stable/api_reference/ingestion/ 3 | # https://docs.llamaindex.ai/en/stable/examples/ingestion/advanced_ingestion_pipeline/ 4 | 5 | from llama_index.core import Settings 6 | from llama_index.core.ingestion import IngestionPipeline, DocstoreStrategy 7 | from server.splitters import ChineseTitleExtractor 8 | from server.stores.strage_context import STORAGE_CONTEXT 9 | from server.stores.ingestion_cache import INGESTION_CACHE 10 | 11 | class AdvancedIngestionPipeline(IngestionPipeline): 12 | def __init__( 13 | self, 14 | ): 15 | # Initialize the embedding model, 
text splitter 16 | embed_model = Settings.embed_model 17 | text_splitter = Settings.text_splitter 18 | 19 | # Call the super class's __init__ method with the necessary arguments 20 | super().__init__( 21 | transformations=[ 22 | text_splitter, 23 | embed_model, 24 | ChineseTitleExtractor(), # modified Chinese title enhance: zh_title_enhance 25 | ], 26 | docstore=STORAGE_CONTEXT.docstore, 27 | vector_store=STORAGE_CONTEXT.vector_store, 28 | cache=INGESTION_CACHE, 29 | docstore_strategy=DocstoreStrategy.UPSERTS, # UPSERTS: Update or insert 30 | ) 31 | 32 | # If you need to override the run method or add new methods, you can do so here 33 | def run(self, documents): 34 | print(f"Load {len(documents)} Documents") 35 | nodes = super().run(documents=documents) 36 | print(f"Ingested {len(nodes)} Nodes") 37 | return nodes -------------------------------------------------------------------------------- /server/models/embedding.py: -------------------------------------------------------------------------------- 1 | # Create embedding models 2 | import os 3 | from llama_index.core import Settings 4 | from llama_index.embeddings.huggingface import HuggingFaceEmbedding 5 | from config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_MODEL_PATH, MODEL_DIR 6 | from server.utils.hf_mirror import use_hf_mirror 7 | 8 | def create_embedding_model(model_name = DEFAULT_EMBEDDING_MODEL) -> HuggingFaceEmbedding: 9 | try: 10 | use_hf_mirror() 11 | model_path = EMBEDDING_MODEL_PATH[model_name] 12 | if MODEL_DIR is not None: 13 | path = f"./{MODEL_DIR}/{model_path}" 14 | if os.path.exists(path): # Use local models if the path exists 15 | model_path = path 16 | embed_model = HuggingFaceEmbedding(model_name=model_path) 17 | Settings.embed_model = embed_model 18 | print(f"created embed model: {model_path}") 19 | except Exception as e: 20 | print(f"An error occurred while creating the embedding model: {type(e).__name__}: {e}") 21 | Settings.embed_model = None 22 | 23 | return Settings.embed_model -------------------------------------------------------------------------------- /server/models/llm_api.py: -------------------------------------------------------------------------------- 1 | # Create LLM with API compatible with OpenAI 2 | from llama_index.core import Settings 3 | from langchain_openai import ChatOpenAI 4 | from llama_index.llms.langchain import LangChainLLM 5 | 6 | def create_openai_llm(model_name:str, api_base:str, api_key:str, temperature:float = 0.5, system_prompt:str = None) -> ChatOpenAI: 7 | try: 8 | llm = LangChainLLM( 9 | llm=ChatOpenAI( 10 | openai_api_base=api_base, 11 | openai_api_key=api_key, 12 | model_name=model_name, 13 | temperature=temperature, 14 | ), 15 | system_prompt=system_prompt, 16 | ) 17 | Settings.llm = llm 18 | return llm 19 | except Exception as e: 20 | print(f"An error occurred while creating the OpenAI compatibale model: {type(e).__name__}: {e}") 21 | return None 22 | 23 | def check_openai_llm(model_name, api_base, api_key) -> bool: 24 | # Make a simple API call to verify the key 25 | try: 26 | llm = ChatOpenAI( 27 | openai_api_base=api_base, 28 | openai_api_key=api_key, 29 | model_name=model_name, 30 | timeout=5, 31 | max_retries=1 32 | ) 33 | response = llm.invoke("Hello, World!") 34 | print(response) 35 | if response: 36 | return True 37 | else: 38 | return False 39 | except Exception as e: 40 | print(f"An error occurred while verifying the LLM API: {type(e).__name__}: {e}") 41 | return False 42 | -------------------------------------------------------------------------------- 
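A minimal usage sketch for the two helpers above (the endpoint, key and model name are placeholder assumptions, not values taken from this repository):

from server.models.llm_api import create_openai_llm, check_openai_llm

api_base = "https://api.openai.com/v1"   # placeholder: any OpenAI-compatible endpoint listed in config.LLM_API_LIST
api_key = "sk-..."                       # placeholder: supply a real key at runtime, never commit it
model_name = "gpt-4o-mini"               # placeholder: one of the models configured for the provider in config.py

if check_openai_llm(model_name, api_base, api_key):
    llm = create_openai_llm(model_name, api_base, api_key,
                            temperature=0.5,
                            system_prompt="You are a helpful assistant.")
    if llm is not None:
        print(llm.complete("Hello, World!"))  # LlamaIndex LLM interface exposed by the LangChainLLM wrapper
else:
    print("API key or endpoint could not be verified")
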
/server/models/ollama.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import streamlit as st 3 | from ollama import Client 4 | from llama_index.core import Settings 5 | from llama_index.llms.ollama import Ollama 6 | 7 | def is_alive(): 8 | try: 9 | response = requests.get(st.session_state.ollama_api_url) 10 | return response.status_code == 200 11 | except requests.exceptions.RequestException: 12 | print("Failed to connect to Ollama") 13 | return False 14 | 15 | def get_model_list(): 16 | st.session_state.ollama_models = [] 17 | if is_alive(): 18 | client = Client(host=st.session_state.ollama_api_url) 19 | response = client.list() 20 | models = response["models"] 21 | # Initialize the list of model names 22 | for model in models: 23 | st.session_state.ollama_models.append(model["name"]) 24 | return response["models"] 25 | else: 26 | print("Ollama is not alive") 27 | return None 28 | 29 | # Create Ollama LLM 30 | def create_ollama_llm(model:str, temperature:float = 0.5, system_prompt:str = None) -> Ollama: 31 | try: 32 | llm = Ollama( 33 | model=model, 34 | base_url=st.session_state.ollama_api_url, 35 | request_timeout=600, 36 | temperature=temperature, 37 | system_prompt=system_prompt, 38 | ) 39 | print(f"created ollama model for query: {model}") 40 | Settings.llm = llm 41 | return llm 42 | except Exception as e: 43 | print(f"An error occurred while creating Ollama LLM: {e}") 44 | return None 45 | -------------------------------------------------------------------------------- /server/models/reranker.py: -------------------------------------------------------------------------------- 1 | # Create Rerank model 2 | # https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/SentenceTransformerRerank/ 3 | import os 4 | from llama_index.core.postprocessor import SentenceTransformerRerank 5 | from config import DEFAULT_RERANKER_MODEL, RERANKER_MODEL_TOP_N, RERANKER_MODEL_PATH, MODEL_DIR 6 | from server.utils.hf_mirror import use_hf_mirror 7 | 8 | def create_reranker_model(model_name = DEFAULT_RERANKER_MODEL, top_n = RERANKER_MODEL_TOP_N) -> SentenceTransformerRerank: 9 | try: 10 | use_hf_mirror() 11 | model_path = RERANKER_MODEL_PATH[model_name] 12 | if MODEL_DIR is not None: 13 | path = f"./{MODEL_DIR}/{model_path}" 14 | if os.path.exists(path): # Use local models if the path exists 15 | model_path = path 16 | rerank_model = SentenceTransformerRerank(model=model_path, top_n=top_n) 17 | print(f"created rerank model: {model_name}") 18 | return rerank_model 19 | except Exception as e: 20 | return None -------------------------------------------------------------------------------- /server/prompt.py: -------------------------------------------------------------------------------- 1 | # https://docs.llamaindex.ai/en/stable/examples/customization/prompts/completion_prompts/ 2 | # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/prompts/default_prompts.py 3 | # https://docs.llamaindex.ai/en/stable/module_guides/models/prompts/usage_pattern/ 4 | 5 | from llama_index.core import PromptTemplate 6 | 7 | text_qa_template_str = ( 8 | "以下为上下文信息\n" 9 | "---------------------\n" 10 | "{context_str}\n" 11 | "---------------------\n" 12 | "请根据上下文信息回答我的问题或回复我的指令。前面的上下文信息可能有用,也可能没用,你需要从我给出的上下文信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。我的问题或指令是什么语种,你就用什么语种回复。\n" 13 | "问题:{query_str}\n" 14 | "你的回复: " 15 | ) 16 | 17 | 18 | text_qa_template = PromptTemplate(text_qa_template_str) 19 | 20 | 
refine_template_str = ( 21 | "这是原本的问题: {query_str}\n" 22 | "我们已经提供了回答: {existing_answer}\n" 23 | "现在我们有机会改进这个回答 " 24 | "使用以下更多上下文(仅当需要用时)\n" 25 | "------------\n" 26 | "{context_msg}\n" 27 | "------------\n" 28 | "根据新的上下文, 请改进原来的回答。" 29 | "如果新的上下文没有用, 直接返回原本的回答。\n" 30 | "改进的回答: " 31 | ) 32 | refine_template = PromptTemplate(refine_template_str) 33 | -------------------------------------------------------------------------------- /server/readers/beautiful_soup_web.py: -------------------------------------------------------------------------------- 1 | """Beautiful Soup Web scraper.""" 2 | 3 | import logging 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | from urllib.parse import urljoin 6 | from datetime import datetime 7 | 8 | from llama_index.core.bridge.pydantic import PrivateAttr 9 | from llama_index.core.readers.base import BasePydanticReader 10 | from llama_index.core.schema import Document 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def _mpweixin_reader(soup: Any, **kwargs) -> Tuple[str, Dict[str, Any]]: 16 | """Extract text from Substack blog post.""" 17 | meta_tag_title = soup.find('meta', attrs={'property': 'og:title'}) 18 | title = meta_tag_title['content'] 19 | extra_info = { 20 | "title": title, 21 | #"Author": soup.select_one("span #js_author_name").getText(), 22 | } 23 | text = soup.select_one("div #page-content").getText() 24 | return text, extra_info 25 | 26 | 27 | DEFAULT_WEBSITE_EXTRACTOR: Dict[ 28 | str, Callable[[Any, str], Tuple[str, Dict[str, Any]]] 29 | ] = { 30 | "mp.weixin.qq.com": _mpweixin_reader, 31 | } 32 | 33 | 34 | class BeautifulSoupWebReader(BasePydanticReader): 35 | """BeautifulSoup web page reader. 36 | 37 | Reads pages from the web. 38 | Requires the `bs4` and `urllib` packages. 39 | 40 | Args: 41 | website_extractor (Optional[Dict[str, Callable]]): A mapping of website 42 | hostname (e.g. google.com) to a function that specifies how to 43 | extract text from the BeautifulSoup obj. See DEFAULT_WEBSITE_EXTRACTOR. 44 | """ 45 | 46 | is_remote: bool = True 47 | _website_extractor: Dict[str, Callable] = PrivateAttr() 48 | 49 | def __init__(self, website_extractor: Optional[Dict[str, Callable]] = None) -> None: 50 | super().__init__() 51 | self._website_extractor = website_extractor or DEFAULT_WEBSITE_EXTRACTOR 52 | 53 | @classmethod 54 | def class_name(cls) -> str: 55 | """Get the name identifier of the class.""" 56 | return "BeautifulSoupWebReader" 57 | 58 | def load_data( 59 | self, 60 | urls: List[str], 61 | custom_hostname: Optional[str] = None, 62 | include_url_in_text: Optional[bool] = True, 63 | ) -> List[Document]: 64 | """Load data from the urls. 65 | 66 | Args: 67 | urls (List[str]): List of URLs to scrape. 68 | custom_hostname (Optional[str]): Force a certain hostname in the case 69 | a website is displayed under custom URLs (e.g. Substack blogs) 70 | include_url_in_text (Optional[bool]): Include the reference url in the text of the document 71 | 72 | Returns: 73 | List[Document]: List of documents. 
74 | 75 | """ 76 | from urllib.parse import urlparse 77 | 78 | import requests 79 | from bs4 import BeautifulSoup 80 | 81 | documents = [] 82 | for url in urls: 83 | try: 84 | page = requests.get(url) 85 | hostname = custom_hostname or urlparse(url).hostname or "" 86 | 87 | soup = BeautifulSoup(page.content, "html.parser") 88 | 89 | data = "" 90 | extra_info = { 91 | "title": soup.select_one("title"), 92 | "url_source": url, 93 | "creation_date": datetime.now().date().isoformat(), # Convert datetime to ISO format string 94 | } 95 | if hostname in self._website_extractor: 96 | data, metadata = self._website_extractor[hostname]( 97 | soup=soup, url=url, include_url_in_text=include_url_in_text 98 | ) 99 | extra_info.update(metadata) 100 | 101 | else: 102 | data = soup.getText() 103 | 104 | documents.append(Document(text=data, id_=url, extra_info=extra_info)) 105 | except Exception: 106 | print(f"Could not scrape {url}") 107 | raise ValueError(f"One of the inputs is not a valid url: {url}") 108 | 109 | return documents 110 | -------------------------------------------------------------------------------- /server/readers/jina_web.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict, Callable 2 | from datetime import datetime 3 | 4 | import requests, re 5 | from llama_index.core.readers.base import BasePydanticReader 6 | from llama_index.core.schema import Document 7 | 8 | 9 | class JinaWebReader(BasePydanticReader): 10 | """Jina web page reader. 11 | 12 | Reads pages from the web. 13 | 14 | """ 15 | 16 | def __init__(self) -> None: 17 | """Initialize with parameters.""" 18 | 19 | def load_data(self, urls: List[str]) -> List[Document]: 20 | """Load data from the input directory. 21 | 22 | Args: 23 | urls (List[str]): List of URLs to scrape. 24 | 25 | Returns: 26 | List[Document]: List of documents. 
27 | 28 | """ 29 | if not isinstance(urls, list): 30 | raise ValueError("urls must be a list of strings.") 31 | 32 | documents = [] 33 | for url in urls: 34 | new_url = "https://r.jina.ai/" + url 35 | response = requests.get(new_url) 36 | text = response.text 37 | 38 | # Extract Title 39 | title_match = re.search(r"Title:\s*(.*)", text) 40 | title = title_match.group(1) if title_match else None 41 | 42 | # Extract URL Source 43 | url_match = re.search(r"URL Source:\s*(.*)", text) 44 | url_source = url_match.group(1) if url_match else None 45 | 46 | # Extract Markdown Content 47 | markdown_match = re.search(r"Markdown Content:\s*(.*)", text, re.DOTALL) 48 | markdown_content = markdown_match.group(1).strip() if markdown_match else None 49 | 50 | # Compose metadata 51 | metadata: Dict = { 52 | "title": title, 53 | "url_source": url_source, 54 | "creation_date": datetime.now().date().isoformat(), # Convert datetime to ISO format string 55 | } 56 | 57 | documents.append(Document(text=markdown_content, id_=url, metadata=metadata or {})) 58 | 59 | return documents 60 | -------------------------------------------------------------------------------- /server/retriever.py: -------------------------------------------------------------------------------- 1 | # Retriever method 2 | 3 | from llama_index.core.retrievers import BaseRetriever 4 | from llama_index.core.retrievers import VectorIndexRetriever 5 | from llama_index.retrievers.bm25 import BM25Retriever 6 | 7 | # A simple BM25 retrieval method, customized for document storage and tokenization 8 | 9 | # BM25Retriever's default tokenizer does not support Chinese 10 | # Reference:https://github.com/run-llama/llama_index/issues/13866 11 | 12 | import jieba 13 | from typing import List 14 | def chinese_tokenizer(text: str) -> List[str]: 15 | return list(jieba.cut(text)) 16 | 17 | class SimpleBM25Retriever(BM25Retriever): 18 | @classmethod 19 | def from_defaults(cls, index, similarity_top_k, **kwargs) -> "BM25Retriever": 20 | docstore = index.docstore 21 | return BM25Retriever.from_defaults( 22 | docstore=docstore, similarity_top_k=similarity_top_k, verbose=True, 23 | tokenizer=chinese_tokenizer, **kwargs 24 | ) 25 | 26 | # A simple hybrid retriever method 27 | # Reference:https://docs.llamaindex.ai/en/stable/examples/retrievers/bm25_retriever/ 28 | 29 | class SimpleHybridRetriever(BaseRetriever): 30 | def __init__(self, vector_index, top_k=2): 31 | self.top_k = top_k 32 | 33 | # Build vector retriever from vector index 34 | self.vector_retriever = VectorIndexRetriever( 35 | index=vector_index, similarity_top_k=top_k, verbose=True, 36 | ) 37 | 38 | # Build BM25 retriever from document storage 39 | self.bm25_retriever = SimpleBM25Retriever.from_defaults( 40 | index=vector_index, similarity_top_k=top_k, 41 | ) 42 | 43 | super().__init__() 44 | 45 | def _retrieve(self, query, **kwargs): 46 | bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs) 47 | 48 | # the score is related to the query and may exceed 1, thus normalization is required 49 | # calculate min and max value 50 | min_score = min(item.score for item in bm25_nodes) 51 | max_score = max(item.score for item in bm25_nodes) 52 | 53 | # normalize score 54 | normalized_data = [(item.score - min_score) / (max_score - min_score) for item in bm25_nodes] 55 | 56 | # Assign normalized score back to the original object 57 | for item, normalized_score in zip(bm25_nodes, normalized_data): 58 | item.score = normalized_score 59 | 60 | vector_nodes = self.vector_retriever.retrieve(query, **kwargs) 61 | 62 
| # Merge two retrieval results, remove duplicates, and return only the Top_K results 63 | all_nodes = [] 64 | node_ids = set() 65 | count = 0 66 | for n in vector_nodes + bm25_nodes: 67 | if n.node.node_id not in node_ids: 68 | all_nodes.append(n) 69 | node_ids.add(n.node.node_id) 70 | count += 1 71 | if count >= self.top_k: 72 | break 73 | for node in all_nodes: 74 | print(f"Hybrid Retrieved Node: {node.node_id} - Score: {node.score:.2f} - {node.text[:10]}...\n-----") 75 | return all_nodes 76 | 77 | # Fusion retriever method 78 | # Reference: https://docs.llamaindex.ai/en/stable/examples/retrievers/relative_score_dist_fusion/ 79 | # https://medium.com/plain-simple-software/distribution-based-score-fusion-dbsf-a-new-approach-to-vector-search-ranking-f87c37488b18 80 | # https://docs.llamaindex.ai/en/stable/examples/low_level/fusion_retriever/?h=retrieverqueryengine 81 | from llama_index.core.retrievers import QueryFusionRetriever 82 | from enum import Enum 83 | 84 | # Three different modes, from LlamaIndex's source code 85 | class FUSION_MODES(str, Enum): 86 | RECIPROCAL_RANK = "reciprocal_rerank" # apply reciprocal rank fusion 87 | RELATIVE_SCORE = "relative_score" # apply relative score fusion 88 | DIST_BASED_SCORE = "dist_based_score" # apply distance-based score fusion 89 | SIMPLE = "simple" # simple re-ordering of results based on original scores 90 | 91 | class SimpleFusionRetriever(QueryFusionRetriever): 92 | def __init__(self, vector_index, top_k=2, mode=FUSION_MODES.DIST_BASED_SCORE): 93 | self.top_k = top_k 94 | self.mode = mode 95 | 96 | # Build vector retriever from vector index 97 | self.vector_retriever = VectorIndexRetriever( 98 | index=vector_index, similarity_top_k=top_k, verbose=True, 99 | ) 100 | 101 | # Build BM25 retriever from document storage 102 | self.bm25_retriever = SimpleBM25Retriever.from_defaults( 103 | index=vector_index, similarity_top_k=top_k, 104 | ) 105 | 106 | super().__init__( 107 | [self.vector_retriever, self.bm25_retriever], 108 | retriever_weights=[0.6, 0.4], 109 | similarity_top_k=top_k, 110 | num_queries=1, # set this to 1 to disable query generation 111 | mode=mode, 112 | use_async=True, 113 | verbose=True, 114 | ) -------------------------------------------------------------------------------- /server/splitters/__init__.py: -------------------------------------------------------------------------------- 1 | from .chinese_text_splitter import ChineseTextSplitter 2 | from .zh_title_enhance import ChineseTitleExtractor 3 | from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter -------------------------------------------------------------------------------- /server/splitters/chinese_recursive_text_splitter.py: -------------------------------------------------------------------------------- 1 | # Chinese recursive text splitter 2 | # Source:LangchainChatChat, QAnything 3 | 4 | import re 5 | from typing import List, Optional, Any 6 | from langchain.text_splitter import RecursiveCharacterTextSplitter 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def _split_text_with_regex_from_end( 13 | text: str, separator: str, keep_separator: bool 14 | ) -> List[str]: 15 | # Now that we have the separator, split the text 16 | if separator: 17 | if keep_separator: 18 | # The parentheses in the pattern keep the delimiters in the result. 
19 | _splits = re.split(f"({separator})", text) 20 | splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] 21 | if len(_splits) % 2 == 1: 22 | splits += _splits[-1:] 23 | # splits = [_splits[0]] + splits 24 | else: 25 | splits = re.split(separator, text) 26 | else: 27 | splits = list(text) 28 | return [s for s in splits if s != ""] 29 | 30 | 31 | class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter): 32 | def __init__( 33 | self, 34 | separators: Optional[List[str]] = None, 35 | keep_separator: bool = True, 36 | is_separator_regex: bool = True, 37 | **kwargs: Any, 38 | ) -> None: 39 | """Create a new TextSplitter.""" 40 | super().__init__(keep_separator=keep_separator, **kwargs) 41 | self._separators = separators or [ 42 | "\n\n", 43 | "\n", 44 | "。|!|?", 45 | "\.\s|\!\s|\?\s", 46 | ";|;\s", 47 | ",|,\s" 48 | ] 49 | self._is_separator_regex = is_separator_regex 50 | 51 | def _split_text(self, text: str, separators: List[str]) -> List[str]: 52 | """Split incoming text and return chunks.""" 53 | final_chunks = [] 54 | # Get appropriate separator to use 55 | separator = separators[-1] 56 | new_separators = [] 57 | for i, _s in enumerate(separators): 58 | _separator = _s if self._is_separator_regex else re.escape(_s) 59 | if _s == "": 60 | separator = _s 61 | break 62 | if re.search(_separator, text): 63 | separator = _s 64 | new_separators = separators[i + 1:] 65 | break 66 | 67 | _separator = separator if self._is_separator_regex else re.escape(separator) 68 | splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator) 69 | 70 | # Now go merging things, recursively splitting longer texts. 71 | _good_splits = [] 72 | _separator = "" if self._keep_separator else separator 73 | for s in splits: 74 | if self._length_function(s) < self._chunk_size: 75 | _good_splits.append(s) 76 | else: 77 | if _good_splits: 78 | merged_text = self._merge_splits(_good_splits, _separator) 79 | final_chunks.extend(merged_text) 80 | _good_splits = [] 81 | if not new_separators: 82 | final_chunks.append(s) 83 | else: 84 | other_info = self._split_text(s, new_separators) 85 | final_chunks.extend(other_info) 86 | if _good_splits: 87 | merged_text = self._merge_splits(_good_splits, _separator) 88 | final_chunks.extend(merged_text) 89 | return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""] 90 | 91 | 92 | if __name__ == "__main__": 93 | text_splitter = ChineseRecursiveTextSplitter( 94 | keep_separator=True, 95 | is_separator_regex=True, 96 | chunk_size=50, 97 | chunk_overlap=0 98 | ) 99 | ls = [ 100 | """中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 
的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""", 101 | ] 102 | # text = """""" 103 | for inum, text in enumerate(ls): 104 | print(inum) 105 | chunks = text_splitter.split_text(text) 106 | for chunk in chunks: 107 | print(chunk) 108 | -------------------------------------------------------------------------------- /server/splitters/chinese_text_splitter.py: -------------------------------------------------------------------------------- 1 | # Chinese text splitter 2 | # Source:LangchainChatChat, QAnything 3 | 4 | from langchain.text_splitter import CharacterTextSplitter 5 | import re 6 | from typing import List 7 | 8 | class ChineseTextSplitter(CharacterTextSplitter): 9 | def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): 10 | super().__init__(**kwargs) 11 | self.pdf = pdf 12 | self.sentence_size = sentence_size 13 | 14 | def split_text1(self, text: str) -> List[str]: 15 | if self.pdf: 16 | text = re.sub(r"\n{3,}", "\n", text) 17 | text = re.sub('\s', ' ', text) 18 | text = text.replace("\n\n", "") 19 | sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :; 20 | sent_list = [] 21 | for ele in sent_sep_pattern.split(text): 22 | if sent_sep_pattern.match(ele) and sent_list: 23 | sent_list[-1] += ele 24 | elif ele: 25 | sent_list.append(ele) 26 | return sent_list 27 | 28 | def split_text(self, text: str) -> List[str]: ## Need further logical optimization here 29 | if self.pdf: 30 | text = re.sub(r"\n{3,}", r"\n", text) 31 | text = re.sub('\s', " ", text) 32 | text = re.sub("\n\n", "", text) 33 | 34 | text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # Single-character delimiter 35 | text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # English ellipsis 36 | text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # Chinese ellipsis 37 | text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) 38 | # If there is an ending punctuation before the double quotes, then the double quotes are considered to be the end of the sentence. 39 | # Place the sentence delimiter \n after the double quotes, and be aware that the double quotes in the previous sentences are preserved. 40 | text = text.rstrip() # Remove the extra \n at the end of the paragraph(if any) 41 | # Semicolons was not considered in this case, along with dashes and English double quotes. If needed, all we need is some simple adjustment. 
42 | ls = [i for i in text.split("\n") if i] 43 | for ele in ls: 44 | if len(ele) > self.sentence_size: 45 | ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) 46 | ele1_ls = ele1.split("\n") 47 | for ele_ele1 in ele1_ls: 48 | if len(ele_ele1) > self.sentence_size: 49 | ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) 50 | ele2_ls = ele_ele2.split("\n") 51 | for ele_ele2 in ele2_ls: 52 | if len(ele_ele2) > self.sentence_size: 53 | ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) 54 | ele2_id = ele2_ls.index(ele_ele2) 55 | ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ 56 | ele2_id + 1:] 57 | ele_id = ele1_ls.index(ele_ele1) 58 | ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] 59 | 60 | id = ls.index(ele) 61 | ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] 62 | return ls 63 | -------------------------------------------------------------------------------- /server/splitters/zh_title_enhance.py: -------------------------------------------------------------------------------- 1 | # Chinese title enhance 2 | # Source:LangchainChatChat, QAnything 3 | 4 | from llama_index.core.schema import BaseNode # modified based on Document in Langchain 5 | from typing import List 6 | import re 7 | 8 | 9 | def under_non_alpha_ratio(text: str, threshold: float = 0.5): 10 | """Checks if the proportion of non-alpha characters in the text snippet exceeds a given 11 | threshold. This helps prevent text like "-----------BREAK---------" from being tagged 12 | as a title or narrative text. The ratio does not count spaces. 13 | 14 | Parameters 15 | ---------- 16 | text 17 | The input string to test 18 | threshold 19 | If the proportion of non-alpha characters exceeds this threshold, the function 20 | returns False 21 | """ 22 | if len(text) == 0: 23 | return False 24 | 25 | alpha_count = len([char for char in text if char.strip() and char.isalpha()]) 26 | total_count = len([char for char in text if char.strip()]) 27 | try: 28 | ratio = alpha_count / total_count 29 | return ratio < threshold 30 | except: 31 | return False 32 | 33 | 34 | def is_possible_title( 35 | text: str, 36 | title_max_word_length: int = 20, 37 | non_alpha_threshold: float = 0.5, 38 | ) -> bool: 39 | """Checks to see if the text passes all of the checks for a valid title. 40 | 41 | Parameters 42 | ---------- 43 | text 44 | The input text to check 45 | title_max_word_length 46 | The maximum number of words a title can contain 47 | non_alpha_threshold 48 | The minimum number of alpha characters the text needs to be considered a title 49 | """ 50 | 51 | # If the text length is zero, it is not a title 52 | if len(text) == 0: 53 | print("Not a title. Text is empty.") 54 | return False 55 | 56 | # If the text has punctuation, it is not a title 57 | ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z" 58 | ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN) 59 | if ENDS_IN_PUNCT_RE.search(text) is not None: 60 | return False 61 | 62 | # The text length must not exceed the set value, which is set to be 20 by default. 63 | # NOTE(robinson) - splitting on spaces here instead of word tokenizing because it 64 | # is less expensive and actual tokenization doesn't add much value for the length check 65 | if len(text) > title_max_word_length: 66 | return False 67 | 68 | # The ratio of numbers in the text should not be too high, otherwise it is not a title. 
69 | if under_non_alpha_ratio(text, threshold=non_alpha_threshold): 70 | return False 71 | 72 | # NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles 73 | if text.endswith((",", ".", ",", "。")): 74 | return False 75 | 76 | if text.isnumeric(): 77 | print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore 78 | return False 79 | 80 | # The initial characters (the first 5 by default) are expected to contain a number, as in a numbered heading 81 | if len(text) < 5: 82 | text_5 = text 83 | else: 84 | text_5 = text[:5] 85 | numeric_in_text_5 = sum(map(lambda x: x.isnumeric(), text_5)) 86 | if not numeric_in_text_5: 87 | return False 88 | 89 | return True 90 | 91 | 92 | def zh_title_enhance(docs: List[BaseNode]) -> List[BaseNode]: # modified based on Document in Langchain 93 | title = None 94 | if len(docs) > 0: 95 | for doc in docs: 96 | if is_possible_title(doc.text): # modified based on doc.page_content in Langchain 97 | doc.metadata['category'] = 'cn_Title' 98 | title = doc.text 99 | elif title: 100 | doc.text = f"下文与({title})有关。{doc.text}" # Prefix meaning "The following text is related to (title)" 101 | return docs 102 | else: 103 | print("No documents to process") 104 | return docs 105 | # The following is an encapsulation based on LlamaIndex 106 | 107 | import re 108 | from llama_index.core.schema import TransformComponent 109 | 110 | class ChineseTitleExtractor(TransformComponent): 111 | def __call__(self, nodes, **kwargs): 112 | nodes = zh_title_enhance(nodes) 113 | return nodes -------------------------------------------------------------------------------- /server/stores/chat_store.py: -------------------------------------------------------------------------------- 1 | # Chat Store 2 | 3 | from config import DEV_MODE, REDIS_URI, CHAT_STORE_KEY 4 | 5 | def create_chat_memory(): 6 | 7 | if DEV_MODE: 8 | # Development environment: SimpleChatStore 9 | # https://docs.llamaindex.ai/en/stable/module_guides/storing/chat_stores/ 10 | from llama_index.core.storage.chat_store import SimpleChatStore 11 | from llama_index.core.memory import ChatMemoryBuffer 12 | 13 | simple_chat_store = SimpleChatStore() 14 | 15 | simple_chat_memory = ChatMemoryBuffer.from_defaults( 16 | token_limit=3000, 17 | chat_store=simple_chat_store, 18 | chat_store_key=CHAT_STORE_KEY, 19 | ) 20 | return simple_chat_memory 21 | else: 22 | # Production environment: Redis 23 | # https://docs.llamaindex.ai/en/stable/examples/vector_stores/RedisIndexDemo/ 24 | 25 | # Start redis locally: 26 | # docker run --name redis-vecdb -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest 27 | 28 | from llama_index.core.memory import ChatMemoryBuffer 29 | from llama_index.storage.chat_store.redis import RedisChatStore 30 | 31 | redis_chat_store = RedisChatStore(redis_url=REDIS_URI, ttl=3600) 32 | 33 | redis_chat_memory = ChatMemoryBuffer.from_defaults( 34 | token_limit=3000, 35 | chat_store=redis_chat_store, 36 | chat_store_key=CHAT_STORE_KEY, 37 | ) 38 | return redis_chat_memory 39 | 40 | CHAT_MEMORY = create_chat_memory() -------------------------------------------------------------------------------- /server/stores/config_store.py: -------------------------------------------------------------------------------- 1 | # Config Store 2 | # Save configuration in a local key-value store or database 3 | 4 | import os 5 | from typing import Optional, Dict 6 | from llama_index.core.storage.kvstore import SimpleKVStore 7 | from config import STORAGE_DIR, CONFIG_STORE_FILE 8 | 9 | DATA_TYPE = Dict[str, Dict[str, dict]] 10 | 11 | PERSIST_PATH = "./" + STORAGE_DIR + "/" + CONFIG_STORE_FILE 12 | 13 |
class LocalKVStore(SimpleKVStore): 14 | # Simple key-value store with local persistence. 15 | 16 | def __init__( 17 | self, 18 | data: Optional[DATA_TYPE] = None, 19 | ) -> None: 20 | """Init a LocalKVStore.""" 21 | super().__init__(data) 22 | 23 | def put(self, key: str, val: dict) -> None: 24 | """Put a key-value pair into the store and persist it to disk.""" 25 | super().put(key=key, val=val) 26 | super().persist(persist_path=self.persist_path) 27 | 28 | def delete(self, key: str) -> bool: 29 | """Delete a value from the store and persist the change.""" 30 | # SimpleKVStore.delete returns False for a missing key rather than raising KeyError 31 | deleted = super().delete(key) 32 | if deleted: 33 | super().persist(persist_path=self.persist_path) 34 | return deleted 35 | 36 | 37 | @classmethod 38 | def from_persist_path( 39 | cls, persist_path: str = PERSIST_PATH 40 | ) -> "LocalKVStore": 41 | """Load a LocalKVStore from a persist path on the local filesystem.""" 42 | cls.persist_path = persist_path 43 | if os.path.exists(persist_path): 44 | return super().from_persist_path(persist_path=persist_path) 45 | else: 46 | return cls({}) 47 | 48 | CONFIG_STORE = LocalKVStore.from_persist_path() -------------------------------------------------------------------------------- /server/stores/doc_store.py: -------------------------------------------------------------------------------- 1 | # Document Store 2 | # https://docs.llamaindex.ai/en/stable/examples/docstore/MongoDocstoreDemo/ 3 | # https://docs.llamaindex.ai/en/stable/examples/docstore/RedisDocstoreIndexStoreDemo/ 4 | import config 5 | 6 | if config.THINKRAG_ENV == "production": 7 | from llama_index.storage.docstore.redis import RedisDocumentStore 8 | DOC_STORE = RedisDocumentStore.from_host_and_port( 9 | host=config.REDIS_HOST, port=config.REDIS_PORT, namespace="think" 10 | ) 11 | elif config.THINKRAG_ENV == "development": 12 | from llama_index.core.storage.docstore import SimpleDocumentStore 13 | DOC_STORE = SimpleDocumentStore() -------------------------------------------------------------------------------- /server/stores/index_store.py: -------------------------------------------------------------------------------- 1 | # Index Store 2 | import config 3 | 4 | if config.THINKRAG_ENV == "production": 5 | from llama_index.storage.index_store.redis import RedisIndexStore 6 | INDEX_STORE = RedisIndexStore.from_host_and_port( 7 | host=config.REDIS_HOST, port=config.REDIS_PORT, namespace="think" 8 | ) 9 | elif config.THINKRAG_ENV == "development": 10 | from llama_index.core.storage.index_store import SimpleIndexStore 11 | INDEX_STORE = SimpleIndexStore() -------------------------------------------------------------------------------- /server/stores/ingestion_cache.py: -------------------------------------------------------------------------------- 1 | from llama_index.core.ingestion import IngestionCache 2 | from llama_index.storage.kvstore.redis import RedisKVStore as RedisCache 3 | from config import REDIS_URI, DEV_MODE 4 | 5 | redis_cache = IngestionCache( 6 | cache=RedisCache(redis_uri=REDIS_URI), 7 | collection="redis_pipeline_cache", 8 | ) 9 | 10 | INGESTION_CACHE = redis_cache if not DEV_MODE else None -------------------------------------------------------------------------------- /server/stores/strage_context.py: -------------------------------------------------------------------------------- 1 | # Storage context 2 | # https://docs.llamaindex.ai/en/stable/module_guides/storing/customization/ 3 | 4 | from llama_index.core import StorageContext 5 | from config import THINKRAG_ENV 6 | from server.stores.doc_store import DOC_STORE 7 | from
server.stores.vector_store import VECTOR_STORE 8 | from server.stores.index_store import INDEX_STORE 9 | 10 | def create_storage_context(): 11 | if THINKRAG_ENV == "development": 12 | # Development environment 13 | import os 14 | from config import STORAGE_DIR 15 | persist_dir = "./" + STORAGE_DIR 16 | if os.path.exists(STORAGE_DIR + "/docstore.json"): 17 | dev_storage_context = StorageContext.from_defaults( 18 | persist_dir=persist_dir # Load from the persist directory 19 | ) 20 | print(f"Loaded storage context from {persist_dir}") 21 | return dev_storage_context 22 | else: 23 | dev_storage_context = StorageContext.from_defaults() # Created new storage context, need persistence 24 | print(f"Created new storage context") 25 | return dev_storage_context 26 | elif THINKRAG_ENV == "production": 27 | pro_storage_context = StorageContext.from_defaults( 28 | docstore=DOC_STORE, 29 | index_store=INDEX_STORE, 30 | vector_store=VECTOR_STORE, 31 | ) 32 | return pro_storage_context 33 | 34 | STORAGE_CONTEXT = create_storage_context() -------------------------------------------------------------------------------- /server/stores/vector_store.py: -------------------------------------------------------------------------------- 1 | # Vector database 2 | 3 | # https://docs.llamaindex.ai/en/stable/examples/vector_stores/ChromaIndexDemo/ 4 | # https://docs.llamaindex.ai/en/stable/module_guides/storing/customization/ 5 | 6 | import config 7 | 8 | def create_vector_store(type=config.DEFAULT_VS_TYPE): 9 | if type == "chroma": 10 | # Vector database Chroma 11 | 12 | # Install Chroma vector database 13 | """ pip install chromadb """ 14 | 15 | import chromadb 16 | from llama_index.vector_stores.chroma import ChromaVectorStore 17 | 18 | db = chromadb.PersistentClient(path=".chroma") 19 | chroma_collection = db.get_or_create_collection("think") 20 | chroma_vector_store = ChromaVectorStore(chroma_collection=chroma_collection) 21 | return chroma_vector_store 22 | elif type == "es": 23 | # Todo: use Metadata Filters 24 | 25 | # Vector database ES 26 | # https://docs.llamaindex.ai/en/stable/examples/vector_stores/ElasticsearchIndexDemo/ 27 | 28 | # Run ES locally 29 | """ docker run -p 9200:9200 \ 30 | -e "discovery.type=single-node" \ 31 | -e "xpack.security.enabled=false" \ 32 | -e "xpack.license.self_generated.type=trial" \ 33 | docker.elastic.co/elasticsearch/elasticsearch:8.13.2 """ 34 | 35 | from llama_index.vector_stores.elasticsearch import ElasticsearchStore 36 | from llama_index.vector_stores.elasticsearch import AsyncDenseVectorStrategy 37 | 38 | es_vector_store = ElasticsearchStore( 39 | es_url="http://localhost:9200", 40 | index_name="think", 41 | retrieval_strategy=AsyncDenseVectorStrategy(hybrid=False), 42 | ) 43 | return es_vector_store 44 | elif type == "lancedb": 45 | # Vector database LanceDB 46 | # https://docs.llamaindex.ai/en/stable/examples/vector_stores/LanceDBIndexDemo/ 47 | # https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/ 48 | from llama_index.vector_stores.lancedb import LanceDBVectorStore 49 | from lancedb.rerankers import LinearCombinationReranker 50 | reranker = LinearCombinationReranker(weight=0.9) 51 | 52 | lance_vector_store = LanceDBVectorStore( 53 | uri=".lancedb", mode="overwrite", query_type="vector", reranker=reranker 54 | ) 55 | return lance_vector_store 56 | elif type == "simple": 57 | from llama_index.core.vector_stores import SimpleVectorStore 58 | return SimpleVectorStore() 59 | else: 60 | raise ValueError(f"Invalid vector store type: {type}") 61 | 62 | if 
config.THINKRAG_ENV == "production": 63 | VECTOR_STORE = create_vector_store(type="chroma") 64 | else: 65 | VECTOR_STORE = create_vector_store(type="simple") -------------------------------------------------------------------------------- /server/text_splitter.py: -------------------------------------------------------------------------------- 1 | # Text splitter 2 | 3 | from config import DEV_MODE 4 | from llama_index.core import Settings 5 | 6 | def create_text_splitter(chunk_size=2048, chunk_overlap=512): 7 | if DEV_MODE: 8 | # Development environment 9 | # SentenceSplitter 10 | from llama_index.core.node_parser import SentenceSplitter 11 | 12 | sentence_splitter = SentenceSplitter( 13 | chunk_size=chunk_size, 14 | chunk_overlap=chunk_overlap, 15 | ) 16 | 17 | return sentence_splitter 18 | 19 | else: 20 | # Production environment 21 | # SpacyTextSplitter 22 | # https://zhuanlan.zhihu.com/p/638827267 23 | # pip install spacy 24 | # python -m spacy download zh_core_web_sm 25 | from langchain.text_splitter import SpacyTextSplitter 26 | from llama_index.core.node_parser import LangchainNodeParser 27 | 28 | spacy_text_splitter = LangchainNodeParser(SpacyTextSplitter( 29 | pipeline="zh_core_web_sm", 30 | chunk_size=chunk_size, 31 | chunk_overlap=chunk_overlap, 32 | )) 33 | 34 | return spacy_text_splitter 35 | 36 | Settings.text_splitter = create_text_splitter() -------------------------------------------------------------------------------- /server/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from config import DATA_DIR 3 | 4 | def get_save_dir(): 5 | save_dir = os.getcwd() + "/" + DATA_DIR 6 | return save_dir 7 | 8 | def save_uploaded_file(uploaded_file, save_dir: str): # uploaded_file is expected to be a Streamlit UploadedFile 9 | try: 10 | if not os.path.exists(save_dir): 11 | os.makedirs(save_dir) 12 | path = os.path.join(save_dir, uploaded_file.name) 13 | with open(path, "wb") as f: 14 | f.write(uploaded_file.getbuffer()) 15 | print(f"Saved file to {path}") 16 | except Exception as e: 17 | print(f"Error saving upload to disk: {e}") -------------------------------------------------------------------------------- /server/utils/hf_mirror.py: -------------------------------------------------------------------------------- 1 | # Setting up a HuggingFace mirror 2 | 3 | def use_hf_mirror(): 4 | import os 5 | from config import HF_ENDPOINT 6 | os.environ['HF_ENDPOINT'] = HF_ENDPOINT 7 | print(f"Use HF mirror: {os.environ['HF_ENDPOINT']}") 8 | return os.environ['HF_ENDPOINT'] 9 | --------------------------------------------------------------------------------
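The retrievers in server/retriever.py are meant to be dropped into a standard LlamaIndex query engine. The following is a minimal usage sketch rather than code from the repository: it assumes a VectorStoreIndex named vector_index has already been built from the project's storage context, and that an LLM and embedding model are configured in LlamaIndex Settings; the query string and top_k value are illustrative.

from llama_index.core.query_engine import RetrieverQueryEngine
from server.retriever import SimpleFusionRetriever, FUSION_MODES

# Fuse vector search and BM25 results with distribution-based score fusion (the default mode)
fusion_retriever = SimpleFusionRetriever(
    vector_index=vector_index,
    top_k=3,
    mode=FUSION_MODES.DIST_BASED_SCORE,
)

# Wrap the retriever in a query engine and run a query
query_engine = RetrieverQueryEngine.from_defaults(retriever=fusion_retriever)
response = query_engine.query("今年一般贸易进出口增长了多少?")
print(response)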
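The Chinese splitters and the title enhancer can be chained as transformations in a LlamaIndex ingestion pipeline. This is an illustrative sketch under assumptions: documents is expected to come from a reader such as the Jina web reader, and the chunk size and overlap values are arbitrary.

from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import LangchainNodeParser
from server.splitters import ChineseRecursiveTextSplitter, ChineseTitleExtractor

pipeline = IngestionPipeline(
    transformations=[
        # Wrap the LangChain-style splitter so it runs as a LlamaIndex node parser
        LangchainNodeParser(ChineseRecursiveTextSplitter(chunk_size=512, chunk_overlap=64)),
        # Tag likely Chinese headings and prefix subsequent chunks with the detected title
        ChineseTitleExtractor(),
    ],
)
nodes = pipeline.run(documents=documents)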
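Finally, a short sketch of how the persistent config store might be used; the key and values shown are examples rather than keys the application actually defines.

from server.stores.config_store import CONFIG_STORE

# put() writes through to the JSON file under the storage directory
CONFIG_STORE.put("current_llm", {"provider": "ollama", "model": "qwen2"})
print(CONFIG_STORE.get("current_llm"))  # {'provider': 'ollama', 'model': 'qwen2'}
CONFIG_STORE.delete("current_llm")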