├── Linear_Adapter.ipynb ├── README.md ├── adapters ├── linear_adapter_10epochs.pth ├── linear_adapter_20epochs.pth ├── linear_adapter_30epochs.pth └── linear_adapter_40epochs.pth ├── data ├── apple_QA_dataset.json └── nvidia_10k.pdf └── media ├── adapter_diagram.png ├── adapters_explainer.png ├── linear_layer.png ├── negative_sampling.png ├── negative_sampling_2.png ├── training_fit.png ├── triplet_loss.png ├── tripletdataset.png ├── validation_chart.png └── vid_screenshot.png /Linear_Adapter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e569c2f3-69bd-488c-8b9a-aae8610e2aae", 6 | "metadata": {}, 7 | "source": [ 8 | "# Fine-Tuning Embedding Models for RAG Pipeline Optimization\n", 9 | "\n", 10 | "Base embedding models used for both knowledge base embedding and query embedding for context retrieval in RAG-based applications generally work well, but there are ways to optimize their performance to improve correct information retrieval based on historical user queries and more efficiently retrieve domain-specific information.\n", 11 | "\n", 12 | "**Essentially, fine-tuning embedding models on your data to improve your RAG application!**\n", 13 | "\n", 14 | "I've gone through various papers and implementations of embedding model fine-tuning techniques and determined that the most efficient way to get this improvement is through a **query-only linear adapter**, or training a simple linear layer transformation to better represent user queries in embedding space for improved retrieval.\n", 15 | "\n", 16 | "This allows us to very easily plug into existing RAG pipelines and optimize for our specific task without needing to completely re-embed our knowledge base or use a lot of resources training larger models, making this a simple, cost/compute-effective way to improve retrieval performance.\n", 17 | "\n", 18 | "\n", 19 | "\n", 20 | "Additionally, while it's preferred to use existing labeled data gathered through something like RAG Question Answering Chatbot logs, it is possible to also improve embedding representations with synthetically generated labels.\n", 21 | "\n", 22 | "In this notebook we will:\n", 23 | "1. Define a RAG application to optimize\n", 24 | "2. Generate a synthetic dataset with gpt-4o-mini\n", 25 | "3. Test retrieval metrics to gather a baseline for all-MiniLM-L6-v2\n", 26 | "4. Create and train a linear adapter\n", 27 | "5. Plug the adapter onto all-MiniLM-L6-v2 and assess performance\n", 28 | "\n", 29 | "Along the way, we'll be implementing many methodologies from [ChromaDB's Research](https://research.trychroma.com/embedding-adapters) on a small scale for task-specific performance increases, specifically their recommendations for triplet loss, random negative sampling, and linear query-only transformation.\n", 30 | "\n", 31 | "\n", 32 | "\n", 33 | "Model adapters trained in this notebook published to [AdamLucek/all-MiniLM-L6-v2-query-only-linear-adapter-AppleQA](https://huggingface.co/AdamLucek/all-MiniLM-L6-v2-query-only-linear-adapter-AppleQA)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "9fdd5485-51da-4d17-a1f8-8d6054dcc9eb", 39 | "metadata": {}, 40 | "source": [ 41 | "---\n", 42 | "# Synthetic Dataset Creation\n", 43 | "\n", 44 | "To optimize document retrieval effectively, a crucial component is having access to high-quality labeled data. This data typically consists of pairs matching user queries with their most relevant documents. 
While the ideal scenario would involve collecting and labeling this data from real-world user interactions and testing, for demonstration purposes we can simulate this data by generating potential queries for each chunk of our knowledgebase.\n", 45 |     "\n", 46 |     "As a plus, [research has shown](https://arxiv.org/pdf/2401.00368) that using LLMs to generate synthetic data for text embedding improvement can lead to gains!" 47 |    ] 48 |   }, 49 |   { 50 |    "cell_type": "code", 51 |    "execution_count": 1, 52 |    "id": "48b5b048-33d0-4716-96c5-ae17f46510ed", 53 |    "metadata": {}, 54 |    "outputs": [], 55 |    "source": [ 56 |     "import os\n", 57 |     "import json\n", 58 |     "from langchain_openai import ChatOpenAI\n", 59 |     "from langchain_core.prompts import ChatPromptTemplate\n", 60 |     "from langchain_core.output_parsers import JsonOutputParser\n", 61 |     "from langchain_community.document_loaders import PyPDFLoader\n", 62 |     "from langchain_text_splitters import RecursiveCharacterTextSplitter" 63 |    ] 64 |   }, 65 |   { 66 |    "cell_type": "markdown", 67 |    "id": "803302c9-c013-46fd-89ff-a892a73670d3", 68 |    "metadata": {}, 69 |    "source": [ 70 |     "### Loading [Apple's 2024 Environmental Report](https://www.apple.com/environment/pdf/Apple_Environmental_Progress_Report_2024.pdf)\n", 71 |     "\n", 72 |     "This will be our main document and candidate for optimization. The concept here is that the \"RAG Application\" we're optimizing would be some sort of question-answering chat flow over this document. Our end goal is to improve accurate document retrieval based on the user's questions." 73 |    ] 74 |   }, 75 |   { 76 |    "cell_type": "code", 77 |    "execution_count": 2, 78 |    "id": "f6beeade-4d01-438c-9637-078b51846568", 79 |    "metadata": {}, 80 |    "outputs": [], 81 |    "source": [ 82 |     "# Will be used as Training Data\n", 83 |     "apple_loader = PyPDFLoader(\"/Users/alucek/Documents/Jupyter_Notebooks/ft_emb/data/Apple_Environmental_Progress_Report_2024.pdf\")\n", 84 |     "apple_pages = apple_loader.load()\n", 85 |     "\n", 86 |     "apple_document = \"\"\n", 87 |     "for i in range(len(apple_pages)):\n", 88 |     "    apple_document += apple_pages[i].page_content" 89 |    ] 90 |   }, 91 |   { 92 |    "cell_type": "markdown", 93 |    "id": "d7078328-d29c-4d71-9759-a665ac0adbac", 94 |    "metadata": {}, 95 |    "source": [ 96 |     "### Chunking Documents\n", 97 |     "\n", 98 |     "Using a token-based chunker with the same parameters that [OpenAI's file search tool](https://platform.openai.com/docs/assistants/tools/file-search/how-it-works) uses for chunk size and overlap. This will split the documents into manageable chunks to be embedded and retrieved. 
" 99 |    ] 100 |   }, 101 |   { 102 |    "cell_type": "code", 103 |    "execution_count": 3, 104 |    "id": "71af17fc-1457-439d-9f4a-1253015b2e45", 105 |    "metadata": {}, 106 |    "outputs": [], 107 |    "source": [ 108 |     "# Split PDFs\n", 109 |     "text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n", 110 |     "    model_name=\"gpt-4\",\n", 111 |     "    chunk_size=800,\n", 112 |     "    chunk_overlap=400,\n", 113 |     ")\n", 114 |     "\n", 115 |     "apple_chunks = text_splitter.split_text(apple_document)" 116 |    ] 117 |   }, 118 |   { 119 |    "cell_type": "markdown", 120 |    "id": "38d9b70d-c95a-42ca-90d7-5038ec37125b", 121 |    "metadata": {}, 122 |    "source": [ 123 |     "### Defining Chunk Question Generation Chain\n", 124 |     "\n", 125 |     "As mentioned, it is best to use real testing data and labels from your RAG application's query/retrieval pairs, but for demonstration we will use synthetic QA pair generation via an LLM.\n", 126 |     "\n", 127 |     "The hope is that for each chunk of text we have, we can create a possible user query that would most likely retrieve that chunk. This will allow us, further down the line, to test the same user query and assess the retrieval rank of the expected chunk." 128 |    ] 129 |   }, 130 |   { 131 |    "cell_type": "code", 132 |    "execution_count": 4, 133 |    "id": "769bad4e-77a7-48ba-88a0-39a67dfb174d", 134 |    "metadata": {}, 135 |    "outputs": [], 136 |    "source": [ 137 |     "label_template = \"\"\"\n", 138 |     "You are an AI assistant tasked with generating a single, realistic question-answer pair based on a given document. The question should be something a user might naturally ask when seeking information contained in the document.\n", 139 |     "\n", 140 |     "Given: {chunk}\n", 141 |     "\n", 142 |     "Instructions:\n", 143 |     "1. Analyze the key topics, facts, and concepts in the given document, choose one to focus on.\n", 144 |     "2. Generate twenty similar questions that a user might ask to find the information in this document. The questions must NOT contain any company name.\n", 145 |     "3. Use natural language and occasionally include typos or colloquialisms to mimic real user behavior in the question.\n", 146 |     "4. Ensure the question is semantically related to the document content WITHOUT directly copying phrases.\n", 147 |     "5. Make sure that all of the questions are similar to each other, i.e. all asking about a similar topic/requesting the same information.\n", 148 |     "\n", 149 |     "Output Format:\n", 150 |     "Return a JSON object with the following structure:\n", 151 |     "```json\n", 152 |     "{{\n", 153 |     "    \"question_1\": \"Generated question text\",\n", 154 |     "    \"question_2\": \"Generated question text\",\n", 155 |     "    ...\n", 156 |     "}}\n", 157 |     "```\n", 158 |     "\n", 159 |     "Be creative, think like a curious user, and generate your 20 similar questions that would naturally lead to the given document in a semantic search. 
Ensure your response is a valid JSON object containing only the questions.\n", 160 | "\n", 161 | "\"\"\"\n", 162 | "\n", 163 | "label_prompt = ChatPromptTemplate.from_template(label_template)\n", 164 | "llm = ChatOpenAI(temperature=1.0, model=\"gpt-4o-mini\")\n", 165 | "\n", 166 | "label_chain = label_prompt | llm | JsonOutputParser()" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 5, 172 | "id": "5de7345d-3ff5-4fff-9c1a-e8472f8df8eb", 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "data": { 177 | "text/plain": [ 178 | "{'question_1': 'How much less energy does the iMac use compared to ENERGY STAR requirements?',\n", 179 | " 'question_2': 'What makes the iMac energy efficient?',\n", 180 | " 'question_3': \"Can you tell me about the iMac's energy consumption compared to standards?\",\n", 181 | " 'question_4': 'What percentage less energy does the iMac consume than required?',\n", 182 | " 'question_5': 'How does the iMac compare in energy usage to the ENERGY STAR benchmark?',\n", 183 | " 'question_6': \"Is there any info on the iMac's energy efficiency ratings?\",\n", 184 | " 'question_7': 'What are the energy savings on the iMac compared to ENERGY STAR?',\n", 185 | " 'question_8': 'How efficiently does the iMac use energy relative to environmental standards?',\n", 186 | " 'question_9': 'What’s the energy reduction percentage of the new iMac models?',\n", 187 | " 'question_10': 'How energy-efficient is the latest iMac according to ENERGY STAR?',\n", 188 | " 'question_11': \"What details can you provide about the iMac's energy savings?\",\n", 189 | " 'question_12': 'Does the iMac meet or exceed ENERGY STAR energy requirements?',\n", 190 | " 'question_13': 'How much energy does the iMac save compared to the required standards?',\n", 191 | " 'question_14': 'What are the efficiency metrics for the iMac regarding energy use?',\n", 192 | " 'question_15': 'Is the iMac compliant with ENERGY STAR for energy usage?',\n", 193 | " 'question_16': 'What’s the energy consumption percentage of the iMac compared to its requirement?',\n", 194 | " 'question_17': 'How does the iMac help in reducing energy costs?',\n", 195 | " 'question_18': 'What advantage does the iMac have in energy efficiency?',\n", 196 | " 'question_19': \"Can you explain the iMac's performance in terms of energy consumption?\",\n", 197 | " 'question_20': \"What are the specifics of the iMac's energy consumption relative to industry standards?\"}" 198 | ] 199 | }, 200 | "execution_count": 5, 201 | "metadata": {}, 202 | "output_type": "execute_result" 203 | } 204 | ], 205 | "source": [ 206 | "label_chain.invoke({\"chunk\": apple_chunks[20]})" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "id": "b53781f3-f686-4294-a73b-28d7993172b2", 212 | "metadata": {}, 213 | "source": [ 214 | "### Training & Validation Split\n", 215 | "\n", 216 | "With 215 chunks and 20 questions each, we have 4300 Question+Chunk pairs. 
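A minimal sketch of how the generation loop and pair flattening could be put together, including the 80/20 shuffle and split described next (the random seed, relative output paths, and flattening logic are illustrative assumptions, not the exact code used to build the published dataset):

```python
import json
import random

# Generate ~20 candidate questions per chunk and flatten into (question, chunk) pairs
qa_pairs = []
for chunk in apple_chunks:
    questions = label_chain.invoke({"chunk": chunk})  # dict of 20 generated questions
    for question in questions.values():
        qa_pairs.append({"question": question, "chunk": chunk})

# Shuffle and split 80/20 into train/validation sets
random.seed(42)
random.shuffle(qa_pairs)
split_idx = int(0.8 * len(qa_pairs))
train_data, validation_data = qa_pairs[:split_idx], qa_pairs[split_idx:]

with open("data/train.json", "w") as f:
    json.dump(train_data, f)
with open("data/validation.json", "w") as f:
    json.dump(validation_data, f)
```

The `{"question": ..., "chunk": ...}` record format is what the evaluation and training code later in the notebook expects.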
These were shuffled into an 80/20 train/validation set resulting in: \n", 217 | "- Training set size: 3440 \n", 218 | "- Validation set size: 860\n", 219 | "\n", 220 | "The dataset has been uploaded to [AdamLucek/apple-environmental-report-QA-retrieval](https://huggingface.co/datasets/AdamLucek/apple-environmental-report-QA-retrieval)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 6, 226 | "id": "f40574d7-a18c-487e-aa7b-7c02787c0f21", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "with open('/Users/alucek/Documents/Jupyter_Notebooks/ft_emb/data/train.json', 'r') as f:\n", 231 | " train_data = json.load(f)\n", 232 | "\n", 233 | "with open('/Users/alucek/Documents/Jupyter_Notebooks/ft_emb/data/validation.json', 'r') as f:\n", 234 | " validation_data = json.load(f)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "id": "52b37f01-be6f-488f-977e-b696953aacfa", 240 | "metadata": {}, 241 | "source": [ 242 | "---\n", 243 | "# Setting Up Vector Database\n", 244 | "\n", 245 | "We'll be using my go-to open source VDB [ChromaDB](https://www.trychroma.com/) as our application database. You would want to use a testing environment of whatever vector database your application is currently using and the same retrieval parameters to ensure it's optimized for your specific use case and data store. \n", 246 | "\n", 247 | "By default, ChromaDB uses [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) as their embedding model, which is the embedding model we will be using as our foundation for training!" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 7, 253 | "id": "d0179c78-2a4f-412a-ae27-f9ec333f6f3f", 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "import chromadb\n", 258 | "\n", 259 | "# Create chroma client\n", 260 | "path = \"/Users/alucek/Documents/Jupyter_Notebooks/ft_emb/chromadb\"\n", 261 | "client = chromadb.PersistentClient(path=path)\n", 262 | "\n", 263 | "# Create collections for both our specific simulated RAG pipelines\n", 264 | "apple_collection = client.get_or_create_collection(name='apple_collection', metadata={\"hnsw:space\": \"cosine\"})" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "id": "bcebbe13-86d1-4be5-84d1-d89103af618b", 270 | "metadata": {}, 271 | "source": [ 272 | "### Adding Chunks to VDB\n", 273 | "\n", 274 | "Simply embedding all chunks into the database." 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "id": "bbba8937-da46-4c9d-82a5-94d0fb4258eb", 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "# Add apple chunks to vdb\n", 285 | "i = 0\n", 286 | "for chunk in apple_chunks:\n", 287 | " apple_collection.add(\n", 288 | " documents=[chunk],\n", 289 | " ids=[f\"chunk_{i}\"]\n", 290 | " )\n", 291 | " i += 1" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "id": "d39c8f7f-0070-4eaf-9543-4b4c25be95d6", 297 | "metadata": {}, 298 | "source": [ 299 | "### Function for Document Retrieval\n", 300 | "\n", 301 | "Takes in the embedding, and retrieves the top 10 similar results from the database." 
302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 8, 307 | "id": "a5911fb2-d1bb-4227-aca9-b7b334b2dbcf", 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "def retrieve_documents_embeddings(query_embedding, k=10):\n", 312 | " query_embedding_list = query_embedding.tolist()\n", 313 | " \n", 314 | " results = apple_collection.query(\n", 315 | " query_embeddings=[query_embedding_list],\n", 316 | " n_results=k)\n", 317 | " return results['documents'][0]" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "id": "de4d49e7-258a-435a-b8be-6c593edc5dfd", 323 | "metadata": {}, 324 | "source": [ 325 | "---\n", 326 | "# Base Model Evaluation\n", 327 | "\n", 328 | "As mentioned, we will be using [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) as our base model for optimization. This technique is possible to use for any embedding model, so you'd want to assess your embedding model of choice here.\n", 329 | "\n", 330 | "We'll be focusing on two specific metrics to optimize towards:\n", 331 | "- **Mean Reciprocal Rank**\n", 332 | "- **Recall@k**" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "id": "8e3dff31-61b7-4c7d-8adf-b93a8b6f802e", 338 | "metadata": {}, 339 | "source": [ 340 | "### Metric: Mean Reciprocal Rank (MRR)\n", 341 | "\n", 342 | "$MRR = \\frac{1}{|Q|} \\sum_{i=1}^{|Q|} \\frac{1}{rank_i}$\n", 343 | "\n", 344 | "Where:\n", 345 | "\n", 346 | "- $|Q|$ is the number of queries \n", 347 | "- $rank_i$ is the rank of the first correct answer for the $i$-th query\n", 348 | "\n", 349 | "MRR (Mean Reciprocal Rank) measures how high the first correct answer appears in the list, on average. MRR is particularly useful when there's only one relevant item in the list or when we're primarily interested in the position of the first correct result." 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 9, 355 | "id": "8012bdb2-e62d-4a6c-9cd0-2fc2f0ee8575", 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "def reciprocal_rank(retrieved_docs, ground_truth, k):\n", 360 | " try:\n", 361 | " rank = retrieved_docs.index(ground_truth) + 1\n", 362 | " return 1.0 / rank if rank <= k else 0.0\n", 363 | " except ValueError:\n", 364 | " return 0.0" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "id": "f8788377-076d-4fce-b05b-907e43181688", 370 | "metadata": {}, 371 | "source": [ 372 | "### Metric: Recall@K\n", 373 | "\n", 374 | "$\\text{Recall@k} = \\frac{1}{|Q|} \\sum_{i=1}^{|Q|} \\mathbb{1}(rank_i \\leq k)$\n", 375 | "\n", 376 | "Where:\n", 377 | "\n", 378 | "- $|Q|$ is the number of queries\n", 379 | "- $rank_i$ is the rank of the ground truth item for the $i$-th query\n", 380 | "- $\\mathbb{1}()$ is the indicator function, which equals 1 if the condition inside the parentheses is true, and 0 otherwise\n", 381 | "- $k$ is the cut-off rank\n", 382 | "\n", 383 | "Recall@k, also known as hit rate, measures the proportion of relevant items that are successfully retrieved within the top k results. In this context with one ground truth document, it's a binary measure that checks if the ground truth (correct item) is present in the top k retrieved documents." 
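A tiny worked example of both metrics at k=3, using the `reciprocal_rank` helper above and the `hit_rate` helper defined just below (document strings shortened to labels for readability):

```python
retrieved = ["doc_b", "doc_a", "doc_c"]  # toy ranked retrieval list

print(reciprocal_rank(retrieved, "doc_a", k=3))  # ground truth at rank 2 -> 1/2 = 0.5
print(hit_rate(retrieved, "doc_a", k=3))         # ground truth within top 3  -> 1.0
print(hit_rate(retrieved, "doc_z", k=3))         # ground truth not retrieved -> 0.0
```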
384 |    ] 385 |   }, 386 |   { 387 |    "cell_type": "code", 388 |    "execution_count": 52, 389 |    "id": "dd7b7f11-a083-4485-8c76-8a8afc78a524", 390 |    "metadata": {}, 391 |    "outputs": [], 392 |    "source": [ 393 |     "def hit_rate(retrieved_docs, ground_truth, k):\n", 394 |     "    return 1.0 if ground_truth in retrieved_docs[:k] else 0.0" 395 |    ] 396 |   }, 397 |   { 398 |    "cell_type": "markdown", 399 |    "id": "eda54c5d-3196-421b-9ec4-7b68e11f9a85", 400 |    "metadata": {}, 401 |    "source": [ 402 |     "### Load the base model\n", 403 |     "\n", 404 |     "We are using [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), a 22.7M parameter embedding model which creates 384-dimensional dense vector representations of text content. Our goal then is to create an adapter that can better map the generated embedding of our query to the original vector space representation of our knowledgebase." 405 |    ] 406 |   }, 407 |   { 408 |    "cell_type": "code", 409 |    "execution_count": 11, 410 |    "id": "cabb0bd1-70c0-4cab-b9ab-e8cb2e5961f4", 411 |    "metadata": {}, 412 |    "outputs": [], 413 |    "source": [ 414 |     "from sentence_transformers import SentenceTransformer\n", 415 |     "\n", 416 |     "# Load the base model\n", 417 |     "base_model = SentenceTransformer('all-MiniLM-L6-v2')" 418 |    ] 419 |   }, 420 |   { 421 |    "cell_type": "markdown", 422 |    "id": "f202a58d-d780-4b33-ba50-d379ade847b9", 423 |    "metadata": {}, 424 |    "source": [ 425 |     "### Evaluation Function\n", 426 |     "\n", 427 |     "This will run our metric calculations within our vector database setup and return the average MRR and R@K." 428 |    ] 429 |   }, 430 |   { 431 |    "cell_type": "code", 432 |    "execution_count": 14, 433 |    "id": "1934db48-a9ef-46be-b673-a1daaedb5399", 434 |    "metadata": {}, 435 |    "outputs": [ 436 |     { 437 |      "name": "stdout", 438 |      "output_type": "stream", 439 |      "text": [ 440 |       "Average Hit Rate @10: 0.6116279069767442\n", 441 |       "Mean Reciprocal Rank @10: 0.31327104097452935\n" 442 |      ] 443 |     } 444 |    ], 445 |    "source": [ 446 |     "import numpy as np\n", 447 |     "\n", 448 |     "def validate_embedding_model(validation_data, base_model, k=10):\n", 449 |     "    hit_rates = []\n", 450 |     "    reciprocal_ranks = []\n", 451 |     "    \n", 452 |     "    for data_point in validation_data:\n", 453 |     "        question = data_point['question']\n", 454 |     "        ground_truth = data_point['chunk']\n", 455 |     "        \n", 456 |     "        # Generate embedding for the question\n", 457 |     "        question_embedding = base_model.encode(question)\n", 458 |     "        \n", 459 |     "        # Retrieve documents using the embedding\n", 460 |     "        retrieved_docs = retrieve_documents_embeddings(question_embedding, k)\n", 461 |     "        \n", 462 |     "        # Calculate metrics\n", 463 |     "        hr = hit_rate(retrieved_docs, ground_truth, k)\n", 464 |     "        rr = reciprocal_rank(retrieved_docs, ground_truth, k)\n", 465 |     "        \n", 466 |     "        hit_rates.append(hr)\n", 467 |     "        reciprocal_ranks.append(rr)\n", 468 |     "    \n", 469 |     "    # Calculate average metrics\n", 470 |     "    avg_hit_rate = np.mean(hit_rates)\n", 471 |     "    avg_reciprocal_rank = np.mean(reciprocal_ranks)\n", 472 |     "    \n", 473 |     "    return {\n", 474 |     "        'average_hit_rate': avg_hit_rate,\n", 475 |     "        'average_reciprocal_rank': avg_reciprocal_rank\n", 476 |     "    }\n", 477 |     "\n", 478 |     "results = validate_embedding_model(validation_data, base_model)\n", 479 |     "print(f\"Average Hit Rate @10: {results['average_hit_rate']}\")\n", 480 |     "print(f\"Mean Reciprocal Rank @10: {results['average_reciprocal_rank']}\")" 481 |    ] 482 |   }, 483 |   { 484 |    "cell_type": "markdown", 485 |    "id": "90eacd06-c6c8-4656-80ab-e8bb72e94d88", 486 |    "metadata": {}, 487 |    "source": [
488 | "### Baseline Interpretation\n", 489 | "\n", 490 | "**R@K: 0.6116**\n", 491 | "- With k=10, for 61.16% of the queries, the correct answer was found within the top 10 results.\n", 492 | "\n", 493 | "**MRR: 0.3133**\n", 494 | "- The reciprocal of 0.3133 is about 3.2, indicating that on average, the first correct result appears at about position 3.2 of 10.\n", 495 | "\n", 496 | "Our goal then is to increase both of these so our correct document is placed higher in the retrieval ranking (calculated as more relevant), and shows up more often within our number of retrieved documents." 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "id": "2b0a723a-6de4-4797-aa85-46480fa7c69a", 502 | "metadata": {}, 503 | "source": [ 504 | "---\n", 505 | "# Linear Adapter Training\n", 506 | "\n", 507 | "\n", 508 | "\n", 509 | "As stated in the introduction, we will be training a query-only linear adapter.\n", 510 | "\n", 511 | "This has the added benefits of:\n", 512 | "- Super lightweight, only one single linear transformation layer from the embedding\n", 513 | "- Minimal added compute at run time, can be trained quickly on minimal hardware \n", 514 | "- No need to re-embed your knowledgebase\n", 515 | "- Proven to be almost as effective as full embedding model fine tuning \n", 516 | "\n", 517 | "We'll be using some techniques from the [ChromaDB embedding adapters paper](https://research.trychroma.com/embedding-adapters) findings like triplet loss, random negative sampling, and some of their hyperparameters. These are outlined below:" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 15, 523 | "id": "9a629c7f-5159-4028-949b-969808d7f9c3", 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "import random\n", 528 | "import torch\n", 529 | "from torch import nn\n", 530 | "from torch.utils.data import Dataset, DataLoader\n", 531 | "from torch.optim import AdamW\n", 532 | "from torch.optim.lr_scheduler import LambdaLR\n", 533 | "from torch.nn.utils import clip_grad_norm_" 534 | ] 535 | }, 536 | { 537 | "cell_type": "markdown", 538 | "id": "18ba855d-049e-4452-adf0-9c210a3f01a6", 539 | "metadata": {}, 540 | "source": [ 541 | "### Random Negative Sampling\n", 542 | "\n", 543 | "\n", 544 | "\n", 545 | "Negative sampling involves randomly selecting unrelated or irrelevant examples (called \"negative samples\") during the training process.\n", 546 | "\n", 547 | "1. **Purpose**: It helps the model learn to distinguish between relevant and irrelevant information more effectively.\n", 548 | "\n", 549 | "2. **Process**: Along with the correct (positive) matches for a query, the model is also shown randomly selected incorrect (negative) matches.\n", 550 | "\n", 551 | "3. **Benefit**: This exposes the model to a wider range of examples, helping it develop a better understanding of what makes a good match versus a poor one.\n", 552 | "\n", 553 | "4. **Efficiency**: It's a computationally cheap way to improve performance, as it doesn't require carefully curated negative examples.\n", 554 | "\n", 555 | "By introducing these random negative samples, the model adapter better learns to create embeddings that not only bring relevant items closer together in the vector space, but also push irrelevant items further apart. 
This leads to more robust and discriminative embeddings, ultimately improving the model's ability to retrieve relevant information accurately.\n", 556 | "\n", 557 | "" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 17, 563 | "id": "e995e05b-b6c6-476b-bbea-d06b2e1027b3", 564 | "metadata": {}, 565 | "outputs": [], 566 | "source": [ 567 | "# Load NVIDIA 10K Document\n", 568 | "# Will be used for random negative sampling\n", 569 | "nvidia_loader = PyPDFLoader(\"/Users/alucek/Documents/Jupyter_Notebooks/ft_emb/data/nvidia_10k.pdf\")\n", 570 | "nvidia_pages = nvidia_loader.load()\n", 571 | "\n", 572 | "nvidia_document = \"\"\n", 573 | "for i in range(len(nvidia_pages)):\n", 574 | " nvidia_document += nvidia_pages[i].page_content\n", 575 | "\n", 576 | "nvidia_chunks = text_splitter.split_text(nvidia_document)" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 18, 582 | "id": "0a666ef9-38a6-4fb7-8de2-76d6856df336", 583 | "metadata": {}, 584 | "outputs": [], 585 | "source": [ 586 | "def random_negative():\n", 587 | " random_sample = random.choice(nvidia_chunks)\n", 588 | " return random_sample" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": 20, 594 | "id": "a18d437a-56d9-4c69-b4c1-0b22d2a3f813", 595 | "metadata": {}, 596 | "outputs": [ 597 | { 598 | "data": { 599 | "text/plain": [ 600 | "'price basis. In certain cases, we can establish standalone selling price based on directly observable prices of products or services sold separately in comparable\\ncircumstances to similar customers. If standalone selling price is not directly observable, such as when we do not sell a product or service separately , we\\ndetermine standalone selling price based on market data and other observable inputs.\\nChange in Accounting Estimate\\nIn February 2023, we assessed the useful lives of our property , plant, and equipment. Based on advances in technology and usage rate, we increased the\\nestimated useful life of a majority of the server , storage, and network equipment from three years to a range of four to five years, and assembly and test\\nequipment from five years to seven years. The estimated effect of this change for fiscal year 2024 was a benefit of $33 million and $102 million for cost of\\nrevenue and operating expenses, respectively , which resulted in an increase in operating income of $135 million and net income of $114 million after tax, or\\n$0.05 per both basic and diluted share.\\nResults of Operations\\nA discussion regarding our financial condition and results of operations for fiscal year 2024 compared to fiscal year 2023 is presented below . A discussion\\nregarding our financial condition and results of operations for fiscal year 2023 compared to fiscal year 2022 can be found under Item 7 in our Annual Report on\\nForm 10-K for the fiscal year ended January 29, 2023, filed with the SEC on February 24, 2023, which is available free of charge on the SEC’ s website at\\nhttp://www .sec.gov and at our investor relations website, http://investor .nvidia.com.\\n38Table of Contents\\nThe following table sets forth, for the periods indicated, certain items in our Consolidated Statements of Income expressed as a percentage of revenue. 
\\n Year Ended\\n Jan 28, 2024 Jan 29, 2023\\nRevenue 100.0 % 100.0 %\\nCost of revenue 27.3 43.1 \\nGross profit 72.7 56.9 \\nOperating expenses \\nResearch and development 14.2 27.2 \\nSales, general and administrative 4.4 9.1 \\nAcquisition termination cost — 5.0 \\nTotal operating expenses 18.6 41.3 \\nOperating income 54.1 15.6 \\nInterest income 1.4 1.0 \\nInterest expense (0.4) (1.0)\\nOther , net 0.4 (0.1)\\nOther income (expense), net 1.4 (0.1)\\nIncome before income tax 55.5 15.5 \\nIncome tax expense (benefit) 6.6 (0.7)\\nNet income 48.9 % 16.2 %\\nReportable Segments\\nRevenue by Reportable Segments\\nYear Ended\\nJan 28, 2024 Jan 29, 2023$\\nChange%\\nChange\\n($ in millions)\\nCompute & Networking $ 47,405 $ 15,068 $ 32,337 215 %\\nGraphics 13,517 11,906 1,611 14 %\\nTotal $ 60,922 $ 26,974 $ 33,948 126 %\\nOperating Income by Reportable Segments\\nYear Ended\\nJan 28, 2024 Jan 29, 2023$\\nChange%\\nChange\\n($ in millions)\\nCompute & Networking $ 32,016 $ 5,083 $ 26,933 530 %\\nGraphics 5,846 4,552 1,294 28 %'" 601 | ] 602 | }, 603 | "execution_count": 20, 604 | "metadata": {}, 605 | "output_type": "execute_result" 606 | } 607 | ], 608 | "source": [ 609 | "random_negative()" 610 | ] 611 | }, 612 | { 613 | "cell_type": "markdown", 614 | "id": "667182f4-a800-4e10-9ec5-a65fa0966bea", 615 | "metadata": {}, 616 | "source": [ 617 | "### Torch Module and Linear Transformations\n", 618 | "\n", 619 | "\n", 620 | "\n", 621 | "A linear transformation is a mathematical operation that takes an input vector and produces an output vector while preserving the operations of vector addition and scalar multiplication.\n", 622 | "\n", 623 | "In the context of machine learning and neural networks, a linear transformation is typically represented as:\n", 624 | "\n", 625 | "$f(x) = Wx + b$\n", 626 | "\n", 627 | "Where:\n", 628 | "- $x$ is the input vector\n", 629 | "- $W$ is a matrix of weights\n", 630 | "- $b$ is a bias vector\n", 631 | "\n", 632 | "Internally, `nn.Linear` creates:\n", 633 | "- A weight matrix $W$ of shape (output_features, input_features)\n", 634 | "- A bias vector $b$ of shape (output_features)\n", 635 | "\n", 636 | "This is all saved using PyTorch and their `Module` class, and these are the trainable parameters that we will be optimizing." 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": 21, 642 | "id": "7a61a937-fb73-4764-abbd-e5d9752781d3", 643 | "metadata": {}, 644 | "outputs": [], 645 | "source": [ 646 | "class LinearAdapter(nn.Module):\n", 647 | " def __init__(self, input_dim):\n", 648 | " super().__init__()\n", 649 | " self.linear = nn.Linear(input_dim, input_dim)\n", 650 | " \n", 651 | " def forward(self, x):\n", 652 | " return self.linear(x)" 653 | ] 654 | }, 655 | { 656 | "cell_type": "markdown", 657 | "id": "e1b000b5-3794-4436-b7ef-ea7c4e36b938", 658 | "metadata": {}, 659 | "source": [ 660 | "### Triplet Dataset Preparation\n", 661 | "\n", 662 | "\n", 663 | "\n", 664 | "The `TripletDataset` class is a custom dataset class that inherits from PyTorch's `Dataset` designed to work with triplet loss where each data point consists of three parts: a query, a positive example, and a negative example.\n", 665 | "\n", 666 | "1. It retrieves the item at index `idx` from the training data.\n", 667 | "2. Extracts the query and positive example from the item.\n", 668 | "3. Uses the `negative_sampler` to generate a negative example.\n", 669 | "4. Encodes the query, positive, and negative examples into embeddings using the `base_model`.\n", 670 | "5. 
Returns the triplet of embeddings (query, positive, negative)." 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 22, 676 | "id": "51e28cdd-f54c-47aa-9426-e426fb330bb2", 677 | "metadata": {}, 678 | "outputs": [], 679 | "source": [ 680 | "class TripletDataset(Dataset):\n", 681 | " def __init__(self, data, base_model, negative_sampler):\n", 682 | " self.data = data\n", 683 | " self.base_model = base_model\n", 684 | " self.negative_sampler = negative_sampler\n", 685 | "\n", 686 | " def __len__(self):\n", 687 | " return len(self.data)\n", 688 | "\n", 689 | " def __getitem__(self, idx):\n", 690 | " item = self.data[idx]\n", 691 | " query = item['question']\n", 692 | " positive = item['chunk']\n", 693 | " negative = self.negative_sampler()\n", 694 | " \n", 695 | " query_emb = self.base_model.encode(query, convert_to_tensor=True)\n", 696 | " positive_emb = self.base_model.encode(positive, convert_to_tensor=True)\n", 697 | " negative_emb = self.base_model.encode(negative, convert_to_tensor=True)\n", 698 | " \n", 699 | " return query_emb, positive_emb, negative_emb" 700 | ] 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "id": "23b21840-ae60-4d29-b5b1-e2c569fc90e8", 705 | "metadata": {}, 706 | "source": [ 707 | "### Trainer Script\n", 708 | "\n", 709 | "This script follows the classic machine learning optimization flow in this way:\n", 710 | "\n", 711 | "1. Learning rate scheduler setup with warmup and decay phases\n", 712 | "2. Initialization of LinearAdapter, loss function, optimizer, and dataloader\n", 713 | "3. Calculation of total training steps and setup of learning rate scheduler\n", 714 | "4. Batch generation from the TripletDataset dataloader\n", 715 | "5. Forward pass through the LinearAdapter\n", 716 | "6. Triplet Margin Loss calculation\n", 717 | "7. Backpropagation and gradient clipping\n", 718 | "8. Optimization parameter updates and learning rate adjustment\n", 719 | "9. Epoch-wise loss reporting\n", 720 | "10. Return of the trained adapter\n", 721 | "\n", 722 | "Let's talk briefly about how random negative sampling and triplet loss are used:\n", 723 | "\n", 724 | "\n", 725 | "\n", 726 | "Triplet loss is a type of loss function used in various machine learning tasks, particularly in metric learning and embedding learning. Its primary goal is to learn embeddings such that similar examples are closer together in the embedding space while dissimilar examples are farther apart.\n", 727 | "\n", 728 | "The triplet loss operates on triplets of data points:\n", 729 | "\n", 730 | "1. Anchor (A): The reference sample\n", 731 | "2. Positive (P): A sample similar to the anchor\n", 732 | "3. Negative (N): A sample dissimilar to the anchor\n", 733 | "\n", 734 | "defined as: $L = max(d(A, P) - d(A, N) + margin, 0)$\n", 735 | "\n", 736 | "Where:\n", 737 | "- $d(x, y)$ is the distance function (Euclidean)\n", 738 | "- margin is a hyperparameter that enforces a minimum distance between the positive and negative pairs\n", 739 | "\n", 740 | "The loss encourages the model to learn embeddings where:\n", 741 | "\n", 742 | "$d(A, P) < d(A, N) - margin$\n", 743 | "\n", 744 | "This means the distance between the anchor and positive should be smaller than the distance between the anchor and negative, by at least the margin. 
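As a quick sanity check of the formula (toy tensors only, not part of the training pipeline), PyTorch's `nn.TripletMarginLoss` reproduces the $max(d(A, P) - d(A, N) + margin, 0)$ behavior used below:

```python
import torch
from torch import nn

loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

# Anchor close to the positive and far from the negative -> loss collapses to zero
anchor   = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[0.1, 0.0]])   # d(A, P) = 0.1
negative = torch.tensor([[2.0, 0.0]])   # d(A, N) = 2.0
print(loss_fn(anchor, positive, negative))       # max(0.1 - 2.0 + 1.0, 0) ≈ 0.0

# Negative closer to the anchor than the positive -> loss is positive
near_negative = torch.tensor([[0.05, 0.0]])      # d(A, N) = 0.05
print(loss_fn(anchor, positive, near_negative))  # max(0.1 - 0.05 + 1.0, 0) ≈ 1.05
```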
The Negative document is dynamically randomly sampled from our NVIDIA form 10-K " 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 23, 750 | "id": "591caa00-410f-4149-ad5a-6a83e79a6838", 751 | "metadata": {}, 752 | "outputs": [], 753 | "source": [ 754 | "def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):\n", 755 | " def lr_lambda(current_step):\n", 756 | " if current_step < num_warmup_steps:\n", 757 | " return float(current_step) / float(max(1, num_warmup_steps))\n", 758 | " return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))\n", 759 | " return LambdaLR(optimizer, lr_lambda)\n", 760 | "\n", 761 | "def train_linear_adapter(base_model, train_data, negative_sampler, num_epochs=10, batch_size=32, \n", 762 | " learning_rate=2e-5, warmup_steps=100, max_grad_norm=1.0, margin=1.0):\n", 763 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 764 | " \n", 765 | " # Initialize the LinearAdapter\n", 766 | " adapter = LinearAdapter(base_model.get_sentence_embedding_dimension()).to(device)\n", 767 | " \n", 768 | " # Define loss function and optimizer\n", 769 | " triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)\n", 770 | " optimizer = AdamW(adapter.parameters(), lr=learning_rate)\n", 771 | " \n", 772 | " # Create dataset and dataloader\n", 773 | " dataset = TripletDataset(train_data, base_model, negative_sampler)\n", 774 | " dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", 775 | " \n", 776 | " # Calculate total number of training steps\n", 777 | " total_steps = len(dataloader) * num_epochs\n", 778 | " \n", 779 | " # Create learning rate scheduler\n", 780 | " scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)\n", 781 | " \n", 782 | " # Training loop\n", 783 | " for epoch in range(num_epochs):\n", 784 | " total_loss = 0\n", 785 | " for batch in dataloader:\n", 786 | " query_emb, positive_emb, negative_emb = [x.to(device) for x in batch]\n", 787 | " \n", 788 | " # Forward pass\n", 789 | " adapted_query_emb = adapter(query_emb)\n", 790 | " \n", 791 | " # Compute loss\n", 792 | " loss = triplet_loss(adapted_query_emb, positive_emb, negative_emb)\n", 793 | " \n", 794 | " # Backward pass and optimization\n", 795 | " optimizer.zero_grad()\n", 796 | " loss.backward()\n", 797 | " \n", 798 | " # Gradient clipping\n", 799 | " clip_grad_norm_(adapter.parameters(), max_grad_norm)\n", 800 | " \n", 801 | " optimizer.step()\n", 802 | " scheduler.step()\n", 803 | " \n", 804 | " total_loss += loss.item()\n", 805 | " \n", 806 | " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}\")\n", 807 | " \n", 808 | " return adapter" 809 | ] 810 | }, 811 | { 812 | "cell_type": "markdown", 813 | "id": "db53c604-03f3-4ab3-b251-e97a49867155", 814 | "metadata": {}, 815 | "source": [ 816 | "### Training Function Execution!\n", 817 | "\n", 818 | "Saves the hyperparameters and linear adapter in the same file" 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": 24, 824 | "id": "c1ca4974-ec5e-44db-b9b2-c5b58319abd4", 825 | "metadata": {}, 826 | "outputs": [ 827 | { 828 | "name": "stdout", 829 | "output_type": "stream", 830 | "text": [ 831 | "Epoch 1/1, Loss: 0.4718\n" 832 | ] 833 | } 834 | ], 835 | "source": [ 836 | "# Define the kwargs dictionary\n", 837 | "adapter_kwargs = {\n", 838 | " 'num_epochs': 1,\n", 839 | " 'batch_size': 32,\n", 840 | 
" 'learning_rate': 0.003,\n", 841 | " 'warmup_steps': 100,\n", 842 | " 'max_grad_norm': 1.0,\n", 843 | " 'margin': 1.0\n", 844 | "}\n", 845 | "\n", 846 | "# Train the adapter using the kwargs dictionary\n", 847 | "trained_adapter = train_linear_adapter(base_model, train_data, random_negative, **adapter_kwargs)\n", 848 | "\n", 849 | "# Create a dictionary to store both the adapter state_dict and the kwargs\n", 850 | "save_dict = {\n", 851 | " 'adapter_state_dict': trained_adapter.state_dict(),\n", 852 | " 'adapter_kwargs': adapter_kwargs\n", 853 | "}\n", 854 | "\n", 855 | "# Save the combined dictionary\n", 856 | "torch.save(save_dict, '/Users/alucek/Documents/Jupyter_Notebooks/ft_emb/adapters/linear_adapter_1epoch.pth')" 857 | ] 858 | }, 859 | { 860 | "cell_type": "markdown", 861 | "id": "0db51ae8-06d1-4a00-8ee1-db43262df725", 862 | "metadata": {}, 863 | "source": [ 864 | "---\n", 865 | "# Evaluate Adapter Performance\n", 866 | "\n", 867 | "Now that we have our trained query-only linear adapter, let's assess its performance on our knowledgebase compared to our baseline model." 868 | ] 869 | }, 870 | { 871 | "cell_type": "markdown", 872 | "id": "ee831ee7-ea22-4257-ae0b-7e0827c4c3bf", 873 | "metadata": {}, 874 | "source": [ 875 | "### Applying the Adapter\n", 876 | "\n", 877 | "Below is the function to get the original embedding for the query and run it through our trained adapter.\n", 878 | "\n", 879 | "This is essentially the new function for embedding your user query." 880 | ] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": 25, 885 | "id": "ba83ca83-6529-4e52-a4a8-d56eeea64ff3", 886 | "metadata": {}, 887 | "outputs": [], 888 | "source": [ 889 | "# Function to encode query using the adapter\n", 890 | "def encode_query(query, base_model, adapter):\n", 891 | " device = next(adapter.parameters()).device\n", 892 | " query_emb = base_model.encode(query, convert_to_tensor=True).to(device)\n", 893 | " adapted_query_emb = adapter(query_emb)\n", 894 | " return adapted_query_emb.cpu().detach().numpy()" 895 | ] 896 | }, 897 | { 898 | "cell_type": "markdown", 899 | "id": "7bdb45e7-d62b-4e87-b455-f26d8d768ef8", 900 | "metadata": {}, 901 | "source": [ 902 | "### Loading the Adapter" 903 | ] 904 | }, 905 | { 906 | "cell_type": "code", 907 | "execution_count": 50, 908 | "id": "2598f1cb-5844-4184-a102-f626c1a4ee96", 909 | "metadata": {}, 910 | "outputs": [ 911 | { 912 | "name": "stdout", 913 | "output_type": "stream", 914 | "text": [ 915 | "Adapter loaded successfully.\n", 916 | "Training parameters used:\n", 917 | "num_epochs: 30\n", 918 | "batch_size: 32\n", 919 | "learning_rate: 0.003\n", 920 | "warmup_steps: 100\n", 921 | "max_grad_norm: 1.0\n", 922 | "margin: 1.0\n" 923 | ] 924 | } 925 | ], 926 | "source": [ 927 | "# Later, loading and using the saved information\n", 928 | "loaded_dict = torch.load('/Users/alucek/Documents/Jupyter_Notebooks/ft_emb/adapters/linear_adapter_30epochs.pth')\n", 929 | "\n", 930 | "# Recreate the adapter\n", 931 | "loaded_adapter = LinearAdapter(base_model.get_sentence_embedding_dimension()) # Initialize with appropriate parameters\n", 932 | "loaded_adapter.load_state_dict(loaded_dict['adapter_state_dict'])\n", 933 | "\n", 934 | "# Access the training parameters\n", 935 | "training_params = loaded_dict['adapter_kwargs']\n", 936 | "\n", 937 | "print(\"Adapter loaded successfully.\")\n", 938 | "print(\"Training parameters used:\")\n", 939 | "for key, value in training_params.items():\n", 940 | " print(f\"{key}: {value}\")" 941 | ] 942 | }, 943 | { 944 | 
"cell_type": "markdown", 945 | "id": "848c1c99-e9dd-417b-9f0e-4e1af239bdd6", 946 | "metadata": {}, 947 | "source": [ 948 | "### Adapter Evaluation Function\n", 949 | "\n", 950 | "New evaluation function to replicate the original experiment, however this time with adapter support" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": 53, 956 | "id": "db5000e9-7abf-4731-ad3b-6c35472278bd", 957 | "metadata": {}, 958 | "outputs": [ 959 | { 960 | "name": "stdout", 961 | "output_type": "stream", 962 | "text": [ 963 | "Average Hit Rate @10: 0.6662790697674419\n", 964 | "Mean Reciprocal Rank @10: 0.33240956072351424\n" 965 | ] 966 | } 967 | ], 968 | "source": [ 969 | "def evaluate_adapter(validation_data, base_model, adapter, k=10):\n", 970 | " hit_rates = []\n", 971 | " reciprocal_ranks = []\n", 972 | " \n", 973 | " for data_point in validation_data:\n", 974 | " question = data_point['question']\n", 975 | " ground_truth = data_point['chunk']\n", 976 | " \n", 977 | " # Generate embedding for the question\n", 978 | " question_embedding = encode_query(question, base_model, adapter)\n", 979 | " # Retrieve documents using the embedding\n", 980 | " retrieved_docs = retrieve_documents_embeddings(question_embedding, k)\n", 981 | " \n", 982 | " # Calculate metrics\n", 983 | " hr = hit_rate(retrieved_docs, ground_truth, k)\n", 984 | " rr = reciprocal_rank(retrieved_docs, ground_truth, k)\n", 985 | " \n", 986 | " hit_rates.append(hr)\n", 987 | " reciprocal_ranks.append(rr)\n", 988 | " \n", 989 | " # Calculate average metrics\n", 990 | " avg_hit_rate = np.mean(hit_rates)\n", 991 | " avg_reciprocal_rank = np.mean(reciprocal_ranks)\n", 992 | " \n", 993 | " return {\n", 994 | " 'average_hit_rate': avg_hit_rate,\n", 995 | " 'average_reciprocal_rank': avg_reciprocal_rank\n", 996 | " }\n", 997 | "\n", 998 | "results = evaluate_adapter(validation_data, base_model, loaded_adapter, k=10)\n", 999 | "print(f\"Average Hit Rate @10: {results['average_hit_rate']}\")\n", 1000 | "print(f\"Mean Reciprocal Rank @10: {results['average_reciprocal_rank']}\")" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "markdown", 1005 | "id": "ece1888b-078e-4c24-8fef-0b4969e001a7", 1006 | "metadata": {}, 1007 | "source": [ 1008 | "Post training, our adapter gave us an average hit rate/recall of 66.7%, a percentage point increase of 4.8 over baseline of 61.9%- **a 7.8% improvement**. And a mean reciprocal rank of 0.332, so our expected document tends to be placed at place 3.0, compared to the baseline of 3.2 (0.3110)- **a 6.2% improvement**.\n", 1009 | "\n", 1010 | "### Validation Metrics Compared to Baseline\n", 1011 | "\n", 1012 | "\n", 1013 | "We can conclude that our 30 epoch trained version on these hyperparameters gave us the biggest improvement, and began to overfit and lose ability to generalize when increasing to 40. To see how our model is fitting to the data, we can run the metrics on our training data, visualized below:\n", 1014 | "\n", 1015 | "### Visualizing Model Fitting on Training Data\n", 1016 | "\n", 1017 | "\n", 1018 | "Decent fitting, not rampant overfitting. If user queries are the same as some of the frequent queries in the training data, they will definitely have a big boost in expected document retrieval accuracy." 
1019 | ] 1020 | }, 1021 | { 1022 | "cell_type": "code", 1023 | "execution_count": null, 1024 | "id": "fcf9e9f4-a9f3-44b9-8a6b-fd36f2ac3326", 1025 | "metadata": {}, 1026 | "outputs": [], 1027 | "source": [] 1028 | } 1029 | ], 1030 | "metadata": { 1031 | "kernelspec": { 1032 | "display_name": "Python 3 (ipykernel)", 1033 | "language": "python", 1034 | "name": "python3" 1035 | }, 1036 | "language_info": { 1037 | "codemirror_mode": { 1038 | "name": "ipython", 1039 | "version": 3 1040 | }, 1041 | "file_extension": ".py", 1042 | "mimetype": "text/x-python", 1043 | "name": "python", 1044 | "nbconvert_exporter": "python", 1045 | "pygments_lexer": "ipython3", 1046 | "version": "3.12.1" 1047 | } 1048 | }, 1049 | "nbformat": 4, 1050 | "nbformat_minor": 5 1051 | } 1052 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fine-Tuning Embedding Models for RAG Pipeline Optimization 2 | 3 | ## Check out my Video Walkthrough Here 4 | 5 | [![vid_screenshot.png](./media/vid_screenshot.png)](https://youtu.be/hztWQcoUbt0) 6 | Click play to watch :) 7 | 8 | Base embedding models used for both knowledge base embedding and query embedding for context retrieval in RAG-based applications generally work well, but there are ways to optimize their performance to improve correct information retrieval based on historical user queries and more efficiently retrieve domain-specific information. 9 | 10 | **Essentially, fine-tuning embedding models on your data to improve your RAG application!** 11 | 12 | I've gone through various papers and implementations of embedding model fine-tuning techniques and determined that the most efficient way to get this improvement is through a **query-only linear adapter**, or training a simple linear layer transformation to better represent user queries in embedding space for improved retrieval. 13 | 14 | This allows us to very easily plug into existing RAG pipelines and optimize for our specific task without needing to completely re-embed our knowledge base or use a lot of resources training larger models, making this a simple, cost/compute-effective way to improve retrieval performance. 15 | 16 | 17 | 18 | Additionally, while it's preferred to use existing labeled data gathered through something like RAG Question Answering Chatbot logs, it is possible to also improve embedding representations with synthetically generated labels. 19 | 20 | In this notebook we will: 21 | 1. Define a RAG application to optimize 22 | 2. Generate a synthetic dataset with gpt-4o-mini 23 | 3. Test retrieval metrics to gather a baseline for all-MiniLM-L6-v2 24 | 4. Create and train a linear adapter 25 | 5. Plug the adapter onto all-MiniLM-L6-v2 and assess performance 26 | 27 | Along the way, we'll be implementing many methodologies from [ChromaDB's Research](https://research.trychroma.com/embedding-adapters) on a small scale for task-specific performance increases, specifically their recommendations for triplet loss, random negative sampling, and linear query-only transformation. 
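For a quick start, here is a minimal usage sketch (the `LinearAdapter` class is defined in `Linear_Adapter.ipynb`, and the checkpoint path refers to the files in `adapters/`; adjust both to your setup):

```python
import torch
from sentence_transformers import SentenceTransformer

base_model = SentenceTransformer("all-MiniLM-L6-v2")

checkpoint = torch.load("adapters/linear_adapter_30epochs.pth")
adapter = LinearAdapter(base_model.get_sentence_embedding_dimension())
adapter.load_state_dict(checkpoint["adapter_state_dict"])

# Only the query side changes; the knowledge base keeps its original embeddings
query_embedding = adapter(base_model.encode("your question here", convert_to_tensor=True))
query_embedding = query_embedding.detach().cpu().numpy()
```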
28 | 29 | 30 | 31 | Model adapters trained in this notebook published to [AdamLucek/all-MiniLM-L6-v2-query-only-linear-adapter-AppleQA](https://huggingface.co/AdamLucek/all-MiniLM-L6-v2-query-only-linear-adapter-AppleQA) 32 | -------------------------------------------------------------------------------- /adapters/linear_adapter_10epochs.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/adapters/linear_adapter_10epochs.pth -------------------------------------------------------------------------------- /adapters/linear_adapter_20epochs.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/adapters/linear_adapter_20epochs.pth -------------------------------------------------------------------------------- /adapters/linear_adapter_30epochs.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/adapters/linear_adapter_30epochs.pth -------------------------------------------------------------------------------- /adapters/linear_adapter_40epochs.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/adapters/linear_adapter_40epochs.pth -------------------------------------------------------------------------------- /data/nvidia_10k.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/data/nvidia_10k.pdf -------------------------------------------------------------------------------- /media/adapter_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/adapter_diagram.png -------------------------------------------------------------------------------- /media/adapters_explainer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/adapters_explainer.png -------------------------------------------------------------------------------- /media/linear_layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/linear_layer.png -------------------------------------------------------------------------------- /media/negative_sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/negative_sampling.png -------------------------------------------------------------------------------- /media/negative_sampling_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/negative_sampling_2.png 
-------------------------------------------------------------------------------- /media/training_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/training_fit.png -------------------------------------------------------------------------------- /media/triplet_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/triplet_loss.png -------------------------------------------------------------------------------- /media/tripletdataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/tripletdataset.png -------------------------------------------------------------------------------- /media/validation_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/validation_chart.png -------------------------------------------------------------------------------- /media/vid_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALucek/linear-adapter-embedding/f2ee2b497821cff60630fe75192ce49b753620fc/media/vid_screenshot.png --------------------------------------------------------------------------------