├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── LICENSE ├── README.md ├── configparser.ini ├── convo_qa_chain.py ├── data ├── ABPI Code of Practice for the Pharmaceutical Industry 2021.pdf ├── Attention Is All You Need.pdf ├── Gradient Descent The Ultimate Optimizer.pdf ├── JP Morgan 2022 Environmental Social Governance Report.pdf ├── Language Models are Few-Shot Learners.pdf ├── Language Models are Unsupervised Multitask Learners.pdf └── United Nations 2022 Annual Report.pdf ├── docs2db.py ├── figs ├── High_Level_Architecture.png └── Sliding_Window_Chunking.png ├── main.py ├── requirements.txt └── toolkit ├── ___init__.py ├── local_llm.py ├── prompts.py ├── retrivers.py ├── together_api_llm.py └── utils.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .history 3 | .vscode 4 | __pycache__ 5 | Archieve 6 | database_store 7 | IncarnaMind.log 8 | experiments.ipynb 9 | .pylintrc 10 | .flake8 11 | models/ 12 | model/ 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 |    To apply the Apache License to your work, attach the following
181 |    boilerplate notice, with the fields enclosed by brackets "[]"
182 |    replaced with your own identifying information. (Don't include
183 |    the brackets!)  The text should be enclosed in the appropriate
184 |    comment syntax for the file format. We also recommend that a
185 |    file or class name and description of purpose be included on the
186 |    same "printed page" as the copyright notice for easier
187 |    identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # 🧠 IncarnaMind
 2 | 
 3 | ## 👀 In a Nutshell
 4 | 
 5 | IncarnaMind enables you to chat with your personal documents 📁 (PDF, TXT) using Large Language Models (LLMs) like GPT ([architecture overview](#high-level-architecture)). While OpenAI has recently launched a fine-tuning API for GPT models, it doesn't enable the base pretrained models to learn new data, and the responses can be prone to factual hallucinations. Our [Sliding Window Chunking](#sliding-window-chunking) mechanism and Ensemble Retriever enable efficient querying of both fine-grained and coarse-grained information within your ground-truth documents to augment the LLMs.
 6 | 
 7 | Feel free to use it, and we welcome any feedback and new feature suggestions 🙌.
 8 | 
 9 | ## ✨ New Updates
10 | 
11 | ### Open-Source and Local LLM Support
12 | 
13 | - **Recommended Model:** We've primarily tested with the Llama2 series models and recommend using [llama2-70b-chat](https://huggingface.co/TheBloke/Llama-2-70B-chat-GGUF) (either the full or the GGUF version) for optimal performance. Feel free to experiment with other LLMs.
14 | - **System Requirements:** Running the GGUF quantized version requires more than 35 GB of GPU RAM.
15 | 
16 | ### Alternative Open-Source LLM Options
17 | 
18 | - **Insufficient RAM:** If you're limited by GPU RAM, consider using the [Together.ai](https://api.together.xyz/playground) API. It supports llama2-70b-chat and most other open-source LLMs, and you get $25 in free usage.
19 | - **Upcoming:** Smaller, more cost-effective fine-tuned models will be released in the future.
20 | 
21 | ### How to use GGUF models
22 | 
23 | - For instructions on acquiring and using quantized GGUF LLMs (similar to GGML), please refer to this [video](https://www.youtube.com/watch?v=lbFmceo4D5E) (from 10:45 to 12:30); a minimal loading sketch is also shown below.
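If you just want to confirm that a downloaded GGUF file loads and generates on your hardware before wiring it into IncarnaMind, a minimal standalone sketch with `llama-cpp-python` (installed in step 1.2 below) could look like the following. The model path and generation settings are illustrative assumptions, not project defaults:

```python
# Minimal sketch: check that a quantized GGUF model loads and generates locally.
# The model path is hypothetical; point it at whatever .gguf file you downloaded
# (e.g. from TheBloke/Llama-2-70B-chat-GGUF).
from llama_cpp import Llama

llm = Llama(
    model_path="./models/llama-2-70b-chat.q4_K_M.gguf",  # hypothetical local path
    n_gpu_layers=100,  # number of layers to offload to the GPU
    n_batch=512,       # prompt batch size; tune to your RAM
    n_ctx=2048,        # context window in tokens
)

out = llm("Q: What does RAG stand for? A:", max_tokens=64, temperature=0)
print(out["choices"][0]["text"])
```

Inside IncarnaMind itself, GGUF models are selected through the `MODEL_NAME` entry in **configparser.ini** (e.g. `HuggingFace|TheBloke/Llama-2-70B-chat-GGUF|llama-2-70b-chat.q4_K_M.gguf`), so this snippet is only a quick sanity check.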
24 | 
25 | Here is a comparison table of the different models I tested, for reference only:
26 | 
27 | | Metrics   | GPT-4  | GPT-3.5 | Claude 2.0 | Llama2-70b | Llama2-70b-gguf | Llama2-70b-api |
28 | |-----------|--------|---------|------------|------------|-----------------|----------------|
29 | | Reasoning | High   | Medium  | High       | Medium     | Medium          | Medium         |
30 | | Speed     | Medium | High    | Medium     | Very Low   | Low             | Medium         |
31 | | GPU RAM   | N/A    | N/A     | N/A        | Very High  | High            | N/A            |
32 | | Safety    | Low    | Low     | Low        | High       | High            | Low            |
33 | 
34 | ## 💻 Demo
35 | 
36 | https://github.com/junruxiong/IncarnaMind/assets/44308338/89d479fb-de90-4f7c-b166-e54f7bc7344c
37 | 
38 | ## 💡 Challenges Addressed
39 | 
40 | - **Fixed Chunking**: Traditional RAG tools rely on fixed chunk sizes, limiting their adaptability in handling varying data complexity and context.
41 | 
42 | - **Precision vs. Semantics**: Current retrieval methods usually focus either on semantic understanding or precise retrieval, but rarely both.
43 | 
44 | - **Single-Document Limitation**: Many solutions can only query one document at a time, restricting multi-document information retrieval.
45 | 
46 | - **Stability**: IncarnaMind is compatible with OpenAI GPT, Anthropic Claude, Llama2, and other open-source LLMs, ensuring stable parsing.
47 | 
48 | ## 🎯 Key Features
49 | 
50 | - **Adaptive Chunking**: Our Sliding Window Chunking technique dynamically adjusts window size and position for RAG, balancing fine-grained and coarse-grained data access based on data complexity and context.
51 | 
52 | - **Multi-Document Conversational QA**: Supports simple and multi-hop queries across multiple documents simultaneously, breaking the single-document limitation.
53 | 
54 | - **File Compatibility**: Supports both PDF and TXT file formats.
55 | 
56 | - **LLM Model Compatibility**: Supports OpenAI GPT, Anthropic Claude, Llama2 and other open-source LLMs.
57 | 
58 | ## 🏗 Architecture
59 | 
60 | ### High Level Architecture
61 | 
62 | ![image](figs/High_Level_Architecture.png)
63 | 
64 | ### Sliding Window Chunking
65 | 
66 | ![image](figs/Sliding_Window_Chunking.png)
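To make the figure concrete, here is a deliberately simplified, self-contained sketch of the idea. It is not the project's actual implementation (that lives in `docs2db.py`, which measures chunk sizes in tokens); the chunk size, window size and step below are illustrative values, not the defaults from **configparser.ini**:

```python
# Conceptual sketch of sliding-window chunking (simplified; see docs2db.py in
# this repository for the real, token-based implementation).
from typing import List


def split_small(text: str, size: int = 20) -> List[str]:
    """Fine-grained base chunks: fixed-size slices of the raw text."""
    return [text[i:i + size] for i in range(0, len(text), size)]


def merge_medium(chunks: List[str], scale: int = 3) -> List[str]:
    """Medium chunks: every `scale` adjacent small chunks merged into one."""
    return ["".join(chunks[i:i + scale]) for i in range(0, len(chunks), scale)]


def sliding_windows(chunks: List[str], window: int = 4, step: int = 2) -> List[str]:
    """Coarse-grained context: overlapping windows sliding over the small chunks."""
    last_start = max(1, len(chunks) - window + 1)
    return ["".join(chunks[i:i + window]) for i in range(0, last_start, step)]


if __name__ == "__main__":
    doc = ("IncarnaMind keeps three views of the same text: small chunks for "
           "precise matching, medium chunks for local context, and sliding "
           "windows for coarse-grained context.")
    small = split_small(doc)
    print(len(small), "small /", len(merge_medium(small)), "medium /",
          len(sliding_windows(small)), "windows")
```

At query time, the Ensemble Retriever appears to work through these granularities in stages: small chunks first for precision, then medium chunks, and finally whole windows for coarse context, which is what the `BASE_CHUNK_SIZE`, `CHUNK_SCALE`, `WINDOW_STEPS`, `WINDOW_SCALE` and `*_RETRIEVAL_K` parameters in **configparser.ini** control.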
67 | 
68 | ## 🚀 Getting Started
69 | 
70 | ### 1. Installation
71 | 
72 | Installation is simple: you only need to run a few commands.
73 | 
74 | #### 1.0. Prerequisites
75 | 
76 | - 3.8 ≤ Python < 3.11 with [Conda](https://www.anaconda.com/download)
77 | - One or all of an [OpenAI API Key](https://beta.openai.com/signup), [Anthropic Claude API Key](https://console.anthropic.com/account/keys), [Together.ai API Key](https://api.together.xyz/settings/api-keys) or [HuggingFace token for Meta Llama models](https://huggingface.co/settings/tokens)
78 | - And, of course, your own documents.
79 | 
80 | #### 1.1. Clone the repository
81 | 
82 | ```shell
83 | git clone https://github.com/junruxiong/IncarnaMind
84 | cd IncarnaMind
85 | ```
86 | 
87 | #### 1.2. Setup
88 | 
89 | Create the Conda virtual environment:
90 | 
91 | ```shell
92 | conda create -n IncarnaMind python=3.10
93 | ```
94 | 
95 | Activate it:
96 | 
97 | ```shell
98 | conda activate IncarnaMind
99 | ```
100 | 
101 | Install all requirements:
102 | 
103 | ```shell
104 | pip install -r requirements.txt
105 | ```
106 | 
107 | Install [llama-cpp](https://github.com/abetlen/llama-cpp-python) separately if you want to run quantized local LLMs:
108 | 
109 | - For `NVIDIA` GPU support, use `cuBLAS`
110 | 
111 | ```shell
112 | CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.1.83 --no-cache-dir
113 | ```
114 | 
115 | - For Apple Metal (`M1/M2`) support, use
116 | 
117 | ```shell
118 | CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.1.83 --no-cache-dir
119 | ```
120 | 
121 | Set up one or all of your API keys in the **configparser.ini** file:
122 | 
123 | ```shell
124 | [tokens]
125 | OPENAI_API_KEY = (replace_me)
126 | ANTHROPIC_API_KEY = (replace_me)
127 | TOGETHER_API_KEY = (replace_me)
128 | # If you use full Meta-Llama models, you may need a Hugging Face token for access.
129 | HUGGINGFACE_TOKEN = (replace_me)
130 | ```
131 | 
132 | (Optional) Set your custom parameters in the **configparser.ini** file:
133 | 
134 | ```shell
135 | [parameters]
136 | PARAMETERS 1 = (replace_me)
137 | PARAMETERS 2 = (replace_me)
138 | ...
139 | PARAMETERS n = (replace_me)
140 | ```
141 | 
142 | ### 2. Usage
143 | 
144 | #### 2.1. Upload and process your files
145 | 
146 | Put all your files into the **/data** directory (please give each file a descriptive name to maximize retrieval performance), then run the following command to ingest the data
147 | (you can delete the example files in the **/data** directory before running the command):
148 | 
149 | ```shell
150 | python docs2db.py
151 | ```
152 | 
153 | #### 2.2. Run
154 | 
155 | To start the conversation, run:
156 | 
157 | ```shell
158 | python main.py
159 | ```
160 | 
161 | #### 2.3. Chat and ask any questions
162 | 
163 | Wait for the script to prompt you for input, as shown below.
164 | 
165 | ```shell
166 | Human:
167 | ```
168 | 
169 | #### 2.4. Others
170 | 
171 | When you start a chat, the system automatically generates an **IncarnaMind.log** file.
172 | If you want to adjust the logging, edit the **configparser.ini** file.
173 | 
174 | ```shell
175 | [logging]
176 | enabled = True
177 | level = INFO
178 | filename = IncarnaMind.log
179 | format = %(asctime)s [%(levelname)s] %(name)s: %(message)s
180 | ```
181 | 
182 | ## 🚫 Limitations
183 | 
184 | - Citations are not supported in the current version, but will be released soon.
185 | - Limited asynchronous capabilities.
186 | 
187 | ## 📝 Upcoming Features
188 | 
189 | - Frontend UI interface
190 | - Fine-tuned small open-source LLMs
191 | - OCR support
192 | - Asynchronous optimization
193 | - Support for more document formats
194 | 
195 | ## 🙌 Acknowledgements
196 | 
197 | Special thanks to [Langchain](https://github.com/langchain-ai/langchain), [Chroma DB](https://github.com/chroma-core/chroma), [LocalGPT](https://github.com/PromtEngineer/localGPT), and [Llama-cpp](https://github.com/abetlen/llama-cpp-python) for their invaluable contributions to the open-source community. Their work has been instrumental in making the IncarnaMind project a reality.
198 | 
199 | ## 🖋 Citation
200 | 
201 | If you want to cite our work, please use the following bibtex entry:
202 | 
203 | ```bibtex
204 | @misc{IncarnaMind2023,
205 |   author = {Junru Xiong},
206 |   title = {IncarnaMind},
207 |   year = {2023},
208 |   publisher = {GitHub},
209 |   journal = {GitHub Repository},
210 |   howpublished = {\url{https://github.com/junruxiong/IncarnaMind}}
211 | }
212 | ```
213 | 
214 | ## 📑 License
215 | 
216 | [Apache 2.0 License](LICENSE)
--------------------------------------------------------------------------------
/configparser.ini:
--------------------------------------------------------------------------------
 1 | [tokens]
 2 | ; Enter one or all of your API keys here.
 3 | ; E.g., OPENAI_API_KEY = sk-xxxxxxx
 4 | OPENAI_API_KEY = xxxxx
 5 | ANTHROPIC_API_KEY = xxxxx
 6 | TOGETHER_API_KEY = xxxxx
 7 | ; If you use Meta-Llama models, you may need a Hugging Face token for access.
 8 | HUGGINGFACE_TOKEN = xxxxx
 9 | VERSION = 1.0.1
10 | 
11 | 
12 | [directory]
13 | ; Directory for source files.
14 | DOCS_DIR = ./data
15 | ; Directory to store embeddings and Langchain documents.
16 | DB_DIR = ./database_store
17 | LOCAL_MODEL_DIR = ./models
18 | 
19 | 
20 | ; The parameters below are optional to modify:
21 | ; --------------------------------------------
22 | [parameters]
23 | ; Model name schema: Model Provider|Model Name|Model File. Model File is only valid for the GGUF format; set it to None for other formats.
24 | 
25 | ; For example:
26 | ; OpenAI|gpt-3.5-turbo|None
27 | ; OpenAI|gpt-4|None
28 | ; Anthropic|claude-2.0|None
29 | ; Together|togethercomputer/llama-2-70b-chat|None
30 | ; HuggingFace|TheBloke/Llama-2-70B-chat-GGUF|llama-2-70b-chat.q4_K_M.gguf
31 | ; HuggingFace|meta-llama/Llama-2-70b-chat-hf|None
32 | 
33 | ; The full Together.AI model list can be found at the end of this file; we currently only support quantized GGUF and full Hugging Face local LLMs.
34 | MODEL_NAME = OpenAI|gpt-4-1106-preview|None
35 | ; LLM temperature.
36 | TEMPURATURE = 0
37 | ; Maximum tokens for storing chat history.
38 | MAX_CHAT_HISTORY = 800
39 | ; Maximum tokens of LLM context for retrieved information.
40 | MAX_LLM_CONTEXT = 1200
41 | ; Maximum tokens for LLM generation.
42 | MAX_LLM_GENERATION = 1000
43 | ; Supported embeddings: openAIEmbeddings and hkunlpInstructorLarge.
44 | EMBEDDING_NAME = openAIEmbeddings
45 | 
46 | ; This is dependent on your GPU type.
47 | N_GPU_LAYERS = 100
48 | ; This depends on your GPU and CPU RAM when using open-source LLMs.
49 | N_BATCH = 512
50 | 
51 | 
52 | ; The base (small) chunk size for first-stage document retrieval.
53 | BASE_CHUNK_SIZE = 100
54 | ; Set to 0 for no overlap.
55 | CHUNK_OVERLAP = 0
56 | ; The final retrieval (medium) chunk size will be BASE_CHUNK_SIZE * CHUNK_SCALE.
57 | CHUNK_SCALE = 3
58 | WINDOW_STEPS = 3
59 | ; The number of tokens in a window chunk will be BASE_CHUNK_SIZE * WINDOW_SCALE.
60 | WINDOW_SCALE = 18
61 | 
62 | ; Ratio of the BM25 retriever to the Chroma Vectorstore retriever.
63 | RETRIEVER_WEIGHTS = 0.5, 0.5
64 | ; The number of retrieved chunks will range from FIRST_RETRIEVAL_K to 2*FIRST_RETRIEVAL_K due to the ensemble retriever.
65 | FIRST_RETRIEVAL_K = 3
66 | ; The number of retrieved chunks will range from SECOND_RETRIEVAL_K to 2*SECOND_RETRIEVAL_K due to the ensemble retriever.
67 | SECOND_RETRIEVAL_K = 3
68 | ; Number of windows (large chunks) for the third retriever.
69 | NUM_WINDOWS = 2
70 | ; (The third retrieval gets the final chunks passed to the LLM QA chain.
The 'k' value is dynamic (based on MAX_LLM_CONTEXT), depending on the number of rephrased questions and retrieved documents.) 71 | 72 | 73 | [logging] 74 | ; If you do not want to enable logging, set enabled to False. 75 | enabled = True 76 | level = INFO 77 | filename = IncarnaMind.log 78 | format = %(asctime)s [%(levelname)s] %(name)s: %(message)s 79 | 80 | 81 | ; Together.AI supported models: 82 | 83 | ; 0 Austism/chronos-hermes-13b 84 | ; 1 EleutherAI/pythia-12b-v0 85 | ; 2 EleutherAI/pythia-1b-v0 86 | ; 3 EleutherAI/pythia-2.8b-v0 87 | ; 4 EleutherAI/pythia-6.9b 88 | ; 5 Gryphe/MythoMax-L2-13b 89 | ; 6 HuggingFaceH4/starchat-alpha 90 | ; 7 NousResearch/Nous-Hermes-13b 91 | ; 8 NousResearch/Nous-Hermes-Llama2-13b 92 | ; 9 NumbersStation/nsql-llama-2-7B 93 | ; 10 OpenAssistant/llama2-70b-oasst-sft-v10 94 | ; 11 OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5 95 | ; 12 OpenAssistant/stablelm-7b-sft-v7-epoch-3 96 | ; 13 Phind/Phind-CodeLlama-34B-Python-v1 97 | ; 14 Phind/Phind-CodeLlama-34B-v2 98 | ; 15 SG161222/Realistic_Vision_V3.0_VAE 99 | ; 16 WizardLM/WizardCoder-15B-V1.0 100 | ; 17 WizardLM/WizardCoder-Python-34B-V1.0 101 | ; 18 WizardLM/WizardLM-70B-V1.0 102 | ; 19 bigcode/starcoder 103 | ; 20 databricks/dolly-v2-12b 104 | ; 21 databricks/dolly-v2-3b 105 | ; 22 databricks/dolly-v2-7b 106 | ; 23 defog/sqlcoder 107 | ; 24 garage-bAInd/Platypus2-70B-instruct 108 | ; 25 huggyllama/llama-13b 109 | ; 26 huggyllama/llama-30b 110 | ; 27 huggyllama/llama-65b 111 | ; 28 huggyllama/llama-7b 112 | ; 29 lmsys/fastchat-t5-3b-v1.0 113 | ; 30 lmsys/vicuna-13b-v1.3 114 | ; 31 lmsys/vicuna-13b-v1.5-16k 115 | ; 32 lmsys/vicuna-13b-v1.5 116 | ; 33 lmsys/vicuna-7b-v1.3 117 | ; 34 prompthero/openjourney 118 | ; 35 runwayml/stable-diffusion-v1-5 119 | ; 36 stabilityai/stable-diffusion-2-1 120 | ; 37 stabilityai/stable-diffusion-xl-base-1.0 121 | ; 38 togethercomputer/CodeLlama-13b-Instruct 122 | ; 39 togethercomputer/CodeLlama-13b-Python 123 | ; 40 togethercomputer/CodeLlama-13b 124 | ; 41 togethercomputer/CodeLlama-34b-Instruct 125 | ; 42 togethercomputer/CodeLlama-34b-Python 126 | ; 43 togethercomputer/CodeLlama-34b 127 | ; 44 togethercomputer/CodeLlama-7b-Instruct 128 | ; 45 togethercomputer/CodeLlama-7b-Python 129 | ; 46 togethercomputer/CodeLlama-7b 130 | ; 47 togethercomputer/GPT-JT-6B-v1 131 | ; 48 togethercomputer/GPT-JT-Moderation-6B 132 | ; 49 togethercomputer/GPT-NeoXT-Chat-Base-20B 133 | ; 50 togethercomputer/Koala-13B 134 | ; 51 togethercomputer/LLaMA-2-7B-32K 135 | ; 52 togethercomputer/Llama-2-7B-32K-Instruct 136 | ; 53 togethercomputer/Pythia-Chat-Base-7B-v0.16 137 | ; 54 togethercomputer/Qwen-7B-Chat 138 | ; 55 togethercomputer/Qwen-7B 139 | ; 56 togethercomputer/RedPajama-INCITE-7B-Base 140 | ; 57 togethercomputer/RedPajama-INCITE-7B-Chat 141 | ; 58 togethercomputer/RedPajama-INCITE-7B-Instruct 142 | ; 59 togethercomputer/RedPajama-INCITE-Base-3B-v1 143 | ; 60 togethercomputer/RedPajama-INCITE-Chat-3B-v1 144 | ; 61 togethercomputer/RedPajama-INCITE-Instruct-3B-v1 145 | ; 62 togethercomputer/alpaca-7b 146 | ; 63 togethercomputer/codegen2-16B 147 | ; 64 togethercomputer/codegen2-7B 148 | ; 65 togethercomputer/falcon-40b-instruct 149 | ; 66 togethercomputer/falcon-40b 150 | ; 67 togethercomputer/falcon-7b-instruct 151 | ; 68 togethercomputer/falcon-7b 152 | ; 69 togethercomputer/guanaco-13b 153 | ; 70 togethercomputer/guanaco-33b 154 | ; 71 togethercomputer/guanaco-65b 155 | ; 72 togethercomputer/guanaco-7b 156 | ; 73 togethercomputer/llama-2-13b-chat 157 | ; 74 togethercomputer/llama-2-13b 
158 | ; 75 togethercomputer/llama-2-70b-chat 159 | ; 76 togethercomputer/llama-2-70b 160 | ; 77 togethercomputer/llama-2-7b-chat 161 | ; 78 togethercomputer/llama-2-7b 162 | ; 79 togethercomputer/mpt-30b-chat 163 | ; 80 togethercomputer/mpt-30b-instruct 164 | ; 81 togethercomputer/mpt-30b 165 | ; 82 togethercomputer/mpt-7b-chat 166 | ; 83 togethercomputer/mpt-7b 167 | ; 84 togethercomputer/replit-code-v1-3b 168 | ; 85 upstage/SOLAR-0-70b-16bit 169 | ; 86 wavymulder/Analog-Diffusion -------------------------------------------------------------------------------- /convo_qa_chain.py: -------------------------------------------------------------------------------- 1 | """Conversational QA Chain""" 2 | from __future__ import annotations 3 | import inspect 4 | import logging 5 | from typing import Any, Dict, List, Optional 6 | from pydantic import Field 7 | 8 | from langchain.schema import BasePromptTemplate, BaseRetriever, Document 9 | from langchain.schema.language_model import BaseLanguageModel 10 | from langchain.chains import LLMChain 11 | from langchain.chains.question_answering import load_qa_chain 12 | from langchain.chains.conversational_retrieval.base import ( 13 | BaseConversationalRetrievalChain, 14 | ) 15 | from langchain.callbacks.manager import ( 16 | AsyncCallbackManagerForChainRun, 17 | CallbackManagerForChainRun, 18 | Callbacks, 19 | ) 20 | 21 | from toolkit.utils import ( 22 | Config, 23 | _get_chat_history, 24 | _get_standalone_questions_list, 25 | ) 26 | from toolkit.retrivers import MyRetriever 27 | from toolkit.prompts import PromptTemplates 28 | 29 | configs = Config("configparser.ini") 30 | logger = logging.getLogger(__name__) 31 | 32 | prompt_templates = PromptTemplates() 33 | 34 | 35 | class ConvoRetrievalChain(BaseConversationalRetrievalChain): 36 | """Chain for having a conversation based on retrieved documents. 37 | 38 | This chain takes in chat history (a list of messages) and new questions, 39 | and then returns an answer to that question. 40 | The algorithm for this chain consists of three parts: 41 | 42 | 1. Use the chat history and the new question to create a "standalone question". 43 | This is done so that this question can be passed into the retrieval step to fetch 44 | relevant documents. If only the new question was passed in, then relevant context 45 | may be lacking. If the whole conversation was passed into retrieval, there may 46 | be unnecessary information there that would distract from retrieval. 47 | 48 | 2. This new question is passed to the retriever and relevant documents are 49 | returned. 50 | 51 | 3. The retrieved documents are passed to an LLM along with either the new question 52 | (default behavior) or the original question and chat history to generate a final 53 | response. 54 | 55 | Example: 56 | .. code-block:: python 57 | 58 | from langchain.chains import ( 59 | StuffDocumentsChain, LLMChain, ConversationalRetrievalChain 60 | ) 61 | from langchain.prompts import PromptTemplate 62 | from langchain.llms import OpenAI 63 | 64 | combine_docs_chain = StuffDocumentsChain(...) 65 | vectorstore = ... 66 | retriever = vectorstore.as_retriever() 67 | 68 | # This controls how the standalone question is generated. 69 | # Should take `chat_history` and `question` as input variables. 70 | template = ( 71 | "Combine the chat history and follow up question into " 72 | "a standalone question. 
Chat History: {chat_history}" 73 | "Follow up question: {question}" 74 | ) 75 | prompt = PromptTemplate.from_template(template) 76 | llm = OpenAI() 77 | question_generator_chain = LLMChain(llm=llm, prompt=prompt) 78 | chain = ConversationalRetrievalChain( 79 | combine_docs_chain=combine_docs_chain, 80 | retriever=retriever, 81 | question_generator=question_generator_chain, 82 | ) 83 | """ 84 | 85 | retriever: MyRetriever = Field(exclude=True) 86 | """Retriever to use to fetch documents.""" 87 | file_names: List = Field(exclude=True) 88 | """file_names (List): List of file names used for retrieval.""" 89 | 90 | def _get_docs( 91 | self, 92 | question: str, 93 | inputs: Dict[str, Any], 94 | num_query: int, 95 | *, 96 | run_manager: Optional[CallbackManagerForChainRun] = None, 97 | ) -> List[Document]: 98 | """Get docs.""" 99 | try: 100 | docs = self.retriever.get_relevant_documents( 101 | question, num_query=num_query, run_manager=run_manager 102 | ) 103 | return docs 104 | except (IOError, FileNotFoundError) as error: 105 | logger.error("An error occurred in _get_docs: %s", error) 106 | return [] 107 | 108 | def _retrieve( 109 | self, 110 | question_list: List[str], 111 | inputs: Dict[str, Any], 112 | run_manager: Optional[CallbackManagerForChainRun] = None, 113 | ) -> List[str]: 114 | num_query = len(question_list) 115 | accepts_run_manager = ( 116 | "run_manager" in inspect.signature(self._get_docs).parameters 117 | ) 118 | 119 | total_results = {} 120 | for question in question_list: 121 | docs_dict = ( 122 | self._get_docs( 123 | question, inputs, num_query=num_query, run_manager=run_manager 124 | ) 125 | if accepts_run_manager 126 | else self._get_docs(question, inputs, num_query=num_query) 127 | ) 128 | 129 | for file_name, docs in docs_dict.items(): 130 | if file_name not in total_results: 131 | total_results[file_name] = docs 132 | else: 133 | total_results[file_name].extend(docs) 134 | 135 | logger.info( 136 | "-----step_done--------------------------------------------------", 137 | ) 138 | 139 | snippets = "" 140 | redundancy = set() 141 | for file_name, docs in total_results.items(): 142 | sorted_docs = sorted(docs, key=lambda x: x.metadata["medium_chunk_idx"]) 143 | temp = "\n".join( 144 | doc.page_content 145 | for doc in sorted_docs 146 | if doc.metadata["page_content_md5"] not in redundancy 147 | ) 148 | redundancy.update(doc.metadata["page_content_md5"] for doc in sorted_docs) 149 | snippets += f"\nContext about {file_name}:\n{{{temp}}}\n" 150 | 151 | return snippets, docs_dict 152 | 153 | def _call( 154 | self, 155 | inputs: Dict[str, Any], 156 | run_manager: Optional[CallbackManagerForChainRun] = None, 157 | ) -> Dict[str, Any]: 158 | _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() 159 | question = inputs["question"] 160 | get_chat_history = self.get_chat_history or _get_chat_history 161 | chat_history_str = get_chat_history(inputs["chat_history"]) 162 | 163 | callbacks = _run_manager.get_child() 164 | new_questions = self.question_generator.run( 165 | question=question, 166 | chat_history=chat_history_str, 167 | database=self.file_names, 168 | callbacks=callbacks, 169 | ) 170 | logger.info("new_questions: %s", new_questions) 171 | new_question_list = _get_standalone_questions_list(new_questions, question)[:3] 172 | # print("new_question_list:", new_question_list) 173 | logger.info("user_input: %s", question) 174 | logger.info("new_question_list: %s", new_question_list) 175 | 176 | snippets, source_docs = self._retrieve( 177 | new_question_list, 
inputs, run_manager=_run_manager 178 | ) 179 | 180 | docs = [ 181 | Document( 182 | page_content=snippets, 183 | metadata={}, 184 | ) 185 | ] 186 | 187 | new_inputs = inputs.copy() 188 | new_inputs["chat_history"] = chat_history_str 189 | answer = self.combine_docs_chain.run( 190 | input_documents=docs, 191 | database=self.file_names, 192 | callbacks=_run_manager.get_child(), 193 | **new_inputs, 194 | ) 195 | output: Dict[str, Any] = {self.output_key: answer} 196 | if self.return_source_documents: 197 | output["source_documents"] = source_docs 198 | if self.return_generated_question: 199 | output["generated_question"] = new_questions 200 | 201 | logger.info("*****response*****: %s", output["answer"]) 202 | logger.info( 203 | "=====epoch_done============================================================", 204 | ) 205 | return output 206 | 207 | async def _aget_docs( 208 | self, 209 | question: str, 210 | inputs: Dict[str, Any], 211 | num_query: int, 212 | *, 213 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None, 214 | ) -> List[Document]: 215 | """Get docs.""" 216 | try: 217 | docs = await self.retriever.aget_relevant_documents( 218 | question, num_query=num_query, run_manager=run_manager 219 | ) 220 | return docs 221 | except (IOError, FileNotFoundError) as error: 222 | logger.error("An error occurred in _get_docs: %s", error) 223 | return [] 224 | 225 | async def _aretrieve( 226 | self, 227 | question_list: List[str], 228 | inputs: Dict[str, Any], 229 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None, 230 | ) -> Dict[str, Any]: 231 | num_query = len(question_list) 232 | accepts_run_manager = ( 233 | "run_manager" in inspect.signature(self._get_docs).parameters 234 | ) 235 | 236 | total_results = {} 237 | for question in question_list: 238 | docs_dict = ( 239 | await self._aget_docs( 240 | question, inputs, num_query=num_query, run_manager=run_manager 241 | ) 242 | if accepts_run_manager 243 | else await self._aget_docs(question, inputs, num_query=num_query) 244 | ) 245 | 246 | for file_name, docs in docs_dict.items(): 247 | if file_name not in total_results: 248 | total_results[file_name] = docs 249 | else: 250 | total_results[file_name].extend(docs) 251 | 252 | logger.info( 253 | "-----step_done--------------------------------------------------", 254 | ) 255 | 256 | snippets = "" 257 | redundancy = set() 258 | for file_name, docs in total_results.items(): 259 | sorted_docs = sorted(docs, key=lambda x: x.metadata["medium_chunk_idx"]) 260 | temp = "\n".join( 261 | doc.page_content 262 | for doc in sorted_docs 263 | if doc.metadata["page_content_md5"] not in redundancy 264 | ) 265 | redundancy.update(doc.metadata["page_content_md5"] for doc in sorted_docs) 266 | snippets += f"\nContext about {file_name}:\n{{{temp}}}\n" 267 | 268 | return snippets, docs_dict 269 | 270 | async def _acall( 271 | self, 272 | inputs: Dict[str, Any], 273 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None, 274 | ) -> Dict[str, Any]: 275 | _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() 276 | question = inputs["question"] 277 | get_chat_history = self.get_chat_history or _get_chat_history 278 | chat_history_str = get_chat_history(inputs["chat_history"]) 279 | 280 | callbacks = _run_manager.get_child() 281 | new_questions = await self.question_generator.arun( 282 | question=question, 283 | chat_history=chat_history_str, 284 | database=self.file_names, 285 | callbacks=callbacks, 286 | ) 287 | new_question_list = 
_get_standalone_questions_list(new_questions, question)[:3] 288 | logger.info("new_questions: %s", new_questions) 289 | logger.info("new_question_list: %s", new_question_list) 290 | 291 | snippets, source_docs = await self._aretrieve( 292 | new_question_list, inputs, run_manager=_run_manager 293 | ) 294 | 295 | docs = [ 296 | Document( 297 | page_content=snippets, 298 | metadata={}, 299 | ) 300 | ] 301 | 302 | new_inputs = inputs.copy() 303 | new_inputs["chat_history"] = chat_history_str 304 | answer = await self.combine_docs_chain.arun( 305 | input_documents=docs, 306 | database=self.file_names, 307 | callbacks=_run_manager.get_child(), 308 | **new_inputs, 309 | ) 310 | output: Dict[str, Any] = {self.output_key: answer} 311 | if self.return_source_documents: 312 | output["source_documents"] = source_docs 313 | if self.return_generated_question: 314 | output["generated_question"] = new_questions 315 | 316 | logger.info("*****response*****: %s", output["answer"]) 317 | logger.info( 318 | "=====epoch_done============================================================", 319 | ) 320 | 321 | return output 322 | 323 | @classmethod 324 | def from_llm( 325 | cls, 326 | llm: BaseLanguageModel, 327 | retriever: BaseRetriever, 328 | condense_question_prompt: BasePromptTemplate = prompt_templates.get_refine_qa_template( 329 | configs.model_name 330 | ), 331 | chain_type: str = "stuff", # only support stuff chain now 332 | verbose: bool = False, 333 | condense_question_llm: Optional[BaseLanguageModel] = None, 334 | combine_docs_chain_kwargs: Optional[Dict] = None, 335 | callbacks: Callbacks = None, 336 | **kwargs: Any, 337 | ) -> BaseConversationalRetrievalChain: 338 | """Convenience method to load chain from LLM and retriever. 339 | 340 | This provides some logic to create the `question_generator` chain 341 | as well as the combine_docs_chain. 342 | 343 | Args: 344 | llm: The default language model to use at every part of this chain 345 | (eg in both the question generation and the answering) 346 | retriever: The retriever to use to fetch relevant documents from. 347 | condense_question_prompt: The prompt to use to condense the chat history 348 | and new question into standalone question(s). 349 | chain_type: The chain type to use to create the combine_docs_chain, will 350 | be sent to `load_qa_chain`. 351 | verbose: Verbosity flag for logging to stdout. 352 | condense_question_llm: The language model to use for condensing the chat 353 | history and new question into standalone question(s). If none is 354 | provided, will default to `llm`. 355 | combine_docs_chain_kwargs: Parameters to pass as kwargs to `load_qa_chain` 356 | when constructing the combine_docs_chain. 357 | callbacks: Callbacks to pass to all subchains. 
358 | **kwargs: Additional parameters to pass when initializing 359 | ConversationalRetrievalChain 360 | """ 361 | combine_docs_chain_kwargs = combine_docs_chain_kwargs or { 362 | "prompt": prompt_templates.get_retrieval_qa_template_selector( 363 | configs.model_name 364 | ).get_prompt(llm) 365 | } 366 | doc_chain = load_qa_chain( 367 | llm, 368 | chain_type=chain_type, 369 | verbose=verbose, 370 | callbacks=callbacks, 371 | **combine_docs_chain_kwargs, 372 | ) 373 | 374 | _llm = condense_question_llm or llm 375 | condense_question_chain = LLMChain( 376 | llm=_llm, 377 | prompt=condense_question_prompt, 378 | verbose=verbose, 379 | callbacks=callbacks, 380 | ) 381 | return cls( 382 | retriever=retriever, 383 | combine_docs_chain=doc_chain, 384 | question_generator=condense_question_chain, 385 | callbacks=callbacks, 386 | **kwargs, 387 | ) 388 | -------------------------------------------------------------------------------- /data/ABPI Code of Practice for the Pharmaceutical Industry 2021.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/ABPI Code of Practice for the Pharmaceutical Industry 2021.pdf -------------------------------------------------------------------------------- /data/Attention Is All You Need.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/Attention Is All You Need.pdf -------------------------------------------------------------------------------- /data/Gradient Descent The Ultimate Optimizer.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/Gradient Descent The Ultimate Optimizer.pdf -------------------------------------------------------------------------------- /data/JP Morgan 2022 Environmental Social Governance Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/JP Morgan 2022 Environmental Social Governance Report.pdf -------------------------------------------------------------------------------- /data/Language Models are Few-Shot Learners.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/Language Models are Few-Shot Learners.pdf -------------------------------------------------------------------------------- /data/Language Models are Unsupervised Multitask Learners.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/Language Models are Unsupervised Multitask Learners.pdf -------------------------------------------------------------------------------- /data/United Nations 2022 Annual Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/United Nations 2022 Annual Report.pdf -------------------------------------------------------------------------------- /docs2db.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This module save documents to embeddings and langchain Documents. 3 | """ 4 | import os 5 | import glob 6 | import pickle 7 | from typing import List 8 | from multiprocessing import Pool 9 | from collections import deque 10 | import hashlib 11 | import tiktoken 12 | 13 | from tqdm import tqdm 14 | 15 | from langchain.schema import Document 16 | from langchain.vectorstores import Chroma 17 | from langchain.text_splitter import ( 18 | RecursiveCharacterTextSplitter, 19 | ) 20 | from langchain.document_loaders import ( 21 | PyPDFLoader, 22 | TextLoader, 23 | ) 24 | 25 | from toolkit.utils import Config, choose_embeddings, clean_text 26 | 27 | 28 | # Load the config file 29 | configs = Config("configparser.ini") 30 | 31 | os.environ["OPENAI_API_KEY"] = configs.openai_api_key 32 | os.environ["ANTHROPIC_API_KEY"] = configs.anthropic_api_key 33 | 34 | embedding_store_path = configs.db_dir 35 | files_path = glob.glob(configs.docs_dir + "/*") 36 | 37 | tokenizer_name = tiktoken.encoding_for_model("gpt-3.5-turbo") 38 | tokenizer = tiktoken.get_encoding(tokenizer_name.name) 39 | 40 | loaders = { 41 | "pdf": (PyPDFLoader, {}), 42 | "txt": (TextLoader, {}), 43 | } 44 | 45 | 46 | def tiktoken_len(text: str): 47 | """Calculate the token length of a given text string using TikToken. 48 | 49 | Args: 50 | text (str): The text to be tokenized. 51 | 52 | Returns: 53 | int: The length of the tokenized text. 54 | """ 55 | tokens = tokenizer.encode(text, disallowed_special=()) 56 | 57 | return len(tokens) 58 | 59 | 60 | def string2md5(text: str): 61 | """Convert a string to its MD5 hash. 62 | 63 | Args: 64 | text (str): The text to be hashed. 65 | 66 | Returns: 67 | str: The MD5 hash of the input string. 68 | """ 69 | hash_md5 = hashlib.md5() 70 | hash_md5.update(text.encode("utf-8")) 71 | 72 | return hash_md5.hexdigest() 73 | 74 | 75 | def load_file(file_path): 76 | """Load a file and return its content as a Document object. 77 | 78 | Args: 79 | file_path (str): The path to the file. 80 | 81 | Returns: 82 | Document: The loaded document. 83 | """ 84 | ext = file_path.split(".")[-1] 85 | 86 | if ext in loaders: 87 | loader_type, args = loaders[ext] 88 | loader = loader_type(file_path, **args) 89 | doc = loader.load() 90 | 91 | return doc 92 | 93 | raise ValueError(f"Extension {ext} not supported") 94 | 95 | 96 | def docs2vectorstore(docs: List[Document], embedding_name: str, suffix: str = ""): 97 | """Convert a list of Documents into a Chroma vector store. 98 | 99 | Args: 100 | docs (Document): The list of Documents. 101 | suffix (str, optional): Suffix for the embedding. Defaults to "". 102 | """ 103 | embedding = choose_embeddings(embedding_name) 104 | name = f"{embedding_name}_{suffix}" 105 | # if embedding_store_path is not existing, create it 106 | if not os.path.exists(embedding_store_path): 107 | os.makedirs(embedding_store_path) 108 | Chroma.from_documents( 109 | docs, 110 | embedding, 111 | persist_directory=f"{embedding_store_path}/chroma_{name}", 112 | ) 113 | 114 | 115 | def file_names2pickle(file_names: list, save_name: str = ""): 116 | """Save the list of file names to a pickle file. 117 | 118 | Args: 119 | file_names (list): The list of file names. 120 | save_name (str, optional): The name for the saved pickle file. Defaults to "". 
121 | """ 122 | name = f"{save_name}" 123 | if not os.path.exists(embedding_store_path): 124 | os.makedirs(embedding_store_path) 125 | with open(f"{embedding_store_path}/{name}.pkl", "wb") as file: 126 | pickle.dump(file_names, file) 127 | 128 | 129 | def docs2pickle(docs: List[Document], suffix: str = ""): 130 | """Serializes a list of Document objects to a pickle file. 131 | 132 | Args: 133 | docs (Document): List of Document objects. 134 | suffix (str, optional): Suffix for the pickle file. Defaults to "". 135 | """ 136 | for doc in docs: 137 | doc.page_content = clean_text(doc.page_content) 138 | name = f"pickle_{suffix}" 139 | if not os.path.exists(embedding_store_path): 140 | os.makedirs(embedding_store_path) 141 | with open(f"{embedding_store_path}/docs_{name}.pkl", "wb") as file: 142 | pickle.dump(docs, file) 143 | 144 | 145 | def split_doc( 146 | doc: List[Document], chunk_size: int, chunk_overlap: int, chunk_idx_name: str 147 | ): 148 | """Splits a document into smaller chunks based on the provided size and overlap. 149 | 150 | Args: 151 | doc (Document): Document to be split. 152 | chunk_size (int): Size of each chunk. 153 | chunk_overlap (int): Overlap between adjacent chunks. 154 | chunk_idx_name (str): Metadata key for storing chunk indices. 155 | 156 | Returns: 157 | list: List of Document objects representing the chunks. 158 | """ 159 | data_splitter = RecursiveCharacterTextSplitter( 160 | chunk_size=chunk_size, 161 | chunk_overlap=chunk_overlap, 162 | length_function=tiktoken_len, 163 | ) 164 | doc_split = data_splitter.split_documents(doc) 165 | chunk_idx = 0 166 | 167 | for d_split in doc_split: 168 | d_split.metadata[chunk_idx_name] = chunk_idx 169 | chunk_idx += 1 170 | 171 | return doc_split 172 | 173 | 174 | def process_metadata(doc: List[Document]): 175 | """Processes and updates the metadata for a list of Document objects. 176 | 177 | Args: 178 | doc (list): List of Document objects. 179 | """ 180 | # get file name and remove extension 181 | file_name_with_extension = os.path.basename(doc[0].metadata["source"]) 182 | file_name, _ = os.path.splitext(file_name_with_extension) 183 | 184 | for _, item in enumerate(doc): 185 | for key, value in item.metadata.items(): 186 | if isinstance(value, list): 187 | item.metadata[key] = str(value) 188 | item.metadata["page_content"] = item.page_content 189 | item.metadata["page_content_md5"] = string2md5(item.page_content) 190 | item.metadata["source_md5"] = string2md5(item.metadata["source"]) 191 | item.page_content = f"{file_name}\n{item.page_content}" 192 | 193 | 194 | def add_window( 195 | doc: Document, window_steps: int, window_size: int, window_idx_name: str 196 | ): 197 | """Adds windowing information to the metadata of each document in the list. 198 | 199 | Args: 200 | doc (Document): List of Document objects. 201 | window_steps (int): Step size for windowing. 202 | window_size (int): Size of each window. 203 | window_idx_name (str): Metadata key for storing window indices. 
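        Note: With the defaults in configparser.ini (WINDOW_STEPS=3 and
        WINDOW_SCALE=18, the latter passed in here as ``window_size``), a new
        window index is opened every ``window_steps`` small chunks (except at
        the very start and near the end of the document), and each chunk's
        metadata records the smallest and largest window index it is currently
        associated with, so that a small chunk matched at retrieval time can
        later be expanded back into the larger sliding window around it.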
204 | """ 205 | window_id = 0 206 | window_deque = deque() 207 | 208 | for idx, item in enumerate(doc): 209 | if idx % window_steps == 0 and idx != 0 and idx < len(doc) - window_size: 210 | window_id += 1 211 | window_deque.append(window_id) 212 | 213 | if len(window_deque) > window_size: 214 | for _ in range(window_steps): 215 | window_deque.popleft() 216 | 217 | window = set(window_deque) 218 | item.metadata[f"{window_idx_name}_lower_bound"] = min(window) 219 | item.metadata[f"{window_idx_name}_upper_bound"] = max(window) 220 | 221 | 222 | def merge_metadata(dicts_list: dict): 223 | """Merges a list of metadata dictionaries into a single dictionary. 224 | 225 | Args: 226 | dicts_list (list): List of metadata dictionaries. 227 | 228 | Returns: 229 | dict: Merged metadata dictionary. 230 | """ 231 | merged_dict = {} 232 | bounds_dict = {} 233 | keys_to_remove = set() 234 | 235 | for dic in dicts_list: 236 | for key, value in dic.items(): 237 | if key in merged_dict: 238 | if value not in merged_dict[key]: 239 | merged_dict[key].append(value) 240 | else: 241 | merged_dict[key] = [value] 242 | 243 | for key, values in merged_dict.items(): 244 | if len(values) > 1 and all(isinstance(x, (int, float)) for x in values): 245 | bounds_dict[f"{key}_lower_bound"] = min(values) 246 | bounds_dict[f"{key}_upper_bound"] = max(values) 247 | keys_to_remove.add(key) 248 | 249 | merged_dict.update(bounds_dict) 250 | 251 | for key in keys_to_remove: 252 | del merged_dict[key] 253 | 254 | return { 255 | k: v[0] if isinstance(v, list) and len(v) == 1 else v 256 | for k, v in merged_dict.items() 257 | } 258 | 259 | 260 | def merge_chunks(doc: Document, scale_factor: int, chunk_idx_name: str): 261 | """Merges adjacent chunks into larger chunks based on a scaling factor. 262 | 263 | Args: 264 | doc (Document): List of Document objects. 265 | scale_factor (int): The number of small chunks to merge into a larger chunk. 266 | chunk_idx_name (str): Metadata key for storing chunk indices. 267 | 268 | Returns: 269 | list: List of Document objects representing the merged chunks. 270 | """ 271 | merged_doc = [] 272 | page_content = "" 273 | metadata_list = [] 274 | chunk_idx = 0 275 | 276 | for idx, item in enumerate(doc): 277 | page_content += item.page_content 278 | metadata_list.append(item.metadata) 279 | 280 | if (idx + 1) % scale_factor == 0 or idx == len(doc) - 1: 281 | metadata = merge_metadata(metadata_list) 282 | metadata[chunk_idx_name] = chunk_idx 283 | merged_doc.append( 284 | Document( 285 | page_content=page_content, 286 | metadata=metadata, 287 | ) 288 | ) 289 | chunk_idx += 1 290 | page_content = "" 291 | metadata_list = [] 292 | 293 | return merged_doc 294 | 295 | 296 | def process_files(): 297 | """Main function for processing files. 
Loads, tokenizes, and saves document data.""" 298 | with Pool() as pool: 299 | chunks_small = [] 300 | chunks_medium = [] 301 | file_names = [] 302 | 303 | with tqdm(total=len(files_path), desc="Processing files", ncols=80) as pbar: 304 | for doc in pool.imap_unordered(load_file, files_path): 305 | file_name_with_extension = os.path.basename(doc[0].metadata["source"]) 306 | # file_name, _ = os.path.splitext(file_name_with_extension) 307 | 308 | chunk_split_small = split_doc( 309 | doc=doc, 310 | chunk_size=configs.base_chunk_size, 311 | chunk_overlap=configs.chunk_overlap, 312 | chunk_idx_name="small_chunk_idx", 313 | ) 314 | add_window( 315 | doc=chunk_split_small, 316 | window_steps=configs.window_steps, 317 | window_size=configs.window_scale, 318 | window_idx_name="large_chunks_idx", 319 | ) 320 | 321 | chunk_split_medium = merge_chunks( 322 | doc=chunk_split_small, 323 | scale_factor=configs.chunk_scale, 324 | chunk_idx_name="medium_chunk_idx", 325 | ) 326 | 327 | process_metadata(chunk_split_small) 328 | process_metadata(chunk_split_medium) 329 | 330 | file_names.append(file_name_with_extension) 331 | chunks_small.extend(chunk_split_small) 332 | chunks_medium.extend(chunk_split_medium) 333 | 334 | pbar.update() 335 | 336 | file_names2pickle(file_names, save_name="file_names") 337 | 338 | docs2vectorstore(chunks_small, configs.embedding_name, suffix="chunks_small") 339 | docs2vectorstore(chunks_medium, configs.embedding_name, suffix="chunks_medium") 340 | 341 | docs2pickle(chunks_small, suffix="chunks_small") 342 | docs2pickle(chunks_medium, suffix="chunks_medium") 343 | 344 | 345 | if __name__ == "__main__": 346 | process_files() 347 | -------------------------------------------------------------------------------- /figs/High_Level_Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/figs/High_Level_Architecture.png -------------------------------------------------------------------------------- /figs/Sliding_Window_Chunking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/figs/Sliding_Window_Chunking.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Conversational QA Chain""" 2 | from __future__ import annotations 3 | import os 4 | import re 5 | import time 6 | import logging 7 | 8 | from langchain.chat_models import ChatOpenAI, ChatAnthropic 9 | from langchain.memory import ConversationTokenBufferMemory 10 | from convo_qa_chain import ConvoRetrievalChain 11 | 12 | from toolkit.together_api_llm import TogetherLLM 13 | from toolkit.retrivers import MyRetriever 14 | from toolkit.local_llm import load_local_llm 15 | from toolkit.utils import ( 16 | Config, 17 | choose_embeddings, 18 | load_embedding, 19 | load_pickle, 20 | check_device, 21 | ) 22 | 23 | 24 | # Load the config file 25 | configs = Config("configparser.ini") 26 | logger = logging.getLogger(__name__) 27 | 28 | os.environ["OPENAI_API_KEY"] = configs.openai_api_key 29 | os.environ["ANTHROPIC_API_KEY"] = configs.anthropic_api_key 30 | 31 | embedding = choose_embeddings(configs.embedding_name) 32 | db_store_path = configs.db_dir 33 | 34 | 35 | # get models 36 | def get_llm(llm_name: str, temperature: float, max_tokens: 
int): 37 | """Get the LLM model from the model name.""" 38 | 39 | if not os.path.exists(configs.local_model_dir): 40 | os.makedirs(configs.local_model_dir) 41 | 42 | splits = llm_name.split("|") # [provider, model_name, model_file] 43 | 44 | if "openai" in splits[0].lower(): 45 | llm_model = ChatOpenAI( 46 | model=splits[1], 47 | temperature=temperature, 48 | max_tokens=max_tokens, 49 | ) 50 | 51 | elif "anthropic" in splits[0].lower(): 52 | llm_model = ChatAnthropic( 53 | model=splits[1], 54 | temperature=temperature, 55 | max_tokens_to_sample=max_tokens, 56 | ) 57 | 58 | elif "together" in splits[0].lower(): 59 | llm_model = TogetherLLM( 60 | model=splits[1], 61 | temperature=temperature, 62 | max_tokens=max_tokens, 63 | ) 64 | elif "huggingface" in splits[0].lower(): 65 | llm_model = load_local_llm( 66 | model_id=splits[1], 67 | model_basename=splits[-1], 68 | temperature=temperature, 69 | max_tokens=max_tokens, 70 | device_type=check_device(), 71 | ) 72 | else: 73 | raise ValueError("Invalid Model Name") 74 | 75 | return llm_model 76 | 77 | 78 | llm = get_llm(configs.model_name, configs.temperature, configs.max_llm_generation) 79 | 80 | 81 | # load retrieval database 82 | db_embedding_chunks_small = load_embedding( 83 | store_name=configs.embedding_name, 84 | embedding=embedding, 85 | suffix="chunks_small", 86 | path=db_store_path, 87 | ) 88 | db_embedding_chunks_medium = load_embedding( 89 | store_name=configs.embedding_name, 90 | embedding=embedding, 91 | suffix="chunks_medium", 92 | path=db_store_path, 93 | ) 94 | 95 | db_docs_chunks_small = load_pickle( 96 | prefix="docs_pickle", suffix="chunks_small", path=db_store_path 97 | ) 98 | db_docs_chunks_medium = load_pickle( 99 | prefix="docs_pickle", suffix="chunks_medium", path=db_store_path 100 | ) 101 | file_names = load_pickle(prefix="file", suffix="names", path=db_store_path) 102 | 103 | 104 | # Initialize the retriever 105 | my_retriever = MyRetriever( 106 | llm=llm, 107 | embedding_chunks_small=db_embedding_chunks_small, 108 | embedding_chunks_medium=db_embedding_chunks_medium, 109 | docs_chunks_small=db_docs_chunks_small, 110 | docs_chunks_medium=db_docs_chunks_medium, 111 | first_retrieval_k=configs.first_retrieval_k, 112 | second_retrieval_k=configs.second_retrieval_k, 113 | num_windows=configs.num_windows, 114 | retriever_weights=configs.retriever_weights, 115 | ) 116 | 117 | 118 | # Initialize the memory 119 | memory = ConversationTokenBufferMemory( 120 | llm=llm, 121 | memory_key="chat_history", 122 | input_key="question", 123 | output_key="answer", 124 | return_messages=True, 125 | max_token_limit=configs.max_chat_history, 126 | ) 127 | 128 | 129 | # Initialize the QA chain 130 | qa = ConvoRetrievalChain.from_llm( 131 | llm, 132 | my_retriever, 133 | file_names=file_names, 134 | memory=memory, 135 | return_source_documents=False, 136 | return_generated_question=False, 137 | ) 138 | 139 | 140 | if __name__ == "__main__": 141 | while True: 142 | user_input = input("Human: ") 143 | start_time = time.time() 144 | user_input_ = re.sub(r"^Human: ", "", user_input) 145 | print("*" * 6) 146 | resp = qa({"question": user_input_}) 147 | print() 148 | print(f"AI:{resp['answer']}") 149 | print(f"Time used: {time.time() - start_time}") 150 | print("-" * 60) 151 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | chromadb==0.4.13 2 | InstructorEmbedding==1.0.1 3 | langchain==0.0.308 4 | openai==0.28.1 5 
| pypdf==3.16.2 6 | rank-bm25==0.2.2 7 | sentence-transformers==2.2.2 8 | tiktoken==0.5.1 9 | torch==2.0.1 10 | torchaudio==2.0.2 11 | torchvision==0.15.2 12 | together==0.2.4 13 | tqdm==4.66.1 -------------------------------------------------------------------------------- /toolkit/___init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/toolkit/___init__.py -------------------------------------------------------------------------------- /toolkit/local_llm.py: -------------------------------------------------------------------------------- 1 | """The below code is borrowed from: https://github.com/PromtEngineer/localGPT 2 | The reason to use gguf/ggml models: https://huggingface.co/TheBloke/wizardLM-7B-GGML/discussions/3""" 3 | import logging 4 | import torch 5 | from huggingface_hub import hf_hub_download 6 | from huggingface_hub import login 7 | from langchain.llms import LlamaCpp, HuggingFacePipeline 8 | from transformers import ( 9 | AutoModelForCausalLM, 10 | AutoTokenizer, 11 | LlamaForCausalLM, 12 | LlamaTokenizer, 13 | GenerationConfig, 14 | pipeline, 15 | ) 16 | from toolkit.utils import Config 17 | 18 | 19 | configs = Config("configparser.ini") 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def load_gguf_hf_model( 24 | model_id: str, 25 | model_basename: str, 26 | max_tokens: int, 27 | temperature: float, 28 | device_type: str, 29 | ): 30 | """ 31 | Load a GGUF/GGML quantized model using LlamaCpp. 32 | 33 | This function attempts to load a GGUF/GGML quantized model using the LlamaCpp library. 34 | If the model is of type GGML, and newer version of LLAMA-CPP is used which does not support GGML, 35 | it logs a message indicating that LLAMA-CPP has dropped support for GGML. 36 | 37 | Parameters: 38 | - model_id (str): The identifier for the model on HuggingFace Hub. 39 | - model_basename (str): The base name of the model file. 40 | - max_tokens (int): The maximum number of tokens to generate in the completion. 41 | - temperature (float): The temperature of LLM. 42 | - device_type (str): The type of device where the model will run, e.g., 'mps', 'cuda', etc. 43 | 44 | Returns: 45 | - LlamaCpp: An instance of the LlamaCpp model if successful, otherwise None. 46 | 47 | Notes: 48 | - The function uses the `hf_hub_download` function to download the model from the HuggingFace Hub. 49 | - The number of GPU layers is set based on the device type. 
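- If loading fails (for example, a GGML file used with a llama-cpp build that has dropped GGML support), the function logs a hint and returns None, so callers should be prepared for a missing model.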
50 | """ 51 | 52 | try: 53 | logger.info("Using Llamacpp for GGUF/GGML quantized models") 54 | model_path = hf_hub_download( 55 | repo_id=model_id, 56 | filename=model_basename, 57 | resume_download=True, 58 | cache_dir=configs.local_model_dir, 59 | ) 60 | kwargs = { 61 | "model_path": model_path, 62 | "n_ctx": configs.max_llm_context, 63 | "max_tokens": max_tokens, 64 | "temperature": temperature, 65 | "n_batch": configs.n_batch, # set this based on your GPU & CPU RAM 66 | "verbose": False, 67 | } 68 | if device_type.lower() == "mps": 69 | kwargs["n_gpu_layers"] = 1 70 | if device_type.lower() == "cuda": 71 | kwargs["n_gpu_layers"] = configs.n_gpu_layers # set this based on your GPU 72 | 73 | return LlamaCpp(**kwargs) 74 | except: 75 | if "ggml" in model_basename: 76 | logger.info( 77 | "If you were using GGML model, LLAMA-CPP Dropped Support, Use GGUF Instead" 78 | ) 79 | return None 80 | 81 | 82 | def load_full_hf_model(model_id: str, model_basename: str, device_type: str): 83 | """ 84 | Load a full model using either LlamaTokenizer or AutoModelForCausalLM. 85 | 86 | This function loads a full model based on the specified device type. 87 | If the device type is 'mps' or 'cpu', it uses LlamaTokenizer and LlamaForCausalLM. 88 | Otherwise, it uses AutoModelForCausalLM. 89 | 90 | Parameters: 91 | - model_id (str): The identifier for the model on HuggingFace Hub. 92 | - model_basename (str): The base name of the model file. 93 | - device_type (str): The type of device where the model will run. 94 | 95 | Returns: 96 | - model (Union[LlamaForCausalLM, AutoModelForCausalLM]): The loaded model. 97 | - tokenizer (Union[LlamaTokenizer, AutoTokenizer]): The tokenizer associated with the model. 98 | 99 | Notes: 100 | - The function uses the `from_pretrained` method to load both the model and the tokenizer. 101 | - Additional settings are provided for NVIDIA GPUs, such as loading in 4-bit and setting the compute dtype. 102 | """ 103 | if "meta-llama" in model_id.lower(): 104 | login(token=configs.huggingface_token) 105 | 106 | if device_type.lower() in ["mps", "cpu"]: 107 | logger.info("Using LlamaTokenizer") 108 | tokenizer = LlamaTokenizer.from_pretrained( 109 | model_id, 110 | cache_dir=configs.local_model_dir, 111 | ) 112 | model = LlamaForCausalLM.from_pretrained( 113 | model_id, 114 | cache_dir=configs.local_model_dir, 115 | ) 116 | else: 117 | logger.info("Using AutoModelForCausalLM for full models") 118 | tokenizer = AutoTokenizer.from_pretrained( 119 | model_id, cache_dir=configs.local_model_dir 120 | ) 121 | logger.info("Tokenizer loaded") 122 | model = AutoModelForCausalLM.from_pretrained( 123 | model_id, 124 | device_map="auto", 125 | torch_dtype=torch.float16, 126 | low_cpu_mem_usage=True, 127 | cache_dir=configs.local_model_dir, 128 | # trust_remote_code=True, # set these if you are using NVIDIA GPU 129 | # load_in_4bit=True, 130 | # bnb_4bit_quant_type="nf4", 131 | # bnb_4bit_compute_dtype=torch.float16, 132 | # max_memory={0: "15GB"} # Uncomment this line with you encounter CUDA out of memory errors 133 | ) 134 | model.tie_weights() 135 | return model, tokenizer 136 | 137 | 138 | def load_local_llm( 139 | model_id: str, 140 | model_basename: str, 141 | temperature: float, 142 | max_tokens: int, 143 | device_type: str, 144 | ): 145 | """ 146 | Select a model for text generation using the HuggingFace library. 147 | If you are running this for the first time, it will download a model for you. 148 | subsequent runs will use the model from the disk. 
149 | 150 | Args: 151 | device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU. 152 | model_id (str): Identifier of the model to load from HuggingFace's model hub. 153 | model_basename (str, optional): Basename of the model if using quantized models. 154 | Defaults to None. 155 | 156 | Returns: 157 | HuggingFacePipeline: A pipeline object for text generation using the loaded model. 158 | 159 | Raises: 160 | ValueError: If an unsupported model or device type is provided. 161 | """ 162 | logger.info(f"Loading Model: {model_id}, on: {device_type}") 163 | logger.info("This action can take a few minutes!") 164 | 165 | if model_basename.lower() != "none": 166 | if ".gguf" in model_basename.lower(): 167 | llm = load_gguf_hf_model( 168 | model_id, model_basename, max_tokens, temperature, device_type 169 | ) 170 | return llm 171 | 172 | model, tokenizer = load_full_hf_model(model_id, None, device_type) 173 | # Load configuration from the model to avoid warnings 174 | generation_config = GenerationConfig.from_pretrained(model_id) 175 | # see here for details: 176 | # https://huggingface.co/docs/transformers/ 177 | # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns 178 | 179 | # Create a pipeline for text generation 180 | pipe = pipeline( 181 | "text-generation", 182 | model=model, 183 | tokenizer=tokenizer, 184 | max_length=max_tokens, 185 | temperature=temperature, 186 | # top_p=0.95, 187 | repetition_penalty=1.15, 188 | generation_config=generation_config, 189 | ) 190 | local_llm = HuggingFacePipeline(pipeline=pipe) 191 | logger.info("Local LLM Loaded") 192 | 193 | return local_llm 194 | -------------------------------------------------------------------------------- /toolkit/prompts.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts import PromptTemplate 2 | from langchain.prompts.chat import ( 3 | ChatPromptTemplate, 4 | HumanMessagePromptTemplate, 5 | SystemMessagePromptTemplate, 6 | ) 7 | from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model 8 | 9 | # ================================================================================ 10 | 11 | REFINE_QA_TEMPLATE = """Break down or rephrase the follow up input into fewer than 3 heterogeneous one-hop queries to be the input of a retrieval tool, if the follow up inout is multi-hop, multi-step, complex or comparative queries and relevant to Chat History and Document Names. Otherwise keep the follow up input as it is. 12 | 13 | 14 | The output format should strictly follow the following, and each query can only conatain 1 document name: 15 | ``` 16 | 1. One-hop standalone query 17 | ... 18 | 3. One-hop standalone query 19 | ... 20 | ``` 21 | 22 | 23 | Document Names in the database: 24 | ``` 25 | {database} 26 | ``` 27 | 28 | 29 | Chat History: 30 | ``` 31 | {chat_history} 32 | ``` 33 | 34 | 35 | Begin: 36 | 37 | Follow Up Input: {question} 38 | 39 | One-hop standalone queries(s): 40 | """ 41 | 42 | 43 | # ================================================================================ 44 | 45 | DOCS_SELECTION_TEMPLATE = """Below are some verified sources and a human input. If you think any of them are relevant to the human input, then list all possible context numbers. 46 | 47 | ``` 48 | {snippets} 49 | ``` 50 | 51 | The output format must be like the following, nothing else. 
If not, you will output []: 52 | [0, ..., n] 53 | 54 | Human Input: {query} 55 | """ 56 | 57 | 58 | # ================================================================================ 59 | 60 | RETRIEVAL_QA_SYS = """You are a helpful assistant designed by IncarnaMind. 61 | If you think the below below information are relevant to the human input, please respond to the human based on the relevant retrieved sources; otherwise, respond in your own words only about the human input.""" 62 | 63 | 64 | RETRIEVAL_QA_TEMPLATE = """ 65 | File Names in the database: 66 | ``` 67 | {database} 68 | ``` 69 | 70 | 71 | Chat History: 72 | ``` 73 | {chat_history} 74 | ``` 75 | 76 | 77 | Verified Sources: 78 | ``` 79 | {context} 80 | ``` 81 | 82 | 83 | User: {question} 84 | """ 85 | 86 | 87 | RETRIEVAL_QA_CHAT_TEMPLATE = """ 88 | File Names in the database: 89 | ``` 90 | {database} 91 | ``` 92 | 93 | 94 | Chat History: 95 | ``` 96 | {chat_history} 97 | ``` 98 | 99 | 100 | Verified Sources: 101 | ``` 102 | {context} 103 | ``` 104 | """ 105 | 106 | 107 | class PromptTemplates: 108 | """_summary_""" 109 | 110 | def __init__(self): 111 | self.refine_qa_prompt = REFINE_QA_TEMPLATE 112 | self.docs_selection_prompt = DOCS_SELECTION_TEMPLATE 113 | self.retrieval_qa_sys = RETRIEVAL_QA_SYS 114 | self.retrieval_qa_prompt = RETRIEVAL_QA_TEMPLATE 115 | self.retrieval_qa_chat_prompt = RETRIEVAL_QA_CHAT_TEMPLATE 116 | 117 | def get_refine_qa_template(self, llm: str): 118 | """get the refine qa prompt template""" 119 | if "llama" in llm.lower(): 120 | temp = f"[INST] {self.refine_qa_prompt} [/INST]" 121 | else: 122 | temp = self.refine_qa_prompt 123 | 124 | return PromptTemplate( 125 | input_variables=["database", "chat_history", "question"], 126 | template=temp, 127 | ) 128 | 129 | def get_docs_selection_template(self, llm: str): 130 | """get the docs selection prompt template""" 131 | if "llama" in llm.lower(): 132 | temp = f"[INST] {self.docs_selection_prompt} [/INST]" 133 | else: 134 | temp = self.docs_selection_prompt 135 | 136 | return PromptTemplate( 137 | input_variables=["snippets", "query"], 138 | template=temp, 139 | ) 140 | 141 | def get_retrieval_qa_template_selector(self, llm: str): 142 | """get the retrieval qa prompt template""" 143 | if "llama" in llm.lower(): 144 | temp = f"[INST] <>\n{self.retrieval_qa_sys}\n<>\n\n{self.retrieval_qa_prompt} [/INST]" 145 | messages = [ 146 | SystemMessagePromptTemplate.from_template( 147 | f"[INST] <>\n{self.retrieval_qa_sys}\n<>\n\n{self.retrieval_qa_chat_prompt} [/INST]" 148 | ), 149 | HumanMessagePromptTemplate.from_template("{question}"), 150 | ] 151 | else: 152 | temp = f"{self.retrieval_qa_sys}\n{self.retrieval_qa_prompt}" 153 | messages = [ 154 | SystemMessagePromptTemplate.from_template( 155 | f"{self.retrieval_qa_sys}\n{self.retrieval_qa_chat_prompt}" 156 | ), 157 | HumanMessagePromptTemplate.from_template("{question}"), 158 | ] 159 | 160 | prompt_temp = PromptTemplate( 161 | template=temp, 162 | input_variables=["database", "chat_history", "context", "question"], 163 | ) 164 | prompt_temp_chat = ChatPromptTemplate.from_messages(messages) 165 | 166 | return ConditionalPromptSelector( 167 | default_prompt=prompt_temp, 168 | conditionals=[(is_chat_model, prompt_temp_chat)], 169 | ) 170 | -------------------------------------------------------------------------------- /toolkit/retrivers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides custom implementation of a document retriever, designed for 
multi-stage retrieval. 3 | The system uses ensemble methods combining BM25 and Chroma Embeddings to retrieve relevant documents for a given query. 4 | It also utilizes various optimizations like rank fusion and weighted reciprocal rank by Langchain. 5 | 6 | Classes: 7 | -------- 8 | - MyEnsembleRetriever: Custom retriever for BM25 and Chroma Embeddings. 9 | - MyRetriever: Handles multi-stage retrieval. 10 | 11 | """ 12 | import re 13 | import ast 14 | import copy 15 | import math 16 | import logging 17 | from typing import Dict, List, Optional 18 | from langchain.chains import LLMChain 19 | from langchain.schema import BaseRetriever, Document 20 | from langchain.retrievers import BM25Retriever, EnsembleRetriever 21 | from langchain.callbacks.manager import ( 22 | AsyncCallbackManagerForRetrieverRun, 23 | CallbackManagerForRetrieverRun, 24 | AsyncCallbackManagerForChainRun, 25 | CallbackManagerForChainRun, 26 | ) 27 | 28 | from toolkit.utils import Config, clean_text, DocIndexer, IndexerOperator 29 | from toolkit.prompts import PromptTemplates 30 | 31 | prompt_templates = PromptTemplates() 32 | 33 | configs = Config("configparser.ini") 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | class MyEnsembleRetriever(EnsembleRetriever): 38 | """ 39 | Custom retriever for BM24 and Chroma Embeddings 40 | """ 41 | 42 | retrievers: Dict[str, BaseRetriever] 43 | 44 | def rank_fusion( 45 | self, query: str, run_manager: CallbackManagerForRetrieverRun 46 | ) -> List[Document]: 47 | """ 48 | Retrieve the results of the retrievers and use rank_fusion_func to get 49 | the final result. 50 | 51 | Args: 52 | query: The query to search for. 53 | 54 | Returns: 55 | A list of reranked documents. 56 | """ 57 | # Get the results of all retrievers. 58 | retriever_docs = [] 59 | for key, retriever in self.retrievers.items(): 60 | if key == "bm25": 61 | res = retriever.get_relevant_documents( 62 | clean_text(query), 63 | callbacks=run_manager.get_child(tag=f"retriever_{key}"), 64 | ) 65 | retriever_docs.append(res) 66 | else: 67 | res = retriever.get_relevant_documents( 68 | query, callbacks=run_manager.get_child(tag=f"retriever_{key}") 69 | ) 70 | retriever_docs.append(res) 71 | 72 | # apply rank fusion 73 | fused_documents = self.weighted_reciprocal_rank(retriever_docs) 74 | 75 | return fused_documents 76 | 77 | async def arank_fusion( 78 | self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun 79 | ) -> List[Document]: 80 | """ 81 | Asynchronously retrieve the results of the retrievers 82 | and use rank_fusion_func to get the final result. 83 | 84 | Args: 85 | query: The query to search for. 86 | 87 | Returns: 88 | A list of reranked documents. 89 | """ 90 | 91 | # Get the results of all retrievers. 
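# Note: in this async variant the BM25 retriever is still queried synchronously
# with the cleaned query text; only the embedding retriever is awaited below via
# aget_relevant_documents.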
92 | retriever_docs = [] 93 | for key, retriever in self.retrievers.items(): 94 | if key == "bm25": 95 | res = retriever.get_relevant_documents( 96 | clean_text(query), 97 | callbacks=run_manager.get_child(tag=f"retriever_{key}"), 98 | ) 99 | retriever_docs.append(res) 100 | # print("retriever_docs 1:", res) 101 | else: 102 | res = await retriever.aget_relevant_documents( 103 | query, callbacks=run_manager.get_child(tag=f"retriever_{key}") 104 | ) 105 | retriever_docs.append(res) 106 | 107 | # apply rank fusion 108 | fused_documents = self.weighted_reciprocal_rank(retriever_docs) 109 | 110 | return fused_documents 111 | 112 | def weighted_reciprocal_rank( 113 | self, doc_lists: List[List[Document]] 114 | ) -> List[Document]: 115 | """ 116 | Perform weighted Reciprocal Rank Fusion on multiple rank lists. 117 | You can find more details about RRF here: 118 | https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf 119 | 120 | Args: 121 | doc_lists: A list of rank lists, where each rank list contains unique items. 122 | 123 | Returns: 124 | list: The final aggregated list of items sorted by their weighted RRF 125 | scores in descending order. 126 | """ 127 | if len(doc_lists) != len(self.weights): 128 | raise ValueError( 129 | "Number of rank lists must be equal to the number of weights." 130 | ) 131 | 132 | # replace the page_content with the original uncleaned page_content 133 | doc_lists_ = copy.copy(doc_lists) 134 | for doc_list in doc_lists_: 135 | for doc in doc_list: 136 | doc.page_content = doc.metadata["page_content"] 137 | # doc.metadata["page_content"] = None 138 | 139 | # Create a union of all unique documents in the input doc_lists 140 | all_documents = set() 141 | for doc_list in doc_lists_: 142 | for doc in doc_list: 143 | all_documents.add(doc.page_content) 144 | 145 | # Initialize the RRF score dictionary for each document 146 | rrf_score_dic = {doc: 0.0 for doc in all_documents} 147 | 148 | # Calculate RRF scores for each document 149 | for doc_list, weight in zip(doc_lists_, self.weights): 150 | for rank, doc in enumerate(doc_list, start=1): 151 | rrf_score = weight * (1 / (rank + self.c)) 152 | rrf_score_dic[doc.page_content] += rrf_score 153 | 154 | # Sort documents by their RRF scores in descending order 155 | sorted_documents = sorted( 156 | rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True 157 | ) 158 | 159 | # Map the sorted page_content back to the original document objects 160 | page_content_to_doc_map = { 161 | doc.page_content: doc for doc_list in doc_lists_ for doc in doc_list 162 | } 163 | sorted_docs = [ 164 | page_content_to_doc_map[page_content] for page_content in sorted_documents 165 | ] 166 | 167 | return sorted_docs 168 | 169 | 170 | class MyRetriever: 171 | """ 172 | Retriever class to handle multi-stage retrieval. 173 | """ 174 | 175 | def __init__( 176 | self, 177 | llm, 178 | embedding_chunks_small: List[Document], 179 | embedding_chunks_medium: List[Document], 180 | docs_chunks_small: DocIndexer, 181 | docs_chunks_medium: DocIndexer, 182 | first_retrieval_k: int, 183 | second_retrieval_k: int, 184 | num_windows: int, 185 | retriever_weights: List[float], 186 | ): 187 | """ 188 | Initialize the MyRetriever class. 189 | 190 | Args: 191 | llm: Language model for retrieval. 192 | embedding_chunks_small (List[Document]): List of small embedding chunks. 193 | embedding_chunks_medium (List[Document]): List of medium embedding chunks. 194 | docs_chunks_small (DocIndexer): Document indexer for small chunks. 
195 | docs_chunks_medium (DocIndexer): Document indexer for medium chunks. 196 | first_retrieval_k (int): Number of top documents to retrieve in first retrieval. 197 | second_retrieval_k (int): Number of top documents to retrieve in second retrieval. 198 | num_windows (int): Number of overlapping windows to consider. 199 | retriever_weights (List[float]): Weights for ensemble retrieval. 200 | """ 201 | self.llm = llm 202 | self.embedding_chunks_small = embedding_chunks_small 203 | self.embedding_chunks_medium = embedding_chunks_medium 204 | self.docs_index_small = DocIndexer(docs_chunks_small) 205 | self.docs_index_medium = DocIndexer(docs_chunks_medium) 206 | 207 | self.first_retrieval_k = first_retrieval_k 208 | self.second_retrieval_k = second_retrieval_k 209 | self.num_windows = num_windows 210 | self.retriever_weights = retriever_weights 211 | 212 | def get_retriever( 213 | self, 214 | docs_chunks, 215 | emb_chunks, 216 | emb_filter=None, 217 | k=2, 218 | weights=(0.5, 0.5), 219 | ): 220 | """ 221 | Initialize and return a retriever instance with specified parameters. 222 | 223 | Args: 224 | docs_chunks: The document chunks for the BM25 retriever. 225 | emb_chunks: The document chunks for the Embedding retriever. 226 | emb_filter: A filter for embedding retriever. 227 | k (int): The number of top documents to return. 228 | weights (list): Weights for ensemble retrieval. 229 | 230 | Returns: 231 | MyEnsembleRetriever: An instance of MyEnsembleRetriever. 232 | """ 233 | bm25_retriever = BM25Retriever.from_documents(docs_chunks) 234 | bm25_retriever.k = k 235 | 236 | emb_retriever = emb_chunks.as_retriever( 237 | search_kwargs={ 238 | "filter": emb_filter, 239 | "k": k, 240 | "search_type": "mmr", 241 | } 242 | ) 243 | return MyEnsembleRetriever( 244 | retrievers={"bm25": bm25_retriever, "chroma": emb_retriever}, 245 | weights=weights, 246 | ) 247 | 248 | def find_overlaps(self, doc: List[Document]): 249 | """ 250 | Find overlapping intervals of windows. 251 | 252 | Args: 253 | doc (Document): A document object to find overlaps in. 254 | 255 | Returns: 256 | list: A list of overlapping intervals. 
257 | """ 258 | intervals = [] 259 | for item in doc: 260 | intervals.append( 261 | ( 262 | item.metadata["large_chunks_idx_lower_bound"], 263 | item.metadata["large_chunks_idx_upper_bound"], 264 | ) 265 | ) 266 | remaining_intervals, grouped_intervals, centroids = intervals.copy(), [], [] 267 | 268 | while remaining_intervals: 269 | curr_interval = remaining_intervals.pop(0) 270 | curr_group = [curr_interval] 271 | subset_interval = None 272 | 273 | for start, end in remaining_intervals.copy(): 274 | for s, e in curr_group: 275 | overlap = set(range(s, e + 1)) & set(range(start, end + 1)) 276 | if overlap: 277 | curr_group.append((start, end)) 278 | remaining_intervals.remove((start, end)) 279 | if set(range(start, end + 1)).issubset(set(range(s, e + 1))): 280 | subset_interval = (start, end) 281 | break 282 | 283 | if subset_interval: 284 | centroid = [math.ceil((subset_interval[0] + subset_interval[1]) / 2)] 285 | elif len(curr_group) > 2: 286 | first_overlap = max( 287 | set(range(curr_group[0][0], curr_group[0][1] + 1)) 288 | & set(range(curr_group[1][0], curr_group[1][1] + 1)) 289 | ) 290 | last_overlap_set = set( 291 | range(curr_group[-1][0], curr_group[-1][1] + 1) 292 | ) & set(range(curr_group[-2][0], curr_group[-2][1] + 1)) 293 | 294 | if not last_overlap_set: 295 | last_overlap = first_overlap # Fallback if no overlap 296 | else: 297 | last_overlap = min(last_overlap_set) 298 | 299 | step = 1 if first_overlap <= last_overlap else -1 300 | centroid = list(range(first_overlap, last_overlap + step, step)) 301 | else: 302 | centroid = [ 303 | round( 304 | sum([math.ceil((s + e) / 2) for s, e in curr_group]) 305 | / len(curr_group) 306 | ) 307 | ] 308 | 309 | grouped_intervals.append( 310 | curr_group if len(curr_group) > 1 else curr_group[0] 311 | ) 312 | centroids.extend(centroid) 313 | 314 | return centroids 315 | 316 | def get_filter(self, top_k: int, file_md5: str, doc: List[Document]): 317 | """ 318 | Create a filter for retrievers based on overlapping intervals. 319 | 320 | Args: 321 | top_k (int): Number of top intervals to consider. 322 | file_md5 (str): MD5 hash of the file to filter. 323 | doc (List[Document]): List of document objects. 324 | 325 | Returns: 326 | tuple: A tuple of containing dictionary filters for DocIndexer and Chroma retrievers. 
327 | """ 328 | overlaps = self.find_overlaps(doc) 329 | if len(overlaps) < 1: 330 | raise ValueError("No overlapping intervals found.") 331 | 332 | overlaps_k = overlaps[:top_k] 333 | logger.info("windows_at_2nd_retrieval: %s", overlaps_k) 334 | search_dict_docindexer = {"OR": []} 335 | search_dict_chroma = {"$or": []} 336 | 337 | for chunk_idx in overlaps_k: 338 | search_dict_docindexer["OR"].append( 339 | { 340 | "large_chunks_idx_lower_bound": ( 341 | IndexerOperator.LTE, 342 | chunk_idx, 343 | ), 344 | "large_chunks_idx_upper_bound": ( 345 | IndexerOperator.GTE, 346 | chunk_idx, 347 | ), 348 | "source_md5": (IndexerOperator.EQ, file_md5), 349 | } 350 | ) 351 | 352 | if len(overlaps_k) == 1: 353 | search_dict_chroma = { 354 | "$and": [ 355 | {"large_chunks_idx_lower_bound": {"$lte": overlaps_k[0]}}, 356 | {"large_chunks_idx_upper_bound": {"$gte": overlaps_k[0]}}, 357 | {"source_md5": {"$eq": file_md5}}, 358 | ] 359 | } 360 | else: 361 | search_dict_chroma["$or"].append( 362 | { 363 | "$and": [ 364 | {"large_chunks_idx_lower_bound": {"$lte": chunk_idx}}, 365 | {"large_chunks_idx_upper_bound": {"$gte": chunk_idx}}, 366 | {"source_md5": {"$eq": file_md5}}, 367 | ] 368 | } 369 | ) 370 | 371 | return search_dict_docindexer, search_dict_chroma 372 | 373 | def get_relevant_doc_ids(self, docs: List[Document], query: str): 374 | """ 375 | Get relevant document IDs given a query using an LLM. 376 | 377 | Args: 378 | docs (List[Document]): List of document objects to find relevant IDs in. 379 | query (str): The query string. 380 | 381 | Returns: 382 | list: A list of relevant document IDs. 383 | """ 384 | snippets = "\n\n\n".join( 385 | [ 386 | f"Context {idx}:\n{{{doc.page_content}}}. {{source: {doc.metadata['source']}}}" 387 | for idx, doc in enumerate(docs) 388 | ] 389 | ) 390 | id_chain = LLMChain( 391 | llm=self.llm, 392 | prompt=prompt_templates.get_docs_selection_template(configs.model_name), 393 | output_key="IDs", 394 | ) 395 | ids = id_chain.run({"query": query, "snippets": snippets}) 396 | logger.info("relevant doc ids: %s", ids) 397 | pattern = r"\[\s*\d+\s*(?:,\s*\d+\s*)*\]" 398 | match = re.search(pattern, ids) 399 | if match: 400 | return ast.literal_eval(match.group(0)) 401 | else: 402 | return [] 403 | 404 | def get_relevant_documents( 405 | self, 406 | query: str, 407 | num_query: int, 408 | *, 409 | run_manager: Optional[CallbackManagerForChainRun] = None, 410 | ) -> List[Document]: 411 | """ 412 | Perform multi-stage retrieval to get relevant documents. 413 | 414 | Args: 415 | query (str): The query string. 416 | num_query (int): Number of queries. 417 | run_manager (Optional[CallbackManagerForChainRun], optional): Callback manager for chain run. 418 | 419 | Returns: 420 | List[Document]: A list of relevant documents. 421 | """ 422 | # ! 
First retrieval 423 | first_retriever = self.get_retriever( 424 | docs_chunks=self.docs_index_small.documents, 425 | emb_chunks=self.embedding_chunks_small, 426 | emb_filter=None, 427 | k=self.first_retrieval_k, 428 | weights=self.retriever_weights, 429 | ) 430 | first = first_retriever.get_relevant_documents( 431 | query, callbacks=run_manager.get_child() 432 | ) 433 | for doc in first: 434 | logger.info("----1st retrieval----: %s", doc) 435 | ids_clean = self.get_relevant_doc_ids(first, query) 436 | # ids_clean = [0, 1, 2] 437 | logger.info("relevant cleaned doc ids: %s", ids_clean) 438 | qa_chunks = {} # key is file name, value is a list of relevant documents 439 | # res_chunks = [] 440 | if ids_clean and isinstance(ids_clean, list): 441 | source_md5_dict = {} 442 | for ids_c in ids_clean: 443 | if ids_c < len(first): 444 | if ids_c not in source_md5_dict: 445 | source_md5_dict[first[ids_c].metadata["source_md5"]] = [ 446 | first[ids_c] 447 | ] 448 | # else: 449 | # source_md5_dict[first[ids_c].metadata["source_md5"]].append( 450 | # ids_clean[ids_c] 451 | # ) 452 | if len(source_md5_dict) == 0: 453 | source_md5_dict[first[0].metadata["source_md5"]] = [first[0]] 454 | num_docs = len(source_md5_dict.keys()) 455 | third_num_k = max( 456 | 1, 457 | ( 458 | int( 459 | ( 460 | configs.max_llm_context 461 | / (configs.base_chunk_size * configs.chunk_scale) 462 | ) 463 | // (num_docs * num_query) 464 | ) 465 | ), 466 | ) 467 | 468 | for source_md5, docs in source_md5_dict.items(): 469 | logger.info( 470 | "selected_docs_at_1st_retrieval: %s", docs[0].metadata["source"] 471 | ) 472 | second_docs_chunks = self.docs_index_small.retrieve_metadata( 473 | { 474 | "source_md5": (IndexerOperator.EQ, source_md5), 475 | } 476 | ) 477 | second_retriever = self.get_retriever( 478 | docs_chunks=second_docs_chunks, 479 | emb_chunks=self.embedding_chunks_small, 480 | emb_filter={"source_md5": source_md5}, 481 | k=self.second_retrieval_k, 482 | weights=self.retriever_weights, 483 | ) 484 | # ! Second retrieval 485 | second = second_retriever.get_relevant_documents( 486 | query, callbacks=run_manager.get_child() 487 | ) 488 | for doc in second: 489 | logger.info("----2nd retrieval----: %s", doc) 490 | docs.extend(second) 491 | docindexer_filter, chroma_filter = self.get_filter( 492 | self.num_windows, source_md5, docs 493 | ) 494 | third_docs_chunks = self.docs_index_medium.retrieve_metadata( 495 | docindexer_filter 496 | ) 497 | third_retriever = self.get_retriever( 498 | docs_chunks=third_docs_chunks, 499 | emb_chunks=self.embedding_chunks_medium, 500 | emb_filter=chroma_filter, 501 | k=third_num_k, 502 | weights=self.retriever_weights, 503 | ) 504 | # ! 
Third retrieval 505 | third_temp = third_retriever.get_relevant_documents( 506 | query, callbacks=run_manager.get_child() 507 | ) 508 | third = third_temp[:third_num_k] 509 | # chunks = sorted(third, key=lambda x: x.metadata["medium_chunk_idx"]) 510 | for doc in third: 511 | logger.info( 512 | "----3rd retrieval----page_content: %s", [doc.page_content] 513 | ) 514 | mtdata = doc.metadata 515 | mtdata["page_content"] = None 516 | logger.info("----3rd retrieval----metadata: %s", mtdata) 517 | file_name = third[0].metadata["source"].split("/")[-1] 518 | if file_name not in qa_chunks: 519 | qa_chunks[file_name] = third 520 | else: 521 | qa_chunks[file_name].extend(third) 522 | 523 | return qa_chunks 524 | 525 | async def aget_relevant_documents( 526 | self, 527 | query: str, 528 | num_query: int, 529 | *, 530 | run_manager: AsyncCallbackManagerForChainRun, 531 | ) -> List[Document]: 532 | """ 533 | Asynchronous version of get_relevant_documents method. 534 | 535 | Args: 536 | query (str): The query string. 537 | num_query (int): Number of queries. 538 | run_manager (AsyncCallbackManagerForChainRun): Callback manager for asynchronous chain run. 539 | 540 | Returns: 541 | List[Document]: A list of relevant documents. 542 | """ 543 | # ! First retrieval 544 | first_retriever = self.get_retriever( 545 | docs_chunks=self.docs_index_small.documents, 546 | emb_chunks=self.embedding_chunks_small, 547 | emb_filter=None, 548 | k=self.first_retrieval_k, 549 | weights=self.retriever_weights, 550 | ) 551 | first = await first_retriever.aget_relevant_documents( 552 | query, callbacks=run_manager.get_child() 553 | ) 554 | for doc in first: 555 | logger.info("----1st retrieval----: %s", doc) 556 | ids_clean = self.get_relevant_doc_ids(first, query) 557 | logger.info("relevant doc ids: %s", ids_clean) 558 | qa_chunks = {} # key is file name, value is a list of relevant documents 559 | # res_chunks = [] 560 | if ids_clean and isinstance(ids_clean, list): 561 | source_md5_dict = {} 562 | for ids_c in ids_clean: 563 | if ids_c < len(first): 564 | if ids_c not in source_md5_dict: 565 | source_md5_dict[first[ids_c].metadata["source_md5"]] = [ 566 | first[ids_c] 567 | ] 568 | # else: 569 | # source_md5_dict[first[ids_c].metadata["source_md5"]].append( 570 | # ids_clean[ids_c] 571 | # ) 572 | if len(source_md5_dict) == 0: 573 | source_md5_dict[first[0].metadata["source_md5"]] = [first[0]] 574 | num_docs = len(source_md5_dict.keys()) 575 | third_num_k = max( 576 | 1, 577 | ( 578 | int( 579 | ( 580 | configs.max_llm_context 581 | / (configs.base_chunk_size * configs.chunk_scale) 582 | ) 583 | // (num_docs * num_query) 584 | ) 585 | ), 586 | ) 587 | 588 | for source_md5, docs in source_md5_dict.items(): 589 | logger.info( 590 | "selected_docs_at_1st_retrieval: %s", docs[0].metadata["source"] 591 | ) 592 | second_docs_chunks = self.docs_index_small.retrieve_metadata( 593 | { 594 | "source_md5": (IndexerOperator.EQ, source_md5), 595 | } 596 | ) 597 | second_retriever = self.get_retriever( 598 | docs_chunks=second_docs_chunks, 599 | emb_chunks=self.embedding_chunks_small, 600 | emb_filter={"source_md5": source_md5}, 601 | k=self.second_retrieval_k, 602 | weights=self.retriever_weights, 603 | ) 604 | # ! 
Second retrieval 605 | second = await second_retriever.aget_relevant_documents( 606 | query, callbacks=run_manager.get_child() 607 | ) 608 | for doc in second: 609 | logger.info("----2nd retrieval----: %s", doc) 610 | docs.extend(second) 611 | docindexer_filter, chroma_filter = self.get_filter( 612 | self.num_windows, source_md5, docs 613 | ) 614 | third_docs_chunks = self.docs_index_medium.retrieve_metadata( 615 | docindexer_filter 616 | ) 617 | third_retriever = self.get_retriever( 618 | docs_chunks=third_docs_chunks, 619 | emb_chunks=self.embedding_chunks_medium, 620 | emb_filter=chroma_filter, 621 | k=third_num_k, 622 | weights=self.retriever_weights, 623 | ) 624 | # ! Third retrieval 625 | third_temp = await third_retriever.aget_relevant_documents( 626 | query, callbacks=run_manager.get_child() 627 | ) 628 | third = third_temp[:third_num_k] 629 | # chunks = sorted(third, key=lambda x: x.metadata["medium_chunk_idx"]) 630 | for doc in third: 631 | logger.info( 632 | "----3rd retrieval----page_content: %s", [doc.page_content] 633 | ) 634 | mtdata = doc.metadata 635 | mtdata["page_content"] = None 636 | logger.info("----3rd retrieval----metadata: %s", mtdata) 637 | file_name = third[0].metadata["source"].split("/")[-1] 638 | if file_name not in qa_chunks: 639 | qa_chunks[file_name] = third 640 | else: 641 | qa_chunks[file_name].extend(third) 642 | 643 | return qa_chunks 644 | -------------------------------------------------------------------------------- /toolkit/together_api_llm.py: -------------------------------------------------------------------------------- 1 | """The code borrowed from https://colab.research.google.com/drive/1RW2yTxh5b9w7F3IrK00Iz51FTO5W01Rx?usp=sharing#scrollTo=RgbLVmf-o4j7""" 2 | import os 3 | from typing import Any, Dict 4 | import together 5 | from pydantic import Extra, root_validator 6 | 7 | from langchain.llms.base import LLM 8 | from langchain.utils import get_from_dict_or_env 9 | from toolkit.utils import Config 10 | 11 | configs = Config("configparser.ini") 12 | os.environ["TOGETHER_API_KEY"] = configs.together_api_key 13 | 14 | # together.api_key = configs.together_api_key 15 | # models = together.Models.list() 16 | # for idx, model in enumerate(models): 17 | # print(idx, model["name"]) 18 | 19 | 20 | class TogetherLLM(LLM): 21 | """Together large language models.""" 22 | 23 | model: str = "togethercomputer/llama-2-70b-chat" 24 | """model endpoint to use""" 25 | 26 | together_api_key: str = os.environ["TOGETHER_API_KEY"] 27 | """Together API key""" 28 | 29 | temperature: float = 0 30 | """What sampling temperature to use.""" 31 | 32 | max_tokens: int = 512 33 | """The maximum number of tokens to generate in the completion.""" 34 | 35 | class Config: 36 | extra = "forbid" 37 | 38 | # @root_validator() 39 | # def validate_environment(cls, values: Dict) -> Dict: 40 | # """Validate that the API key is set.""" 41 | # api_key = get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY") 42 | # values["together_api_key"] = api_key 43 | # return values 44 | 45 | @property 46 | def _llm_type(self) -> str: 47 | """Return type of LLM.""" 48 | return "together" 49 | 50 | def _call( 51 | self, 52 | prompt: str, 53 | **kwargs: Any, 54 | ) -> str: 55 | """Call to Together endpoint.""" 56 | together.api_key = self.together_api_key 57 | output = together.Complete.create( 58 | prompt, 59 | model=self.model, 60 | max_tokens=self.max_tokens, 61 | temperature=self.temperature, 62 | ) 63 | text = output["output"]["choices"][0]["text"] 64 | return text 65 | 66 | 67 | 
# if __name__ == "__main__": 68 | # test_llm = TogetherLLM( 69 | # model="togethercomputer/llama-2-70b-chat", temperature=0, max_tokens=1000 70 | # ) 71 | 72 | # print(test_llm("What are the olympics? ")) 73 | -------------------------------------------------------------------------------- /toolkit/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The widgets defines utility functions for loading data, text cleaning, 3 | and indexing documents, as well as classes for handling document queries 4 | and formatting chat history. 5 | """ 6 | import re 7 | import pickle 8 | import string 9 | import logging 10 | import configparser 11 | from enum import Enum 12 | from typing import List, Tuple, Union 13 | import nltk 14 | from nltk.stem import WordNetLemmatizer 15 | from nltk.tokenize import word_tokenize 16 | from nltk.corpus import stopwords 17 | import torch 18 | import tiktoken 19 | from langchain.vectorstores import Chroma 20 | 21 | from langchain.schema import Document, BaseMessage 22 | from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings 23 | from langchain.embeddings.openai import OpenAIEmbeddings 24 | 25 | 26 | tokenizer_name = tiktoken.encoding_for_model("gpt-3.5-turbo") 27 | tokenizer = tiktoken.get_encoding(tokenizer_name.name) 28 | 29 | # if nltk stopwords, punkt and wordnet are not downloaded, download it 30 | try: 31 | nltk.data.find("corpora/stopwords") 32 | except LookupError: 33 | nltk.download("stopwords") 34 | try: 35 | nltk.data.find("tokenizers/punkt") 36 | except LookupError: 37 | nltk.download("punkt") 38 | try: 39 | nltk.data.find("corpora/wordnet") 40 | except LookupError: 41 | nltk.download("wordnet") 42 | 43 | ChatTurnType = Union[Tuple[str, str], BaseMessage] 44 | _ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "} 45 | 46 | 47 | class Config: 48 | """Initializes configs.""" 49 | 50 | def __init__(self, config_file): 51 | self.config = configparser.ConfigParser(interpolation=None) 52 | self.config.read(config_file) 53 | 54 | # Tokens 55 | self.openai_api_key = self.config.get("tokens", "OPENAI_API_KEY") 56 | self.anthropic_api_key = self.config.get("tokens", "ANTHROPIC_API_KEY") 57 | self.together_api_key = self.config.get("tokens", "TOGETHER_API_KEY") 58 | self.huggingface_token = self.config.get("tokens", "HUGGINGFACE_TOKEN") 59 | self.version = self.config.get("tokens", "VERSION") 60 | 61 | # Directory 62 | self.docs_dir = self.config.get("directory", "DOCS_DIR") 63 | self.db_dir = self.config.get("directory", "db_DIR") 64 | self.local_model_dir = self.config.get("directory", "LOCAL_MODEL_DIR") 65 | 66 | # Parameters 67 | self.model_name = self.config.get("parameters", "MODEL_NAME") 68 | self.temperature = self.config.getfloat("parameters", "TEMPURATURE") 69 | self.max_chat_history = self.config.getint("parameters", "MAX_CHAT_HISTORY") 70 | self.max_llm_context = self.config.getint("parameters", "MAX_LLM_CONTEXT") 71 | self.max_llm_generation = self.config.getint("parameters", "MAX_LLM_GENERATION") 72 | self.embedding_name = self.config.get("parameters", "EMBEDDING_NAME") 73 | 74 | self.n_gpu_layers = self.config.getint("parameters", "N_GPU_LAYERS") 75 | self.n_batch = self.config.getint("parameters", "N_BATCH") 76 | 77 | self.base_chunk_size = self.config.getint("parameters", "BASE_CHUNK_SIZE") 78 | self.chunk_overlap = self.config.getint("parameters", "CHUNK_OVERLAP") 79 | self.chunk_scale = self.config.getint("parameters", "CHUNK_SCALE") 80 | self.window_steps = 
self.config.getint("parameters", "WINDOW_STEPS") 81 | self.window_scale = self.config.getint("parameters", "WINDOW_SCALE") 82 | 83 | self.retriever_weights = [ 84 | float(x.strip()) 85 | for x in self.config.get("parameters", "RETRIEVER_WEIGHTS").split(",") 86 | ] 87 | self.first_retrieval_k = self.config.getint("parameters", "FIRST_RETRIEVAL_K") 88 | self.second_retrieval_k = self.config.getint("parameters", "SECOND_RETRIEVAL_K") 89 | self.num_windows = self.config.getint("parameters", "NUM_WINDOWS") 90 | 91 | # Logging 92 | self.logging_enabled = self.config.getboolean("logging", "enabled") 93 | self.logging_level = self.config.get("logging", "level") 94 | self.logging_filename = self.config.get("logging", "filename") 95 | self.logging_format = self.config.get("logging", "format") 96 | 97 | self.configure_logging() 98 | 99 | def configure_logging(self): 100 | """ 101 | Configure the logger for each .py files. 102 | """ 103 | 104 | if not self.logging_enabled: 105 | logging.disable(logging.CRITICAL + 1) 106 | return 107 | 108 | log_level = self.config.get("logging", "level") 109 | log_filename = self.config.get("logging", "filename") 110 | log_format = self.config.get("logging", "format") 111 | 112 | logging.basicConfig(level=log_level, filename=log_filename, format=log_format) 113 | 114 | 115 | def configure_logger(): 116 | """ 117 | Configure the logger for each .py files. 118 | """ 119 | config = configparser.ConfigParser(interpolation=None) 120 | config.read("configparser.ini") 121 | 122 | enabled = config.getboolean("logging", "enabled") 123 | 124 | if not enabled: 125 | logging.disable(logging.CRITICAL + 1) 126 | return 127 | 128 | log_level = config.get("logging", "level") 129 | log_filename = config.get("logging", "filename") 130 | log_format = config.get("logging", "format") 131 | 132 | logging.basicConfig(level=log_level, filename=log_filename, format=log_format) 133 | 134 | 135 | def tiktoken_len(text): 136 | """token length function""" 137 | tokens = tokenizer.encode(text, disallowed_special=()) 138 | return len(tokens) 139 | 140 | 141 | def check_device(): 142 | """Check if cuda or MPS is available, else fallback to CPU""" 143 | if torch.cuda.is_available(): 144 | device = "cuda" 145 | elif torch.backends.mps.is_available(): 146 | device = "mps" 147 | else: 148 | device = "cpu" 149 | return device 150 | 151 | 152 | def choose_embeddings(embedding_name): 153 | """Choose embeddings for a given model's name""" 154 | try: 155 | if embedding_name == "openAIEmbeddings": 156 | return OpenAIEmbeddings() 157 | elif embedding_name == "hkunlpInstructorLarge": 158 | device = check_device() 159 | return HuggingFaceInstructEmbeddings( 160 | model_name="hkunlp/instructor-large", model_kwargs={"device": device} 161 | ) 162 | else: 163 | device = check_device() 164 | return HuggingFaceEmbeddings(model_name=embedding_name, device=device) 165 | except Exception as error: 166 | raise ValueError(f"Embedding {embedding_name} not supported") from error 167 | 168 | 169 | def load_embedding(store_name, embedding, suffix, path): 170 | """Load chroma embeddings""" 171 | vector_store = Chroma( 172 | persist_directory=f"{path}/chroma_{store_name}_{suffix}", 173 | embedding_function=embedding, 174 | ) 175 | return vector_store 176 | 177 | 178 | def load_pickle(prefix, suffix, path): 179 | """Load langchain documents from a pickle file. 180 | 181 | Args: 182 | store_name (str): The name of the store where data is saved. 183 | suffix (str): Suffix to append to the store name. 
184 | path (str): The path where the pickle file is stored. 185 | 186 | Returns: 187 | Document: documents from the pickle file 188 | """ 189 | with open(f"{path}/{prefix}_{suffix}.pkl", "rb") as file: 190 | return pickle.load(file) 191 | 192 | 193 | def clean_text(text): 194 | """ 195 | Converts text to lowercase, removes punctuation, stopwords, and lemmatizes it 196 | for BM25 retriever. 197 | 198 | Parameters: 199 | text (str): The text to be cleaned. 200 | 201 | Returns: 202 | str: The cleaned and lemmatized text. 203 | """ 204 | # remove [SEP] in the text 205 | text = text.replace("[SEP]", "") 206 | # Tokenization 207 | tokens = word_tokenize(text) 208 | # Lowercasing 209 | tokens = [w.lower() for w in tokens] 210 | # Remove punctuation 211 | table = str.maketrans("", "", string.punctuation) 212 | stripped = [w.translate(table) for w in tokens] 213 | # Keep tokens that are alphabetic, numeric, or contain both. 214 | words = [ 215 | word 216 | for word in stripped 217 | if word.isalpha() 218 | or word.isdigit() 219 | or (re.search("\d", word) and re.search("[a-zA-Z]", word)) 220 | ] 221 | # Remove stopwords 222 | stop_words = set(stopwords.words("english")) 223 | words = [w for w in words if w not in stop_words] 224 | # Lemmatization (or you could use stemming instead) 225 | lemmatizer = WordNetLemmatizer() 226 | lemmatized = [lemmatizer.lemmatize(w) for w in words] 227 | # Convert list of words to a string 228 | lemmatized_ = " ".join(lemmatized) 229 | 230 | return lemmatized_ 231 | 232 | 233 | class IndexerOperator(Enum): 234 | """ 235 | Enumeration for different query operators used in indexing. 236 | """ 237 | 238 | EQ = "==" 239 | GT = ">" 240 | GTE = ">=" 241 | LT = "<" 242 | LTE = "<=" 243 | 244 | 245 | class DocIndexer: 246 | """ 247 | A class to handle indexing and searching of documents. 248 | 249 | Attributes: 250 | documents (List[Document]): List of documents to be indexed. 251 | """ 252 | 253 | def __init__(self, documents): 254 | self.documents = documents 255 | self.index = self.build_index(documents) 256 | 257 | def build_index(self, documents): 258 | """ 259 | Build an index for the given list of documents. 260 | 261 | Parameters: 262 | documents (List[Document]): The list of documents to be indexed. 263 | 264 | Returns: 265 | dict: The built index. 266 | """ 267 | index = {} 268 | for doc in documents: 269 | for key, value in doc.metadata.items(): 270 | if key not in index: 271 | index[key] = {} 272 | if value not in index[key]: 273 | index[key][value] = [] 274 | index[key][value].append(doc) 275 | return index 276 | 277 | def retrieve_metadata(self, search_dict): 278 | """ 279 | Retrieve documents based on the search criteria provided in search_dict. 280 | 281 | Parameters: 282 | search_dict (dict): Dictionary specifying the search criteria. 283 | It can contain "AND" or "OR" operators for 284 | complex queries. 285 | 286 | Returns: 287 | List[Document]: List of documents that match the search criteria. 
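Example (illustrative; the variable name and the md5 placeholder below are not from the repository):
    indexer.retrieve_metadata({"source_md5": (IndexerOperator.EQ, "<file-md5>")})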
288 | """ 289 | if "AND" in search_dict: 290 | return self._handle_and(search_dict["AND"]) 291 | elif "OR" in search_dict: 292 | return self._handle_or(search_dict["OR"]) 293 | else: 294 | return self._handle_single(search_dict) 295 | 296 | def _handle_and(self, search_dicts): 297 | results = [self.retrieve_metadata(sd) for sd in search_dicts] 298 | if results: 299 | intersection = set.intersection( 300 | *[set(map(self._hash_doc, r)) for r in results] 301 | ) 302 | return [self._unhash_doc(h) for h in intersection] 303 | else: 304 | return [] 305 | 306 | def _handle_or(self, search_dicts): 307 | results = [self.retrieve_metadata(sd) for sd in search_dicts] 308 | union = set.union(*[set(map(self._hash_doc, r)) for r in results]) 309 | return [self._unhash_doc(h) for h in union] 310 | 311 | def _handle_single(self, search_dict): 312 | unions = [] 313 | for key, query in search_dict.items(): 314 | operator, value = query 315 | union = set() 316 | if operator == IndexerOperator.EQ: 317 | if key in self.index and value in self.index[key]: 318 | union.update(map(self._hash_doc, self.index[key][value])) 319 | else: 320 | if key in self.index: 321 | for k, v in self.index[key].items(): 322 | if ( 323 | (operator == IndexerOperator.GT and k > value) 324 | or (operator == IndexerOperator.GTE and k >= value) 325 | or (operator == IndexerOperator.LT and k < value) 326 | or (operator == IndexerOperator.LTE and k <= value) 327 | ): 328 | union.update(map(self._hash_doc, v)) 329 | if union: 330 | unions.append(union) 331 | 332 | if unions: 333 | intersection = set.intersection(*unions) 334 | return [self._unhash_doc(h) for h in intersection] 335 | else: 336 | return [] 337 | 338 | def _hash_doc(self, doc): 339 | return (doc.page_content, frozenset(doc.metadata.items())) 340 | 341 | def _unhash_doc(self, hashed_doc): 342 | page_content, metadata = hashed_doc 343 | return Document(page_content=page_content, metadata=dict(metadata)) 344 | 345 | 346 | def _get_chat_history(chat_history: List[ChatTurnType]) -> str: 347 | buffer = "" 348 | for dialogue_turn in chat_history: 349 | if isinstance(dialogue_turn, BaseMessage): 350 | role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ") 351 | buffer += f"\n{role_prefix}{dialogue_turn.content}" 352 | elif isinstance(dialogue_turn, tuple): 353 | human = "Human: " + dialogue_turn[0] 354 | ai = "Assistant: " + dialogue_turn[1] 355 | buffer += "\n" + "\n".join([human, ai]) 356 | else: 357 | raise ValueError( 358 | f"Unsupported chat history format: {type(dialogue_turn)}." 359 | f" Full chat history: {chat_history} " 360 | ) 361 | return buffer 362 | 363 | 364 | def _get_standalone_questions_list( 365 | standalone_questions_str: str, original_question: str 366 | ) -> List[str]: 367 | pattern = r"\d+\.\s(.*?)(?=\n\d+\.|\n|$)" 368 | 369 | matches = [ 370 | match.group(1) for match in re.finditer(pattern, standalone_questions_str) 371 | ] 372 | if matches: 373 | return matches 374 | 375 | match = re.search( 376 | r"(?i)standalone[^\n]*:[^\n](.*)", standalone_questions_str, re.DOTALL 377 | ) 378 | sentence_source = match.group(1).strip() if match else standalone_questions_str 379 | sentences = sentence_source.split("\n") 380 | 381 | return [ 382 | re.sub( 383 | r"^\((\d+)\)\.? ?|^\d+\.? ?\)?|^(\d+)\) ?|^(\d+)\) ?|^[Qq]uery \d+: ?|^[Qq]uery: ?", 384 | "", 385 | sentence.strip(), 386 | ) 387 | for sentence in sentences 388 | if sentence.strip() 389 | ] 390 | --------------------------------------------------------------------------------
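A minimal sketch of the weighted Reciprocal Rank Fusion used by MyEnsembleRetriever.weighted_reciprocal_rank in toolkit/retrivers.py: each document earns weight * 1 / (rank + c) from every ranked list it appears in, and the fused order sorts documents by the summed score. The function name, the c value of 60, and the toy rankings below are illustrative assumptions, not code taken from the repository.

from collections import defaultdict
from typing import Dict, List


def weighted_rrf(rank_lists: List[List[str]], weights: List[float], c: int = 60) -> List[str]:
    """Fuse ranked lists of document ids with weighted Reciprocal Rank Fusion."""
    if len(rank_lists) != len(weights):
        raise ValueError("Number of rank lists must be equal to the number of weights.")

    scores: Dict[str, float] = defaultdict(float)
    for rank_list, weight in zip(rank_lists, weights):
        for rank, doc_id in enumerate(rank_list, start=1):
            # Higher-ranked documents (smaller rank) receive larger increments.
            scores[doc_id] += weight * (1.0 / (rank + c))

    return sorted(scores, key=scores.get, reverse=True)


if __name__ == "__main__":
    bm25_hits = ["doc_b", "doc_a", "doc_c"]    # hypothetical BM25 ranking
    chroma_hits = ["doc_a", "doc_b", "doc_d"]  # hypothetical embedding ranking
    print(weighted_rrf([bm25_hits, chroma_hits], weights=[0.5, 0.5]))

With equal weights the fused list interleaves both rankings; shifting the weights (as RETRIEVER_WEIGHTS read from configparser.ini does for the real retriever) biases the final order toward either the BM25 or the embedding results.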