├── LICENSE
├── README.md
└── finetune_embedding_lora.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # llamaindex-embedding-lora
2 |
3 | Example notebook accompanying the blog post "LoRA fine-tuning of embedding models using LlamaIndex", available at https://medium.com/@diagnosta/lora-fine-tuning-of-embedding-models-using-llamaindex-a60b823a2c94
4 |
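5 | The notebook's core trick: wrap the LoRA-wrapped embedding model in a no-op adapter, so that LlamaIndex's `EmbeddingAdapterFinetuneEngine` ends up training the LoRA weights themselves. A minimal sketch of the idea (see the notebook for the full, working version):
6 | 
7 | ```python
8 | import torch
9 | from llama_index.finetuning.embeddings.adapter_utils import BaseAdapter
10 | 
11 | class UniversalAdapter(torch.nn.Identity, BaseAdapter):
12 |     """Identity adapter whose only job is to expose the embedding
13 |     model's trainable (LoRA) parameters to the finetune engine."""
14 | 
15 |     def __init__(self, embed_model):
16 |         super().__init__()
17 |         # registered as a submodule, so .parameters() picks up the LoRA weights
18 |         self.embed_model = embed_model
19 | 
20 |     def save(self, output_path):
21 |         self.embed_model.save_pretrained(output_path)
22 | ```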
--------------------------------------------------------------------------------
/finetune_embedding_lora.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
11 | {
12 | "cell_type": "markdown",
13 | "id": "03f5ac7e-d36d-4879-959a-1af414fe4c02",
14 | "metadata": {},
15 | "source": [
16 | "# LoRA finetuning of any Black-Box Embedding Model\n",
17 | "\n",
18 | "This notebook is based on https://github.com/run-llama/llama_index/blob/3e5d0a146fcda01a984818d381f31a19287aead8/docs/examples/finetuning/embeddings/finetune_embedding_adapter.ipynb and demonstrates how to:\n",
19 | "\n",
20 | "- Generate a fine-tuning corpus using a local LLM\n",
21 | "- Fine-tune a local embedding model using LoRA\n",
22 | "\n",
23 | "The latter is achieved by subclassing the `EmbeddingAdapterFinetuneEngine` and a few tricks in order to make it behave (in the way we want it to)."
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "id": "9ab6c5cc-8b31-41cd-95aa-6d60fbefff9b",
29 | "metadata": {},
30 | "source": [
31 | "## Generate Corpus\n",
32 | "\n",
33 | "We use our helper abstractions, `generate_qa_embedding_pairs`, to generate our training and evaluation dataset. This function takes in any set of text nodes (chunks) and generates a structured dataset containing (question, context) pairs."
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 1,
39 | "id": "9b36f73f-83b1-4715-bd4d-7ce1353d1a19",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "import torch\n",
44 | "from typing import Any, List, Optional, Tuple#, Union\n",
45 | "from llama_index.core import SimpleDirectoryReader\n",
46 | "from llama_index.core.base.embeddings.base import BaseEmbedding\n",
47 | "from llama_index.core.node_parser import SentenceSplitter\n",
48 | "from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding\n",
49 | "from llama_index.embeddings.huggingface.pooling import Pooling\n",
50 | "from llama_index.finetuning import EmbeddingAdapterFinetuneEngine\n",
51 | "from llama_index.finetuning.embeddings.adapter_utils import BaseAdapter"
52 | ]
53 | },
54 | {
55 | "attachments": {},
56 | "cell_type": "markdown",
57 | "id": "2fc4bd24",
58 | "metadata": {},
59 | "source": [
60 | "Download Data"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 2,
66 | "id": "6ae97522",
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "name": "stdout",
71 | "output_type": "stream",
72 | "text": [
73 | "--2024-03-18 14:51:34-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf\n",
74 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...\n",
75 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
76 | "HTTP request sent, awaiting response... 200 OK\n",
77 | "Length: 1880483 (1.8M) [application/octet-stream]\n",
78 | "Saving to: ‘data/10k/uber_2021.pdf’\n",
79 | "\n",
80 | "data/10k/uber_2021. 100%[===================>] 1.79M --.-KB/s in 0.04s \n",
81 | "\n",
82 | "2024-03-18 14:51:34 (41.6 MB/s) - ‘data/10k/uber_2021.pdf’ saved [1880483/1880483]\n",
83 | "\n",
84 | "--2024-03-18 14:51:34-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf\n",
85 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...\n",
86 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
87 | "HTTP request sent, awaiting response... 200 OK\n",
88 | "Length: 1440303 (1.4M) [application/octet-stream]\n",
89 | "Saving to: ‘data/10k/lyft_2021.pdf’\n",
90 | "\n",
91 | "data/10k/lyft_2021. 100%[===================>] 1.37M --.-KB/s in 0.03s \n",
92 | "\n",
93 | "2024-03-18 14:51:35 (43.5 MB/s) - ‘data/10k/lyft_2021.pdf’ saved [1440303/1440303]\n",
94 | "\n"
95 | ]
96 | }
97 | ],
98 | "source": [
99 | "!mkdir -p 'data/10k/'\n",
100 | "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf' -O 'data/10k/uber_2021.pdf'\n",
101 | "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 3,
107 | "id": "58c43042-2ed1-4ab7-a53d-7f65dd856f83",
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "TRAIN_FILES = [\"./data/10k/lyft_2021.pdf\"]\n",
112 | "VAL_FILES = [\"./data/10k/uber_2021.pdf\"]\n",
113 | "\n",
114 | "TRAIN_CORPUS_FPATH = \"./data/train_corpus.json\"\n",
115 | "VAL_CORPUS_FPATH = \"./data/val_corpus.json\""
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 4,
121 | "id": "3c7e38d0-39ff-44e2-ab7f-fded56dcd707",
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "def load_corpus(files, verbose=False):\n",
126 | " if verbose: print(f\"Loading files {files}\")\n",
127 | "\n",
128 | " reader = SimpleDirectoryReader(input_files=files)\n",
129 | " docs = reader.load_data()\n",
130 | " if verbose: print(f\"Loaded {len(docs)} docs\")\n",
131 | "\n",
132 | " parser = SentenceSplitter()\n",
133 | " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n",
134 | " if verbose: print(f\"Parsed {len(nodes)} nodes\")\n",
135 | "\n",
136 | " return nodes"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "id": "d1257dce-0be1-42c4-9346-a1fe68505fdd",
142 | "metadata": {},
143 | "source": [
144 | "We do a very naive train/val split by having the Lyft corpus as the train dataset, and the Uber corpus as the val dataset."
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 5,
150 | "id": "ffd6d8af-5382-48b8-8a7d-98a03d2f150d",
151 | "metadata": {},
152 | "outputs": [
153 | {
154 | "name": "stdout",
155 | "output_type": "stream",
156 | "text": [
157 | "Loading files ['./data/10k/lyft_2021.pdf']\n",
158 | "Loaded 238 docs\n"
159 | ]
160 | },
161 | {
162 | "data": {
163 | "application/vnd.jupyter.widget-view+json": {
164 | "model_id": "00e3746a01034f1387e563606519788d",
165 | "version_major": 2,
166 | "version_minor": 0
167 | },
168 | "text/plain": [
169 | "Parsing nodes: 0%| | 0/238 [00:00, ?it/s]"
170 | ]
171 | },
172 | "metadata": {},
173 | "output_type": "display_data"
174 | },
175 | {
176 | "name": "stdout",
177 | "output_type": "stream",
178 | "text": [
179 | "Parsed 344 nodes\n",
180 | "Loading files ['./data/10k/uber_2021.pdf']\n",
181 | "Loaded 307 docs\n"
182 | ]
183 | },
184 | {
185 | "data": {
186 | "application/vnd.jupyter.widget-view+json": {
187 | "model_id": "3980765635d540168c0c12f90f4cba92",
188 | "version_major": 2,
189 | "version_minor": 0
190 | },
191 | "text/plain": [
192 | "Parsing nodes: 0%| | 0/307 [00:00, ?it/s]"
193 | ]
194 | },
195 | "metadata": {},
196 | "output_type": "display_data"
197 | },
198 | {
199 | "name": "stdout",
200 | "output_type": "stream",
201 | "text": [
202 | "Parsed 410 nodes\n"
203 | ]
204 | }
205 | ],
206 | "source": [
207 | "train_nodes = load_corpus(TRAIN_FILES, verbose=True)\n",
208 | "val_nodes = load_corpus(VAL_FILES, verbose=True)"
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "id": "1893a5f1-6fdf-473b-80ea-5ea3df5681a7",
214 | "metadata": {},
215 | "source": [
216 | "### Generate synthetic queries\n",
217 | "\n",
218 | "Now, we use an LLM (Mixtral) to generate questions using each text chunk in the corpus as context.\n",
219 | "\n",
220 | "Each pair of (generated question, text chunk used as context) becomes a datapoint in the finetuning dataset (either for training or evaluation)."
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 6,
226 | "id": "ee1c892e-e27d-49f6-96d4-b99af330aed8",
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "from llama_index.finetuning import generate_qa_embedding_pairs\n",
231 | "from llama_index.core.evaluation import EmbeddingQAFinetuneDataset"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 7,
237 | "id": "f9eddecf",
238 | "metadata": {},
239 | "outputs": [
240 | {
241 | "name": "stderr",
242 | "output_type": "stream",
243 | "text": [
244 | "/opt/conda/envs/llama-index/lib/python3.11/site-packages/transformers/modeling_utils.py:4193: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n",
245 | " warnings.warn(\n"
246 | ]
247 | }
248 | ],
249 | "source": [
250 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
251 | "from llama_index.llms.huggingface import HuggingFaceLLM\n",
252 | "from llama_index.core.prompts import PromptTemplate\n",
253 | "\n",
254 | "model_id = 'TheBloke/Mixtral-8x7B-v0.1-GPTQ'\n",
255 | "code_revision = 'gptq-4bit-32g-actorder_True'\n",
256 | "tokenizer = AutoTokenizer.from_pretrained(model_id, attn_implementation='flash_attention_2')\n",
257 | "model = AutoModelForCausalLM.from_pretrained(model_id, code_revision=code_revision, device_map='auto')\n",
258 | "\n",
259 | "llm = HuggingFaceLLM(\n",
260 | " model=model,\n",
261 | " tokenizer=tokenizer,\n",
262 | " query_wrapper_prompt=PromptTemplate('[INST] {query_str} [/INST]'),\n",
263 | " context_window=16*1024,\n",
264 | " max_new_tokens=1024,\n",
265 | ")"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "id": "7330fb1f-cfb4-4b9b-b614-06910d5330b3",
272 | "metadata": {},
273 | "outputs": [],
274 | "source": [
275 | "train_dataset = generate_qa_embedding_pairs(train_nodes, llm=llm)\n",
276 | "train_dataset.save_json(\"train_dataset.json\")\n",
277 | "\n",
278 | "val_dataset = generate_qa_embedding_pairs(val_nodes, llm=llm)\n",
279 | "val_dataset.save_json(\"val_dataset.json\")"
280 | ]
281 | },
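    |  {
    |   "cell_type": "markdown",
    |   "id": "qa-pair-peek-md",
    |   "metadata": {},
    |   "source": [
    |    "As a quick sanity check, we can peek at one generated (question, context) pair; `EmbeddingQAFinetuneDataset` keeps `queries`, `corpus`, and `relevant_docs` as plain dicts keyed by ID. A minimal sketch:"
    |   ]
    |  },
    |  {
    |   "cell_type": "code",
    |   "execution_count": null,
    |   "id": "qa-pair-peek",
    |   "metadata": {},
    |   "outputs": [],
    |   "source": [
    |    "# inspect one (question, context) pair from the training set\n",
    |    "qid, question = next(iter(train_dataset.queries.items()))\n",
    |    "context_id = train_dataset.relevant_docs[qid][0]\n",
    |    "print(question)\n",
    |    "print(train_dataset.corpus[context_id][:300])"
    |   ]
    |  },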
282 | {
283 | "cell_type": "code",
284 | "execution_count": 2,
285 | "id": "909ca757-bf02-4304-a59e-7d61a12a67df",
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "# release cuda memory - at this point it's probably a good idea to restart the kernel and load the data\n",
290 | "from llama_index.finetuning import generate_qa_embedding_pairs\n",
291 | "from llama_index.core.evaluation import EmbeddingQAFinetuneDataset\n",
292 | "\n",
293 | "train_dataset = EmbeddingQAFinetuneDataset.from_json(\"train_dataset.json\")\n",
294 | "val_dataset = EmbeddingQAFinetuneDataset.from_json(\"val_dataset.json\")"
295 | ]
296 | },
297 | {
298 | "cell_type": "markdown",
299 | "id": "b619e9a6-4795-4ff5-bb48-ae2c50324eb2",
300 | "metadata": {},
301 | "source": [
302 | "## Run Embedding Finetuning\n",
303 | "\n",
304 | "Here we first define the subclasses needed for LoRA finetuning."
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 3,
310 | "id": "ea49d4c3",
311 | "metadata": {},
312 | "outputs": [],
313 | "source": [
314 | "class UniversalAdapter(torch.nn.Identity, BaseAdapter):\n",
315 | " \"\"\"Adapter model that does nothing, but includes trainable parameters \n",
316 | " (e.g. LoRAs) of the embedding model, which the FinetuneEngine actually trains.\"\"\"\n",
317 | " def __init__(self, embed_model):\n",
318 | " super().__init__()\n",
319 | " self.embed_model = embed_model\n",
320 | "\n",
321 | " def save(self, output_path):\n",
322 | " self.embed_model.save_pretrained(output_path, save_adapter=True, save_config=True)"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 4,
328 | "id": "a5234bae",
329 | "metadata": {},
330 | "outputs": [],
331 | "source": [
332 | "class UniversalEmbeddingFinetuneEngine(EmbeddingAdapterFinetuneEngine):\n",
333 | " \"\"\"Fintune any parameters of embed_model with requires_grad set to True, e.g. LoRA adapaters.\"\"\"\n",
334 | " def __init__(\n",
335 | " self,\n",
336 | " dataset: EmbeddingQAFinetuneDataset,\n",
337 | " embed_model: BaseEmbedding,\n",
338 | " batch_size: int = 10,\n",
339 | " epochs: int = 1,\n",
340 | " dim: Optional[int] = None,\n",
341 | " device: Optional[str] = None,\n",
342 | " model_output_path: str = \"model_output\",\n",
343 | " model_checkpoint_path: Optional[str] = None,\n",
344 | " checkpoint_save_steps: int = 100,\n",
345 | " verbose: bool = False,\n",
346 | " bias: bool = False,\n",
347 | " **train_kwargs: Any,\n",
348 | " ) -> None:\n",
349 | " super().__init__(\n",
350 | " dataset=dataset,\n",
351 | " embed_model=embed_model,\n",
352 | " batch_size=batch_size,\n",
353 | " epochs=epochs,\n",
354 | " adapter_model=UniversalAdapter(embed_model._model),\n",
355 | " dim=dim,\n",
356 | " device=device,\n",
357 | " model_output_path=model_output_path,\n",
358 | " model_checkpoint_path=model_checkpoint_path,\n",
359 | " checkpoint_save_steps=checkpoint_save_steps,\n",
360 | " verbose=verbose,\n",
361 | " bias=bias,\n",
362 | " **train_kwargs,\n",
363 | " )\n",
364 | "\n",
365 | " def smart_batching_collate(self, batch: List) -> Tuple[Any, Any]:\n",
366 | " \"\"\"Smart batching collate.\"\"\"\n",
367 | " import torch\n",
368 | " from torch import Tensor\n",
369 | "\n",
370 | " query_embeddings: List[Tensor] = []\n",
371 | " text_embeddings: List[Tensor] = []\n",
372 | "\n",
373 | " for query, text in batch:\n",
374 | " query_embedding = self.embed_model.get_query_embedding(query)\n",
375 | " text_embedding = self.embed_model.get_text_embedding(text)\n",
376 | "\n",
377 | " query_embeddings.append(query_embedding) # was stripping gradients: query_embeddings.append(torch.tensor(query_embedding))\n",
378 | " text_embeddings.append(text_embedding) # was stripping gradients: text_embeddings.append(torch.tensor(text_embedding))\n",
379 | "\n",
380 | " query_embeddings_t = torch.stack(query_embeddings)\n",
381 | " text_embeddings_t = torch.stack(text_embeddings)\n",
382 | "\n",
383 | " return query_embeddings_t, text_embeddings_t"
384 | ]
385 | },
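    |  {
    |   "cell_type": "markdown",
    |   "id": "stack-vs-tensor-md",
    |   "metadata": {},
    |   "source": [
    |    "The collate override is the crux: the stock engine rebuilt each embedding with `torch.tensor(...)`, which copies the values and detaches them from the autograd graph, so no gradient would ever reach the LoRA weights. `torch.stack` keeps the graph intact. A tiny illustration:"
    |   ]
    |  },
    |  {
    |   "cell_type": "code",
    |   "execution_count": null,
    |   "id": "stack-vs-tensor",
    |   "metadata": {},
    |   "outputs": [],
    |   "source": [
    |    "# torch.stack preserves the autograd graph; torch.tensor() would not\n",
    |    "x = torch.ones(3, requires_grad=True)\n",
    |    "stacked = torch.stack([x * 2, x * 3])\n",
    |    "print(stacked.requires_grad)  # True: gradients can flow back to x"
    |   ]
    |  },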
386 | {
387 | "cell_type": "code",
388 | "execution_count": 5,
389 | "id": "fc9837a9",
390 | "metadata": {},
391 | "outputs": [],
392 | "source": [
393 | "class HuggingFaceEmbeddingWithGrad(HuggingFaceEmbedding):\n",
394 | " \"\"\"HuggingFaceEmbedding with gradient support.\"\"\"\n",
395 | "\n",
396 | " def __getattr__(self, name: str) -> Any:\n",
397 | " return getattr(self._model, name)\n",
398 | " \n",
399 | " def _embed(self, sentences: List[str]) -> torch.Tensor:\n",
400 | " \"\"\"Embed sentences.\"\"\"\n",
401 | " encoded_input = self._tokenizer(\n",
402 | " sentences,\n",
403 | " padding=True,\n",
404 | " max_length=self.max_length,\n",
405 | " truncation=True,\n",
406 | " return_tensors=\"pt\",\n",
407 | " )\n",
408 | "\n",
409 | " # pop token_type_ids\n",
410 | " encoded_input.pop(\"token_type_ids\", None)\n",
411 | "\n",
412 | " # move tokenizer inputs to device\n",
413 | " encoded_input = {\n",
414 | " key: val.to(self._device) for key, val in encoded_input.items()\n",
415 | " }\n",
416 | "\n",
417 | " model_output = self._model(**encoded_input)\n",
418 | "\n",
419 | " context_layer: \"torch.Tensor\" = model_output[0]\n",
420 | " if self.pooling == Pooling.CLS:\n",
421 | " embeddings = self.pooling.cls_pooling(context_layer)\n",
422 | " elif self.pooling == Pooling.LAST:\n",
423 | " embeddings = self.pooling.last_pooling(context_layer) \n",
424 | " else:\n",
425 | " embeddings = self._mean_pooling(\n",
426 | " token_embeddings=context_layer,\n",
427 | " attention_mask=encoded_input[\"attention_mask\"],\n",
428 | " )\n",
429 | "\n",
430 | " if self.normalize:\n",
431 | " import torch\n",
432 | " embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)\n",
433 | "\n",
434 | " return embeddings # was embeddings.tolist()"
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": 25,
440 | "id": "837cb16f",
441 | "metadata": {},
442 | "outputs": [],
443 | "source": [
444 | "from pydantic import fields as pydantic_fields\n",
445 | "\n",
446 | "class disable_pydantic:\n",
447 | " \"\"\"Context manager to disable pydantic validation.\"\"\"\n",
448 | "\n",
449 | " def __enter__(self) -> None:\n",
450 | " self.validate = pydantic_fields.ModelField.validate\n",
451 | " pydantic_fields.ModelField.validate = lambda *args, **kwargs: (args[1], None)\n",
452 | "\n",
453 | " def __exit__(self, *args) -> None:\n",
454 | " pydantic_fields.ModelField.validate = self.validate"
455 | ]
456 | },
457 | {
458 | "cell_type": "markdown",
459 | "id": "1aeb224e",
460 | "metadata": {},
461 | "source": [
462 | "### Fine-tune sfr-embedding-mistral\n",
463 | "\n",
464 | "As of March 2024 SFR-Embedding-Mistral is at the top of the Massive Text Embedding Benchmark (MTEB) Leaderboard: https://huggingface.co/spaces/mteb/leaderboard"
465 | ]
466 | },
467 | {
468 | "cell_type": "markdown",
469 | "id": "20ef7b91",
470 | "metadata": {},
471 | "source": [
472 | "We quantize the model to 4-bit first:"
473 | ]
474 | },
475 | {
476 | "cell_type": "code",
477 | "execution_count": 7,
478 | "id": "86b9b422",
479 | "metadata": {},
480 | "outputs": [],
481 | "source": [
482 | "model_id = 'Salesforce/SFR-Embedding-Mistral'\n",
483 | "quant_path = f'/tmp/models/{model_id.replace(\"/\",\"-\")}-quant'"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": 8,
489 | "id": "1f42ec76",
490 | "metadata": {},
491 | "outputs": [
492 | {
493 | "data": {
494 | "application/vnd.jupyter.widget-view+json": {
495 | "model_id": "5717387749124b5e9248cbeee6c428f9",
496 | "version_major": 2,
497 | "version_minor": 0
498 | },
499 | "text/plain": [
500 | "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
501 | ]
502 | },
503 | "metadata": {},
504 | "output_type": "display_data"
505 | },
506 | {
507 | "name": "stdout",
508 | "output_type": "stream",
509 | "text": [
510 | "Quantized model saved to /tmp/models/Salesforce-SFR-Embedding-Mistral-quant\n"
511 | ]
512 | }
513 | ],
514 | "source": [
515 | "from transformers import BitsAndBytesConfig, AutoModel, AutoTokenizer\n",
516 | "\n",
517 | "bnb_config = BitsAndBytesConfig(\n",
518 | " load_in_4bit=True,\n",
519 | " bnb_4bit_use_double_quant=True,\n",
520 | " bnb_4bit_quant_type=\"nf4\",\n",
521 | " bnb_4bit_compute_dtype=torch.bfloat16\n",
522 | ")\n",
523 | "\n",
524 | "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
525 | "tokenizer.save_pretrained(quant_path)\n",
526 | "\n",
527 | "model = AutoModel.from_pretrained(\n",
528 | " model_id,\n",
529 | " trust_remote_code=True,\n",
530 | " device_map='auto',\n",
531 | " torch_dtype=torch.bfloat16,\n",
532 | " quantization_config=bnb_config\n",
533 | ")\n",
534 | "\n",
535 | "# freeze the model before saving just as a precaution\n",
536 | "for param in model.parameters():\n",
537 | " param.requires_grad = False\n",
538 | "\n",
539 | "model.save_pretrained(quant_path, low_cpu_mem_usage=False)\n",
540 | "print(f'Quantized model saved to {quant_path}')"
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": 9,
546 | "id": "c09683b4",
547 | "metadata": {},
548 | "outputs": [],
549 | "source": [
550 | "# release cuda memory\n",
551 | "del model, tokenizer, bnb_config\n",
552 | "import gc; gc.collect()\n",
553 | "with torch.no_grad(): torch.cuda.empty_cache()"
554 | ]
555 | },
556 | {
557 | "cell_type": "code",
558 | "execution_count": 10,
559 | "id": "811fbdb2",
560 | "metadata": {},
561 | "outputs": [],
562 | "source": [
563 | "lora_adapters_path = '/tmp/whatever'"
564 | ]
565 | },
566 | {
567 | "cell_type": "code",
568 | "execution_count": 11,
569 | "id": "15bed4c1",
570 | "metadata": {},
571 | "outputs": [],
572 | "source": [
573 | "from transformers import AutoModel, AutoTokenizer\n",
574 | "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
575 | "\n",
576 | "embed_tokenizer = AutoTokenizer.from_pretrained(quant_path)\n",
577 | "embed_model = AutoModel.from_pretrained(quant_path, low_cpu_mem_usage=True)\n",
578 | "embed_model.to = lambda _: embed_model # quantized model does not have .to() method\n",
579 | "for param in embed_model.parameters():\n",
580 | " param.requires_grad = False"
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "execution_count": 12,
586 | "id": "e437c606",
587 | "metadata": {},
588 | "outputs": [],
589 | "source": [
590 | "hf_base_model = HuggingFaceEmbedding(\n",
591 | " model=embed_model, \n",
592 | " tokenizer=embed_tokenizer, \n",
593 | " query_instruction='Instruct: Given a web search query, retrieve relevant passages that answer the query\\nQuery:',\n",
594 | " pooling='last',\n",
595 | " embed_batch_size=1\n",
596 | ")"
597 | ]
598 | },
599 | {
600 | "cell_type": "markdown",
601 | "id": "bf592ed6",
602 | "metadata": {},
603 | "source": [
604 | "Evaluate the base model:"
605 | ]
606 | },
607 | {
608 | "cell_type": "code",
609 | "execution_count": 13,
610 | "id": "a3068379",
611 | "metadata": {},
612 | "outputs": [
613 | {
614 | "data": {
615 | "application/vnd.jupyter.widget-view+json": {
616 | "model_id": "e164db6a4bb74df0bb81c3818a022984",
617 | "version_major": 2,
618 | "version_minor": 0
619 | },
620 | "text/plain": [
621 | "Generating embeddings: 0%| | 0/410 [00:00, ?it/s]"
622 | ]
623 | },
624 | "metadata": {},
625 | "output_type": "display_data"
626 | },
627 | {
628 | "name": "stderr",
629 | "output_type": "stream",
630 | "text": [
631 | "100%|██████████| 861/861 [01:56<00:00, 7.39it/s]\n"
632 | ]
633 | },
634 | {
635 | "data": {
636 | "text/html": [
637 | "
\n",
638 | "\n",
651 | "
\n",
652 | " \n",
653 | " \n",
654 | " | \n",
655 | " retrievers | \n",
656 | " hit_rate | \n",
657 | " mrr | \n",
658 | "
\n",
659 | " \n",
660 | " \n",
661 | " \n",
662 | " 0 | \n",
663 | " base_sfr | \n",
664 | " 0.872242 | \n",
665 | " 0.68494 | \n",
666 | "
\n",
667 | " \n",
668 | "
\n",
669 | "
"
670 | ],
671 | "text/plain": [
672 | " retrievers hit_rate mrr\n",
673 | "0 base_sfr 0.872242 0.68494"
674 | ]
675 | },
676 | "metadata": {},
677 | "output_type": "display_data"
678 | }
679 | ],
680 | "source": [
681 | "from eval_utils import evaluate, display_results\n",
682 | "\n",
683 | "with torch.no_grad():\n",
684 | " base_sfr_val_results = evaluate(val_dataset, hf_base_model)\n",
685 | "display_results([\"base_sfr\"], [base_sfr_val_results])"
686 | ]
687 | },
688 | {
689 | "cell_type": "code",
690 | "execution_count": 16,
691 | "id": "605768ca",
692 | "metadata": {},
693 | "outputs": [],
694 | "source": [
695 | "# create the peft model\n",
696 | "peft_config = LoraConfig(\n",
697 | " r=8,\n",
698 | " lora_alpha=16,\n",
699 | " lora_dropout=0.05,\n",
700 | " target_modules=[\"q_proj\", \"v_proj\"],\n",
701 | " task_type=\"FEATURE_EXTRACTION\",\n",
702 | ")\n",
703 | "\n",
704 | "kbit_model = prepare_model_for_kbit_training(embed_model)\n",
705 | "peft_model = get_peft_model(kbit_model, peft_config)"
706 | ]
707 | },
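    |  {
    |   "cell_type": "markdown",
    |   "id": "trainable-params-md",
    |   "metadata": {},
    |   "source": [
    |    "Optionally, PEFT's `print_trainable_parameters()` confirms that only the LoRA adapters are trainable while the 4-bit base weights stay frozen:"
    |   ]
    |  },
    |  {
    |   "cell_type": "code",
    |   "execution_count": null,
    |   "id": "trainable-params",
    |   "metadata": {},
    |   "outputs": [],
    |   "source": [
    |    "# should report a tiny trainable fraction: just the q_proj/v_proj LoRA weights\n",
    |    "peft_model.print_trainable_parameters()"
    |   ]
    |  },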
708 | {
709 | "cell_type": "code",
710 | "execution_count": null,
711 | "id": "afb8d02f",
712 | "metadata": {},
713 | "outputs": [],
714 | "source": [
715 | "# ...or load trained adapters\n",
716 | "from peft import PeftModel\n",
717 | "peft_model = PeftModel.from_pretrained(embed_model, lora_adapters_path)"
718 | ]
719 | },
720 | {
721 | "cell_type": "code",
722 | "execution_count": 17,
723 | "id": "b31b0c71",
724 | "metadata": {},
725 | "outputs": [],
726 | "source": [
727 | "hf_qlora_model = HuggingFaceEmbeddingWithGrad(\n",
728 | " model=peft_model, \n",
729 | " tokenizer=embed_tokenizer, \n",
730 | " query_instruction='Instruct: Given a web search query, retrieve relevant passages that answer the query\\nQuery:',\n",
731 | " pooling='last',\n",
732 | " embed_batch_size=1\n",
733 | ")"
734 | ]
735 | },
736 | {
737 | "cell_type": "code",
738 | "execution_count": 26,
739 | "id": "6f1bf68e",
740 | "metadata": {},
741 | "outputs": [
742 | {
743 | "data": {
744 | "application/vnd.jupyter.widget-view+json": {
745 | "model_id": "59fa371cb6a24eab86b86b8b15eb2838",
746 | "version_major": 2,
747 | "version_minor": 0
748 | },
749 | "text/plain": [
750 | "Epoch: 0%| | 0/1 [00:00, ?it/s]"
751 | ]
752 | },
753 | "metadata": {},
754 | "output_type": "display_data"
755 | },
756 | {
757 | "data": {
758 | "application/vnd.jupyter.widget-view+json": {
759 | "model_id": "39ef9d7d89974bf5a03a7a3869d6fc51",
760 | "version_major": 2,
761 | "version_minor": 0
762 | },
763 | "text/plain": [
764 | "Iteration: 0%| | 0/77 [00:00, ?it/s]"
765 | ]
766 | },
767 | "metadata": {},
768 | "output_type": "display_data"
769 | },
770 | {
771 | "name": "stderr",
772 | "output_type": "stream",
773 | "text": [
774 | "/opt/conda/envs/llama-index/lib/python3.11/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
775 | " warnings.warn(\n",
776 | "/opt/conda/envs/llama-index/lib/python3.11/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /tmp/models/Salesforce-SFR-Embedding-Mistral-quant - will assume that the vocabulary was not modified.\n",
777 | " warnings.warn(\n"
778 | ]
779 | }
780 | ],
781 | "source": [
782 | "finetune_engine = UniversalEmbeddingFinetuneEngine(\n",
783 | " train_dataset,\n",
784 | " embed_model=hf_qlora_model,\n",
785 | " dim=4096,\n",
786 | " model_output_path=lora_adapters_path,\n",
787 | " epochs=5,\n",
788 | " verbose=False,\n",
789 | ")\n",
790 | "\n",
791 | "with disable_pydantic():\n",
792 | " finetune_engine.finetune()"
793 | ]
794 | },
795 | {
796 | "cell_type": "code",
797 | "execution_count": 19,
798 | "id": "14dd5c85",
799 | "metadata": {},
800 | "outputs": [],
801 | "source": [
802 | "# repackage as HuggingFaceEmbedding to avoid grief from pydantic which wants embeddings to be lists not tensors\n",
803 | "hf_embeddig_model = HuggingFaceEmbedding(\n",
804 | " model=hf_qlora_model.model, \n",
805 | " tokenizer=hf_qlora_model._tokenizer, \n",
806 | " query_instruction=hf_qlora_model.query_instruction,\n",
807 | " pooling=hf_qlora_model.pooling,\n",
808 | " embed_batch_size=hf_qlora_model.embed_batch_size\n",
809 | ")"
810 | ]
811 | },
812 | {
813 | "cell_type": "markdown",
814 | "id": "0d25a60e",
815 | "metadata": {},
816 | "source": [
817 | "Evaluate the fine-tuned model:"
818 | ]
819 | },
820 | {
821 | "cell_type": "code",
822 | "execution_count": 20,
823 | "id": "bda7f46a",
824 | "metadata": {},
825 | "outputs": [
826 | {
827 | "data": {
828 | "application/vnd.jupyter.widget-view+json": {
829 | "model_id": "63255805c566460bb3cd003ed029fc56",
830 | "version_major": 2,
831 | "version_minor": 0
832 | },
833 | "text/plain": [
834 | "Generating embeddings: 0%| | 0/410 [00:00, ?it/s]"
835 | ]
836 | },
837 | "metadata": {},
838 | "output_type": "display_data"
839 | },
840 | {
841 | "name": "stderr",
842 | "output_type": "stream",
843 | "text": [
844 | "/opt/conda/envs/llama-index/lib/python3.11/site-packages/torch/utils/checkpoint.py:90: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
845 | " warnings.warn(\n",
846 | "100%|██████████| 861/861 [01:59<00:00, 7.19it/s]\n"
847 | ]
848 | },
849 | {
850 | "data": {
851 | "text/html": [
852 | "\n",
853 | "\n",
866 | "
\n",
867 | " \n",
868 | " \n",
869 | " | \n",
870 | " retrievers | \n",
871 | " hit_rate | \n",
872 | " mrr | \n",
873 | "
\n",
874 | " \n",
875 | " \n",
876 | " \n",
877 | " 0 | \n",
878 | " lora_sfr | \n",
879 | " 0.941928 | \n",
880 | " 0.803949 | \n",
881 | "
\n",
882 | " \n",
883 | "
\n",
884 | "
"
885 | ],
886 | "text/plain": [
887 | " retrievers hit_rate mrr\n",
888 | "0 lora_sfr 0.941928 0.803949"
889 | ]
890 | },
891 | "metadata": {},
892 | "output_type": "display_data"
893 | }
894 | ],
895 | "source": [
896 | "from eval_utils import evaluate, display_results\n",
897 | "\n",
898 | "with torch.no_grad():\n",
899 | " lora_sfr_val_results = evaluate(val_dataset, hf_embeddig_model)\n",
900 | "display_results([\"lora_sfr\"], [lora_sfr_val_results])"
901 | ]
 902 |   }
911 | ],
912 | "metadata": {
913 | "kernelspec": {
914 | "display_name": "llama-index",
915 | "language": "python",
916 | "name": "python3"
917 | },
918 | "language_info": {
919 | "codemirror_mode": {
920 | "name": "ipython",
921 | "version": 3
922 | },
923 | "file_extension": ".py",
924 | "mimetype": "text/x-python",
925 | "name": "python",
926 | "nbconvert_exporter": "python",
927 | "pygments_lexer": "ipython3",
928 | "version": "3.11.8"
929 | }
930 | },
931 | "nbformat": 4,
932 | "nbformat_minor": 5
933 | }
934 |
--------------------------------------------------------------------------------