├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.md
│       └── feature_request.md
├── .gitignore
├── LICENSE
├── README.md
├── configparser.ini
├── convo_qa_chain.py
├── data
│   ├── ABPI Code of Practice for the Pharmaceutical Industry 2021.pdf
│   ├── Attention Is All You Need.pdf
│   ├── Gradient Descent The Ultimate Optimizer.pdf
│   ├── JP Morgan 2022 Environmental Social Governance Report.pdf
│   ├── Language Models are Few-Shot Learners.pdf
│   ├── Language Models are Unsupervised Multitask Learners.pdf
│   └── United Nations 2022 Annual Report.pdf
├── docs2db.py
├── figs
│   ├── High_Level_Architecture.png
│   └── Sliding_Window_Chunking.png
├── main.py
├── requirements.txt
└── toolkit
    ├── __init__.py
    ├── local_llm.py
    ├── prompts.py
    ├── retrivers.py
    ├── together_api_llm.py
    └── utils.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .history
3 | .vscode
4 | __pycache__
5 | Archieve
6 | database_store
7 | IncarnaMind.log
8 | experiments.ipynb
9 | .pylintrc
10 | .flake8
11 | models/
12 | model/
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🧠 IncarnaMind
2 |
3 | ## 👀 In a Nutshell
4 |
5 | IncarnaMind enables you to chat with your personal documents 📁 (PDF, TXT) using Large Language Models (LLMs) like GPT ([architecture overview](#high-level-architecture)). While OpenAI has recently launched a fine-tuning API for GPT models, it doesn't let the base pretrained models learn new data, and responses can be prone to factual hallucinations. Our [Sliding Window Chunking](#sliding-window-chunking) mechanism and Ensemble Retriever enable efficient querying of both fine-grained and coarse-grained information within your ground-truth documents to augment the LLMs.
6 |
7 | Feel free to use it and we welcome any feedback and new feature suggestions 🙌.
8 |
9 | ## ✨ New Updates
10 |
11 | ### Open-Source and Local LLM Support
12 |
13 | - **Recommended Model:** We've primarily tested with the Llama2 series models and recommend using [llama2-70b-chat](https://huggingface.co/TheBloke/Llama-2-70B-chat-GGUF) (either full or GGUF version) for optimal performance. Feel free to experiment with other LLMs.
14 | - **System Requirements:** Running the GGUF quantized version requires more than 35 GB of GPU RAM.
15 |
16 | ### Alternative Open-Source LLM Options
17 |
18 | - **Insufficient RAM:** If you're limited by GPU RAM, consider using the [Together.ai](https://api.together.xyz/playground) API. It supports llama2-70b-chat and most other open-source LLMs. Plus, you get $25 in free usage.
19 | - **Upcoming:** Smaller, cost-effective fine-tuned models will be released in the future.
20 |
21 | ### How to use GGUF models
22 |
23 | - For instructions on acquiring and using quantized GGUF LLMs (similar to GGML), please refer to this [video](https://www.youtube.com/watch?v=lbFmceo4D5E) (from 10:45 to 12:30). Once you have a model, point `MODEL_NAME` in **configparser.ini** at it, as shown below.
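
For example, to use the quantized Llama2-70B model recommended above, set `MODEL_NAME` in **configparser.ini** following the `Model Provider|Model Name|Model File` schema (this mirrors the example already listed in that file; swap in whichever GGUF file you downloaded):

```shell
[parameters]
MODEL_NAME = HuggingFace|TheBloke/Llama-2-70B-chat-GGUF|llama-2-70b-chat.q4_K_M.gguf
```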
24 |
25 | Here is a comparison table of the different models I tested, for reference only:
26 |
27 | | Metrics | GPT-4 | GPT-3.5 | Claude 2.0 | Llama2-70b | Llama2-70b-gguf | Llama2-70b-api |
28 | |-----------|--------|---------|------------|------------|-----------------|----------------|
29 | | Reasoning | High | Medium | High | Medium | Medium | Medium |
30 | | Speed | Medium | High | Medium | Very Low | Low | Medium |
31 | | GPU RAM | N/A | N/A | N/A | Very High | High | N/A |
32 | | Safety | Low | Low | Low | High | High | Low |
33 |
34 | ## 💻 Demo
35 |
36 | https://github.com/junruxiong/IncarnaMind/assets/44308338/89d479fb-de90-4f7c-b166-e54f7bc7344c
37 |
38 | ## 💡 Challenges Addressed
39 |
40 | - **Fixed Chunking**: Traditional RAG tools rely on fixed chunk sizes, limiting their adaptability in handling varying data complexity and context.
41 |
42 | - **Precision vs. Semantics**: Current retrieval methods usually focus either on semantic understanding or precise retrieval, but rarely both.
43 |
44 | - **Single-Document Limitation**: Many solutions can only query one document at a time, restricting multi-document information retrieval.
45 |
46 | - **Stability**: IncarnaMind is compatible with OpenAI GPT, Anthropic Claude, Llama2, and other open-source LLMs, ensuring stable parsing.
47 |
48 | ## 🎯 Key Features
49 |
50 | - **Adaptive Chunking**: Our Sliding Window Chunking technique dynamically adjusts window size and position for RAG, balancing fine-grained and coarse-grained data access based on data complexity and context.
51 |
52 | - **Multi-Document Conversational QA**: Supports simple and multi-hop queries across multiple documents simultaneously, breaking the single-document limitation.
53 |
54 | - **File Compatibility**: Supports both PDF and TXT file formats.
55 |
56 | - **LLM Model Compatibility**: Supports OpenAI GPT, Anthropic Claude, Llama2 and other open-source LLMs.
57 |
58 | ## 🏗 Architecture
59 |
60 | ### High Level Architecture
61 |
62 | 
63 |
64 | ### Sliding Window Chunking
65 |
66 | 
67 |
68 | ## 🚀 Getting Started
69 |
70 | ### 1. Installation
71 |
72 | Installation is simple; you just need to run a few commands.
73 |
74 | #### 1.0. Prerequisites
75 |
76 | - 3.8 ≤ Python < 3.11 with [Conda](https://www.anaconda.com/download)
77 | - One or more of: an [OpenAI API key](https://beta.openai.com/signup), an [Anthropic Claude API key](https://console.anthropic.com/account/keys), a [Together.ai API key](https://api.together.xyz/settings/api-keys), or a [Hugging Face token for Meta Llama models](https://huggingface.co/settings/tokens)
78 | - And of course, your own documents.
79 |
80 | #### 1.1. Clone the repository
81 |
82 | ```shell
83 | git clone https://github.com/junruxiong/IncarnaMind
84 | cd IncarnaMind
85 | ```
86 |
87 | #### 1.2. Setup
88 |
89 | Create a Conda virtual environment:
90 |
91 | ```shell
92 | conda create -n IncarnaMind python=3.10
93 | ```
94 |
95 | Activate:
96 |
97 | ```shell
98 | conda activate IncarnaMind
99 | ```
100 |
101 | Install all requirements:
102 |
103 | ```shell
104 | pip install -r requirements.txt
105 | ```
106 |
107 | Install [llama-cpp](https://github.com/abetlen/llama-cpp-python) separately if you want to run quantized local LLMs:
108 |
109 | - For `NVIDIA` GPU support, use `cuBLAS`:
110 |
111 | ```shell
112 | CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.1.83 --no-cache-dir
113 | ```
114 |
115 | - For Apple Metal (`M1/M2`) support, use:
116 |
117 | ```shell
118 | CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.1.83 --no-cache-dir
119 | ```
120 |
121 | Set up one or more of your API keys in the **configparser.ini** file:
122 |
123 | ```shell
124 | [tokens]
125 | OPENAI_API_KEY = (replace_me)
126 | ANTHROPIC_API_KEY = (replace_me)
127 | TOGETHER_API_KEY = (replace_me)
128 | # if you use full Meta-Llama models, you may need a Hugging Face token for access.
129 | HUGGINGFACE_TOKEN = (replace_me)
130 | ```
131 |
132 | (Optional) Set your custom parameters in the **configparser.ini** file; a concrete example with the actual parameter names follows the placeholder block below:
133 |
134 | ```shell
135 | [parameters]
136 | PARAMETERS 1 = (replace_me)
137 | PARAMETERS 2 = (replace_me)
138 | ...
139 | PARAMETERS n = (replace_me)
140 | ```
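
For reference, these are some of the parameter names and default values shipped in **configparser.ini** (see that file for the full list and inline comments):

```shell
[parameters]
MODEL_NAME = OpenAI|gpt-4-1106-preview|None
MAX_CHAT_HISTORY = 800
MAX_LLM_CONTEXT = 1200
MAX_LLM_GENERATION = 1000
EMBEDDING_NAME = openAIEmbeddings
BASE_CHUNK_SIZE = 100
CHUNK_OVERLAP = 0
CHUNK_SCALE = 3
WINDOW_STEPS = 3
WINDOW_SCALE = 18
```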
141 |
142 | ### 2. Usage
143 |
144 | #### 2.1. Upload and process your files
145 |
146 | Put all your files into the **/data** directory (give each file a descriptive name to maximize retrieval performance), then run the following command to ingest all data:
147 | (You can delete the example files in the **/data** directory before running the command.)
148 |
149 | ```shell
150 | python docs2db.py
151 | ```
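
If ingestion succeeds, the embeddings and pickled documents are written to `DB_DIR` (`./database_store` by default). Based on the file naming in `docs2db.py`, with the default `openAIEmbeddings` embedding the directory should look roughly like this (an illustration, not verbatim output):

```shell
database_store/
├── chroma_openAIEmbeddings_chunks_small/
├── chroma_openAIEmbeddings_chunks_medium/
├── docs_pickle_chunks_small.pkl
├── docs_pickle_chunks_medium.pkl
└── file_names.pkl
```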
152 |
153 | #### 2.2. Run
154 |
155 | To start the conversation, run:
156 |
157 | ```shell
158 | python main.py
159 | ```
160 |
161 | #### 2.3. Chat and ask any questions
162 |
163 | Wait for the script to prompt you for input, as shown below.
164 |
165 | ```shell
166 | Human:
167 | ```
168 |
169 | #### 2.4. Others
170 |
171 | When you start a chat, the system will automatically generate an **IncarnaMind.log** file.
172 | If you want to change the logging settings, edit the **configparser.ini** file.
173 |
174 | ```shell
175 | [logging]
176 | enabled = True
177 | level = INFO
178 | filename = IncarnaMind.log
179 | format = %(asctime)s [%(levelname)s] %(name)s: %(message)s
180 | ```
181 |
182 | ## 🚫 Limitations
183 |
184 | - Citations are not supported in the current version, but will be added soon.
185 | - Limited asynchronous capabilities.
186 |
187 | ## 📝 Upcoming Features
188 |
189 | - Frontend UI interface
190 | - Fine-tuned, small-size open-source LLMs
191 | - OCR support
192 | - Asynchronous optimization
193 | - Support for more document formats
194 |
195 | ## 🙌 Acknowledgements
196 |
197 | Special thanks to [Langchain](https://github.com/langchain-ai/langchain), [Chroma DB](https://github.com/chroma-core/chroma), [LocalGPT](https://github.com/PromtEngineer/localGPT), [Llama-cpp](https://github.com/abetlen/llama-cpp-python) for their invaluable contributions to the open-source community. Their work has been instrumental in making the IncarnaMind project a reality.
198 |
199 | ## 🖋 Citation
200 |
201 | If you want to cite our work, please use the following BibTeX entry:
202 |
203 | ```bibtex
204 | @misc{IncarnaMind2023,
205 | author = {Junru Xiong},
206 | title = {IncarnaMind},
207 | year = {2023},
208 | publisher = {GitHub},
209 | journal = {GitHub Repository},
210 | howpublished = {\url{https://github.com/junruxiong/IncarnaMind}}
211 | }
212 | ```
213 |
214 | ## 📑 License
215 |
216 | [Apache 2.0 License](LICENSE)
--------------------------------------------------------------------------------
/configparser.ini:
--------------------------------------------------------------------------------
1 | [tokens]
2 | ; Enter one or more of your API keys here.
3 | ; E.g., OPENAI_API_KEY = sk-xxxxxxx
4 | OPENAI_API_KEY = xxxxx
5 | ANTHROPIC_API_KEY = xxxxx
6 | TOGETHER_API_KEY = xxxxx
7 | ; if you use Meta-Llama models, you may need a Hugging Face token for access.
8 | HUGGINGFACE_TOKEN = xxxxx
9 | VERSION = 1.0.1
10 |
11 |
12 | [directory]
13 | ; Directory for source files.
14 | DOCS_DIR = ./data
15 | ; Directory to store embeddings and Langchain documents.
16 | DB_DIR = ./database_store
17 | LOCAL_MODEL_DIR = ./models
18 |
19 |
20 | ; The parameters below are optional to modify:
21 | ; --------------------------------------------
22 | [parameters]
23 | ; Model name schema: Model Provider|Model Name|Model File. Model File is only valid for the GGUF format; set it to None for other formats.
24 |
25 | ; For example:
26 | ; OpenAI|gpt-3.5-turbo|None
27 | ; OpenAI|gpt-4|None
28 | ; Anthropic|claude-2.0|None
29 | ; Together|togethercomputer/llama-2-70b-chat|None
30 | ; HuggingFace|TheBloke/Llama-2-70B-chat-GGUF|llama-2-70b-chat.q4_K_M.gguf
31 | ; HuggingFace|meta-llama/Llama-2-70b-chat-hf|None
32 |
33 | ; The full Together.AI model list can be found at the end of this file; we currently only support quantized GGUF and full Hugging Face local LLMs.
34 | MODEL_NAME = OpenAI|gpt-4-1106-preview|None
35 | ; LLM temperature
36 | TEMPURATURE = 0
37 | ; Maximum tokens for storing chat history.
38 | MAX_CHAT_HISTORY = 800
39 | ; Maximum tokens for LLM context for retrieved information.
40 | MAX_LLM_CONTEXT = 1200
41 | ; Maximum tokens for LLM generation.
42 | MAX_LLM_GENERATION = 1000
43 | ; Supported embeddings: openAIEmbeddings and hkunlpInstructorLarge.
44 | EMBEDDING_NAME = openAIEmbeddings
45 |
46 | ; This is dependent on your GPU type.
47 | N_GPU_LAYERS = 100
48 | ; This depends on your GPU and CPU RAM when using open-source LLMs.
49 | N_BATCH = 512
50 |
51 |
52 | ; The base (small) chunk size for first stage document retrieval.
53 | BASE_CHUNK_SIZE = 100
54 | ; Set to 0 for no overlap.
55 | CHUNK_OVERLAP = 0
56 | ; The final retrieval (medium) chunk size will be BASE_CHUNK_SIZE * CHUNK_SCALE.
57 | CHUNK_SCALE = 3
58 | WINDOW_STEPS = 3
59 | ; The number of tokens in a window chunk will be BASE_CHUNK_SIZE * WINDOW_SCALE.
60 | WINDOW_SCALE = 18
61 |
62 | ; Ratio of BM25 retriever to Chroma Vectorstore retriever.
63 | RETRIEVER_WEIGHTS = 0.5, 0.5
64 | ; Number of retrieved chunks will range from FIRST_RETRIEVAL_K to 2*FIRST_RETRIEVAL_K due to the ensemble retriever.
65 | FIRST_RETRIEVAL_K = 3
66 | ; Number of retrieved chunks will range from SECOND_RETRIEVAL_K to 2*SECOND_RETRIEVAL_K due to the ensemble retriever.
67 | SECOND_RETRIEVAL_K = 3
68 | ; Number of windows (large chunks) for the third retriever.
69 | NUM_WINDOWS = 2
70 | ; (The third retrieval gets the final chunks passed to the LLM QA chain. The 'k' value is dynamic (based on MAX_LLM_CONTEXT), depending on the number of rephrased questions and retrieved documents.)
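; For example, with the defaults above (BASE_CHUNK_SIZE = 100, CHUNK_SCALE = 3, WINDOW_SCALE = 18),
; small chunks are ~100 tokens, medium chunks are ~100 * 3 = 300 tokens, and window chunks are ~100 * 18 = 1800 tokens.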
71 |
72 |
73 | [logging]
74 | ; If you do not want to enable logging, set enabled to False.
75 | enabled = True
76 | level = INFO
77 | filename = IncarnaMind.log
78 | format = %(asctime)s [%(levelname)s] %(name)s: %(message)s
79 |
80 |
81 | ; Together.AI supported models:
82 |
83 | ; 0 Austism/chronos-hermes-13b
84 | ; 1 EleutherAI/pythia-12b-v0
85 | ; 2 EleutherAI/pythia-1b-v0
86 | ; 3 EleutherAI/pythia-2.8b-v0
87 | ; 4 EleutherAI/pythia-6.9b
88 | ; 5 Gryphe/MythoMax-L2-13b
89 | ; 6 HuggingFaceH4/starchat-alpha
90 | ; 7 NousResearch/Nous-Hermes-13b
91 | ; 8 NousResearch/Nous-Hermes-Llama2-13b
92 | ; 9 NumbersStation/nsql-llama-2-7B
93 | ; 10 OpenAssistant/llama2-70b-oasst-sft-v10
94 | ; 11 OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5
95 | ; 12 OpenAssistant/stablelm-7b-sft-v7-epoch-3
96 | ; 13 Phind/Phind-CodeLlama-34B-Python-v1
97 | ; 14 Phind/Phind-CodeLlama-34B-v2
98 | ; 15 SG161222/Realistic_Vision_V3.0_VAE
99 | ; 16 WizardLM/WizardCoder-15B-V1.0
100 | ; 17 WizardLM/WizardCoder-Python-34B-V1.0
101 | ; 18 WizardLM/WizardLM-70B-V1.0
102 | ; 19 bigcode/starcoder
103 | ; 20 databricks/dolly-v2-12b
104 | ; 21 databricks/dolly-v2-3b
105 | ; 22 databricks/dolly-v2-7b
106 | ; 23 defog/sqlcoder
107 | ; 24 garage-bAInd/Platypus2-70B-instruct
108 | ; 25 huggyllama/llama-13b
109 | ; 26 huggyllama/llama-30b
110 | ; 27 huggyllama/llama-65b
111 | ; 28 huggyllama/llama-7b
112 | ; 29 lmsys/fastchat-t5-3b-v1.0
113 | ; 30 lmsys/vicuna-13b-v1.3
114 | ; 31 lmsys/vicuna-13b-v1.5-16k
115 | ; 32 lmsys/vicuna-13b-v1.5
116 | ; 33 lmsys/vicuna-7b-v1.3
117 | ; 34 prompthero/openjourney
118 | ; 35 runwayml/stable-diffusion-v1-5
119 | ; 36 stabilityai/stable-diffusion-2-1
120 | ; 37 stabilityai/stable-diffusion-xl-base-1.0
121 | ; 38 togethercomputer/CodeLlama-13b-Instruct
122 | ; 39 togethercomputer/CodeLlama-13b-Python
123 | ; 40 togethercomputer/CodeLlama-13b
124 | ; 41 togethercomputer/CodeLlama-34b-Instruct
125 | ; 42 togethercomputer/CodeLlama-34b-Python
126 | ; 43 togethercomputer/CodeLlama-34b
127 | ; 44 togethercomputer/CodeLlama-7b-Instruct
128 | ; 45 togethercomputer/CodeLlama-7b-Python
129 | ; 46 togethercomputer/CodeLlama-7b
130 | ; 47 togethercomputer/GPT-JT-6B-v1
131 | ; 48 togethercomputer/GPT-JT-Moderation-6B
132 | ; 49 togethercomputer/GPT-NeoXT-Chat-Base-20B
133 | ; 50 togethercomputer/Koala-13B
134 | ; 51 togethercomputer/LLaMA-2-7B-32K
135 | ; 52 togethercomputer/Llama-2-7B-32K-Instruct
136 | ; 53 togethercomputer/Pythia-Chat-Base-7B-v0.16
137 | ; 54 togethercomputer/Qwen-7B-Chat
138 | ; 55 togethercomputer/Qwen-7B
139 | ; 56 togethercomputer/RedPajama-INCITE-7B-Base
140 | ; 57 togethercomputer/RedPajama-INCITE-7B-Chat
141 | ; 58 togethercomputer/RedPajama-INCITE-7B-Instruct
142 | ; 59 togethercomputer/RedPajama-INCITE-Base-3B-v1
143 | ; 60 togethercomputer/RedPajama-INCITE-Chat-3B-v1
144 | ; 61 togethercomputer/RedPajama-INCITE-Instruct-3B-v1
145 | ; 62 togethercomputer/alpaca-7b
146 | ; 63 togethercomputer/codegen2-16B
147 | ; 64 togethercomputer/codegen2-7B
148 | ; 65 togethercomputer/falcon-40b-instruct
149 | ; 66 togethercomputer/falcon-40b
150 | ; 67 togethercomputer/falcon-7b-instruct
151 | ; 68 togethercomputer/falcon-7b
152 | ; 69 togethercomputer/guanaco-13b
153 | ; 70 togethercomputer/guanaco-33b
154 | ; 71 togethercomputer/guanaco-65b
155 | ; 72 togethercomputer/guanaco-7b
156 | ; 73 togethercomputer/llama-2-13b-chat
157 | ; 74 togethercomputer/llama-2-13b
158 | ; 75 togethercomputer/llama-2-70b-chat
159 | ; 76 togethercomputer/llama-2-70b
160 | ; 77 togethercomputer/llama-2-7b-chat
161 | ; 78 togethercomputer/llama-2-7b
162 | ; 79 togethercomputer/mpt-30b-chat
163 | ; 80 togethercomputer/mpt-30b-instruct
164 | ; 81 togethercomputer/mpt-30b
165 | ; 82 togethercomputer/mpt-7b-chat
166 | ; 83 togethercomputer/mpt-7b
167 | ; 84 togethercomputer/replit-code-v1-3b
168 | ; 85 upstage/SOLAR-0-70b-16bit
169 | ; 86 wavymulder/Analog-Diffusion
--------------------------------------------------------------------------------
/convo_qa_chain.py:
--------------------------------------------------------------------------------
1 | """Conversational QA Chain"""
2 | from __future__ import annotations
3 | import inspect
4 | import logging
5 | from typing import Any, Dict, List, Optional, Tuple
6 | from pydantic import Field
7 |
8 | from langchain.schema import BasePromptTemplate, BaseRetriever, Document
9 | from langchain.schema.language_model import BaseLanguageModel
10 | from langchain.chains import LLMChain
11 | from langchain.chains.question_answering import load_qa_chain
12 | from langchain.chains.conversational_retrieval.base import (
13 | BaseConversationalRetrievalChain,
14 | )
15 | from langchain.callbacks.manager import (
16 | AsyncCallbackManagerForChainRun,
17 | CallbackManagerForChainRun,
18 | Callbacks,
19 | )
20 |
21 | from toolkit.utils import (
22 | Config,
23 | _get_chat_history,
24 | _get_standalone_questions_list,
25 | )
26 | from toolkit.retrivers import MyRetriever
27 | from toolkit.prompts import PromptTemplates
28 |
29 | configs = Config("configparser.ini")
30 | logger = logging.getLogger(__name__)
31 |
32 | prompt_templates = PromptTemplates()
33 |
34 |
35 | class ConvoRetrievalChain(BaseConversationalRetrievalChain):
36 | """Chain for having a conversation based on retrieved documents.
37 |
38 | This chain takes in chat history (a list of messages) and new questions,
39 | and then returns an answer to that question.
40 | The algorithm for this chain consists of three parts:
41 |
42 | 1. Use the chat history and the new question to create a "standalone question".
43 | This is done so that this question can be passed into the retrieval step to fetch
44 | relevant documents. If only the new question was passed in, then relevant context
45 | may be lacking. If the whole conversation was passed into retrieval, there may
46 | be unnecessary information there that would distract from retrieval.
47 |
48 | 2. This new question is passed to the retriever and relevant documents are
49 | returned.
50 |
51 | 3. The retrieved documents are passed to an LLM along with either the new question
52 | (default behavior) or the original question and chat history to generate a final
53 | response.
54 |
55 | Example:
56 | .. code-block:: python
57 |
58 | from langchain.chains import (
59 | StuffDocumentsChain, LLMChain, ConversationalRetrievalChain
60 | )
61 | from langchain.prompts import PromptTemplate
62 | from langchain.llms import OpenAI
63 |
64 | combine_docs_chain = StuffDocumentsChain(...)
65 | vectorstore = ...
66 | retriever = vectorstore.as_retriever()
67 |
68 | # This controls how the standalone question is generated.
69 | # Should take `chat_history` and `question` as input variables.
70 | template = (
71 | "Combine the chat history and follow up question into "
72 | "a standalone question. Chat History: {chat_history}"
73 | "Follow up question: {question}"
74 | )
75 | prompt = PromptTemplate.from_template(template)
76 | llm = OpenAI()
77 | question_generator_chain = LLMChain(llm=llm, prompt=prompt)
78 | chain = ConversationalRetrievalChain(
79 | combine_docs_chain=combine_docs_chain,
80 | retriever=retriever,
81 | question_generator=question_generator_chain,
82 | )
83 | """
84 |
85 | retriever: MyRetriever = Field(exclude=True)
86 | """Retriever to use to fetch documents."""
87 | file_names: List = Field(exclude=True)
88 | """file_names (List): List of file names used for retrieval."""
89 |
90 | def _get_docs(
91 | self,
92 | question: str,
93 | inputs: Dict[str, Any],
94 | num_query: int,
95 | *,
96 | run_manager: Optional[CallbackManagerForChainRun] = None,
97 | ) -> List[Document]:
98 | """Get docs."""
99 | try:
100 | docs = self.retriever.get_relevant_documents(
101 | question, num_query=num_query, run_manager=run_manager
102 | )
103 | return docs
104 | except (IOError, FileNotFoundError) as error:
105 | logger.error("An error occurred in _get_docs: %s", error)
106 | return []
107 |
108 | def _retrieve(
109 | self,
110 | question_list: List[str],
111 | inputs: Dict[str, Any],
112 | run_manager: Optional[CallbackManagerForChainRun] = None,
113 |     ) -> Tuple[str, Dict[str, Any]]:
114 | num_query = len(question_list)
115 | accepts_run_manager = (
116 | "run_manager" in inspect.signature(self._get_docs).parameters
117 | )
118 |
119 | total_results = {}
120 | for question in question_list:
121 | docs_dict = (
122 | self._get_docs(
123 | question, inputs, num_query=num_query, run_manager=run_manager
124 | )
125 | if accepts_run_manager
126 | else self._get_docs(question, inputs, num_query=num_query)
127 | )
128 |
129 | for file_name, docs in docs_dict.items():
130 | if file_name not in total_results:
131 | total_results[file_name] = docs
132 | else:
133 | total_results[file_name].extend(docs)
134 |
135 | logger.info(
136 | "-----step_done--------------------------------------------------",
137 | )
138 |
139 | snippets = ""
140 | redundancy = set()
141 | for file_name, docs in total_results.items():
142 | sorted_docs = sorted(docs, key=lambda x: x.metadata["medium_chunk_idx"])
143 | temp = "\n".join(
144 | doc.page_content
145 | for doc in sorted_docs
146 | if doc.metadata["page_content_md5"] not in redundancy
147 | )
148 | redundancy.update(doc.metadata["page_content_md5"] for doc in sorted_docs)
149 | snippets += f"\nContext about {file_name}:\n{{{temp}}}\n"
150 |
151 | return snippets, docs_dict
152 |
153 | def _call(
154 | self,
155 | inputs: Dict[str, Any],
156 | run_manager: Optional[CallbackManagerForChainRun] = None,
157 | ) -> Dict[str, Any]:
158 | _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
159 | question = inputs["question"]
160 | get_chat_history = self.get_chat_history or _get_chat_history
161 | chat_history_str = get_chat_history(inputs["chat_history"])
162 |
163 | callbacks = _run_manager.get_child()
164 | new_questions = self.question_generator.run(
165 | question=question,
166 | chat_history=chat_history_str,
167 | database=self.file_names,
168 | callbacks=callbacks,
169 | )
170 | logger.info("new_questions: %s", new_questions)
171 | new_question_list = _get_standalone_questions_list(new_questions, question)[:3]
172 | # print("new_question_list:", new_question_list)
173 | logger.info("user_input: %s", question)
174 | logger.info("new_question_list: %s", new_question_list)
175 |
176 | snippets, source_docs = self._retrieve(
177 | new_question_list, inputs, run_manager=_run_manager
178 | )
179 |
180 | docs = [
181 | Document(
182 | page_content=snippets,
183 | metadata={},
184 | )
185 | ]
186 |
187 | new_inputs = inputs.copy()
188 | new_inputs["chat_history"] = chat_history_str
189 | answer = self.combine_docs_chain.run(
190 | input_documents=docs,
191 | database=self.file_names,
192 | callbacks=_run_manager.get_child(),
193 | **new_inputs,
194 | )
195 | output: Dict[str, Any] = {self.output_key: answer}
196 | if self.return_source_documents:
197 | output["source_documents"] = source_docs
198 | if self.return_generated_question:
199 | output["generated_question"] = new_questions
200 |
201 | logger.info("*****response*****: %s", output["answer"])
202 | logger.info(
203 | "=====epoch_done============================================================",
204 | )
205 | return output
206 |
207 | async def _aget_docs(
208 | self,
209 | question: str,
210 | inputs: Dict[str, Any],
211 | num_query: int,
212 | *,
213 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
214 | ) -> List[Document]:
215 | """Get docs."""
216 | try:
217 | docs = await self.retriever.aget_relevant_documents(
218 | question, num_query=num_query, run_manager=run_manager
219 | )
220 | return docs
221 | except (IOError, FileNotFoundError) as error:
222 |             logger.error("An error occurred in _aget_docs: %s", error)
223 | return []
224 |
225 | async def _aretrieve(
226 | self,
227 | question_list: List[str],
228 | inputs: Dict[str, Any],
229 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
230 |     ) -> Tuple[str, Dict[str, Any]]:
231 | num_query = len(question_list)
232 | accepts_run_manager = (
233 |             "run_manager" in inspect.signature(self._aget_docs).parameters
234 | )
235 |
236 | total_results = {}
237 | for question in question_list:
238 | docs_dict = (
239 | await self._aget_docs(
240 | question, inputs, num_query=num_query, run_manager=run_manager
241 | )
242 | if accepts_run_manager
243 | else await self._aget_docs(question, inputs, num_query=num_query)
244 | )
245 |
246 | for file_name, docs in docs_dict.items():
247 | if file_name not in total_results:
248 | total_results[file_name] = docs
249 | else:
250 | total_results[file_name].extend(docs)
251 |
252 | logger.info(
253 | "-----step_done--------------------------------------------------",
254 | )
255 |
256 | snippets = ""
257 | redundancy = set()
258 | for file_name, docs in total_results.items():
259 | sorted_docs = sorted(docs, key=lambda x: x.metadata["medium_chunk_idx"])
260 | temp = "\n".join(
261 | doc.page_content
262 | for doc in sorted_docs
263 | if doc.metadata["page_content_md5"] not in redundancy
264 | )
265 | redundancy.update(doc.metadata["page_content_md5"] for doc in sorted_docs)
266 | snippets += f"\nContext about {file_name}:\n{{{temp}}}\n"
267 |
268 | return snippets, docs_dict
269 |
270 | async def _acall(
271 | self,
272 | inputs: Dict[str, Any],
273 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
274 | ) -> Dict[str, Any]:
275 | _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
276 | question = inputs["question"]
277 | get_chat_history = self.get_chat_history or _get_chat_history
278 | chat_history_str = get_chat_history(inputs["chat_history"])
279 |
280 | callbacks = _run_manager.get_child()
281 | new_questions = await self.question_generator.arun(
282 | question=question,
283 | chat_history=chat_history_str,
284 | database=self.file_names,
285 | callbacks=callbacks,
286 | )
287 | new_question_list = _get_standalone_questions_list(new_questions, question)[:3]
288 | logger.info("new_questions: %s", new_questions)
289 | logger.info("new_question_list: %s", new_question_list)
290 |
291 | snippets, source_docs = await self._aretrieve(
292 | new_question_list, inputs, run_manager=_run_manager
293 | )
294 |
295 | docs = [
296 | Document(
297 | page_content=snippets,
298 | metadata={},
299 | )
300 | ]
301 |
302 | new_inputs = inputs.copy()
303 | new_inputs["chat_history"] = chat_history_str
304 | answer = await self.combine_docs_chain.arun(
305 | input_documents=docs,
306 | database=self.file_names,
307 | callbacks=_run_manager.get_child(),
308 | **new_inputs,
309 | )
310 | output: Dict[str, Any] = {self.output_key: answer}
311 | if self.return_source_documents:
312 | output["source_documents"] = source_docs
313 | if self.return_generated_question:
314 | output["generated_question"] = new_questions
315 |
316 | logger.info("*****response*****: %s", output["answer"])
317 | logger.info(
318 | "=====epoch_done============================================================",
319 | )
320 |
321 | return output
322 |
323 | @classmethod
324 | def from_llm(
325 | cls,
326 | llm: BaseLanguageModel,
327 | retriever: BaseRetriever,
328 | condense_question_prompt: BasePromptTemplate = prompt_templates.get_refine_qa_template(
329 | configs.model_name
330 | ),
331 | chain_type: str = "stuff", # only support stuff chain now
332 | verbose: bool = False,
333 | condense_question_llm: Optional[BaseLanguageModel] = None,
334 | combine_docs_chain_kwargs: Optional[Dict] = None,
335 | callbacks: Callbacks = None,
336 | **kwargs: Any,
337 | ) -> BaseConversationalRetrievalChain:
338 | """Convenience method to load chain from LLM and retriever.
339 |
340 | This provides some logic to create the `question_generator` chain
341 | as well as the combine_docs_chain.
342 |
343 | Args:
344 | llm: The default language model to use at every part of this chain
345 | (eg in both the question generation and the answering)
346 | retriever: The retriever to use to fetch relevant documents from.
347 | condense_question_prompt: The prompt to use to condense the chat history
348 | and new question into standalone question(s).
349 | chain_type: The chain type to use to create the combine_docs_chain, will
350 | be sent to `load_qa_chain`.
351 | verbose: Verbosity flag for logging to stdout.
352 | condense_question_llm: The language model to use for condensing the chat
353 | history and new question into standalone question(s). If none is
354 | provided, will default to `llm`.
355 | combine_docs_chain_kwargs: Parameters to pass as kwargs to `load_qa_chain`
356 | when constructing the combine_docs_chain.
357 | callbacks: Callbacks to pass to all subchains.
358 | **kwargs: Additional parameters to pass when initializing
359 | ConversationalRetrievalChain
360 | """
361 | combine_docs_chain_kwargs = combine_docs_chain_kwargs or {
362 | "prompt": prompt_templates.get_retrieval_qa_template_selector(
363 | configs.model_name
364 | ).get_prompt(llm)
365 | }
366 | doc_chain = load_qa_chain(
367 | llm,
368 | chain_type=chain_type,
369 | verbose=verbose,
370 | callbacks=callbacks,
371 | **combine_docs_chain_kwargs,
372 | )
373 |
374 | _llm = condense_question_llm or llm
375 | condense_question_chain = LLMChain(
376 | llm=_llm,
377 | prompt=condense_question_prompt,
378 | verbose=verbose,
379 | callbacks=callbacks,
380 | )
381 | return cls(
382 | retriever=retriever,
383 | combine_docs_chain=doc_chain,
384 | question_generator=condense_question_chain,
385 | callbacks=callbacks,
386 | **kwargs,
387 | )
388 |
--------------------------------------------------------------------------------
/data/ABPI Code of Practice for the Pharmaceutical Industry 2021.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/ABPI Code of Practice for the Pharmaceutical Industry 2021.pdf
--------------------------------------------------------------------------------
/data/Attention Is All You Need.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/Attention Is All You Need.pdf
--------------------------------------------------------------------------------
/data/Gradient Descent The Ultimate Optimizer.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/Gradient Descent The Ultimate Optimizer.pdf
--------------------------------------------------------------------------------
/data/JP Morgan 2022 Environmental Social Governance Report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/JP Morgan 2022 Environmental Social Governance Report.pdf
--------------------------------------------------------------------------------
/data/Language Models are Few-Shot Learners.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/Language Models are Few-Shot Learners.pdf
--------------------------------------------------------------------------------
/data/Language Models are Unsupervised Multitask Learners.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/Language Models are Unsupervised Multitask Learners.pdf
--------------------------------------------------------------------------------
/data/United Nations 2022 Annual Report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/data/United Nations 2022 Annual Report.pdf
--------------------------------------------------------------------------------
/docs2db.py:
--------------------------------------------------------------------------------
1 | """
2 | This module converts documents into embeddings and LangChain Documents and saves them.
3 | """
4 | import os
5 | import glob
6 | import pickle
7 | from typing import List
8 | from multiprocessing import Pool
9 | from collections import deque
10 | import hashlib
11 | import tiktoken
12 |
13 | from tqdm import tqdm
14 |
15 | from langchain.schema import Document
16 | from langchain.vectorstores import Chroma
17 | from langchain.text_splitter import (
18 | RecursiveCharacterTextSplitter,
19 | )
20 | from langchain.document_loaders import (
21 | PyPDFLoader,
22 | TextLoader,
23 | )
24 |
25 | from toolkit.utils import Config, choose_embeddings, clean_text
26 |
27 |
28 | # Load the config file
29 | configs = Config("configparser.ini")
30 |
31 | os.environ["OPENAI_API_KEY"] = configs.openai_api_key
32 | os.environ["ANTHROPIC_API_KEY"] = configs.anthropic_api_key
33 |
34 | embedding_store_path = configs.db_dir
35 | files_path = glob.glob(configs.docs_dir + "/*")
36 |
37 | tokenizer_name = tiktoken.encoding_for_model("gpt-3.5-turbo")
38 | tokenizer = tiktoken.get_encoding(tokenizer_name.name)
39 |
40 | loaders = {
41 | "pdf": (PyPDFLoader, {}),
42 | "txt": (TextLoader, {}),
43 | }
44 |
45 |
46 | def tiktoken_len(text: str):
47 | """Calculate the token length of a given text string using TikToken.
48 |
49 | Args:
50 | text (str): The text to be tokenized.
51 |
52 | Returns:
53 | int: The length of the tokenized text.
54 | """
55 | tokens = tokenizer.encode(text, disallowed_special=())
56 |
57 | return len(tokens)
58 |
59 |
60 | def string2md5(text: str):
61 | """Convert a string to its MD5 hash.
62 |
63 | Args:
64 | text (str): The text to be hashed.
65 |
66 | Returns:
67 | str: The MD5 hash of the input string.
68 | """
69 | hash_md5 = hashlib.md5()
70 | hash_md5.update(text.encode("utf-8"))
71 |
72 | return hash_md5.hexdigest()
73 |
74 |
75 | def load_file(file_path):
76 | """Load a file and return its content as a Document object.
77 |
78 | Args:
79 | file_path (str): The path to the file.
80 |
81 | Returns:
82 | Document: The loaded document.
83 | """
84 | ext = file_path.split(".")[-1]
85 |
86 | if ext in loaders:
87 | loader_type, args = loaders[ext]
88 | loader = loader_type(file_path, **args)
89 | doc = loader.load()
90 |
91 | return doc
92 |
93 | raise ValueError(f"Extension {ext} not supported")
94 |
95 |
96 | def docs2vectorstore(docs: List[Document], embedding_name: str, suffix: str = ""):
97 | """Convert a list of Documents into a Chroma vector store.
98 |
99 | Args:
100 | docs (Document): The list of Documents.
101 | suffix (str, optional): Suffix for the embedding. Defaults to "".
102 | """
103 | embedding = choose_embeddings(embedding_name)
104 | name = f"{embedding_name}_{suffix}"
105 |     # create embedding_store_path if it does not exist
106 | if not os.path.exists(embedding_store_path):
107 | os.makedirs(embedding_store_path)
108 | Chroma.from_documents(
109 | docs,
110 | embedding,
111 | persist_directory=f"{embedding_store_path}/chroma_{name}",
112 | )
113 |
114 |
115 | def file_names2pickle(file_names: list, save_name: str = ""):
116 | """Save the list of file names to a pickle file.
117 |
118 | Args:
119 | file_names (list): The list of file names.
120 | save_name (str, optional): The name for the saved pickle file. Defaults to "".
121 | """
122 | name = f"{save_name}"
123 | if not os.path.exists(embedding_store_path):
124 | os.makedirs(embedding_store_path)
125 | with open(f"{embedding_store_path}/{name}.pkl", "wb") as file:
126 | pickle.dump(file_names, file)
127 |
128 |
129 | def docs2pickle(docs: List[Document], suffix: str = ""):
130 | """Serializes a list of Document objects to a pickle file.
131 |
132 | Args:
133 | docs (Document): List of Document objects.
134 | suffix (str, optional): Suffix for the pickle file. Defaults to "".
135 | """
136 | for doc in docs:
137 | doc.page_content = clean_text(doc.page_content)
138 | name = f"pickle_{suffix}"
139 | if not os.path.exists(embedding_store_path):
140 | os.makedirs(embedding_store_path)
141 | with open(f"{embedding_store_path}/docs_{name}.pkl", "wb") as file:
142 | pickle.dump(docs, file)
143 |
144 |
145 | def split_doc(
146 | doc: List[Document], chunk_size: int, chunk_overlap: int, chunk_idx_name: str
147 | ):
148 | """Splits a document into smaller chunks based on the provided size and overlap.
149 |
150 | Args:
151 | doc (Document): Document to be split.
152 | chunk_size (int): Size of each chunk.
153 | chunk_overlap (int): Overlap between adjacent chunks.
154 | chunk_idx_name (str): Metadata key for storing chunk indices.
155 |
156 | Returns:
157 | list: List of Document objects representing the chunks.
158 | """
159 | data_splitter = RecursiveCharacterTextSplitter(
160 | chunk_size=chunk_size,
161 | chunk_overlap=chunk_overlap,
162 | length_function=tiktoken_len,
163 | )
164 | doc_split = data_splitter.split_documents(doc)
165 | chunk_idx = 0
166 |
167 | for d_split in doc_split:
168 | d_split.metadata[chunk_idx_name] = chunk_idx
169 | chunk_idx += 1
170 |
171 | return doc_split
172 |
173 |
174 | def process_metadata(doc: List[Document]):
175 | """Processes and updates the metadata for a list of Document objects.
176 |
177 | Args:
178 | doc (list): List of Document objects.
179 | """
180 | # get file name and remove extension
181 | file_name_with_extension = os.path.basename(doc[0].metadata["source"])
182 | file_name, _ = os.path.splitext(file_name_with_extension)
183 |
184 | for _, item in enumerate(doc):
185 | for key, value in item.metadata.items():
186 | if isinstance(value, list):
187 | item.metadata[key] = str(value)
188 | item.metadata["page_content"] = item.page_content
189 | item.metadata["page_content_md5"] = string2md5(item.page_content)
190 | item.metadata["source_md5"] = string2md5(item.metadata["source"])
191 | item.page_content = f"{file_name}\n{item.page_content}"
192 |
193 |
194 | def add_window(
195 |     doc: List[Document], window_steps: int, window_size: int, window_idx_name: str
196 | ):
197 | """Adds windowing information to the metadata of each document in the list.
198 |
199 | Args:
200 | doc (Document): List of Document objects.
201 | window_steps (int): Step size for windowing.
202 | window_size (int): Size of each window.
203 | window_idx_name (str): Metadata key for storing window indices.
204 | """
205 | window_id = 0
206 | window_deque = deque()
207 |
208 | for idx, item in enumerate(doc):
209 | if idx % window_steps == 0 and idx != 0 and idx < len(doc) - window_size:
210 | window_id += 1
211 | window_deque.append(window_id)
212 |
213 | if len(window_deque) > window_size:
214 | for _ in range(window_steps):
215 | window_deque.popleft()
216 |
217 | window = set(window_deque)
218 | item.metadata[f"{window_idx_name}_lower_bound"] = min(window)
219 | item.metadata[f"{window_idx_name}_upper_bound"] = max(window)
220 |
221 |
222 | def merge_metadata(dicts_list: List[dict]):
223 | """Merges a list of metadata dictionaries into a single dictionary.
224 |
225 | Args:
226 | dicts_list (list): List of metadata dictionaries.
227 |
228 | Returns:
229 | dict: Merged metadata dictionary.
230 | """
231 | merged_dict = {}
232 | bounds_dict = {}
233 | keys_to_remove = set()
234 |
235 | for dic in dicts_list:
236 | for key, value in dic.items():
237 | if key in merged_dict:
238 | if value not in merged_dict[key]:
239 | merged_dict[key].append(value)
240 | else:
241 | merged_dict[key] = [value]
242 |
243 | for key, values in merged_dict.items():
244 | if len(values) > 1 and all(isinstance(x, (int, float)) for x in values):
245 | bounds_dict[f"{key}_lower_bound"] = min(values)
246 | bounds_dict[f"{key}_upper_bound"] = max(values)
247 | keys_to_remove.add(key)
248 |
249 | merged_dict.update(bounds_dict)
250 |
251 | for key in keys_to_remove:
252 | del merged_dict[key]
253 |
254 | return {
255 | k: v[0] if isinstance(v, list) and len(v) == 1 else v
256 | for k, v in merged_dict.items()
257 | }
258 |
259 |
260 | def merge_chunks(doc: List[Document], scale_factor: int, chunk_idx_name: str):
261 | """Merges adjacent chunks into larger chunks based on a scaling factor.
262 |
263 | Args:
264 | doc (Document): List of Document objects.
265 | scale_factor (int): The number of small chunks to merge into a larger chunk.
266 | chunk_idx_name (str): Metadata key for storing chunk indices.
267 |
268 | Returns:
269 | list: List of Document objects representing the merged chunks.
270 | """
271 | merged_doc = []
272 | page_content = ""
273 | metadata_list = []
274 | chunk_idx = 0
275 |
276 | for idx, item in enumerate(doc):
277 | page_content += item.page_content
278 | metadata_list.append(item.metadata)
279 |
280 | if (idx + 1) % scale_factor == 0 or idx == len(doc) - 1:
281 | metadata = merge_metadata(metadata_list)
282 | metadata[chunk_idx_name] = chunk_idx
283 | merged_doc.append(
284 | Document(
285 | page_content=page_content,
286 | metadata=metadata,
287 | )
288 | )
289 | chunk_idx += 1
290 | page_content = ""
291 | metadata_list = []
292 |
293 | return merged_doc
294 |
295 |
296 | def process_files():
297 | """Main function for processing files. Loads, tokenizes, and saves document data."""
298 | with Pool() as pool:
299 | chunks_small = []
300 | chunks_medium = []
301 | file_names = []
302 |
303 | with tqdm(total=len(files_path), desc="Processing files", ncols=80) as pbar:
304 | for doc in pool.imap_unordered(load_file, files_path):
305 | file_name_with_extension = os.path.basename(doc[0].metadata["source"])
306 | # file_name, _ = os.path.splitext(file_name_with_extension)
307 |
308 | chunk_split_small = split_doc(
309 | doc=doc,
310 | chunk_size=configs.base_chunk_size,
311 | chunk_overlap=configs.chunk_overlap,
312 | chunk_idx_name="small_chunk_idx",
313 | )
314 | add_window(
315 | doc=chunk_split_small,
316 | window_steps=configs.window_steps,
317 | window_size=configs.window_scale,
318 | window_idx_name="large_chunks_idx",
319 | )
320 |
321 | chunk_split_medium = merge_chunks(
322 | doc=chunk_split_small,
323 | scale_factor=configs.chunk_scale,
324 | chunk_idx_name="medium_chunk_idx",
325 | )
326 |
327 | process_metadata(chunk_split_small)
328 | process_metadata(chunk_split_medium)
329 |
330 | file_names.append(file_name_with_extension)
331 | chunks_small.extend(chunk_split_small)
332 | chunks_medium.extend(chunk_split_medium)
333 |
334 | pbar.update()
335 |
336 | file_names2pickle(file_names, save_name="file_names")
337 |
338 | docs2vectorstore(chunks_small, configs.embedding_name, suffix="chunks_small")
339 | docs2vectorstore(chunks_medium, configs.embedding_name, suffix="chunks_medium")
340 |
341 | docs2pickle(chunks_small, suffix="chunks_small")
342 | docs2pickle(chunks_medium, suffix="chunks_medium")
343 |
344 |
345 | if __name__ == "__main__":
346 | process_files()
347 |
--------------------------------------------------------------------------------
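The `merge_metadata` / `merge_chunks` pair in `docs2db.py` above collapses every `scale_factor` small chunks into one medium chunk and turns repeated numeric metadata (such as window indices) into `*_lower_bound` / `*_upper_bound` pairs. Below is a minimal, self-contained sketch of that merging rule; the sample metadata values are made up for illustration and the helper name is ours, not part of the repo.

```python
# Standalone sketch of the metadata-merging rule used by merge_metadata() above.
def merge_metadata_sketch(dicts_list):
    merged, bounds, numeric_keys = {}, {}, set()
    for dic in dicts_list:
        for key, value in dic.items():
            merged.setdefault(key, [])
            if value not in merged[key]:
                merged[key].append(value)
    for key, values in merged.items():
        if len(values) > 1 and all(isinstance(v, (int, float)) for v in values):
            bounds[f"{key}_lower_bound"] = min(values)  # numeric ranges become bounds
            bounds[f"{key}_upper_bound"] = max(values)
            numeric_keys.add(key)
    merged.update(bounds)
    for key in numeric_keys:
        del merged[key]
    # single-value lists are unwrapped, just like in the original function
    return {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in merged.items()}


if __name__ == "__main__":
    chunks = [  # illustrative metadata for three adjacent small chunks
        {"source": "report.pdf", "small_chunk_idx": 0, "large_chunks_idx": 0},
        {"source": "report.pdf", "small_chunk_idx": 1, "large_chunks_idx": 0},
        {"source": "report.pdf", "small_chunk_idx": 2, "large_chunks_idx": 1},
    ]
    print(merge_metadata_sketch(chunks))
    # {'source': 'report.pdf',
    #  'small_chunk_idx_lower_bound': 0, 'small_chunk_idx_upper_bound': 2,
    #  'large_chunks_idx_lower_bound': 0, 'large_chunks_idx_upper_bound': 1}
```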
/figs/High_Level_Architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/figs/High_Level_Architecture.png
--------------------------------------------------------------------------------
/figs/Sliding_Window_Chunking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/figs/Sliding_Window_Chunking.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """Conversational QA Chain"""
2 | from __future__ import annotations
3 | import os
4 | import re
5 | import time
6 | import logging
7 |
8 | from langchain.chat_models import ChatOpenAI, ChatAnthropic
9 | from langchain.memory import ConversationTokenBufferMemory
10 | from convo_qa_chain import ConvoRetrievalChain
11 |
12 | from toolkit.together_api_llm import TogetherLLM
13 | from toolkit.retrivers import MyRetriever
14 | from toolkit.local_llm import load_local_llm
15 | from toolkit.utils import (
16 | Config,
17 | choose_embeddings,
18 | load_embedding,
19 | load_pickle,
20 | check_device,
21 | )
22 |
23 |
24 | # Load the config file
25 | configs = Config("configparser.ini")
26 | logger = logging.getLogger(__name__)
27 |
28 | os.environ["OPENAI_API_KEY"] = configs.openai_api_key
29 | os.environ["ANTHROPIC_API_KEY"] = configs.anthropic_api_key
30 |
31 | embedding = choose_embeddings(configs.embedding_name)
32 | db_store_path = configs.db_dir
33 |
34 |
35 | # get models
36 | def get_llm(llm_name: str, temperature: float, max_tokens: int):
37 | """Get the LLM model from the model name."""
38 |
39 | if not os.path.exists(configs.local_model_dir):
40 | os.makedirs(configs.local_model_dir)
41 |
42 | splits = llm_name.split("|") # [provider, model_name, model_file]
43 |
44 | if "openai" in splits[0].lower():
45 | llm_model = ChatOpenAI(
46 | model=splits[1],
47 | temperature=temperature,
48 | max_tokens=max_tokens,
49 | )
50 |
51 | elif "anthropic" in splits[0].lower():
52 | llm_model = ChatAnthropic(
53 | model=splits[1],
54 | temperature=temperature,
55 | max_tokens_to_sample=max_tokens,
56 | )
57 |
58 | elif "together" in splits[0].lower():
59 | llm_model = TogetherLLM(
60 | model=splits[1],
61 | temperature=temperature,
62 | max_tokens=max_tokens,
63 | )
64 | elif "huggingface" in splits[0].lower():
65 | llm_model = load_local_llm(
66 | model_id=splits[1],
67 | model_basename=splits[-1],
68 | temperature=temperature,
69 | max_tokens=max_tokens,
70 | device_type=check_device(),
71 | )
72 | else:
73 | raise ValueError("Invalid Model Name")
74 |
75 | return llm_model
76 |
77 |
78 | llm = get_llm(configs.model_name, configs.temperature, configs.max_llm_generation)
79 |
80 |
81 | # load retrieval database
82 | db_embedding_chunks_small = load_embedding(
83 | store_name=configs.embedding_name,
84 | embedding=embedding,
85 | suffix="chunks_small",
86 | path=db_store_path,
87 | )
88 | db_embedding_chunks_medium = load_embedding(
89 | store_name=configs.embedding_name,
90 | embedding=embedding,
91 | suffix="chunks_medium",
92 | path=db_store_path,
93 | )
94 |
95 | db_docs_chunks_small = load_pickle(
96 | prefix="docs_pickle", suffix="chunks_small", path=db_store_path
97 | )
98 | db_docs_chunks_medium = load_pickle(
99 | prefix="docs_pickle", suffix="chunks_medium", path=db_store_path
100 | )
101 | file_names = load_pickle(prefix="file", suffix="names", path=db_store_path)
102 |
103 |
104 | # Initialize the retriever
105 | my_retriever = MyRetriever(
106 | llm=llm,
107 | embedding_chunks_small=db_embedding_chunks_small,
108 | embedding_chunks_medium=db_embedding_chunks_medium,
109 | docs_chunks_small=db_docs_chunks_small,
110 | docs_chunks_medium=db_docs_chunks_medium,
111 | first_retrieval_k=configs.first_retrieval_k,
112 | second_retrieval_k=configs.second_retrieval_k,
113 | num_windows=configs.num_windows,
114 | retriever_weights=configs.retriever_weights,
115 | )
116 |
117 |
118 | # Initialize the memory
119 | memory = ConversationTokenBufferMemory(
120 | llm=llm,
121 | memory_key="chat_history",
122 | input_key="question",
123 | output_key="answer",
124 | return_messages=True,
125 | max_token_limit=configs.max_chat_history,
126 | )
127 |
128 |
129 | # Initialize the QA chain
130 | qa = ConvoRetrievalChain.from_llm(
131 | llm,
132 | my_retriever,
133 | file_names=file_names,
134 | memory=memory,
135 | return_source_documents=False,
136 | return_generated_question=False,
137 | )
138 |
139 |
140 | if __name__ == "__main__":
141 | while True:
142 | user_input = input("Human: ")
143 | start_time = time.time()
144 | user_input_ = re.sub(r"^Human: ", "", user_input)
145 | print("*" * 6)
146 | resp = qa({"question": user_input_})
147 | print()
148 |         print(f"AI: {resp['answer']}")
149 | print(f"Time used: {time.time() - start_time}")
150 | print("-" * 60)
151 |
--------------------------------------------------------------------------------
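`get_llm()` in `main.py` above dispatches on a `provider|model_name|model_file` string taken from `MODEL_NAME` in `configparser.ini`. Here is a small sketch of how such a value is split; the concrete names below are illustrative examples, not the values shipped in the config file.

```python
# Illustrates the "provider|model_name|model_file" convention expected by get_llm();
# the example values are hypothetical.
examples = [
    "openai|gpt-3.5-turbo|none",
    "anthropic|claude-2|none",
    "together|togethercomputer/llama-2-70b-chat|none",
    "huggingface|TheBloke/Llama-2-7B-Chat-GGUF|llama-2-7b-chat.Q4_K_M.gguf",
]

for name in examples:
    provider, model_name, model_file = name.split("|")
    print(f"{provider:12s} -> model={model_name}, file={model_file}")
```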
/requirements.txt:
--------------------------------------------------------------------------------
1 | chromadb==0.4.13
2 | InstructorEmbedding==1.0.1
3 | langchain==0.0.308
4 | openai==0.28.1
5 | pypdf==3.16.2
6 | rank-bm25==0.2.2
7 | sentence-transformers==2.2.2
8 | tiktoken==0.5.1
9 | torch==2.0.1
10 | torchaudio==2.0.2
11 | torchvision==0.15.2
12 | together==0.2.4
13 | tqdm==4.66.1
--------------------------------------------------------------------------------
/toolkit/___init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junruxiong/IncarnaMind/0866e4ce0cfebdec0d4de722ab843780ccb61826/toolkit/___init__.py
--------------------------------------------------------------------------------
/toolkit/local_llm.py:
--------------------------------------------------------------------------------
1 | """The below code is borrowed from: https://github.com/PromtEngineer/localGPT
2 | The reason to use gguf/ggml models: https://huggingface.co/TheBloke/wizardLM-7B-GGML/discussions/3"""
3 | import logging
4 | import torch
5 | from huggingface_hub import hf_hub_download
6 | from huggingface_hub import login
7 | from langchain.llms import LlamaCpp, HuggingFacePipeline
8 | from transformers import (
9 | AutoModelForCausalLM,
10 | AutoTokenizer,
11 | LlamaForCausalLM,
12 | LlamaTokenizer,
13 | GenerationConfig,
14 | pipeline,
15 | )
16 | from toolkit.utils import Config
17 |
18 |
19 | configs = Config("configparser.ini")
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | def load_gguf_hf_model(
24 | model_id: str,
25 | model_basename: str,
26 | max_tokens: int,
27 | temperature: float,
28 | device_type: str,
29 | ):
30 | """
31 | Load a GGUF/GGML quantized model using LlamaCpp.
32 |
33 | This function attempts to load a GGUF/GGML quantized model using the LlamaCpp library.
34 |     If the model is of type GGML and a newer version of LLAMA-CPP (which no longer supports GGML) is used,
35 | it logs a message indicating that LLAMA-CPP has dropped support for GGML.
36 |
37 | Parameters:
38 | - model_id (str): The identifier for the model on HuggingFace Hub.
39 | - model_basename (str): The base name of the model file.
40 | - max_tokens (int): The maximum number of tokens to generate in the completion.
41 | - temperature (float): The temperature of LLM.
42 | - device_type (str): The type of device where the model will run, e.g., 'mps', 'cuda', etc.
43 |
44 | Returns:
45 | - LlamaCpp: An instance of the LlamaCpp model if successful, otherwise None.
46 |
47 | Notes:
48 | - The function uses the `hf_hub_download` function to download the model from the HuggingFace Hub.
49 | - The number of GPU layers is set based on the device type.
50 | """
51 |
52 | try:
53 | logger.info("Using Llamacpp for GGUF/GGML quantized models")
54 | model_path = hf_hub_download(
55 | repo_id=model_id,
56 | filename=model_basename,
57 | resume_download=True,
58 | cache_dir=configs.local_model_dir,
59 | )
60 | kwargs = {
61 | "model_path": model_path,
62 | "n_ctx": configs.max_llm_context,
63 | "max_tokens": max_tokens,
64 | "temperature": temperature,
65 | "n_batch": configs.n_batch, # set this based on your GPU & CPU RAM
66 | "verbose": False,
67 | }
68 | if device_type.lower() == "mps":
69 | kwargs["n_gpu_layers"] = 1
70 | if device_type.lower() == "cuda":
71 | kwargs["n_gpu_layers"] = configs.n_gpu_layers # set this based on your GPU
72 |
73 | return LlamaCpp(**kwargs)
74 |     except Exception:
75 | if "ggml" in model_basename:
76 | logger.info(
77 |                 "If you were using a GGML model: LLAMA-CPP dropped GGML support, use a GGUF model instead"
78 | )
79 | return None
80 |
81 |
82 | def load_full_hf_model(model_id: str, model_basename: str, device_type: str):
83 | """
84 | Load a full model using either LlamaTokenizer or AutoModelForCausalLM.
85 |
86 | This function loads a full model based on the specified device type.
87 | If the device type is 'mps' or 'cpu', it uses LlamaTokenizer and LlamaForCausalLM.
88 | Otherwise, it uses AutoModelForCausalLM.
89 |
90 | Parameters:
91 | - model_id (str): The identifier for the model on HuggingFace Hub.
92 | - model_basename (str): The base name of the model file.
93 | - device_type (str): The type of device where the model will run.
94 |
95 | Returns:
96 | - model (Union[LlamaForCausalLM, AutoModelForCausalLM]): The loaded model.
97 | - tokenizer (Union[LlamaTokenizer, AutoTokenizer]): The tokenizer associated with the model.
98 |
99 | Notes:
100 | - The function uses the `from_pretrained` method to load both the model and the tokenizer.
101 | - Additional settings are provided for NVIDIA GPUs, such as loading in 4-bit and setting the compute dtype.
102 | """
103 | if "meta-llama" in model_id.lower():
104 | login(token=configs.huggingface_token)
105 |
106 | if device_type.lower() in ["mps", "cpu"]:
107 | logger.info("Using LlamaTokenizer")
108 | tokenizer = LlamaTokenizer.from_pretrained(
109 | model_id,
110 | cache_dir=configs.local_model_dir,
111 | )
112 | model = LlamaForCausalLM.from_pretrained(
113 | model_id,
114 | cache_dir=configs.local_model_dir,
115 | )
116 | else:
117 | logger.info("Using AutoModelForCausalLM for full models")
118 | tokenizer = AutoTokenizer.from_pretrained(
119 | model_id, cache_dir=configs.local_model_dir
120 | )
121 | logger.info("Tokenizer loaded")
122 | model = AutoModelForCausalLM.from_pretrained(
123 | model_id,
124 | device_map="auto",
125 | torch_dtype=torch.float16,
126 | low_cpu_mem_usage=True,
127 | cache_dir=configs.local_model_dir,
128 | # trust_remote_code=True, # set these if you are using NVIDIA GPU
129 | # load_in_4bit=True,
130 | # bnb_4bit_quant_type="nf4",
131 | # bnb_4bit_compute_dtype=torch.float16,
132 |             # max_memory={0: "15GB"} # Uncomment this line if you encounter CUDA out of memory errors
133 | )
134 | model.tie_weights()
135 | return model, tokenizer
136 |
137 |
138 | def load_local_llm(
139 | model_id: str,
140 | model_basename: str,
141 | temperature: float,
142 | max_tokens: int,
143 | device_type: str,
144 | ):
145 | """
146 | Select a model for text generation using the HuggingFace library.
147 |     If you are running this for the first time, it will download the model for you;
148 |     subsequent runs will load the model from disk.
149 |
150 | Args:
151 | device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU.
152 | model_id (str): Identifier of the model to load from HuggingFace's model hub.
153 | model_basename (str, optional): Basename of the model if using quantized models.
154 | Defaults to None.
155 |
156 | Returns:
157 | HuggingFacePipeline: A pipeline object for text generation using the loaded model.
158 |
159 | Raises:
160 | ValueError: If an unsupported model or device type is provided.
161 | """
162 | logger.info(f"Loading Model: {model_id}, on: {device_type}")
163 | logger.info("This action can take a few minutes!")
164 |
165 | if model_basename.lower() != "none":
166 | if ".gguf" in model_basename.lower():
167 | llm = load_gguf_hf_model(
168 | model_id, model_basename, max_tokens, temperature, device_type
169 | )
170 | return llm
171 |
172 | model, tokenizer = load_full_hf_model(model_id, None, device_type)
173 | # Load configuration from the model to avoid warnings
174 | generation_config = GenerationConfig.from_pretrained(model_id)
175 | # see here for details:
176 | # https://huggingface.co/docs/transformers/
177 | # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns
178 |
179 | # Create a pipeline for text generation
180 | pipe = pipeline(
181 | "text-generation",
182 | model=model,
183 | tokenizer=tokenizer,
184 | max_length=max_tokens,
185 | temperature=temperature,
186 | # top_p=0.95,
187 | repetition_penalty=1.15,
188 | generation_config=generation_config,
189 | )
190 | local_llm = HuggingFacePipeline(pipeline=pipe)
191 | logger.info("Local LLM Loaded")
192 |
193 | return local_llm
194 |
--------------------------------------------------------------------------------
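A hedged usage sketch for `load_local_llm()` above: the HuggingFace repo id and GGUF file name are placeholders, and the call assumes `configparser.ini` is filled in, since `toolkit.local_llm` reads it at import time.

```python
# Usage sketch for load_local_llm(); the repo id and file name are placeholders.
from toolkit.local_llm import load_local_llm
from toolkit.utils import check_device

llm = load_local_llm(
    model_id="TheBloke/Llama-2-7B-Chat-GGUF",      # hypothetical HF repo id
    model_basename="llama-2-7b-chat.Q4_K_M.gguf",  # hypothetical GGUF file
    temperature=0.0,
    max_tokens=512,
    device_type=check_device(),                    # "cuda", "mps" or "cpu"
)
print(llm("Summarise retrieval-augmented generation in one sentence."))
```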
/toolkit/prompts.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts import PromptTemplate
2 | from langchain.prompts.chat import (
3 | ChatPromptTemplate,
4 | HumanMessagePromptTemplate,
5 | SystemMessagePromptTemplate,
6 | )
7 | from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
8 |
9 | # ================================================================================
10 |
11 | REFINE_QA_TEMPLATE = """Break down or rephrase the follow up input into fewer than 3 heterogeneous one-hop queries to be the input of a retrieval tool, if the follow up input is a multi-hop, multi-step, complex or comparative query and is relevant to the Chat History and Document Names. Otherwise keep the follow up input as it is.
12 |
13 |
14 | The output should strictly follow the format below, and each query can only contain 1 document name:
15 | ```
16 | 1. One-hop standalone query
17 | ...
18 | 3. One-hop standalone query
19 | ...
20 | ```
21 |
22 |
23 | Document Names in the database:
24 | ```
25 | {database}
26 | ```
27 |
28 |
29 | Chat History:
30 | ```
31 | {chat_history}
32 | ```
33 |
34 |
35 | Begin:
36 |
37 | Follow Up Input: {question}
38 |
39 | One-hop standalone query(s):
40 | """
41 |
42 |
43 | # ================================================================================
44 |
45 | DOCS_SELECTION_TEMPLATE = """Below are some verified sources and a human input. If you think any of them are relevant to the human input, then list all possible context numbers.
46 |
47 | ```
48 | {snippets}
49 | ```
50 |
51 | The output format must be exactly like the following, nothing else. If no sources are relevant, output []:
52 | [0, ..., n]
53 |
54 | Human Input: {query}
55 | """
56 |
57 |
58 | # ================================================================================
59 |
60 | RETRIEVAL_QA_SYS = """You are a helpful assistant designed by IncarnaMind.
61 | If you think the information below is relevant to the human input, please respond to the human based on the relevant retrieved sources; otherwise, respond in your own words only about the human input."""
62 |
63 |
64 | RETRIEVAL_QA_TEMPLATE = """
65 | File Names in the database:
66 | ```
67 | {database}
68 | ```
69 |
70 |
71 | Chat History:
72 | ```
73 | {chat_history}
74 | ```
75 |
76 |
77 | Verified Sources:
78 | ```
79 | {context}
80 | ```
81 |
82 |
83 | User: {question}
84 | """
85 |
86 |
87 | RETRIEVAL_QA_CHAT_TEMPLATE = """
88 | File Names in the database:
89 | ```
90 | {database}
91 | ```
92 |
93 |
94 | Chat History:
95 | ```
96 | {chat_history}
97 | ```
98 |
99 |
100 | Verified Sources:
101 | ```
102 | {context}
103 | ```
104 | """
105 |
106 |
107 | class PromptTemplates:
108 |     """Collection of prompt templates used by the IncarnaMind QA chains."""
109 |
110 | def __init__(self):
111 | self.refine_qa_prompt = REFINE_QA_TEMPLATE
112 | self.docs_selection_prompt = DOCS_SELECTION_TEMPLATE
113 | self.retrieval_qa_sys = RETRIEVAL_QA_SYS
114 | self.retrieval_qa_prompt = RETRIEVAL_QA_TEMPLATE
115 | self.retrieval_qa_chat_prompt = RETRIEVAL_QA_CHAT_TEMPLATE
116 |
117 | def get_refine_qa_template(self, llm: str):
118 | """get the refine qa prompt template"""
119 | if "llama" in llm.lower():
120 | temp = f"[INST] {self.refine_qa_prompt} [/INST]"
121 | else:
122 | temp = self.refine_qa_prompt
123 |
124 | return PromptTemplate(
125 | input_variables=["database", "chat_history", "question"],
126 | template=temp,
127 | )
128 |
129 | def get_docs_selection_template(self, llm: str):
130 | """get the docs selection prompt template"""
131 | if "llama" in llm.lower():
132 | temp = f"[INST] {self.docs_selection_prompt} [/INST]"
133 | else:
134 | temp = self.docs_selection_prompt
135 |
136 | return PromptTemplate(
137 | input_variables=["snippets", "query"],
138 | template=temp,
139 | )
140 |
141 | def get_retrieval_qa_template_selector(self, llm: str):
142 | """get the retrieval qa prompt template"""
143 | if "llama" in llm.lower():
144 |             temp = f"[INST] <<SYS>>\n{self.retrieval_qa_sys}\n<</SYS>>\n\n{self.retrieval_qa_prompt} [/INST]"
145 | messages = [
146 | SystemMessagePromptTemplate.from_template(
147 |                     f"[INST] <<SYS>>\n{self.retrieval_qa_sys}\n<</SYS>>\n\n{self.retrieval_qa_chat_prompt} [/INST]"
148 | ),
149 | HumanMessagePromptTemplate.from_template("{question}"),
150 | ]
151 | else:
152 | temp = f"{self.retrieval_qa_sys}\n{self.retrieval_qa_prompt}"
153 | messages = [
154 | SystemMessagePromptTemplate.from_template(
155 | f"{self.retrieval_qa_sys}\n{self.retrieval_qa_chat_prompt}"
156 | ),
157 | HumanMessagePromptTemplate.from_template("{question}"),
158 | ]
159 |
160 | prompt_temp = PromptTemplate(
161 | template=temp,
162 | input_variables=["database", "chat_history", "context", "question"],
163 | )
164 | prompt_temp_chat = ChatPromptTemplate.from_messages(messages)
165 |
166 | return ConditionalPromptSelector(
167 | default_prompt=prompt_temp,
168 | conditionals=[(is_chat_model, prompt_temp_chat)],
169 | )
170 |
--------------------------------------------------------------------------------
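A short sketch of how the `ConditionalPromptSelector` returned by `get_retrieval_qa_template_selector()` above resolves to a concrete prompt. The placeholder API key and the example inputs are ours; the selector picks the chat branch simply because `ChatOpenAI` is a chat model.

```python
# Sketch: resolving the prompt selector for a non-llama, chat-style model.
from langchain.chat_models import ChatOpenAI
from toolkit.prompts import PromptTemplates

templates = PromptTemplates()
selector = templates.get_retrieval_qa_template_selector("openai|gpt-3.5-turbo|none")

chat_llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key="sk-placeholder")
prompt = selector.get_prompt(chat_llm)  # chat model -> ChatPromptTemplate branch
print(type(prompt).__name__)            # ChatPromptTemplate

print(
    prompt.format(
        database="report.pdf",
        chat_history="",
        context="(retrieved snippets would go here)",
        question="What does the report cover?",
    )
)
```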
/toolkit/retrivers.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides a custom implementation of a document retriever designed for multi-stage retrieval.
3 | The system uses ensemble methods combining BM25 and Chroma Embeddings to retrieve relevant documents for a given query.
4 | It also uses optimizations such as rank fusion and weighted reciprocal rank, built on LangChain.
5 |
6 | Classes:
7 | --------
8 | - MyEnsembleRetriever: Custom retriever for BM25 and Chroma Embeddings.
9 | - MyRetriever: Handles multi-stage retrieval.
10 |
11 | """
12 | import re
13 | import ast
14 | import copy
15 | import math
16 | import logging
17 | from typing import Dict, List, Optional
18 | from langchain.chains import LLMChain
19 | from langchain.schema import BaseRetriever, Document
20 | from langchain.retrievers import BM25Retriever, EnsembleRetriever
21 | from langchain.callbacks.manager import (
22 | AsyncCallbackManagerForRetrieverRun,
23 | CallbackManagerForRetrieverRun,
24 | AsyncCallbackManagerForChainRun,
25 | CallbackManagerForChainRun,
26 | )
27 |
28 | from toolkit.utils import Config, clean_text, DocIndexer, IndexerOperator
29 | from toolkit.prompts import PromptTemplates
30 |
31 | prompt_templates = PromptTemplates()
32 |
33 | configs = Config("configparser.ini")
34 | logger = logging.getLogger(__name__)
35 |
36 |
37 | class MyEnsembleRetriever(EnsembleRetriever):
38 | """
39 |     Custom ensemble retriever for BM25 and Chroma Embeddings.
40 | """
41 |
42 | retrievers: Dict[str, BaseRetriever]
43 |
44 | def rank_fusion(
45 | self, query: str, run_manager: CallbackManagerForRetrieverRun
46 | ) -> List[Document]:
47 | """
48 | Retrieve the results of the retrievers and use rank_fusion_func to get
49 | the final result.
50 |
51 | Args:
52 | query: The query to search for.
53 |
54 | Returns:
55 | A list of reranked documents.
56 | """
57 | # Get the results of all retrievers.
58 | retriever_docs = []
59 | for key, retriever in self.retrievers.items():
60 | if key == "bm25":
61 | res = retriever.get_relevant_documents(
62 | clean_text(query),
63 | callbacks=run_manager.get_child(tag=f"retriever_{key}"),
64 | )
65 | retriever_docs.append(res)
66 | else:
67 | res = retriever.get_relevant_documents(
68 | query, callbacks=run_manager.get_child(tag=f"retriever_{key}")
69 | )
70 | retriever_docs.append(res)
71 |
72 | # apply rank fusion
73 | fused_documents = self.weighted_reciprocal_rank(retriever_docs)
74 |
75 | return fused_documents
76 |
77 | async def arank_fusion(
78 | self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
79 | ) -> List[Document]:
80 | """
81 | Asynchronously retrieve the results of the retrievers
82 | and use rank_fusion_func to get the final result.
83 |
84 | Args:
85 | query: The query to search for.
86 |
87 | Returns:
88 | A list of reranked documents.
89 | """
90 |
91 | # Get the results of all retrievers.
92 | retriever_docs = []
93 | for key, retriever in self.retrievers.items():
94 | if key == "bm25":
95 | res = retriever.get_relevant_documents(
96 | clean_text(query),
97 | callbacks=run_manager.get_child(tag=f"retriever_{key}"),
98 | )
99 | retriever_docs.append(res)
100 | # print("retriever_docs 1:", res)
101 | else:
102 | res = await retriever.aget_relevant_documents(
103 | query, callbacks=run_manager.get_child(tag=f"retriever_{key}")
104 | )
105 | retriever_docs.append(res)
106 |
107 | # apply rank fusion
108 | fused_documents = self.weighted_reciprocal_rank(retriever_docs)
109 |
110 | return fused_documents
111 |
112 | def weighted_reciprocal_rank(
113 | self, doc_lists: List[List[Document]]
114 | ) -> List[Document]:
115 | """
116 | Perform weighted Reciprocal Rank Fusion on multiple rank lists.
117 | You can find more details about RRF here:
118 | https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
119 |
120 | Args:
121 | doc_lists: A list of rank lists, where each rank list contains unique items.
122 |
123 | Returns:
124 | list: The final aggregated list of items sorted by their weighted RRF
125 | scores in descending order.
126 | """
127 | if len(doc_lists) != len(self.weights):
128 | raise ValueError(
129 | "Number of rank lists must be equal to the number of weights."
130 | )
131 |
132 | # replace the page_content with the original uncleaned page_content
133 | doc_lists_ = copy.copy(doc_lists)
134 | for doc_list in doc_lists_:
135 | for doc in doc_list:
136 | doc.page_content = doc.metadata["page_content"]
137 | # doc.metadata["page_content"] = None
138 |
139 | # Create a union of all unique documents in the input doc_lists
140 | all_documents = set()
141 | for doc_list in doc_lists_:
142 | for doc in doc_list:
143 | all_documents.add(doc.page_content)
144 |
145 | # Initialize the RRF score dictionary for each document
146 | rrf_score_dic = {doc: 0.0 for doc in all_documents}
147 |
148 | # Calculate RRF scores for each document
149 | for doc_list, weight in zip(doc_lists_, self.weights):
150 | for rank, doc in enumerate(doc_list, start=1):
151 | rrf_score = weight * (1 / (rank + self.c))
152 | rrf_score_dic[doc.page_content] += rrf_score
153 |
154 | # Sort documents by their RRF scores in descending order
155 | sorted_documents = sorted(
156 | rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
157 | )
158 |
159 | # Map the sorted page_content back to the original document objects
160 | page_content_to_doc_map = {
161 | doc.page_content: doc for doc_list in doc_lists_ for doc in doc_list
162 | }
163 | sorted_docs = [
164 | page_content_to_doc_map[page_content] for page_content in sorted_documents
165 | ]
166 |
167 | return sorted_docs
168 |
169 |
170 | class MyRetriever:
171 | """
172 | Retriever class to handle multi-stage retrieval.
173 | """
174 |
175 | def __init__(
176 | self,
177 | llm,
178 | embedding_chunks_small: List[Document],
179 | embedding_chunks_medium: List[Document],
180 | docs_chunks_small: DocIndexer,
181 | docs_chunks_medium: DocIndexer,
182 | first_retrieval_k: int,
183 | second_retrieval_k: int,
184 | num_windows: int,
185 | retriever_weights: List[float],
186 | ):
187 | """
188 | Initialize the MyRetriever class.
189 |
190 | Args:
191 | llm: Language model for retrieval.
192 | embedding_chunks_small (List[Document]): List of small embedding chunks.
193 | embedding_chunks_medium (List[Document]): List of medium embedding chunks.
194 | docs_chunks_small (DocIndexer): Document indexer for small chunks.
195 | docs_chunks_medium (DocIndexer): Document indexer for medium chunks.
196 | first_retrieval_k (int): Number of top documents to retrieve in first retrieval.
197 | second_retrieval_k (int): Number of top documents to retrieve in second retrieval.
198 | num_windows (int): Number of overlapping windows to consider.
199 | retriever_weights (List[float]): Weights for ensemble retrieval.
200 | """
201 | self.llm = llm
202 | self.embedding_chunks_small = embedding_chunks_small
203 | self.embedding_chunks_medium = embedding_chunks_medium
204 | self.docs_index_small = DocIndexer(docs_chunks_small)
205 | self.docs_index_medium = DocIndexer(docs_chunks_medium)
206 |
207 | self.first_retrieval_k = first_retrieval_k
208 | self.second_retrieval_k = second_retrieval_k
209 | self.num_windows = num_windows
210 | self.retriever_weights = retriever_weights
211 |
212 | def get_retriever(
213 | self,
214 | docs_chunks,
215 | emb_chunks,
216 | emb_filter=None,
217 | k=2,
218 | weights=(0.5, 0.5),
219 | ):
220 | """
221 | Initialize and return a retriever instance with specified parameters.
222 |
223 | Args:
224 | docs_chunks: The document chunks for the BM25 retriever.
225 | emb_chunks: The document chunks for the Embedding retriever.
226 | emb_filter: A filter for embedding retriever.
227 | k (int): The number of top documents to return.
228 | weights (list): Weights for ensemble retrieval.
229 |
230 | Returns:
231 | MyEnsembleRetriever: An instance of MyEnsembleRetriever.
232 | """
233 | bm25_retriever = BM25Retriever.from_documents(docs_chunks)
234 | bm25_retriever.k = k
235 |
236 | emb_retriever = emb_chunks.as_retriever(
237 | search_kwargs={
238 | "filter": emb_filter,
239 | "k": k,
240 | "search_type": "mmr",
241 | }
242 | )
243 | return MyEnsembleRetriever(
244 | retrievers={"bm25": bm25_retriever, "chroma": emb_retriever},
245 | weights=weights,
246 | )
247 |
248 | def find_overlaps(self, doc: List[Document]):
249 | """
250 | Find overlapping intervals of windows.
251 |
252 | Args:
253 | doc (Document): A document object to find overlaps in.
254 |
255 | Returns:
256 | list: A list of overlapping intervals.
257 | """
258 | intervals = []
259 | for item in doc:
260 | intervals.append(
261 | (
262 | item.metadata["large_chunks_idx_lower_bound"],
263 | item.metadata["large_chunks_idx_upper_bound"],
264 | )
265 | )
266 | remaining_intervals, grouped_intervals, centroids = intervals.copy(), [], []
267 |
268 | while remaining_intervals:
269 | curr_interval = remaining_intervals.pop(0)
270 | curr_group = [curr_interval]
271 | subset_interval = None
272 |
273 | for start, end in remaining_intervals.copy():
274 | for s, e in curr_group:
275 | overlap = set(range(s, e + 1)) & set(range(start, end + 1))
276 | if overlap:
277 | curr_group.append((start, end))
278 | remaining_intervals.remove((start, end))
279 | if set(range(start, end + 1)).issubset(set(range(s, e + 1))):
280 | subset_interval = (start, end)
281 | break
282 |
283 | if subset_interval:
284 | centroid = [math.ceil((subset_interval[0] + subset_interval[1]) / 2)]
285 | elif len(curr_group) > 2:
286 | first_overlap = max(
287 | set(range(curr_group[0][0], curr_group[0][1] + 1))
288 | & set(range(curr_group[1][0], curr_group[1][1] + 1))
289 | )
290 | last_overlap_set = set(
291 | range(curr_group[-1][0], curr_group[-1][1] + 1)
292 | ) & set(range(curr_group[-2][0], curr_group[-2][1] + 1))
293 |
294 | if not last_overlap_set:
295 | last_overlap = first_overlap # Fallback if no overlap
296 | else:
297 | last_overlap = min(last_overlap_set)
298 |
299 | step = 1 if first_overlap <= last_overlap else -1
300 | centroid = list(range(first_overlap, last_overlap + step, step))
301 | else:
302 | centroid = [
303 | round(
304 | sum([math.ceil((s + e) / 2) for s, e in curr_group])
305 | / len(curr_group)
306 | )
307 | ]
308 |
309 | grouped_intervals.append(
310 | curr_group if len(curr_group) > 1 else curr_group[0]
311 | )
312 | centroids.extend(centroid)
313 |
314 | return centroids
315 |
316 | def get_filter(self, top_k: int, file_md5: str, doc: List[Document]):
317 | """
318 | Create a filter for retrievers based on overlapping intervals.
319 |
320 | Args:
321 | top_k (int): Number of top intervals to consider.
322 | file_md5 (str): MD5 hash of the file to filter.
323 | doc (List[Document]): List of document objects.
324 |
325 | Returns:
326 |             tuple: A tuple containing dictionary filters for the DocIndexer and Chroma retrievers.
327 | """
328 | overlaps = self.find_overlaps(doc)
329 | if len(overlaps) < 1:
330 | raise ValueError("No overlapping intervals found.")
331 |
332 | overlaps_k = overlaps[:top_k]
333 | logger.info("windows_at_2nd_retrieval: %s", overlaps_k)
334 | search_dict_docindexer = {"OR": []}
335 | search_dict_chroma = {"$or": []}
336 |
337 | for chunk_idx in overlaps_k:
338 | search_dict_docindexer["OR"].append(
339 | {
340 | "large_chunks_idx_lower_bound": (
341 | IndexerOperator.LTE,
342 | chunk_idx,
343 | ),
344 | "large_chunks_idx_upper_bound": (
345 | IndexerOperator.GTE,
346 | chunk_idx,
347 | ),
348 | "source_md5": (IndexerOperator.EQ, file_md5),
349 | }
350 | )
351 |
352 | if len(overlaps_k) == 1:
353 | search_dict_chroma = {
354 | "$and": [
355 | {"large_chunks_idx_lower_bound": {"$lte": overlaps_k[0]}},
356 | {"large_chunks_idx_upper_bound": {"$gte": overlaps_k[0]}},
357 | {"source_md5": {"$eq": file_md5}},
358 | ]
359 | }
360 |         else:
361 |             # add one "$and" clause per selected window centroid
362 |             for chunk_idx in overlaps_k:
363 |                 search_dict_chroma["$or"].append({
364 |                     "$and": [
365 |                         {"large_chunks_idx_lower_bound": {"$lte": chunk_idx}},
366 |                         {"large_chunks_idx_upper_bound": {"$gte": chunk_idx}},
367 |                         {"source_md5": {"$eq": file_md5}},
368 |                     ]
369 |                 })
370 |
371 | return search_dict_docindexer, search_dict_chroma
372 |
373 | def get_relevant_doc_ids(self, docs: List[Document], query: str):
374 | """
375 | Get relevant document IDs given a query using an LLM.
376 |
377 | Args:
378 | docs (List[Document]): List of document objects to find relevant IDs in.
379 | query (str): The query string.
380 |
381 | Returns:
382 | list: A list of relevant document IDs.
383 | """
384 | snippets = "\n\n\n".join(
385 | [
386 | f"Context {idx}:\n{{{doc.page_content}}}. {{source: {doc.metadata['source']}}}"
387 | for idx, doc in enumerate(docs)
388 | ]
389 | )
390 | id_chain = LLMChain(
391 | llm=self.llm,
392 | prompt=prompt_templates.get_docs_selection_template(configs.model_name),
393 | output_key="IDs",
394 | )
395 | ids = id_chain.run({"query": query, "snippets": snippets})
396 | logger.info("relevant doc ids: %s", ids)
397 | pattern = r"\[\s*\d+\s*(?:,\s*\d+\s*)*\]"
398 | match = re.search(pattern, ids)
399 | if match:
400 | return ast.literal_eval(match.group(0))
401 | else:
402 | return []
403 |
404 | def get_relevant_documents(
405 | self,
406 | query: str,
407 | num_query: int,
408 | *,
409 | run_manager: Optional[CallbackManagerForChainRun] = None,
410 |     ) -> Dict[str, List[Document]]:
411 | """
412 | Perform multi-stage retrieval to get relevant documents.
413 |
414 | Args:
415 | query (str): The query string.
416 | num_query (int): Number of queries.
417 | run_manager (Optional[CallbackManagerForChainRun], optional): Callback manager for chain run.
418 |
419 | Returns:
420 |             Dict[str, List[Document]]: Relevant documents grouped by source file name.
421 | """
422 | # ! First retrieval
423 | first_retriever = self.get_retriever(
424 | docs_chunks=self.docs_index_small.documents,
425 | emb_chunks=self.embedding_chunks_small,
426 | emb_filter=None,
427 | k=self.first_retrieval_k,
428 | weights=self.retriever_weights,
429 | )
430 | first = first_retriever.get_relevant_documents(
431 | query, callbacks=run_manager.get_child()
432 | )
433 | for doc in first:
434 | logger.info("----1st retrieval----: %s", doc)
435 | ids_clean = self.get_relevant_doc_ids(first, query)
436 | # ids_clean = [0, 1, 2]
437 | logger.info("relevant cleaned doc ids: %s", ids_clean)
438 | qa_chunks = {} # key is file name, value is a list of relevant documents
439 | # res_chunks = []
440 | if ids_clean and isinstance(ids_clean, list):
441 | source_md5_dict = {}
442 |             for ids_c in ids_clean:
443 |                 if ids_c < len(first):
444 |                     doc_md5 = first[ids_c].metadata["source_md5"]
445 |                     # keep the first retrieved chunk per source document
446 |                     if doc_md5 not in source_md5_dict:
447 |                         source_md5_dict[doc_md5] = [first[ids_c]]
448 |                     # else:
449 |                     #     source_md5_dict[doc_md5].append(
450 |                     #         first[ids_c]
451 |                     #     )
452 | if len(source_md5_dict) == 0:
453 | source_md5_dict[first[0].metadata["source_md5"]] = [first[0]]
454 | num_docs = len(source_md5_dict.keys())
455 | third_num_k = max(
456 | 1,
457 | (
458 | int(
459 | (
460 | configs.max_llm_context
461 | / (configs.base_chunk_size * configs.chunk_scale)
462 | )
463 | // (num_docs * num_query)
464 | )
465 | ),
466 | )
467 |
468 | for source_md5, docs in source_md5_dict.items():
469 | logger.info(
470 | "selected_docs_at_1st_retrieval: %s", docs[0].metadata["source"]
471 | )
472 | second_docs_chunks = self.docs_index_small.retrieve_metadata(
473 | {
474 | "source_md5": (IndexerOperator.EQ, source_md5),
475 | }
476 | )
477 | second_retriever = self.get_retriever(
478 | docs_chunks=second_docs_chunks,
479 | emb_chunks=self.embedding_chunks_small,
480 | emb_filter={"source_md5": source_md5},
481 | k=self.second_retrieval_k,
482 | weights=self.retriever_weights,
483 | )
484 | # ! Second retrieval
485 | second = second_retriever.get_relevant_documents(
486 | query, callbacks=run_manager.get_child()
487 | )
488 | for doc in second:
489 | logger.info("----2nd retrieval----: %s", doc)
490 | docs.extend(second)
491 | docindexer_filter, chroma_filter = self.get_filter(
492 | self.num_windows, source_md5, docs
493 | )
494 | third_docs_chunks = self.docs_index_medium.retrieve_metadata(
495 | docindexer_filter
496 | )
497 | third_retriever = self.get_retriever(
498 | docs_chunks=third_docs_chunks,
499 | emb_chunks=self.embedding_chunks_medium,
500 | emb_filter=chroma_filter,
501 | k=third_num_k,
502 | weights=self.retriever_weights,
503 | )
504 | # ! Third retrieval
505 | third_temp = third_retriever.get_relevant_documents(
506 | query, callbacks=run_manager.get_child()
507 | )
508 | third = third_temp[:third_num_k]
509 | # chunks = sorted(third, key=lambda x: x.metadata["medium_chunk_idx"])
510 | for doc in third:
511 | logger.info(
512 | "----3rd retrieval----page_content: %s", [doc.page_content]
513 | )
514 | mtdata = doc.metadata
515 | mtdata["page_content"] = None
516 | logger.info("----3rd retrieval----metadata: %s", mtdata)
517 | file_name = third[0].metadata["source"].split("/")[-1]
518 | if file_name not in qa_chunks:
519 | qa_chunks[file_name] = third
520 | else:
521 | qa_chunks[file_name].extend(third)
522 |
523 | return qa_chunks
524 |
525 | async def aget_relevant_documents(
526 | self,
527 | query: str,
528 | num_query: int,
529 | *,
530 | run_manager: AsyncCallbackManagerForChainRun,
531 |     ) -> Dict[str, List[Document]]:
532 | """
533 | Asynchronous version of get_relevant_documents method.
534 |
535 | Args:
536 | query (str): The query string.
537 | num_query (int): Number of queries.
538 | run_manager (AsyncCallbackManagerForChainRun): Callback manager for asynchronous chain run.
539 |
540 | Returns:
541 |             Dict[str, List[Document]]: Relevant documents grouped by source file name.
542 | """
543 | # ! First retrieval
544 | first_retriever = self.get_retriever(
545 | docs_chunks=self.docs_index_small.documents,
546 | emb_chunks=self.embedding_chunks_small,
547 | emb_filter=None,
548 | k=self.first_retrieval_k,
549 | weights=self.retriever_weights,
550 | )
551 | first = await first_retriever.aget_relevant_documents(
552 | query, callbacks=run_manager.get_child()
553 | )
554 | for doc in first:
555 | logger.info("----1st retrieval----: %s", doc)
556 | ids_clean = self.get_relevant_doc_ids(first, query)
557 | logger.info("relevant doc ids: %s", ids_clean)
558 | qa_chunks = {} # key is file name, value is a list of relevant documents
559 | # res_chunks = []
560 | if ids_clean and isinstance(ids_clean, list):
561 | source_md5_dict = {}
562 |             for ids_c in ids_clean:
563 |                 if ids_c < len(first):
564 |                     doc_md5 = first[ids_c].metadata["source_md5"]
565 |                     # keep the first retrieved chunk per source document
566 |                     if doc_md5 not in source_md5_dict:
567 |                         source_md5_dict[doc_md5] = [first[ids_c]]
568 |                     # else:
569 |                     #     source_md5_dict[doc_md5].append(
570 |                     #         first[ids_c]
571 |                     #     )
572 | if len(source_md5_dict) == 0:
573 | source_md5_dict[first[0].metadata["source_md5"]] = [first[0]]
574 | num_docs = len(source_md5_dict.keys())
575 | third_num_k = max(
576 | 1,
577 | (
578 | int(
579 | (
580 | configs.max_llm_context
581 | / (configs.base_chunk_size * configs.chunk_scale)
582 | )
583 | // (num_docs * num_query)
584 | )
585 | ),
586 | )
587 |
588 | for source_md5, docs in source_md5_dict.items():
589 | logger.info(
590 | "selected_docs_at_1st_retrieval: %s", docs[0].metadata["source"]
591 | )
592 | second_docs_chunks = self.docs_index_small.retrieve_metadata(
593 | {
594 | "source_md5": (IndexerOperator.EQ, source_md5),
595 | }
596 | )
597 | second_retriever = self.get_retriever(
598 | docs_chunks=second_docs_chunks,
599 | emb_chunks=self.embedding_chunks_small,
600 | emb_filter={"source_md5": source_md5},
601 | k=self.second_retrieval_k,
602 | weights=self.retriever_weights,
603 | )
604 | # ! Second retrieval
605 | second = await second_retriever.aget_relevant_documents(
606 | query, callbacks=run_manager.get_child()
607 | )
608 | for doc in second:
609 | logger.info("----2nd retrieval----: %s", doc)
610 | docs.extend(second)
611 | docindexer_filter, chroma_filter = self.get_filter(
612 | self.num_windows, source_md5, docs
613 | )
614 | third_docs_chunks = self.docs_index_medium.retrieve_metadata(
615 | docindexer_filter
616 | )
617 | third_retriever = self.get_retriever(
618 | docs_chunks=third_docs_chunks,
619 | emb_chunks=self.embedding_chunks_medium,
620 | emb_filter=chroma_filter,
621 | k=third_num_k,
622 | weights=self.retriever_weights,
623 | )
624 | # ! Third retrieval
625 | third_temp = await third_retriever.aget_relevant_documents(
626 | query, callbacks=run_manager.get_child()
627 | )
628 | third = third_temp[:third_num_k]
629 | # chunks = sorted(third, key=lambda x: x.metadata["medium_chunk_idx"])
630 | for doc in third:
631 | logger.info(
632 | "----3rd retrieval----page_content: %s", [doc.page_content]
633 | )
634 | mtdata = doc.metadata
635 | mtdata["page_content"] = None
636 | logger.info("----3rd retrieval----metadata: %s", mtdata)
637 | file_name = third[0].metadata["source"].split("/")[-1]
638 | if file_name not in qa_chunks:
639 | qa_chunks[file_name] = third
640 | else:
641 | qa_chunks[file_name].extend(third)
642 |
643 | return qa_chunks
644 |
--------------------------------------------------------------------------------
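`MyEnsembleRetriever.weighted_reciprocal_rank()` above scores each document as `sum_i weight_i / (rank_i(doc) + c)` across the BM25 and Chroma rankings. Below is a tiny self-contained worked example; the document ids, weights and `c = 60` are illustrative (60 is the constant suggested in the RRF paper).

```python
# Worked example of the weighted Reciprocal Rank Fusion used above:
#     score(doc) = sum_i  weight_i / (rank_i(doc) + c)
def weighted_rrf(rankings, weights, c=60):
    scores = {}
    for ranked_docs, weight in zip(rankings, weights):
        for rank, doc_id in enumerate(ranked_docs, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (rank + c)
    return sorted(scores, key=scores.get, reverse=True)


bm25_ranking = ["doc_a", "doc_b", "doc_c"]    # lexical retriever order (illustrative)
chroma_ranking = ["doc_c", "doc_a", "doc_d"]  # embedding retriever order (illustrative)
print(weighted_rrf([bm25_ranking, chroma_ranking], weights=[0.5, 0.5]))
# ['doc_a', 'doc_c', 'doc_b', 'doc_d']
# doc_a: 0.5/61 + 0.5/62 ≈ 0.01626; doc_c: 0.5/63 + 0.5/61 ≈ 0.01613
```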
/toolkit/together_api_llm.py:
--------------------------------------------------------------------------------
1 | """The code borrowed from https://colab.research.google.com/drive/1RW2yTxh5b9w7F3IrK00Iz51FTO5W01Rx?usp=sharing#scrollTo=RgbLVmf-o4j7"""
2 | import os
3 | from typing import Any, Dict
4 | import together
5 | from pydantic import Extra, root_validator
6 |
7 | from langchain.llms.base import LLM
8 | from langchain.utils import get_from_dict_or_env
9 | from toolkit.utils import Config
10 |
11 | configs = Config("configparser.ini")
12 | os.environ["TOGETHER_API_KEY"] = configs.together_api_key
13 |
14 | # together.api_key = configs.together_api_key
15 | # models = together.Models.list()
16 | # for idx, model in enumerate(models):
17 | # print(idx, model["name"])
18 |
19 |
20 | class TogetherLLM(LLM):
21 | """Together large language models."""
22 |
23 | model: str = "togethercomputer/llama-2-70b-chat"
24 | """model endpoint to use"""
25 |
26 | together_api_key: str = os.environ["TOGETHER_API_KEY"]
27 | """Together API key"""
28 |
29 | temperature: float = 0
30 | """What sampling temperature to use."""
31 |
32 | max_tokens: int = 512
33 | """The maximum number of tokens to generate in the completion."""
34 |
35 | class Config:
36 | extra = "forbid"
37 |
38 | # @root_validator()
39 | # def validate_environment(cls, values: Dict) -> Dict:
40 | # """Validate that the API key is set."""
41 | # api_key = get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
42 | # values["together_api_key"] = api_key
43 | # return values
44 |
45 | @property
46 | def _llm_type(self) -> str:
47 | """Return type of LLM."""
48 | return "together"
49 |
50 | def _call(
51 | self,
52 | prompt: str,
53 | **kwargs: Any,
54 | ) -> str:
55 | """Call to Together endpoint."""
56 | together.api_key = self.together_api_key
57 | output = together.Complete.create(
58 | prompt,
59 | model=self.model,
60 | max_tokens=self.max_tokens,
61 | temperature=self.temperature,
62 | )
63 | text = output["output"]["choices"][0]["text"]
64 | return text
65 |
66 |
67 | # if __name__ == "__main__":
68 | # test_llm = TogetherLLM(
69 | # model="togethercomputer/llama-2-70b-chat", temperature=0, max_tokens=1000
70 | # )
71 |
72 | # print(test_llm("What are the olympics? "))
73 |
--------------------------------------------------------------------------------
/toolkit/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module defines utility functions for loading data, text cleaning,
3 | and indexing documents, as well as classes for handling document queries
4 | and formatting chat history.
5 | """
6 | import re
7 | import pickle
8 | import string
9 | import logging
10 | import configparser
11 | from enum import Enum
12 | from typing import List, Tuple, Union
13 | import nltk
14 | from nltk.stem import WordNetLemmatizer
15 | from nltk.tokenize import word_tokenize
16 | from nltk.corpus import stopwords
17 | import torch
18 | import tiktoken
19 | from langchain.vectorstores import Chroma
20 |
21 | from langchain.schema import Document, BaseMessage
22 | from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
23 | from langchain.embeddings.openai import OpenAIEmbeddings
24 |
25 |
26 | tokenizer_name = tiktoken.encoding_for_model("gpt-3.5-turbo")
27 | tokenizer = tiktoken.get_encoding(tokenizer_name.name)
28 |
29 | # if nltk stopwords, punkt and wordnet are not downloaded, download it
30 | try:
31 | nltk.data.find("corpora/stopwords")
32 | except LookupError:
33 | nltk.download("stopwords")
34 | try:
35 | nltk.data.find("tokenizers/punkt")
36 | except LookupError:
37 | nltk.download("punkt")
38 | try:
39 | nltk.data.find("corpora/wordnet")
40 | except LookupError:
41 | nltk.download("wordnet")
42 |
43 | ChatTurnType = Union[Tuple[str, str], BaseMessage]
44 | _ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}
45 |
46 |
47 | class Config:
48 | """Initializes configs."""
49 |
50 | def __init__(self, config_file):
51 | self.config = configparser.ConfigParser(interpolation=None)
52 | self.config.read(config_file)
53 |
54 | # Tokens
55 | self.openai_api_key = self.config.get("tokens", "OPENAI_API_KEY")
56 | self.anthropic_api_key = self.config.get("tokens", "ANTHROPIC_API_KEY")
57 | self.together_api_key = self.config.get("tokens", "TOGETHER_API_KEY")
58 | self.huggingface_token = self.config.get("tokens", "HUGGINGFACE_TOKEN")
59 | self.version = self.config.get("tokens", "VERSION")
60 |
61 | # Directory
62 | self.docs_dir = self.config.get("directory", "DOCS_DIR")
63 | self.db_dir = self.config.get("directory", "db_DIR")
64 | self.local_model_dir = self.config.get("directory", "LOCAL_MODEL_DIR")
65 |
66 | # Parameters
67 | self.model_name = self.config.get("parameters", "MODEL_NAME")
68 | self.temperature = self.config.getfloat("parameters", "TEMPURATURE")
69 | self.max_chat_history = self.config.getint("parameters", "MAX_CHAT_HISTORY")
70 | self.max_llm_context = self.config.getint("parameters", "MAX_LLM_CONTEXT")
71 | self.max_llm_generation = self.config.getint("parameters", "MAX_LLM_GENERATION")
72 | self.embedding_name = self.config.get("parameters", "EMBEDDING_NAME")
73 |
74 | self.n_gpu_layers = self.config.getint("parameters", "N_GPU_LAYERS")
75 | self.n_batch = self.config.getint("parameters", "N_BATCH")
76 |
77 | self.base_chunk_size = self.config.getint("parameters", "BASE_CHUNK_SIZE")
78 | self.chunk_overlap = self.config.getint("parameters", "CHUNK_OVERLAP")
79 | self.chunk_scale = self.config.getint("parameters", "CHUNK_SCALE")
80 | self.window_steps = self.config.getint("parameters", "WINDOW_STEPS")
81 | self.window_scale = self.config.getint("parameters", "WINDOW_SCALE")
82 |
83 | self.retriever_weights = [
84 | float(x.strip())
85 | for x in self.config.get("parameters", "RETRIEVER_WEIGHTS").split(",")
86 | ]
87 | self.first_retrieval_k = self.config.getint("parameters", "FIRST_RETRIEVAL_K")
88 | self.second_retrieval_k = self.config.getint("parameters", "SECOND_RETRIEVAL_K")
89 | self.num_windows = self.config.getint("parameters", "NUM_WINDOWS")
90 |
91 | # Logging
92 | self.logging_enabled = self.config.getboolean("logging", "enabled")
93 | self.logging_level = self.config.get("logging", "level")
94 | self.logging_filename = self.config.get("logging", "filename")
95 | self.logging_format = self.config.get("logging", "format")
96 |
97 | self.configure_logging()
98 |
99 | def configure_logging(self):
100 | """
101 |         Configure the logger for each .py file.
102 | """
103 |
104 | if not self.logging_enabled:
105 | logging.disable(logging.CRITICAL + 1)
106 | return
107 |
108 | log_level = self.config.get("logging", "level")
109 | log_filename = self.config.get("logging", "filename")
110 | log_format = self.config.get("logging", "format")
111 |
112 | logging.basicConfig(level=log_level, filename=log_filename, format=log_format)
113 |
114 |
115 | def configure_logger():
116 | """
117 |     Configure the logger for each .py file.
118 | """
119 | config = configparser.ConfigParser(interpolation=None)
120 | config.read("configparser.ini")
121 |
122 | enabled = config.getboolean("logging", "enabled")
123 |
124 | if not enabled:
125 | logging.disable(logging.CRITICAL + 1)
126 | return
127 |
128 | log_level = config.get("logging", "level")
129 | log_filename = config.get("logging", "filename")
130 | log_format = config.get("logging", "format")
131 |
132 | logging.basicConfig(level=log_level, filename=log_filename, format=log_format)
133 |
134 |
135 | def tiktoken_len(text):
136 | """token length function"""
137 | tokens = tokenizer.encode(text, disallowed_special=())
138 | return len(tokens)
139 |
140 |
141 | def check_device():
142 | """Check if cuda or MPS is available, else fallback to CPU"""
143 | if torch.cuda.is_available():
144 | device = "cuda"
145 | elif torch.backends.mps.is_available():
146 | device = "mps"
147 | else:
148 | device = "cpu"
149 | return device
150 |
151 |
152 | def choose_embeddings(embedding_name):
153 | """Choose embeddings for a given model's name"""
154 | try:
155 | if embedding_name == "openAIEmbeddings":
156 | return OpenAIEmbeddings()
157 | elif embedding_name == "hkunlpInstructorLarge":
158 | device = check_device()
159 | return HuggingFaceInstructEmbeddings(
160 | model_name="hkunlp/instructor-large", model_kwargs={"device": device}
161 | )
162 | else:
163 | device = check_device()
164 |             return HuggingFaceEmbeddings(
165 |                 model_name=embedding_name, model_kwargs={"device": device}
166 |             )
165 | except Exception as error:
166 | raise ValueError(f"Embedding {embedding_name} not supported") from error
167 |
168 |
169 | def load_embedding(store_name, embedding, suffix, path):
170 | """Load chroma embeddings"""
171 | vector_store = Chroma(
172 | persist_directory=f"{path}/chroma_{store_name}_{suffix}",
173 | embedding_function=embedding,
174 | )
175 | return vector_store
176 |
177 |
178 | def load_pickle(prefix, suffix, path):
179 | """Load langchain documents from a pickle file.
180 |
181 | Args:
182 |         prefix (str): Prefix of the pickle file name.
183 | suffix (str): Suffix to append to the store name.
184 | path (str): The path where the pickle file is stored.
185 |
186 | Returns:
187 | Document: documents from the pickle file
188 | """
189 | with open(f"{path}/{prefix}_{suffix}.pkl", "rb") as file:
190 | return pickle.load(file)
191 |
192 |
193 | def clean_text(text):
194 | """
195 | Converts text to lowercase, removes punctuation, stopwords, and lemmatizes it
196 | for BM25 retriever.
197 |
198 | Parameters:
199 | text (str): The text to be cleaned.
200 |
201 | Returns:
202 | str: The cleaned and lemmatized text.
203 | """
204 | # remove [SEP] in the text
205 | text = text.replace("[SEP]", "")
206 | # Tokenization
207 | tokens = word_tokenize(text)
208 | # Lowercasing
209 | tokens = [w.lower() for w in tokens]
210 | # Remove punctuation
211 | table = str.maketrans("", "", string.punctuation)
212 | stripped = [w.translate(table) for w in tokens]
213 | # Keep tokens that are alphabetic, numeric, or contain both.
214 | words = [
215 | word
216 | for word in stripped
217 | if word.isalpha()
218 | or word.isdigit()
219 |         or (re.search(r"\d", word) and re.search(r"[a-zA-Z]", word))
220 | ]
221 | # Remove stopwords
222 | stop_words = set(stopwords.words("english"))
223 | words = [w for w in words if w not in stop_words]
224 | # Lemmatization (or you could use stemming instead)
225 | lemmatizer = WordNetLemmatizer()
226 | lemmatized = [lemmatizer.lemmatize(w) for w in words]
227 | # Convert list of words to a string
228 | lemmatized_ = " ".join(lemmatized)
229 |
230 | return lemmatized_
231 |
232 |
233 | class IndexerOperator(Enum):
234 | """
235 | Enumeration for different query operators used in indexing.
236 | """
237 |
238 | EQ = "=="
239 | GT = ">"
240 | GTE = ">="
241 | LT = "<"
242 | LTE = "<="
243 |
244 |
245 | class DocIndexer:
246 | """
247 | A class to handle indexing and searching of documents.
248 |
249 | Attributes:
250 | documents (List[Document]): List of documents to be indexed.
251 | """
252 |
253 | def __init__(self, documents):
254 | self.documents = documents
255 | self.index = self.build_index(documents)
256 |
257 | def build_index(self, documents):
258 | """
259 | Build an index for the given list of documents.
260 |
261 | Parameters:
262 | documents (List[Document]): The list of documents to be indexed.
263 |
264 | Returns:
265 | dict: The built index.
266 | """
267 | index = {}
268 | for doc in documents:
269 | for key, value in doc.metadata.items():
270 | if key not in index:
271 | index[key] = {}
272 | if value not in index[key]:
273 | index[key][value] = []
274 | index[key][value].append(doc)
275 | return index
276 |
277 | def retrieve_metadata(self, search_dict):
278 | """
279 | Retrieve documents based on the search criteria provided in search_dict.
280 |
281 | Parameters:
282 | search_dict (dict): Dictionary specifying the search criteria.
283 | It can contain "AND" or "OR" operators for
284 | complex queries.
285 |
286 | Returns:
287 | List[Document]: List of documents that match the search criteria.
288 | """
289 | if "AND" in search_dict:
290 | return self._handle_and(search_dict["AND"])
291 | elif "OR" in search_dict:
292 | return self._handle_or(search_dict["OR"])
293 | else:
294 | return self._handle_single(search_dict)
295 |
296 | def _handle_and(self, search_dicts):
297 | results = [self.retrieve_metadata(sd) for sd in search_dicts]
298 | if results:
299 | intersection = set.intersection(
300 | *[set(map(self._hash_doc, r)) for r in results]
301 | )
302 | return [self._unhash_doc(h) for h in intersection]
303 | else:
304 | return []
305 |
306 | def _handle_or(self, search_dicts):
307 | results = [self.retrieve_metadata(sd) for sd in search_dicts]
308 | union = set.union(*[set(map(self._hash_doc, r)) for r in results])
309 | return [self._unhash_doc(h) for h in union]
310 |
311 | def _handle_single(self, search_dict):
312 | unions = []
313 | for key, query in search_dict.items():
314 | operator, value = query
315 | union = set()
316 | if operator == IndexerOperator.EQ:
317 | if key in self.index and value in self.index[key]:
318 | union.update(map(self._hash_doc, self.index[key][value]))
319 | else:
320 | if key in self.index:
321 | for k, v in self.index[key].items():
322 | if (
323 | (operator == IndexerOperator.GT and k > value)
324 | or (operator == IndexerOperator.GTE and k >= value)
325 | or (operator == IndexerOperator.LT and k < value)
326 | or (operator == IndexerOperator.LTE and k <= value)
327 | ):
328 | union.update(map(self._hash_doc, v))
329 | if union:
330 | unions.append(union)
331 |
332 | if unions:
333 | intersection = set.intersection(*unions)
334 | return [self._unhash_doc(h) for h in intersection]
335 | else:
336 | return []
337 |
338 | def _hash_doc(self, doc):
339 | return (doc.page_content, frozenset(doc.metadata.items()))
340 |
341 | def _unhash_doc(self, hashed_doc):
342 | page_content, metadata = hashed_doc
343 | return Document(page_content=page_content, metadata=dict(metadata))
344 |
345 |
346 | def _get_chat_history(chat_history: List[ChatTurnType]) -> str:
347 | buffer = ""
348 | for dialogue_turn in chat_history:
349 | if isinstance(dialogue_turn, BaseMessage):
350 | role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
351 | buffer += f"\n{role_prefix}{dialogue_turn.content}"
352 | elif isinstance(dialogue_turn, tuple):
353 | human = "Human: " + dialogue_turn[0]
354 | ai = "Assistant: " + dialogue_turn[1]
355 | buffer += "\n" + "\n".join([human, ai])
356 | else:
357 | raise ValueError(
358 | f"Unsupported chat history format: {type(dialogue_turn)}."
359 | f" Full chat history: {chat_history} "
360 | )
361 | return buffer
362 |
363 |
364 | def _get_standalone_questions_list(
365 | standalone_questions_str: str, original_question: str
366 | ) -> List[str]:
367 | pattern = r"\d+\.\s(.*?)(?=\n\d+\.|\n|$)"
368 |
369 | matches = [
370 | match.group(1) for match in re.finditer(pattern, standalone_questions_str)
371 | ]
372 | if matches:
373 | return matches
374 |
375 | match = re.search(
376 | r"(?i)standalone[^\n]*:[^\n](.*)", standalone_questions_str, re.DOTALL
377 | )
378 | sentence_source = match.group(1).strip() if match else standalone_questions_str
379 | sentences = sentence_source.split("\n")
380 |
381 | return [
382 | re.sub(
383 | r"^\((\d+)\)\.? ?|^\d+\.? ?\)?|^(\d+)\) ?|^(\d+)\) ?|^[Qq]uery \d+: ?|^[Qq]uery: ?",
384 | "",
385 | sentence.strip(),
386 | )
387 | for sentence in sentences
388 | if sentence.strip()
389 | ]
390 |
--------------------------------------------------------------------------------
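The `DocIndexer` / `IndexerOperator` pair in `toolkit/utils.py` above is what `MyRetriever.get_filter()` queries against. A minimal sketch of the query format follows; the documents and metadata values are illustrative.

```python
# Sketch of a DocIndexer metadata query: each leaf maps a metadata key to an
# (IndexerOperator, value) pair, and "OR"/"AND" combine several such dictionaries.
from langchain.schema import Document
from toolkit.utils import DocIndexer, IndexerOperator

docs = [  # illustrative documents
    Document(
        page_content="chunk 0",
        metadata={"source_md5": "abc", "large_chunks_idx_lower_bound": 0,
                  "large_chunks_idx_upper_bound": 2},
    ),
    Document(
        page_content="chunk 1",
        metadata={"source_md5": "abc", "large_chunks_idx_lower_bound": 3,
                  "large_chunks_idx_upper_bound": 5},
    ),
]
indexer = DocIndexer(docs)

# "which chunks have a window range covering index 1 for this source?"
hits = indexer.retrieve_metadata(
    {
        "OR": [
            {
                "large_chunks_idx_lower_bound": (IndexerOperator.LTE, 1),
                "large_chunks_idx_upper_bound": (IndexerOperator.GTE, 1),
                "source_md5": (IndexerOperator.EQ, "abc"),
            }
        ]
    }
)
print([d.page_content for d in hits])  # ['chunk 0']
```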