├── experimental_tech ├── .gitkeep ├── 2_compress_and_rerank.ipynb └── 1_estimating_k.ipynb ├── images ├── 1_estimating_k_diagram.png └── 2_compress_and_rerank_diagram.png ├── LICENSE ├── README.md └── .gitignore /experimental_tech/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/1_estimating_k_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucaStrano/Experimental_RAG_Tech/HEAD/images/1_estimating_k_diagram.png -------------------------------------------------------------------------------- /images/2_compress_and_rerank_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucaStrano/Experimental_RAG_Tech/HEAD/images/2_compress_and_rerank_diagram.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Luca Strano 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🧪 Experimental RAG Techniques 2 | 3 | Welcome to the **Experimental RAG Techniques** repository! This repo contains various experimental techniques for implementing and optimizing certain aspects of Retrieval Augmented Generation (RAG) systems. For each technique, you can find a dedicated notebook file that goes into detail about the intuition behind the technique, how to implement it, and its potential benefits. 4 | 5 | This repository is adjacent to Nir Diamant's [Advanced RAG Techniques Repo](https://github.com/NirDiamant/RAG_Techniques), which I highly recommend checking out. 6 | 7 | ### ⁉️ What kind of tech are we talking about? 8 | 9 | The techniques included in this repository are **experimental** in nature, meaning they may not have been extensively tested or validated in serious production environments. However, they represent innovative approaches to improving RAG systems and could lead to advancements in the field, which is why I wanted to share them with the community. 
_This repository is a place for experimentation and exploration, so please approach its contents with an open mind and a willingness to test and iterate_. 10 | 11 | The techniques implemented in this repository primarily rely on **Traditional NLP** methods, which offer a strong balance between quality and efficiency. While these methods may not match the raw power of LLMs, they can still produce highly satisfactory results, especially when considering the **quality/latency tradeoff**. This makes them particularly suitable for RAG environments, where low latency and high quality are often two critical requirements. 12 | 13 | ### ⁉️ Can I contribute? 14 | 15 | Absolutely! If you have a novel experimental technique that you've developed or even just thought about, please feel free to contact me. I would be happy to collaborate with you and credit you for your contribution. You can send me an email at **strano.lucass@gmail.com** or reach out to me on [LinkedIn](https://www.linkedin.com/in/strano-lucass/). 16 | 17 | ## 📑 Table of Contents 18 | 19 | > Techniques marked with a 🧪 emoji are original contributions derived from my research that, to the best of my knowledge, have not been published or widely discussed elsewhere. 20 | 21 | | # | Title | Type | Notebook | 22 | |---|-------|------|----------| 23 | | 1 | 🧪 **Dynamic K Estimation with Query Complexity Score** | 🎣 Retrieval | [![Github View](https://img.shields.io/badge/GitHub-View-blue)](https://github.com/LucaStrano/Experimental_RAG_Tech/blob/main/experimental_tech/1_estimating_k.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LucaStrano/Experimental_RAG_Tech/blob/main/experimental_tech/1_estimating_k.ipynb) | 24 | | 2 | 🧪 **Single Pass Rerank and Compression with Recursive Reranking** | 🎣 Retrieval | [![Github View](https://img.shields.io/badge/GitHub-View-blue)](https://github.com/LucaStrano/Experimental_RAG_Tech/blob/main/experimental_tech/2_compress_and_rerank.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LucaStrano/Experimental_RAG_Tech/blob/main/experimental_tech/2_compress_and_rerank.ipynb) | 25 | | 3 | 🧪 **Coming Soon!** | ❓ _Soon_ | _Coming Soon_ | 26 | 27 | ## 🔬 Techniques Overview 28 | 29 | ### 1. Dynamic K Estimation with Query Complexity Score 30 | 31 | **Type: 🎣 Retrieval (🧪)** 32 | 33 | This technique introduces a novel approach to dynamically estimate the optimal number of documents to retrieve (K) based on the complexity of the query. By using traditional NLP methods and by analyzing the query's structure and semantics, the (hyper)parameter K can be adjusted to ensure retrieval of the right amount of information needed for effective RAG. 34 | 35 | ### 2. Single Pass Rerank and Compression with Recursive Reranking 36 | 37 | **Type: 🎣 Retrieval (🧪)** 38 | 39 | This technique combines Reranking and Contextual Compression into a single pass by using a Reranker Model. Retrieved documents are broken down into smaller sub-sections, which are then used to both rerank documents by calculating an average score and compress them by statistically selecting only the most relevant sub-sections with regard to the user query. 40 | 41 | ## 🤝🏻 Acknowledgements 42 | 43 | Made with 100% love and 0% Vibe Coding by [Luca Strano](https://www.linkedin.com/in/strano-lucass/) ❤️. Acknowledgements for future contributions and collaborations will go here! 
44 | 45 | ## 📝 License 46 | 47 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | #pdm.lock 116 | #pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 
122 | #pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 125 | .pixi 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .envrc 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | #.idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data 200 | # refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | 204 | # Marimo 205 | marimo/_static/ 206 | marimo/_lsp/ 207 | __marimo__/ 208 | 209 | # Streamlit 210 | .streamlit/secrets.toml 211 | 212 | # General 213 | .DS_Store 214 | .AppleDouble 215 | .LSOverride 216 | Icon[ ] 217 | 218 | # Thumbnails 219 | ._* 220 | 221 | # Files that might appear in the root of a volume 222 | .DocumentRevisions-V100 223 | .fseventsd 224 | .Spotlight-V100 225 | .TemporaryItems 226 | .Trashes 227 | .VolumeIcon.icns 228 | .com.apple.timemachine.donotpresent 229 | 230 | # Directories potentially created on remote AFP share 231 | .AppleDB 232 | .AppleDesktop 233 | Network Trash Folder 234 | Temporary Items 235 | .apdisk 236 | -------------------------------------------------------------------------------- /experimental_tech/2_compress_and_rerank.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "07dd9068", 6 | "metadata": {}, 7 | "source": [ 8 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LucaStrano/Experimental_RAG_Tech/blob/main/experimental_tech/2_compress_and_rerank.ipynb)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "2db5300c", 14 | "metadata": {}, 15 | "source": [ 16 | "## Single Pass Rerank and Contextual Compression using Recursive Reranking" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "dc8ff636", 22 | "metadata": {}, 23 | "source": [ 24 | "### Overview\n", 25 | "\n", 26 | "This notebook demonstrates how to use a **Reranker Model** to perform both _Reranking_ and _Contextual Compression_ in a single pass. We will go over the **Intuition** behind the technique, the **Implementation** details and, finally, a short **Conclusion**." 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "15a734b8", 32 | "metadata": {}, 33 | "source": [ 34 | "### 1. Intuition\n", 35 | "\n", 36 | "The use of a Reranker Model has become essential in most modern RAG pipelines, especially when dealing with large or complex datasets. It significantly improves the **Precision** of retrieved results by re-evaluating and reordering initially retrieved documents based on deeper semantic understanding. A Reranker Model is typically involved in the following steps:\n", 37 | "\n", 38 | "1. Use a fast **Embedding Model** to retrieve a set of $\text{top}_K$ candidate documents based on query similarity. Usually, $\text{top}_K$ is set to a relatively high number to ensure high **Recall**;\n", 39 | "\n", 40 | "2. Use a **Reranker Model** to re-evaluate the $\text{top}_K$ documents and select the $\text{top}_N$ most relevant ones. $\text{top}_N$ is usually set to a much lower number to ensure high **Precision**.\n", 41 | "\n", 42 | "You might wonder why a reranker model is necessary at all: after all, the initial retrieval step already returns a set of seemingly relevant documents. This is because embedding models, while effective for initial retrieval, rely on the **Encoder** architecture, which compresses the semantic meaning of the documents into fixed-size vectors. Relevance is then estimated using a _similarity function_ (such as cosine similarity) over the calculated vectors. 
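To make the bi-encoder side of this comparison concrete, here is a minimal illustrative sketch (separate from the pipeline built below; it assumes `sentence-transformers` is installed and uses the same lightweight embedding model that Chroma defaults to later in this notebook):

```python
from sentence_transformers import SentenceTransformer, util

# A bi-encoder encodes query and document *independently* into fixed-size
# vectors; relevance is then just a similarity function over those vectors.
model = SentenceTransformer('all-MiniLM-L6-v2')
query_vec = model.encode('What is the capital of Italy?')
doc_vec = model.encode('Rome is the capital city of Italy.')
print(util.cos_sim(query_vec, doc_vec))  # higher cosine = more relevant
```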
While this approach is efficient, it can miss subtle semantic nuances and contextual cues of the original texts.\n", 43 | "Reranker Models, on the other hand, use a **Cross-Encoder** architecture, which jointly processes the query and candidate documents at the _token level_, allowing for a more fine-grained understanding of their relationship. This process, while more computationally expensive, ensures a higher quality of the final results.\n", 44 | "\n", 45 | "To further enhance the quality of the results, we can also apply a **Contextual Compression** step. This step involves breaking down the retrieved documents into smaller, more manageable chunks. This allows us to not only select the most relevant documents but also to extract only the most relevant pieces from them, effectively compressing the context while retaining essential information.\n", 46 | "\n", 47 | "The problem with this pipeline is that it now requires three separate steps: an initial retrieval step, a reranking step, and a compression step. Using traditional methods, this can be inefficient and highly time-consuming. What if we could combine both Reranking and Compression into a single step? This is where the **Recursive Reranking** technique comes into play, which functions as follows:\n", 48 | "\n", 49 | "1. Use a fast Embedding Model to retrieve a set of $\text{top}_K$ candidate documents;\n", 50 | "\n", 51 | "2. Using a Reranker Model, calculate a relevance score for each sub-section of each document;\n", 52 | "\n", 53 | "3. Use the calculated sub-section scores to both rerank documents and select only the most relevant sub-sections of each reranked document." 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "89054d33", 59 | "metadata": {}, 60 | "source": [ 61 | "### 2. Recursive Reranking\n", 62 | "\n", 63 | "This section focuses on the **Preliminaries** and the **Implementation** of the Recursive Reranking Technique.\n", 64 | "\n", 65 | "### 2.1 Preliminaries to Recursive Reranking\n", 66 | "\n", 67 | "Let's start by installing the necessary dependencies. We will use the `chromadb` library to handle our vector database, and the `sentence-transformers` library to use our Reranker Model." 
68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "7db373da", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "%conda install -c conda-forge sentence-transformers hf-xet chromadb\n", 78 | "# %pip install -U sentence-transformers hf-xet chromadb" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "666d71d2", 84 | "metadata": {}, 85 | "source": [ 86 | "Let's first define our example documents that we will use throughout this notebook:" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 1, 92 | "id": "48e0e125", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "docs = [\n", 97 | "\"\"\"\n", 98 | "Italy, officially the Italian Republic, is a country in Southern and Western Europe.\n", 99 | "It consists of a peninsula that extends into the Mediterranean Sea.\n", 100 | "The Alps mountain range forms its northern boundary, while the Apennine Mountains run down the length of the peninsula.\n", 101 | "The territory also includes nearly 800 islands, notably Sicily and Sardinia.\n", 102 | "It is a country in Southern Europe with a population of approximately 60 million people.\n", 103 | "\"\"\",\n", 104 | "\n", 105 | "\"\"\"\n", 106 | "The capital of Italy is Rome, which is also the largest city in the country.\n", 107 | "Rome is known for its nearly 3,000 years of globally influential art, architecture, and culture.\n", 108 | "The city is often referred to as the \"Eternal City\" and is famous for its ancient history, including landmarks such as the Colosseum and the Vatican.\n", 109 | "It is the capital city of Italy and has a population of almost 3 million people.\n", 110 | "\"\"\",\n", 111 | "\n", 112 | "\"\"\"\n", 113 | "Italy's history goes back to numerous Italic peoples—notably including the ancient Romans, \n", 114 | "who conquered the Mediterranean world during the Roman Republic and ruled it for centuries during the Roman Empire.\n", 115 | "The Roman Empire was among the largest in history, wielding great economic, cultural, political, and military power.\n", 116 | "\"\"\",\n", 117 | "\n", 118 | "\"\"\"\n", 119 | "France is a country in Western Europe, known for its rich history, culture, and influence.\n", 120 | "The food in France is renowned worldwide, with dishes like coq au vin and ratatouille.\n", 121 | "France has a world class cuisine and is famous for its wine, cheese, and pastries.\n", 122 | "Regions like Bordeaux and Champagne are particularly well-known in the culinary world.\n", 123 | "\"\"\"\n", 124 | "]" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "id": "2b36c0b3", 130 | "metadata": {}, 131 | "source": [ 132 | "We will use Chroma's in-memory vector database to simulate the initial retrieval step:" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "7148df3d", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "import chromadb\n", 143 | "from uuid import uuid4\n", 144 | "\n", 145 | "client = chromadb.Client()\n", 146 | "collection = client.create_collection(name=\"italy\")\n", 147 | "collection.add(\n", 148 | "    ids = [str(uuid4()) for _ in range(len(docs))],\n", 149 | "    documents = docs,\n", 150 | ")" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "id": "3dbd7ec4", 156 | "metadata": {}, 157 | "source": [ 158 | "Chroma automatically handles the creation of embeddings for the documents we add to the collection. 
By default, it uses the [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model, which is lightweight and efficient. Let's try querying our collection to see if it works:" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 28, 164 | "id": "87caa818", 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "Document 1: \n", 172 | "The capital of Italy is Rome, which is also the largest city in the country.\n", 173 | "Rome is known for its nearly 3,000 years of globally influential art, architecture, and culture.\n", 174 | "The city is often referred to as the \"Eternal City\" and is famous for its ancient history, including landmarks such as the Colosseum and the Vatican.\n", 175 | "It is the capital city of Italy and has a population of almost 3 million people.\n", 176 | "\n", 177 | "ID: 9b23cca7-2749-4ed1-9404-4edd035a1b8a\n", 178 | "Distance: 0.6291064023971558\n", 179 | "--------------------------------------------------\n", 180 | "Document 2: \n", 181 | "Italy, officially the Italian Republic, is a country in Southern and Western Europe.\n", 182 | "It consists of a peninsula that extends into the Mediterranean Sea.\n", 183 | "The Alps mountain range forms its northern boundary, while the Apennine Mountains run down the length of the peninsula.\n", 184 | "The territory also includes nearly 800 islands, notably Sicily and Sardinia.\n", 185 | "It is a country in Southern Europe with a population of approximately 60 million people.\n", 186 | "\n", 187 | "ID: 05b94dfd-010c-426d-8add-bd427066cfff\n", 188 | "Distance: 0.7050184011459351\n", 189 | "--------------------------------------------------\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "query = \"How many people live in Italy and what is the capital?\"\n", 195 | "results = collection.query(query_texts=[query], n_results=2)\n", 196 | "\n", 197 | "for i, id in enumerate(results['ids'][0]):\n", 198 | "    print(f\"Document {i+1}: {results['documents'][0][i]}\")\n", 199 | "    print(f\"ID: {id}\\nDistance: {results['distances'][0][i]}\")\n", 200 | "    print(\"-\" * 50)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "id": "ab38fa59", 206 | "metadata": {}, 207 | "source": [ 208 | "We get great results. Let's now introduce our Reranker Model. We will use the [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2) model, which is a lightweight Cross-Encoder model designed for reranking tasks. 
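Note that for this family of models, `CrossEncoder.predict` returns raw, unbounded relevance scores (you will see values like `9.61` and `-8.50` below), where higher simply means more relevant. The raw scores are all we need for ranking; if you ever want values in $(0,1)$ for thresholding or display, a sigmoid is one common post-processing option — a small sketch, assuming the scores arrive as a NumPy array:

```python
import numpy as np

def to_probability(logits: np.ndarray) -> np.ndarray:
    # Squash raw cross-encoder scores into (0, 1); the ranking order is unchanged.
    return 1.0 / (1.0 + np.exp(-logits))
```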
209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 11, 214 | "id": "461388ad", 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "Using device: mps\n", 222 | "✅ Reranker model loaded.\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "from sentence_transformers import CrossEncoder\n", 228 | "import torch # comes with sentence-transformers\n", 229 | "\n", 230 | "DEVICE = 'cuda' if torch.cuda.is_available() \\\n", 231 | "         else 'mps' if torch.backends.mps.is_available() \\\n", 232 | "         else 'cpu'\n", 233 | "print(f\"Using device: {DEVICE}\")\n", 234 | "rerank = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2', device=DEVICE)\n", 235 | "print(\"✅ Reranker model loaded.\")" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "id": "e6a78eb8", 241 | "metadata": {}, 242 | "source": [ 243 | "Let's do a quick check and see how the model ranks our documents with an example query:" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 32, 249 | "id": "3e18fd13", 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "data": { 254 | "text/plain": [ 255 | "array([ 1.5755582,  9.614823 , -4.1138754, -8.50889  ], dtype=float32)" 256 | ] 257 | }, 258 | "execution_count": 32, 259 | "metadata": {}, 260 | "output_type": "execute_result" 261 | } 262 | ], 263 | "source": [ 264 | "query = \"What is the capital of Italy?\"\n", 265 | "rerank_results = rerank.predict([[query, doc] for doc in docs])\n", 266 | "rerank_results" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "id": "7a06327b", 272 | "metadata": {}, 273 | "source": [ 274 | "Unsurprisingly, the highest ranked document is the one that talks about Rome, which is the capital of Italy. The lowest ranked document is the one that talks about France, which is not relevant to the query at all. This is to be expected. \n", 275 | "Let’s now take the highest-ranked document and predict a relevance score for each of its sentences individually. This allows us to analyze the alignment between the query and different parts of the document in a more fine-grained way, rather than treating the whole document as a single block. 
276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 17, 281 | "id": "1b804149", 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "name": "stdout", 286 | "output_type": "stream", 287 | "text": [ 288 | "Sentence 1: The capital of Italy is Rome, which is also the largest city in the country\n", 289 | "Score: 9.2809\n", 290 | "--------------------------------------------------\n", 291 | "Sentence 2: Rome is known for its nearly 3,000 years of globally influential art, architecture, and culture\n", 292 | "Score: -1.6989\n", 293 | "--------------------------------------------------\n", 294 | "Sentence 3: The city is often referred to as the \"Eternal City\" and is famous for its ancient history, including landmarks such as the Colosseum and the Vatican\n", 295 | "Score: -7.3384\n", 296 | "--------------------------------------------------\n", 297 | "Sentence 4: It is the capital city of Italy and has a population of almost 3 million people\n", 298 | "Score: 6.3446\n", 299 | "--------------------------------------------------\n", 300 | "Mean score of sentences: 1.6470724\n", 301 | "Standard deviation of scores: 6.5626807\n" 302 | ] 303 | } 304 | ], 305 | "source": [ 306 | "query = \"What is the capital of Italy?\"\n", 307 | "sents = [sent.strip() for sent in docs[1].split('.') if sent.strip()]\n", 308 | "sent_scores = rerank.predict([[query, sent] for sent in sents])\n", 309 | "\n", 310 | "for i, (sent, score) in enumerate(zip(sents, sent_scores)):\n", 311 | "    print(f\"Sentence {i+1}: {sent}\")\n", 312 | "    print(f\"Score: {score:.4f}\")\n", 313 | "    print(\"-\" * 50)\n", 314 | "\n", 315 | "print(\"Mean score of sentences:\", sent_scores.mean())\n", 316 | "print(\"Standard deviation of scores:\", sent_scores.std())" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "id": "f783bb5c", 322 | "metadata": {}, 323 | "source": [ 324 | "We can see that the model assigns the highest score to the sentence that is most relevant to the query. The mean score of the whole document is quite low, given that the document contains multiple sentences that are not strictly relevant to the query. This is a common issue with long documents, where a high portion of sentences may not be relevant at all. The standard deviation of the scores is also quite high. We should control for these statistical measures when performing the selection of both documents and sentences." 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "id": "f4a133ed", 330 | "metadata": {}, 331 | "source": [ 332 | "### 2.2 Implementing Recursive Reranking\n", 333 | "\n", 334 | "We are all set up! We can now implement the main logic of our Recursive Reranking Technique. The approach consists of the following steps:\n", 335 | "\n", 336 | "1. Split each document into separate sentences, then use the Reranker Model to calculate a relevance score for each sentence with respect to the query;\n", 337 | "\n", 338 | "2. Each document is then given a score based on the mean of the highest $\text{score}_N$ scores of its sentences (this is done because we could have long chunks that contain multiple non-relevant sentences, which can drag down the overall score of the document);\n", 339 | "\n", 340 | "3. Select the $\text{top}_N$ documents based on their scores;\n", 341 | "\n", 342 | "4. For each selected document, we perform Contextual Compression by selecting only the most relevant sentences using a simple **Static Filter**. 
Specifically, for document $d$, we select all sentences whose score satisfies:\n", 343 | "$$\text{score} \geq \mu_d + \alpha \cdot \sigma_d$$\n", 344 | "where $\mu_d$ is the mean and $\sigma_d$ is the standard deviation of sentence scores in document $d$, and $\alpha$ is a tunable hyperparameter controlling the strictness of the filter. \n", 345 | "\n", 346 | "We start by defining a function that will perform the initial retrieval step using the Chroma collection we created earlier:" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 33, 352 | "id": "fcf4a88d", 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "def retrieve(query : str, k : int) -> list[str]:\n", 357 | "    \"\"\"\n", 358 | "    Retrieve the top-k documents from the Chroma collection based on the query.\n", 359 | "    \"\"\"\n", 360 | "    results = collection.query(query_texts=[query], n_results=k)\n", 361 | "    return results['documents'][0]" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "id": "ebc28c20", 367 | "metadata": {}, 368 | "source": [ 369 | "Let's now define the main hyperparameters and implement the `recursive_rerank` function, which takes a `query` as a string and `docs` as a list of strings, and returns a list of reranked and compressed documents." 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": null, 375 | "id": "a454a1b6", 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "import numpy as np\n", 380 | "\n", 381 | "SCORE_N = 2 # Number of top sentences to consider for document scoring\n", 382 | "TOP_N = 2 # Number of top documents to select\n", 383 | "ALPHA = 0.2 # Strength of Contextual Compression\n", 384 | "\n", 385 | "def recursive_rerank(query: str, \n", 386 | "                     docs: list[str],\n", 387 | "                     score_n : int = SCORE_N,\n", 388 | "                     top_n : int = TOP_N,\n", 389 | "                     alpha : float = ALPHA) -> list[str]:\n", 390 | "    \"\"\"\n", 391 | "    Perform recursive reranking and contextual compression of documents in a single pass.\n", 392 | "    \"\"\"\n", 393 | "\n", 394 | "    # Step 1: Calculate sentence scores\n", 395 | "    all_sents = []\n", 396 | "    sent_scores = []\n", 397 | "    for doc in docs:\n", 398 | "        # Naive split on '.'; a proper sentence segmenter (e.g. SpaCy) would be more robust\n", 399 | "        sents = [sent.strip() for sent in doc.split('.') if sent.strip()]\n", 400 | "        all_sents.append(sents)\n", 401 | "        scores = rerank.predict([[query, sent] for sent in sents])\n", 402 | "        sent_scores.append(scores)\n", 403 | "\n", 404 | "    # Step 2: Calculate document scores based on top score_N sentence scores\n", 405 | "    doc_scores = []\n", 406 | "    for scores in sent_scores:\n", 407 | "        indx = min(score_n, len(scores))\n", 408 | "        sorted_scores = sorted(scores, reverse=True)[:indx]\n", 409 | "        doc_score = sum(sorted_scores) / indx\n", 410 | "        doc_scores.append(doc_score)\n", 411 | "\n", 412 | "    # Step 3: Select top N documents\n", 413 | "    # We will use document indices to save space\n", 414 | "    top_docs_indices = \\\n", 415 | "    sorted(\n", 416 | "        range(len(doc_scores)), \n", 417 | "        key=lambda i: doc_scores[i], \n", 418 | "        reverse=True\n", 419 | "    )[:min(top_n, len(doc_scores))]\n", 420 | "    \n", 421 | "    # Step 4: rerank and compress documents whose indices are in top_docs_indices\n", 422 | "    filtered_docs = []\n", 423 | "    for i in top_docs_indices:\n", 424 | "        mean = np.mean(sent_scores[i])\n", 425 | "        std_dev = np.std(sent_scores[i])\n", 426 | "        filtered_sents = [sent for sent, score in zip(all_sents[i], sent_scores[i]) \n", 427 | "                             if score >= mean + alpha * 
std_dev]\n", 428 | " if filtered_sents:\n", 429 | " filtered_docs.append('.\\n'.join(filtered_sents) + '.')\n", 430 | "\n", 431 | " return filtered_docs" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "id": "8aac4aea", 437 | "metadata": {}, 438 | "source": [ 439 | "Let's finally test our implementation with an example query and see how it performs:" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 57, 445 | "id": "0ffa9c00", 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | "The capital of Italy is Rome, which is also the largest city in the country.\n", 453 | "It is the capital city of Italy and has a population of almost 3 million people.\n", 454 | "--------------------------------------------------\n", 455 | "Italy, officially the Italian Republic, is a country in Southern and Western Europe.\n", 456 | "It is a country in Southern Europe with a population of approximately 60 million people.\n", 457 | "--------------------------------------------------\n" 458 | ] 459 | } 460 | ], 461 | "source": [ 462 | "retrieve_query = \"How many people live in Italy and what is the capital?\"\n", 463 | "docs = retrieve(retrieve_query, k=4)\n", 464 | "reranked_docs = recursive_rerank(retrieve_query, docs)\n", 465 | "for doc in reranked_docs:\n", 466 | " print(doc)\n", 467 | " print(\"-\" * 50)" 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "id": "5a32fecb", 473 | "metadata": {}, 474 | "source": [ 475 | "We get exactly what we want! The two reranked documents returned are the ones discussing Rome and the population of Italy, which are both aligned with the query. We also retained only the most relevant sentences from each document, effectively performing Contextual Compression. Let's test it once again with a different query:" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 75, 481 | "id": "9e889b5f", 482 | "metadata": {}, 483 | "outputs": [ 484 | { 485 | "name": "stdout", 486 | "output_type": "stream", 487 | "text": [ 488 | "The food in France is renowned worldwide, with dishes like coq au vin and ratatouille.\n", 489 | "France has a world class cuisine and is famous for its wine, cheese, and pastries.\n", 490 | "--------------------------------------------------\n", 491 | "Italy, officially the Italian Republic, is a country in Southern and Western Europe.\n", 492 | "It is a country in Southern Europe with a population of approximately 60 million people.\n", 493 | "--------------------------------------------------\n", 494 | "The capital of Italy is Rome, which is also the largest city in the country.\n", 495 | "It is the capital city of Italy and has a population of almost 3 million people.\n", 496 | "--------------------------------------------------\n" 497 | ] 498 | } 499 | ], 500 | "source": [ 501 | "retrieve_query = \"Does france have good food?\"\n", 502 | "docs = retrieve(retrieve_query, k=4)\n", 503 | "reranked_docs = recursive_rerank(retrieve_query, docs, alpha=0.2, top_n=3)\n", 504 | "for doc in reranked_docs:\n", 505 | " print(doc)\n", 506 | " print(\"-\" * 50)" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "id": "51eb71a5", 512 | "metadata": {}, 513 | "source": [ 514 | "This time, the highest ranked document is the one that talks about France and its food. We can also control the strictness of the Contextual Compression step by adjusting the `alpha` parameter. 
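To make the effect of `alpha` concrete, recall the per-sentence scores computed earlier for the Rome document: $9.28, -1.70, -7.34, 6.34$, with mean $\mu_d \approx 1.65$ and standard deviation $\sigma_d \approx 6.56$. With $\alpha = 0.2$ the threshold is $\mu_d + 0.2\,\sigma_d \approx 2.96$, so sentences 1 and 4 survive; with $\alpha = 0.9$ it rises to $\approx 7.55$, and only sentence 1 survives. A higher `alpha` therefore means stricter filtering and more aggressive compression.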
Let's try the same query with a higher `alpha` value:" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 76, 520 | "id": "25e07369", 521 | "metadata": {}, 522 | "outputs": [ 523 | { 524 | "name": "stdout", 525 | "output_type": "stream", 526 | "text": [ 527 | "The food in France is renowned worldwide, with dishes like coq au vin and ratatouille.\n", 528 | "France has a world class cuisine and is famous for its wine, cheese, and pastries.\n", 529 | "--------------------------------------------------\n", 530 | "It is a country in Southern Europe with a population of approximately 60 million people.\n", 531 | "--------------------------------------------------\n", 532 | "The capital of Italy is Rome, which is also the largest city in the country.\n", 533 | "--------------------------------------------------\n" 534 | ] 535 | } 536 | ], 537 | "source": [ 538 | "retrieve_query = \"Does france have good food?\"\n", 539 | "docs = retrieve(retrieve_query, k=4)\n", 540 | "reranked_docs = recursive_rerank(retrieve_query, docs, alpha=0.9, top_n=3)\n", 541 | "for doc in reranked_docs:\n", 542 | "    print(doc)\n", 543 | "    print(\"-\" * 50)" 544 | ] 545 | }, 546 | { 547 | "cell_type": "markdown", 548 | "id": "853c7df2", 549 | "metadata": {}, 550 | "source": [ 551 | "As you can see, the highest ranked document has still retained every sentence, but the other documents have been compressed to only one sentence." 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "id": "787cc9a7", 557 | "metadata": {}, 558 | "source": [ 559 | "### 3. Conclusion\n", 560 | "\n", 561 | "You can find a diagram of the Recursive Reranking Technique [at this link](https://raw.githubusercontent.com/LucaStrano/Experimental_RAG_Tech/refs/heads/main/images/2_compress_and_rerank_diagram.png).\n", 562 | "\n", 563 | "The Recursive Reranking Technique offers a powerful way to combine both Reranking and Contextual Compression in a single pass. This technique is particularly useful when dealing with noisy chunks and high values of the retrieval parameter $\text{top}_K$. \n", 564 | "\n", 565 | "The main advantages of this approach include:\n", 566 | "\n", 567 | "- High efficiency, since it combines both Reranking and Contextual Compression in a single pass;\n", 568 | "\n", 569 | "- Lower latency, since it avoids the need to perform multiple LLM calls to compress the retrieved documents.\n", 570 | "\n", 571 | "This technique works best when paired with other chunking techniques such as **Semantic Chunking** or **Proposition Chunking**. The Recursive Reranking function could also be enhanced by using a better (but more computationally intensive) Reranker Model or by considering windows of sentences instead of single sentences. This would allow for a wider understanding of the context, especially with ambiguous sentences where entities aren't directly mentioned (e.g., _The capital of Italy is Rome. It is a city containing..._). A minimal sketch of this windowing idea is provided in the appendix below.\n", 572 | "\n", 573 | "Thank you for reading this notebook! I hope you found it useful. If you have any questions or suggestions, feel free to send me an email at **strano.lucass@gmail.com** or send me a message on [LinkedIn](https://www.linkedin.com/in/strano-lucass/). 
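_Appendix: sentence windows._ As a starting point for the windowing idea mentioned in the conclusion, here is one possible sketch (the `size` and `stride` parameters are illustrative choices, not part of the implementation above). Each window would be scored by `rerank.predict` exactly like a single sentence, giving the model surrounding context for pronouns and implicit references:

```python
def sentence_windows(sents: list[str], size: int = 2, stride: int = 1) -> list[str]:
    # Overlapping windows of consecutive sentences, e.g. size=2 yields
    # [s0 s1], [s1 s2], ...; each window is scored like a single 'sentence'.
    if len(sents) <= size:
        return ['. '.join(sents)]
    return ['. '.join(sents[i:i + size]) for i in range(0, len(sents) - size + 1, stride)]
```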
574 | ] 575 | } 576 | ], 577 | "metadata": { 578 | "kernelspec": { 579 | "display_name": ".exp_rag_tech", 580 | "language": "python", 581 | "name": "python3" 582 | }, 583 | "language_info": { 584 | "codemirror_mode": { 585 | "name": "ipython", 586 | "version": 3 587 | }, 588 | "file_extension": ".py", 589 | "mimetype": "text/x-python", 590 | "name": "python", 591 | "nbconvert_exporter": "python", 592 | "pygments_lexer": "ipython3", 593 | "version": "3.12.2" 594 | } 595 | }, 596 | "nbformat": 4, 597 | "nbformat_minor": 5 598 | } 599 | -------------------------------------------------------------------------------- /experimental_tech/1_estimating_k.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "43f93758", 6 | "metadata": {}, 7 | "source": [ 8 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LucaStrano/Experimental_RAG_Tech/blob/main/experimental_tech/1_estimating_k.ipynb)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "20250bd1", 14 | "metadata": {}, 15 | "source": [ 16 | "## Dynamically Estimating the K (Hyper)parameter using Query Complexity Score (QCS)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "8cf2797d", 22 | "metadata": {}, 23 | "source": [ 24 | "### Overview\n", 25 | "\n", 26 | "This notebook provides a thorough explanation of how to dynamically estimate the **K parameter** (number of docs to retrieve) during the retrieval phase using the **Query Complexity Score** (QCS). We will go over the **Intuition** behind the Query Complexity Score, **Alternative Solutions** to the problem, the **Implementation** of the QCS function and, finally, a short **Conclusion**." 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "f8e6ffa8", 32 | "metadata": {}, 33 | "source": [ 34 | "### 1. Intuition\n", 35 | "\n", 36 | "When building a comprehensive RAG system, one of the most important hyperparameters to tune is **K**, which represents the number of documents (before reranking, if any) to retrieve for each query. The choice of this hyperparameter can significantly impact the performance of the system, especially in terms of **precision** and **answer feasibility**:\n", 37 | "\n", 38 | " - A **low K** will lead to _higher precision_ (since the number of retrieved documents is low), but may result in missing relevant information, which leads to incomplete answers;\n", 39 | " \n", 40 | " - A **high K** will lead to _lower precision_, but may provide more relevant information, which leads to more complete answers.\n", 41 | "\n", 42 | "Unfortunately, **not all queries are created equal**: some queries are more complex than others, and therefore may require more documents to be retrieved (higher K) in order to provide a complete answer. Other queries may be simpler and therefore require fewer documents (lower K) to achieve a satisfactory answer. For example, the query \"_What's the capital of Italy?_\" is trivial and can be answered by retrieving a single chunk; on the other hand, the query \"_What is the capital of Italy and what can I visit there? What are the main attractions?_\" is more complex and may require fetching multiple chunks in order to provide an answer.\n", 43 | "\n", 44 | "Given the wide variety of queries a RAG system can receive, it’s intuitive to prefer a **dynamic K**, making it no longer a _hyperparameter_ of the system. 
What if we could assign each query a score that (roughly) reflects its complexity, and use this score to estimate how many documents to retrieve? Ideally, complex queries should lead to higher scores, which in turn lead to higher K values, allowing the system to retrieve more documents from the Knowledge Base. This is the intuition behind the **Query Complexity Score** (QCS), which we will explore in detail in the following chapters." 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "64b746c9", 50 | "metadata": {}, 51 | "source": [ 52 | "### 2. Alternative Solutions" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "5da88cbc", 58 | "metadata": {}, 59 | "source": [ 60 | "#### 2.1 High Static K\n", 61 | "\n", 62 | "The simplest approach to _\"solve\"_ this problem is setting a high, fixed K value. If we already know the expected range of query complexity, we could fix a K value that is high enough to provide a satisfactory answer for even the most complex queries.\n", 63 | "As already discussed, this approach has the drawback of leading to lower precision as well as higher costs and latency even for the simplest queries.\n", 64 | "\n", 65 | "> Please note that these problems could be somewhat mitigated by using a **reranker** model on the retrieved chunks, but this notebook focuses on another approach entirely." 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "080c7be1", 71 | "metadata": {}, 72 | "source": [ 73 | "#### 2.2 Training a Model to Estimate K for Potential Queries\n", 74 | "\n", 75 | "Another plausible approach is to train a model to predict the value of K for a given query. This model can be trained on a dataset of queries and their corresponding K values, which can be obtained by generating synthetic queries of varying complexity and manually annotating them with an appropriate value of K. This approach works well if the training dataset is highly representative of the queries that the system is going to receive. However, it has the drawback of requiring a good dataset (which is difficult to obtain, especially with synthetic queries) and a good model that is able to generalize well to unseen queries. This approach is also costly, time-consuming, and requires effort to keep the model up-to-date in a system that is constantly evolving.\n", 76 | "\n", 77 | "> For a thorough implementation of this approach, you can read this [Medium Article](https://medium.com/@sauravjoshi23/optimizing-retrieval-augmentation-with-dynamic-top-k-tuning-for-efficient-question-answering-11961503d4ae) by Saurav Joshi." 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "id": "c8da6c76", 83 | "metadata": {}, 84 | "source": [ 85 | "#### 2.3 Ask an LLM\n", 86 | "\n", 87 | "Of course, the **JALM** (Just Ask a Language Model) approach is always an option and, in most cases, the best one in terms of quality. A smart-enough LLM may be able to estimate the K value for a given query based on its complexity and, optionally, the context of the system.\n", 88 | "Another approach relies on the use of **Query Decomposition**: Starting from the original query, the LLM is (_kindly_) asked to generate a set of small, atomic sub-queries (additionally, each with an associated K value) that reflects the decomposition of the original query into smaller, more manageable parts. The retrieval results for each sub-query are then merged using techniques like **Reciprocal Rank Fusion**, sketched below. 
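A minimal sketch of Reciprocal Rank Fusion, independent of any specific retrieval library (`k = 60` is the smoothing constant commonly used for RRF; the document ids and rankings here are purely illustrative):

```python
from collections import defaultdict

def reciprocal_rank_fusion(rankings: list[list[str]], k: int = 60) -> list[str]:
    # Merge several ranked lists of document ids into a single fused ranking.
    scores: defaultdict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)  # earlier ranks contribute more
    return sorted(scores, key=scores.__getitem__, reverse=True)

# reciprocal_rank_fusion([['d1', 'd2'], ['d2', 'd3']])  ->  ['d2', 'd1', 'd3']
```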
Alternatively, each sub-query could be answered separately and the sub-results merged by an LLM to achieve a final answer. This approach is particularly useful when the original query is too complex to be answered in a single step, but it requires a quality LLM to achieve consistent results and adds complexity and latency to the system." 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "846471da", 94 | "metadata": {}, 95 | "source": [ 96 | "\n", 97 | "### 3. Query Complexity Score\n", 98 | "\n", 99 | "To estimate the **Complexity Score** for a given query, we will use different heuristics that capture a variety of aspects that reflect complexity. The following components will be considered:\n", 100 | "\n", 101 | "- **Length of the query**: The longer the query (in terms of _tokens_), the more complex it is likely to be;\n", 102 | "\n", 103 | "- **Number of unique entities in the query**: The more unique entities are mentioned in the query, the more complex it is likely to be;\n", 104 | "\n", 105 | "- **Number of different sentences and conjunctions in the query**: The more sentences and _relevant_ conjunctions (that connect two separate, meaningful clauses) are used in the query, the more complex it is likely to be.\n", 106 | "\n", 107 | "We will then normalize these components and combine them into a single score by using a **Weighted Mean** function, which will allow us to assign different weights to each heuristic based on their importance in the context of the system. The final score will be normalized to a range between 0 and 1, where a low value represents a trivial query and a high value reflects a highly complex query:\n", 108 | "\n", 109 | "$$\n", "\begin{aligned}\n", 110 | "QCS(q) =\ & w_{\text{len}} \cdot \text{norm}(\text{len}(q)) \\\n", 111 | " & +\ w_{\text{cc}} \cdot \text{norm}(\#cc(q)) \\\n", 112 | " & +\ w_{\text{sent}} \cdot \text{norm}(\#sent(q)) \\\n", 113 | " & +\ w_{\text{ent}} \cdot \text{norm}(\#ent(q))\n", 114 | "\end{aligned}\n", "$$\n", 115 | "\n", 116 | "With $QCS(q), w_{x} \in (0,1)$, $\sum_{x} w_x = 1$." 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "id": "24b8e20a", 122 | "metadata": {}, 123 | "source": [ 124 | "#### 3.1 Preliminaries to QCS\n", 125 | "\n", 126 | "To estimate the components discussed earlier, we will take advantage of the [**SpaCy** library](https://spacy.io/), which offers powerful and efficient NLP pipelines to perform operations such as _Dependency Parsing_ and _Named Entity Recognition_. 
Specifically, we'll use the **en_core_web_sm** pipeline, which is blazingly fast, lightweight (~12 MB of disk space) and provides a good balance between performance and accuracy for our use case.\n", 127 | "\n", 128 | "Here is what we'll need:" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "90d37fd4", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "%conda install --yes spacy" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "31700ef9", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Collecting en-core-web-sm==3.8.0\n", 152 | "  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)\n", 153 | "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.8/12.8 MB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", 154 | "\u001b[?25hInstalling collected packages: en-core-web-sm\n", 155 | "Successfully installed en-core-web-sm-3.8.0\n", 156 | "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n", 157 | "You can now load the package via spacy.load('en_core_web_sm')\n", 158 | "\u001b[38;5;3m⚠ Restart to reload dependencies\u001b[0m\n", 159 | "If you are in a Jupyter or Colab notebook, you may need to restart Python in\n", 160 | "order to load all the package's dependencies. You can do this by selecting the\n", 161 | "'Restart kernel' or 'Restart runtime' option.\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "import spacy\n", 167 | "\n", 168 | "try:\n", 169 | "\tnlp = spacy.load(\"en_core_web_sm\")\n", 170 | "except IOError:\n", 171 | "\tfrom spacy.cli.download import download\n", 172 | "\tdownload(\"en_core_web_sm\")\n", 173 | "\tnlp = spacy.load(\"en_core_web_sm\")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "e25e5a0d", 179 | "metadata": {}, 180 | "source": [ 181 | "Now that we've loaded the SpaCy models, we can try them out on a test query:" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "1b2e2cc4", 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "Query: Rome is the capital of Italy. Italy is known for its rich history, art (e.g. David by Michelangelo), and culture.\n", 195 | "********************\n", 196 | "Number of tokens: 26\n", 197 | "--------------------\n", 198 | "Number of sentences: 2\n", 199 | "  Sentence 2 contains CONJUNCTION ('and')\n", 200 | "--------------------\n", 201 | "Number of entities: 5\n", 202 | "  1. Entity: Rome, Label: GPE\n", 203 | "  2. Entity: Italy, Label: GPE\n", 204 | "  3. Entity: Italy, Label: GPE\n", 205 | "  4. Entity: David, Label: PERSON\n", 206 | "  5. Entity: Michelangelo, Label: PERSON\n" 207 | ] 208 | } 209 | ], 210 | "source": [ 211 | "test = \"Rome is the capital of Italy. Italy is known for its rich history, art (e.g. 
David by Michelangelo), and culture.\"\n", 212 | "doc = nlp(test)\n", 213 | "\n", 214 | "print(f\"Query: {test}\")\n", 215 | "print(\"*\"*20)\n", 216 | "print(f\"Number of tokens: {len(doc)}\")\n", 217 | "print(\"-\"*20)\n", 218 | "print(f\"Number of sentences: {len(list(doc.sents))}\")\n", 219 | "for i, sent in enumerate(doc.sents):\n", 220 | "    if \"and\" in sent.text:\n", 221 | "        print(f\"  Sentence {i+1} contains CONJUNCTION ('and')\")\n", 222 | "print(\"-\"*20)\n", 223 | "print(f\"Number of entities: {len(doc.ents)}\")\n", 224 | "for i, ent in enumerate(doc.ents):\n", 225 | "    print(f\"  {i+1}. Entity: {ent.text}, Label: {ent.label_}\")" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "id": "40fe4df8", 231 | "metadata": {}, 232 | "source": [ 233 | "We have exactly what we need. Before we proceed with the implementation, let's go over some observations:\n", 234 | "\n", 235 | "- The **query length** is a great indicator of query complexity, but it is not enough on its own: for example, the query \"_What's the capital of the country next to France_?\" isn't long, but it's ambiguous and requires some reasoning to provide an answer. Another point to consider is that we are using _number of tokens_ as a measure of length instead of _number of characters_. This is because the QCS is going to be a _weighted average_ of the different components, and considering number of tokens as a length measure allows us to have a slightly more consistent measure across queries;\n", 236 | " \n", 237 | "- We are using the number of **Distinct** entities in the query, not the total predicted number of entities. This is because we could have queries such as \"Where's Italy and what's the capital of Italy?\" that have the same entity mentioned multiple times, but since we are also considering the number of distinct clauses, this repetition shouldn't matter while calculating the QCS;\n", 238 | " \n", 239 | "- **Sentence Segmentation** is done smartly: In the example query, we have the \"_e.g. David by Michelangelo_\" part that, while containing dots, isn't split into separate sentences;\n", 240 | " \n", 241 | "- Estimating the number of **Conjunctions** in a query is the trickiest part. We can't just count how many times the token \"_and_\" appears in the query. For example, the query \"What are the Q3 earnings of Johnson and Johnson?\" contains a conjunction, but it's part of a company name, so it shouldn't add complexity. What we ultimately want to count is the number of **Coordinating Conjunctions** (CCs) that connect main clauses in the query. We can use **Dependency Parsing** to achieve this.\n", 242 | "\n", 243 | "> Please note that to calculate the QCS we are only considering the \"and\" token, but the set of coordinating conjunctions for the English language consists of the _FANBOYS_ conjunctions: **For**, **And**, **Nor**, **But**, **Or**, **Yet**, **So**. This is a matter of preference that can be adjusted based on the specific use case." 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "id": "5020e9c8", 249 | "metadata": {}, 250 | "source": [ 251 | "#### 3.2 Implementing QCS\n", 252 | "\n", 253 | "Now that we have a clear understanding of the intuition behind QCS, we can begin implementing the main function and its components. We’ll start by building the function that calculates the number of **Relevant Coordinating Conjunctions** (CCs) in a query. 
254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "b66f649f", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "from spacy.tokens import Doc\n", 264 | "\n", 265 | "def count_ccs(doc : Doc) -> float:\n", 266 | "    \"\"\"\n", 267 | "    Count the number of relevant Coordinating Conjunctions (CCs) in the phrase.\n", 268 | "    \"\"\"\n", 269 | "    cc_count = 0.0\n", 270 | "    for token in doc:\n", 271 | "        if token.text.lower() == \"and\" and token.dep_ == \"cc\":\n", 272 | "            head = token.head\n", 273 | "\n", 274 | "            # CASE 1. Check if 'and' connects two verbal phrases\n", 275 | "            # head verb could be AUX, so we check for root as well\n", 276 | "            if head.pos_ == \"VERB\" or head.dep_ == \"ROOT\":\n", 277 | "                if any(child.dep_ == \"conj\" and child.pos_ == \"VERB\"\\\n", 278 | "                       for child in head.children):\n", 279 | "                    cc_count += 1\n", 280 | "            \n", 281 | "            # CASE 2. Check if 'and' has a question as a conjunct\n", 282 | "            # search for \"Wh-\" words in the CC subtree\n", 283 | "            elif any(child.dep_ == \"conj\" and \\\n", 284 | "                     any(t.tag_ == \"WRB\" or t.tag_ == \"WP\" for t in child.subtree) \\\n", 285 | "                     for child in head.children):\n", 286 | "                cc_count += 1\n", 287 | "\n", 288 | "    return cc_count" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "id": "b1b07a03", 294 | "metadata": {}, 295 | "source": [ 296 | "The `count_ccs` function looks for two main patterns in the query to identify relevant CCs:\n", 297 | "\n", 298 | "1. If the head of the token is a verb or a root, it checks if there are any **verb conjuncts**. These typically appear in the second clause (after the \"and\") of the query;\n", 299 | "2. If the first condition is not met, it checks whether the clause contains \"**Wh-**\" words. Their presence usually indicates that the second part of the query is a separate, complete clause (most likely a separate question).\n", 300 | "\n", 301 | "Let's test the `count_ccs` function on some example queries:" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 121, 307 | "id": "63e0c334", 308 | "metadata": {}, 309 | "outputs": [ 310 | { 311 | "name": "stdout", 312 | "output_type": "stream", 313 | "text": [ 314 | "✅ All tests passed! :)\n" 315 | ] 316 | } 317 | ], 318 | "source": [ 319 | "def test_count_ccs():\n", 320 | "    # This should be 0, as the second clause doesn't contribute to the question\n", 321 | "    assert count_ccs(nlp(\"I would like to visit Rome and Italy and I don't know which one to choose?\")) == 0.0\n", 322 | "    # This should be 1, as the second clause is a separate question\n", 323 | "    assert count_ccs(nlp(\"What is the most important dish in Italy and how is it prepared traditionally?\")) == 1.0\n", 324 | "    # This should be 2, the first \"and\" connects two nouns\n", 325 | "    assert count_ccs(nlp(\"What are the Q3 earnings of Johnson and Johnson and how do they compare to the previous quarter? Also what is the highest quarterly earning and what increased sales the most?\")) == 2.0\n", 326 | "\n", 327 | "test_count_ccs()\n", 328 | "print(\"✅ All tests passed! :)\")" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "id": "eca36ee1", 334 | "metadata": {}, 335 | "source": [ 336 | "The `count_ccs` function works as expected! Before we proceed with the implementation of the main function, let's define some **Normalization Constants**. 
These are needed because the raw component values have no fixed upper range, and we want to make sure that the final QCS is normalized between 0 and 1. These constants can be tuned based on the expected query statistics, but for this example we will use some reasonable values:" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 143, 342 | "id": "6b1b4e82", 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "MAX_LEN = 60 # max token length for a query\n", 347 | "MAX_CC = 2 # max relevant conjunctions expected\n", 348 | "MAX_SENT = 3 # max sentences expected in a query\n", 349 | "MAX_ENT = 4 # max distinct entities expected\n", 350 | "\n", 351 | "MIN_K = 1 # minimum value for K\n", 352 | "MAX_K = 8 # maximum value for K" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "id": "b9a5d30d", 358 | "metadata": {}, 359 | "source": [ 360 | "Let's now implement the main `calculate_qcs` function, which takes a `query` string and four floats, `[len_w, cc_w, sent_w, ent_w]`, representing the **weight** of each corresponding component (query length in tokens, number of CCs, number of sentences, number of distinct entities). This function returns the Query Complexity Score as a `float` in the range [0, 1]." 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "id": "2e070d25", 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "def calculate_qcs(query : str,\n", 371 | "                  len_w : float = .3,\n", 372 | "                  cc_w : float = .2,\n", 373 | "                  sent_w : float = .3,\n", 374 | "                  ent_w : float = .2) -> float:\n", 375 | "    \"\"\"\n", 376 | "    Calculate the Query Complexity Score (QCS) for a given query.\n", 377 | "    \"\"\"\n", 378 | "    if abs(len_w + cc_w + sent_w + ent_w - 1.0) > 1e-9: # avoid exact float equality\n", 379 | "        raise ValueError(\"Weights must sum to 1.0\")\n", 380 | "\n", 381 | "    doc = nlp(query)\n", 382 | "\n", 383 | "    # Calculate each component\n", 384 | "    len_count = len(doc)\n", 385 | "    cc_count = count_ccs(doc)\n", 386 | "    sentence_count = len(list(doc.sents))\n", 387 | "    entity_count = len({ent.text for ent in doc.ents}) # distinct entities\n", 388 | "\n", 389 | "    # Normalize each component, capping at 1.0\n", 390 | "    norm_len = min(len_count / MAX_LEN, 1.0)\n", 391 | "    norm_cc = min(cc_count / MAX_CC, 1.0)\n", 392 | "    norm_sent = min(sentence_count / MAX_SENT, 1.0)\n", 393 | "    norm_ent = min(entity_count / MAX_ENT, 1.0)\n", 394 | "\n", 395 | "    # Return weighted sum\n", 396 | "    return len_w * norm_len + \\\n", 397 | "           cc_w * norm_cc + \\\n", 398 | "           sent_w * norm_sent + \\\n", 399 | "           ent_w * norm_ent" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "id": "8b68a5ea", 405 | "metadata": {}, 406 | "source": [ 407 | "Great! Let's now test this function with some example queries of increasing complexity:" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 144, 413 | "id": "af2a8e62", 414 | "metadata": {}, 415 | "outputs": [ 416 | { 417 | "name": "stdout", 418 | "output_type": "stream", 419 | "text": [ 420 | "0.16999999999999998\n", 421 | "0.20999999999999996\n", 422 | "0.315\n", 423 | "0.715\n" 424 | ] 425 | } 426 | ], 427 | "source": [ 428 | "print(calculate_qcs(\"Capital of Italy?\"))\n", 429 | "print(calculate_qcs(\"What's the capital of Italy and how big is it?\"))\n", 430 | "print(calculate_qcs(\"What's an important dish in Italy and how is it prepared?\"))\n", 431 | "print(calculate_qcs(\"I'm an exchange student and i just got here in Italy. 
What can you tell me about the italian culture, and what famous dishes can i eat in Rome?\"))" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "id": "677b5e25", 437 | "metadata": {}, 438 | "source": [ 439 | "As you can see, the `calculate_qcs` function correctly estimates a QCS for each query, with more complex queries achieving higher scores. Now, estimating the K value is trivial: we can use a linear function that maps the QCS to a range of K values." 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 126, 445 | "id": "bd726aad", 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "from math import floor # ceil could alternatively be used for a more generous K\n", 450 | "\n", 451 | "def estimate_k(query : str,\n", 452 | "               min_k: int = MIN_K,\n", 453 | "               max_k: int = MAX_K) -> int:\n", 454 | "    \"\"\"\n", 455 | "    Estimate the K value based on the QCS.\n", 456 | "    \"\"\"\n", 457 | "    return floor(min_k + (max_k - min_k) * calculate_qcs(query))" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "id": "b8143399", 463 | "metadata": {}, 464 | "source": [ 465 | "Let's finally test the `estimate_k` function with the same queries as before:" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 146, 471 | "id": "3d5d0ab8", 472 | "metadata": {}, 473 | "outputs": [ 474 | { 475 | "name": "stdout", 476 | "output_type": "stream", 477 | "text": [ 478 | "2\n", 479 | "2\n", 480 | "3\n", 481 | "6\n" 482 | ] 483 | } 484 | ], 485 | "source": [ 486 | "print(estimate_k(\"Capital of Italy?\"))\n", 487 | "print(estimate_k(\"What's the capital of Italy and how big is it?\"))\n", 488 | "print(estimate_k(\"What's an important dish in Italy and how is it prepared?\"))\n", 489 | "print(estimate_k(\"I'm an exchange student and i just got here in Italy. What can you tell me about the italian culture, and what famous dishes can i eat in Rome?\"))" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "id": "e2c7783b", 495 | "metadata": {}, 496 | "source": [ 497 | "We can observe that the `estimate_k` function correctly estimates the K value for each query, with more complex queries receiving higher K values. This is exactly what we wanted to achieve!" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "id": "1dfc033a", 503 | "metadata": {}, 504 | "source": [ 505 | "### 4. Conclusion" 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "id": "dfe9349b", 511 | "metadata": {}, 512 | "source": [ 513 | "The following diagram presents a high-level overview of the main steps of the QCS approach:" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "id": "70cfecd4", 519 | "metadata": {}, 520 | "source": [ 521 | "![QCS Diagram](https://raw.githubusercontent.com/LucaStrano/Experimental_RAG_Tech/refs/heads/main/images/1_estimating_k_diagram.png)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "id": "dd5cf5b5", 527 | "metadata": {}, 528 | "source": [ 529 | "Overall, the **Query Complexity Score** offers a fast, lightweight, and practical way to dynamically estimate the complexity of a query and, subsequently, the K value to optimize the retrieval phase.\n",
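"\n", "Plugging it into a retrieval pipeline is a one-liner. The snippet below is a minimal sketch, assuming a hypothetical `retriever` object that exposes a `search(query, k=...)` method; substitute the API of whatever vector store you actually use:\n", "\n", "```python\n", "def dynamic_retrieve(query: str, retriever) -> list:\n", "    # Estimate K from the query's complexity, then fetch that many documents.\n", "    k = estimate_k(query)\n", "    return retriever.search(query, k=k)  # `retriever` is a stand-in, not a real API\n", "```\n", "\n",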
"Key advantages of this approach include:\n", 530 | "\n", 531 | "- No training or fine-tuning required;\n", 532 | " \n", 533 | "- Easy adaptability to different contexts by adjusting the normalization constants and component weights;\n", 534 | " \n", 535 | "- Minimal overhead, as it relies on simple heuristics and the spaCy library, which is optimized for performance.\n", 536 | "\n", 537 | "This approach also comes with some limitations, such as:\n", 538 | "\n", 539 | "- The considered heuristics may not be sufficient to capture the full complexity of a query;\n", 540 | "\n", 541 | "- The QCS function may not generalize well to every context, especially if the queries contain domain-specific terms or formatting that is poorly recognized by the spaCy models;\n", 542 | "\n", 543 | "- Each language might require its own set of heuristics.\n", 544 | "\n", 545 | "To get the most out of this technique, it is recommended to use it in conjunction with other techniques such as **Query Decomposition** by following this approach:\n", 546 | "\n", 547 | "1. Decompose the original query into smaller, manageable sub-queries using an LLM;\n", 548 | " \n", 549 | "2. Estimate the QCS and K value for each sub-query;\n", 550 | " \n", 551 | "3. Retrieve the documents for each sub-query using the estimated K value;\n", 552 | " \n", 553 | "4. Merge the retrieved documents using techniques such as **Reciprocal Rank Fusion**, or answer each sub-query separately and merge the sub-results.\n", 554 | "\n", 555 | "The QCS function could also be further improved by adding more components, such as the presence of **quantifiers** (e.g. \"all\", \"some\", \"most\") or **negations** (e.g. \"not\", \"never\"), which can also affect the complexity of a query. Additionally, a small-scale empirical evaluation could be performed to compare the gains and losses of this approach against a fixed-K baseline.\n", 556 | "\n", 557 | "Thank you for reading this notebook! I hope you found it useful. If you have any questions or suggestions, feel free to send me an email at **strano.lucass@gmail.com** or send me a message on [LinkedIn](https://www.linkedin.com/in/strano-lucass/)." 558 | ] 559 | } 568 | ], 569 | "metadata": { 570 | "kernelspec": { 571 | "display_name": ".exp_rag_tech", 572 | "language": "python", 573 | "name": "python3" 574 | }, 575 | "language_info": { 576 | "codemirror_mode": { 577 | "name": "ipython", 578 | "version": 3 579 | }, 580 | "file_extension": ".py", 581 | "mimetype": "text/x-python", 582 | "name": "python", 583 | "nbconvert_exporter": "python", 584 | "pygments_lexer": "ipython3", 585 | "version": "3.12.11" 586 | } 587 | }, 588 | "nbformat": 4, 589 | "nbformat_minor": 5 590 | } 591 | --------------------------------------------------------------------------------