├── LICENSE ├── README.md └── finetune_embedding_lora.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llamaindex-embedding-lora 2 | 3 | Example notebook accompanying the blog post on `LoRA fine-tuning of embedding models using LlamaIndex` available at https://medium.com/@diagnosta/lora-fine-tuning-of-embedding-models-using-llamaindex-a60b823a2c94 4 | -------------------------------------------------------------------------------- /finetune_embedding_lora.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "44986dbb", 6 | "metadata": {}, 7 | "source": [ 8 | "\"Open" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "03f5ac7e-d36d-4879-959a-1af414fe4c02", 14 | "metadata": {}, 15 | "source": [ 16 | "# LoRA finetuning of any Black-Box Embedding Model\n", 17 | "\n", 18 | "This notebook is based on 
https://github.com/run-llama/llama_index/blob/3e5d0a146fcda01a984818d381f31a19287aead8/docs/examples/finetuning/embeddings/finetune_embedding_adapter.ipynb and demonstrates how to:\n", 19 | "\n", 20 | "- Generate a fine-tuning corpus using a local LLM\n", 21 | "- Fine-tune a local embedding model using LoRA\n", 22 | "\n", 23 | "The latter is achieved by subclassing `EmbeddingAdapterFinetuneEngine` and applying a few tricks to make it behave the way we want it to." 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "9ab6c5cc-8b31-41cd-95aa-6d60fbefff9b", 29 | "metadata": {}, 30 | "source": [ 31 | "## Generate Corpus\n", 32 | "\n", 33 | "We use the helper function `generate_qa_embedding_pairs` to generate our training and evaluation datasets. This function takes in any set of text nodes (chunks) and generates a structured dataset containing (question, context) pairs." 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "id": "9b36f73f-83b1-4715-bd4d-7ce1353d1a19", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "import torch\n", 44 | "from typing import Any, List, Optional, Tuple#, Union\n", 45 | "from llama_index.core import SimpleDirectoryReader\n", 46 | "from llama_index.core.base.embeddings.base import BaseEmbedding\n", 47 | "from llama_index.core.node_parser import SentenceSplitter\n", 48 | "from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding\n", 49 | "from llama_index.embeddings.huggingface.pooling import Pooling\n", 50 | "from llama_index.finetuning import EmbeddingAdapterFinetuneEngine\n", 51 | "from llama_index.finetuning.embeddings.adapter_utils import BaseAdapter" 52 | ] 53 | }, 54 | { 55 | "attachments": {}, 56 | "cell_type": "markdown", 57 | "id": "2fc4bd24", 58 | "metadata": {}, 59 | "source": [ 60 | "Download Data" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 2, 66 | "id": "6ae97522", 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | 
"name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "--2024-03-18 14:51:34-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf\n", 74 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...\n", 75 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n", 76 | "HTTP request sent, awaiting response... 200 OK\n", 77 | "Length: 1880483 (1.8M) [application/octet-stream]\n", 78 | "Saving to: ‘data/10k/uber_2021.pdf’\n", 79 | "\n", 80 | "data/10k/uber_2021. 100%[===================>] 1.79M --.-KB/s in 0.04s \n", 81 | "\n", 82 | "2024-03-18 14:51:34 (41.6 MB/s) - ‘data/10k/uber_2021.pdf’ saved [1880483/1880483]\n", 83 | "\n", 84 | "--2024-03-18 14:51:34-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf\n", 85 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...\n", 86 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n", 87 | "HTTP request sent, awaiting response... 200 OK\n", 88 | "Length: 1440303 (1.4M) [application/octet-stream]\n", 89 | "Saving to: ‘data/10k/lyft_2021.pdf’\n", 90 | "\n", 91 | "data/10k/lyft_2021. 
100%[===================>] 1.37M --.-KB/s in 0.03s \n", 92 | "\n", 93 | "2024-03-18 14:51:35 (43.5 MB/s) - ‘data/10k/lyft_2021.pdf’ saved [1440303/1440303]\n", 94 | "\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "!mkdir -p 'data/10k/'\n", 100 | "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf' -O 'data/10k/uber_2021.pdf'\n", 101 | "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 3, 107 | "id": "58c43042-2ed1-4ab7-a53d-7f65dd856f83", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "TRAIN_FILES = [\"./data/10k/lyft_2021.pdf\"]\n", 112 | "VAL_FILES = [\"./data/10k/uber_2021.pdf\"]\n", 113 | "\n", 114 | "TRAIN_CORPUS_FPATH = \"./data/train_corpus.json\"\n", 115 | "VAL_CORPUS_FPATH = \"./data/val_corpus.json\"" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "id": "3c7e38d0-39ff-44e2-ab7f-fded56dcd707", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "def load_corpus(files, verbose=False):\n", 126 | " if verbose: print(f\"Loading files {files}\")\n", 127 | "\n", 128 | " reader = SimpleDirectoryReader(input_files=files)\n", 129 | " docs = reader.load_data()\n", 130 | " if verbose: print(f\"Loaded {len(docs)} docs\")\n", 131 | "\n", 132 | " parser = SentenceSplitter()\n", 133 | " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n", 134 | " if verbose: print(f\"Parsed {len(nodes)} nodes\")\n", 135 | "\n", 136 | " return nodes" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "id": "d1257dce-0be1-42c4-9346-a1fe68505fdd", 142 | "metadata": {}, 143 | "source": [ 144 | "We do a very naive train/val split by having the Lyft corpus as the train dataset, and the Uber corpus as the val dataset." 
145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 5, 150 | "id": "ffd6d8af-5382-48b8-8a7d-98a03d2f150d", 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "Loading files ['./data/10k/lyft_2021.pdf']\n", 158 | "Loaded 238 docs\n" 159 | ] 160 | }, 161 | { 162 | "data": { 163 | "application/vnd.jupyter.widget-view+json": { 164 | "model_id": "00e3746a01034f1387e563606519788d", 165 | "version_major": 2, 166 | "version_minor": 0 167 | }, 168 | "text/plain": [ 169 | "Parsing nodes: 0%| | 0/238 [00:00 None:\n", 349 | " super().__init__(\n", 350 | " dataset=dataset,\n", 351 | " embed_model=embed_model,\n", 352 | " batch_size=batch_size,\n", 353 | " epochs=epochs,\n", 354 | " adapter_model=UniversalAdapter(embed_model._model),\n", 355 | " dim=dim,\n", 356 | " device=device,\n", 357 | " model_output_path=model_output_path,\n", 358 | " model_checkpoint_path=model_checkpoint_path,\n", 359 | " checkpoint_save_steps=checkpoint_save_steps,\n", 360 | " verbose=verbose,\n", 361 | " bias=bias,\n", 362 | " **train_kwargs,\n", 363 | " )\n", 364 | "\n", 365 | " def smart_batching_collate(self, batch: List) -> Tuple[Any, Any]:\n", 366 | " \"\"\"Smart batching collate.\"\"\"\n", 367 | " import torch\n", 368 | " from torch import Tensor\n", 369 | "\n", 370 | " query_embeddings: List[Tensor] = []\n", 371 | " text_embeddings: List[Tensor] = []\n", 372 | "\n", 373 | " for query, text in batch:\n", 374 | " query_embedding = self.embed_model.get_query_embedding(query)\n", 375 | " text_embedding = self.embed_model.get_text_embedding(text)\n", 376 | "\n", 377 | " query_embeddings.append(query_embedding) # was stripping gradients: query_embeddings.append(torch.tensor(query_embedding))\n", 378 | " text_embeddings.append(text_embedding) # was stripping gradients: text_embeddings.append(torch.tensor(text_embedding))\n", 379 | "\n", 380 | " query_embeddings_t = torch.stack(query_embeddings)\n", 381 | 
" text_embeddings_t = torch.stack(text_embeddings)\n", 382 | "\n", 383 | " return query_embeddings_t, text_embeddings_t" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 5, 389 | "id": "fc9837a9", 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "class HuggingFaceEmbeddingWithGrad(HuggingFaceEmbedding):\n", 394 | " \"\"\"HuggingFaceEmbedding with gradient support.\"\"\"\n", 395 | "\n", 396 | " def __getattr__(self, name: str) -> Any:\n", 397 | " return getattr(self._model, name)\n", 398 | " \n", 399 | " def _embed(self, sentences: List[str]) -> torch.Tensor:\n", 400 | " \"\"\"Embed sentences.\"\"\"\n", 401 | " encoded_input = self._tokenizer(\n", 402 | " sentences,\n", 403 | " padding=True,\n", 404 | " max_length=self.max_length,\n", 405 | " truncation=True,\n", 406 | " return_tensors=\"pt\",\n", 407 | " )\n", 408 | "\n", 409 | " # pop token_type_ids\n", 410 | " encoded_input.pop(\"token_type_ids\", None)\n", 411 | "\n", 412 | " # move tokenizer inputs to device\n", 413 | " encoded_input = {\n", 414 | " key: val.to(self._device) for key, val in encoded_input.items()\n", 415 | " }\n", 416 | "\n", 417 | " model_output = self._model(**encoded_input)\n", 418 | "\n", 419 | " context_layer: \"torch.Tensor\" = model_output[0]\n", 420 | " if self.pooling == Pooling.CLS:\n", 421 | " embeddings = self.pooling.cls_pooling(context_layer)\n", 422 | " elif self.pooling == Pooling.LAST:\n", 423 | " embeddings = self.pooling.last_pooling(context_layer) \n", 424 | " else:\n", 425 | " embeddings = self._mean_pooling(\n", 426 | " token_embeddings=context_layer,\n", 427 | " attention_mask=encoded_input[\"attention_mask\"],\n", 428 | " )\n", 429 | "\n", 430 | " if self.normalize:\n", 431 | " import torch\n", 432 | " embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)\n", 433 | "\n", 434 | " return embeddings # was embeddings.tolist()" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": 25, 440 | "id": 
"837cb16f", 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "from pydantic import fields as pydantic_fields\n", 445 | "\n", 446 | "class disable_pydantic:\n", 447 | " \"\"\"Context manager to disable pydantic validation.\"\"\"\n", 448 | "\n", 449 | " def __enter__(self) -> None:\n", 450 | " self.validate = pydantic_fields.ModelField.validate\n", 451 | " pydantic_fields.ModelField.validate = lambda *args, **kwargs: (args[1], None)\n", 452 | "\n", 453 | " def __exit__(self, *args) -> None:\n", 454 | " pydantic_fields.ModelField.validate = self.validate" 455 | ] 456 | }, 457 | { 458 | "cell_type": "markdown", 459 | "id": "1aeb224e", 460 | "metadata": {}, 461 | "source": [ 462 | "### Fine-tune sfr-embedding-mistral\n", 463 | "\n", 464 | "As of March 2024 SFR-Embedding-Mistral is at the top of the Massive Text Embedding Benchmark (MTEB) Leaderboard: https://huggingface.co/spaces/mteb/leaderboard" 465 | ] 466 | }, 467 | { 468 | "cell_type": "markdown", 469 | "id": "20ef7b91", 470 | "metadata": {}, 471 | "source": [ 472 | "We quantize the model to 4-bit first:" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": 7, 478 | "id": "86b9b422", 479 | "metadata": {}, 480 | "outputs": [], 481 | "source": [ 482 | "model_id = 'Salesforce/SFR-Embedding-Mistral'\n", 483 | "quant_path = f'/tmp/models/{model_id.replace(\"/\",\"-\")}-quant'" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 8, 489 | "id": "1f42ec76", 490 | "metadata": {}, 491 | "outputs": [ 492 | { 493 | "data": { 494 | "application/vnd.jupyter.widget-view+json": { 495 | "model_id": "5717387749124b5e9248cbeee6c428f9", 496 | "version_major": 2, 497 | "version_minor": 0 498 | }, 499 | "text/plain": [ 500 | "Loading checkpoint shards: 0%| | 0/3 [00:00\n", 638 | "\n", 651 | "\n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " 
\n", 666 | " \n", 667 | " \n", 668 | "
retrievershit_ratemrr
0base_sfr0.8722420.68494
\n", 669 | "" 670 | ], 671 | "text/plain": [ 672 | " retrievers hit_rate mrr\n", 673 | "0 base_sfr 0.872242 0.68494" 674 | ] 675 | }, 676 | "metadata": {}, 677 | "output_type": "display_data" 678 | } 679 | ], 680 | "source": [ 681 | "from eval_utils import evaluate, display_results\n", 682 | "\n", 683 | "with torch.no_grad():\n", 684 | " base_sfr_val_results = evaluate(val_dataset, hf_base_model)\n", 685 | "display_results([\"base_sfr\"], [base_sfr_val_results])" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": 16, 691 | "id": "605768ca", 692 | "metadata": {}, 693 | "outputs": [], 694 | "source": [ 695 | "# create the peft model\n", 696 | "peft_config = LoraConfig(\n", 697 | " r=8,\n", 698 | " lora_alpha=16,\n", 699 | " lora_dropout=0.05,\n", 700 | " target_modules=[\"q_proj\", \"v_proj\"],\n", 701 | " task_type=\"FEATURE_EXTRACTION\",\n", 702 | ")\n", 703 | "\n", 704 | "kbit_model = prepare_model_for_kbit_training(embed_model)\n", 705 | "peft_model = get_peft_model(kbit_model, peft_config)" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": null, 711 | "id": "afb8d02f", 712 | "metadata": {}, 713 | "outputs": [], 714 | "source": [ 715 | "# ...or load trained adapters\n", 716 | "from peft import PeftModel\n", 717 | "peft_model = PeftModel.from_pretrained(embed_model, lora_adapters_path)" 718 | ] 719 | }, 720 | { 721 | "cell_type": "code", 722 | "execution_count": 17, 723 | "id": "b31b0c71", 724 | "metadata": {}, 725 | "outputs": [], 726 | "source": [ 727 | "hf_qlora_model = HuggingFaceEmbeddingWithGrad(\n", 728 | " model=peft_model, \n", 729 | " tokenizer=embed_tokenizer, \n", 730 | " query_instruction='Instruct: Given a web search query, retrieve relevant passages that answer the query\\nQuery:',\n", 731 | " pooling='last',\n", 732 | " embed_batch_size=1\n", 733 | ")" 734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "execution_count": 26, 739 | "id": "6f1bf68e", 740 | "metadata": {}, 741 | "outputs": [ 
742 | { 743 | "data": { 744 | "application/vnd.jupyter.widget-view+json": { 745 | "model_id": "59fa371cb6a24eab86b86b8b15eb2838", 746 | "version_major": 2, 747 | "version_minor": 0 748 | }, 749 | "text/plain": [ 750 | "Epoch: 0%| | 0/1 [00:00\n", 853 | "\n", 866 | "\n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | "
retrievershit_ratemrr
0lora_sfr0.9419280.803949
\n", 884 | "" 885 | ], 886 | "text/plain": [ 887 | " retrievers hit_rate mrr\n", 888 | "0 lora_sfr 0.941928 0.803949" 889 | ] 890 | }, 891 | "metadata": {}, 892 | "output_type": "display_data" 893 | } 894 | ], 895 | "source": [ 896 | "from eval_utils import evaluate, display_results\n", 897 | "\n", 898 | "with torch.no_grad():\n", 899 | " lora_sfr_val_results = evaluate(val_dataset, hf_embeddig_model)\n", 900 | "display_results([\"lora_sfr\"], [lora_sfr_val_results])" 901 | ] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "execution_count": null, 906 | "id": "c7c5e9b5", 907 | "metadata": {}, 908 | "outputs": [], 909 | "source": [] 910 | } 911 | ], 912 | "metadata": { 913 | "kernelspec": { 914 | "display_name": "llama-index", 915 | "language": "python", 916 | "name": "python3" 917 | }, 918 | "language_info": { 919 | "codemirror_mode": { 920 | "name": "ipython", 921 | "version": 3 922 | }, 923 | "file_extension": ".py", 924 | "mimetype": "text/x-python", 925 | "name": "python", 926 | "nbconvert_exporter": "python", 927 | "pygments_lexer": "ipython3", 928 | "version": "3.11.8" 929 | } 930 | }, 931 | "nbformat": 4, 932 | "nbformat_minor": 5 933 | } 934 | --------------------------------------------------------------------------------