├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── RAGatouille.png ├── README.md ├── docs ├── api.md ├── index.md └── roadmap.md ├── examples ├── 01-basic_indexing_and_search.ipynb ├── 02-basic_training.ipynb ├── 03-finetuning_without_annotations_with_instructor_and_RAGatouille.ipynb ├── 04-reranking.ipynb ├── 05-llama_hub.ipynb ├── 06-index_free_use.ipynb └── data │ └── llama2.pdf ├── mkdocs.yml ├── pyproject.toml ├── ragatouille ├── RAGPretrainedModel.py ├── RAGTrainer.py ├── __init__.py ├── data │ ├── __init__.py │ ├── corpus_processor.py │ ├── preprocessors.py │ └── training_data_processor.py ├── integrations │ ├── __init__.py │ └── _langchain.py ├── models │ ├── __init__.py │ ├── base.py │ ├── colbert.py │ ├── index.py │ ├── torch_kmeans.py │ └── utils.py ├── negative_miners │ ├── __init__.py │ ├── base.py │ └── simpleminer.py └── utils.py ├── requirements-doc.txt └── tests ├── __init__.py ├── data ├── Studio_Ghibli_wikipedia.txt ├── Toei_Animation_wikipedia.txt └── miyazaki_wikipedia.txt ├── e2e └── test_e2e_indexing_searching.py ├── test_pretrained_loading.py ├── test_pretrained_optional_args.py ├── test_trainer_loading.py ├── test_training.py ├── test_training_data_loading.py └── test_training_data_processor.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: bclavie -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | .pdm.toml 87 | __pypackages__/ 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | .envrc 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | .ragatouille 103 | 104 | # mypy 105 | .mypy_cache/ 106 | .dmypy.json 107 | dmypy.json 108 | 109 | # Pyre type checker 110 | .pyre/ 111 | 112 | # pytype static type analyzer 113 | .pytype/ 114 | 115 | # Cython debug symbols 116 | cython_debug/ 117 | 118 | # data files 119 | *.tsv 120 | *.jsonl 121 | 122 | .mypy.ipynb_checkpoints 123 | .mkdocs.yml 124 | 125 | 126 | archive/ 127 | 128 | */.ragatouille 129 | 130 | local/ 131 | 132 | .vscode/ 133 | 134 | .devcontainer/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /RAGatouille.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/RAGatouille/e75b8a964a870dea886548f78da1900804749040/RAGatouille.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Welcome to RAGatouille 2 | 3 | _Easily use and train state of the art retrieval methods in any RAG pipeline. 
Designed for modularity and ease-of-use, backed by research._ 4 | 5 | [![GitHub stars](https://img.shields.io/github/stars/bclavie/ragatouille.svg)](https://github.com/bclavie/ragatouille/stargazers) 6 | ![Python Versions](https://img.shields.io/badge/Python-3.9_3.10_3.11-blue) 7 | [![Downloads](https://static.pepy.tech/badge/ragatouille/month)](https://pepy.tech/project/ragatouille) 8 | [![Documentation](https://img.shields.io/badge/docs-available-brightgreen)](https://ben.clavie.eu/ragatouille/) 9 | [![Twitter Follow](https://img.shields.io/twitter/follow/bclavie?style=social)](https://twitter.com/bclavie) 10 | 11 |

![The RAGatouille logo: a cheerful rat on his laptop (branded with a slightly eaten piece of cheese) next to a pile of books he's looking for information in.](RAGatouille.png)

12 | 13 | --- 14 | 15 | The main motivation of RAGatouille is simple: bridging the gap between state-of-the-art research and alchemical RAG pipeline practices. RAG is complex, and there are many moving parts. To get the best performance, you need to optimise many components: among them, a very important one is the models you use for retrieval. 16 | 17 | Dense retrieval, i.e. using embeddings such as OpenAI's `text-embedding-ada-002`, is a good baseline, but there's a lot of research [showing dense embeddings might not be the](https://arxiv.org/abs/2104.08663) [best fit for **your** use case](https://arxiv.org/abs/2204.11447). 18 | 19 | The Information Retrieval research field has recently been booming, and models like ColBERT have been shown to [generalise better](https://arxiv.org/abs/2203.10053) [to new or complex domains](https://aclanthology.org/2022.findings-emnlp.78/) [than dense embeddings](https://arxiv.org/abs/2205.02870), are [ridiculously data-efficient](https://arxiv.org/abs/2309.06131) and are even [better suited to being trained efficiently](https://arxiv.org/abs/2312.09508) [on non-English languages with small amounts of data](https://arxiv.org/abs/2312.16144)! Unfortunately, most of those new approaches aren't very well known, and are much harder to use than dense embeddings. 20 | 21 | This is where __RAGatouille__ comes in: its purpose is to bridge this gap, making it easy to use state-of-the-art methods in your RAG pipeline without having to worry about the details or the years of literature! At the moment, RAGatouille focuses on making ColBERT simple to use. If you want to see what's coming next, check out our [broad roadmap](https://ben.clavie.eu/ragatouille/roadmap)! 22 | 23 | _If you want to read more about the motivations, philosophy, and why the late-interaction approach used by ColBERT works so well, check out the [introduction in the docs](https://ben.clavie.eu/ragatouille/)._ 24 | 25 | 26 | 27 | 28 | Want to give it a try? Nothing easier: just run `pip install ragatouille` and you're good to go! 29 | 30 | ⚠️ Running notes/requirements: ⚠️ 31 | 32 | - If running inside a script, you must run it inside an `if __name__ == "__main__":` block 33 | - Windows is not supported. RAGatouille doesn't appear to work outside WSL and has issues with WSL1. Some users have had success running RAGatouille in WSL2. 34 | 35 | ## Get Started 36 | 37 | RAGatouille makes it as simple as can be to use ColBERT! We want the library to work on two levels: 38 | 39 | - Strong, but parameterizable defaults: you should be able to get started with just a few lines of code and still leverage the full power of ColBERT, and you should be able to tweak any relevant parameter if you need to! 40 | - Powerful yet simple re-usable components under-the-hood: any part of the library should be usable stand-alone. You can use our DataProcessor or our negative miners outside of `RAGPretrainedModel` and `RAGTrainer`, and you can even write your own negative miner and use it in the pipeline if you want to! A standalone example is sketched just below.
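For example, here is a minimal sketch of using one of those components on its own: the same `CorpusProcessor` and `llama_index_sentence_splitter` that the training tutorial uses, applied outside of any `RAG*` class. The input documents and chunk size below are purely illustrative.

```python
from ragatouille.data import CorpusProcessor, llama_index_sentence_splitter

# Stand-alone use of RAGatouille's chunking, outside RAGPretrainedModel/RAGTrainer
corpus_processor = CorpusProcessor(document_splitter_fn=llama_index_sentence_splitter)
chunks = corpus_processor.process_corpus(
    ["A very long document...", "Another very long document..."],
    chunk_size=256,  # target chunk size, in tokens
)
```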
41 | 42 | 43 | In this section, we'll quickly walk you through the three core aspects of RAGatouille: 44 | 45 | - [🚀 Training and Fine-Tuning ColBERT models](#-training-and-fine-tuning) 46 | - [🗄️ Embedding and Indexing Documents](#%EF%B8%8F-indexing) 47 | - [🔎 Retrieving documents](#-retrieving-documents) 48 | 49 | ➡️ If you just want to see fully functional code examples, head over to the [examples](https://github.com/bclavie/RAGatouille/tree/main/examples)⬅️ 50 | 51 | ### 🚀 Training and fine-tuning 52 | 53 | _If you're just prototyping, you don't need to train your own model! While fine-tuning can be useful, one of the strengths of ColBERT is that the pretrained models are particularly good at generalisation, and [ColBERTv2](https://huggingface.co/colbert-ir/colbertv2.0) has [repeatedly been shown to be extremely strong](https://arxiv.org/abs/2303.00807) at zero-shot retrieval in new domains!_ 54 | 55 | #### Data Processing 56 | 57 | RAGatouille's RAGTrainer has a built-in `TrainingDataProcessor`, which can take most forms of retrieval training data and automatically convert them to training triplets, with data enhancements. The pipeline works as follows: 58 | 59 | - Accepts pairs, labelled pairs and various forms of triplets as inputs (strings or lists of strings) -- transparently! 60 | - Automatically removes all duplicates and maps all positives/negatives to their respective query. 61 | - By default, mines hard negatives: this means generating negatives that are hard to distinguish from positives, and that are therefore more useful for training. 62 | 63 | This is all handled by `RAGTrainer.prepare_training_data()`, and is as easy as passing your data to it: 64 | 65 | ```python 66 | from ragatouille import RAGTrainer 67 | 68 | my_data = [ 69 | ("What is the meaning of life ?", "The meaning of life is 42"), 70 | ("What is Neural Search?", "Neural Search is a term referring to a family of ..."), 71 | ... 72 | ] # Unlabelled pairs here 73 | trainer = RAGTrainer(model_name="MyColBERT", pretrained_model_name="colbert-ir/colbertv2.0") 74 | trainer.prepare_training_data(raw_data=my_data) 75 | ``` 76 | 77 | ColBERT prefers to store processed training data on-file, which also makes it easier to properly version training data via `wandb` or `dvc`. By default, it will write to `./data/`, but you can override this by passing a `data_out_path` argument to `prepare_training_data()`. 78 | 79 | Just like all things in RAGatouille, `prepare_training_data` uses strong defaults, but is also fully parameterizable. 80 | 81 | 82 | #### Running the Training/Fine-Tuning 83 | 84 | Training and Fine-Tuning follow the exact same process. When you instantiate `RAGTrainer`, you must pass it a `pretrained_model_name`. If this pretrained model is a ColBERT instance, the trainer will be in fine-tuning mode; if it's another kind of transformer, it will be in training mode and begin training a new ColBERT initialised from the model's weights! 85 | 86 | 87 | ```python 88 | from ragatouille import RAGTrainer 89 | from ragatouille.utils import get_wikipedia_page 90 | 91 | pairs = [ 92 | ("What is the meaning of life ?", "The meaning of life is 42"), 93 | ("What is Neural Search?", "Neural Search is a term referring to a family of ..."), 94 | # You need many more pairs to train! Check the examples for more details! 95 | ...
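    # Pairs, labelled pairs and triplets are all accepted here:
    # prepare_training_data() below will normalise whatever you pass into ColBERT training triplets.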
96 | ] 97 | 98 | my_full_corpus = [get_wikipedia_page("Hayao_Miyazaki"), get_wikipedia_page("Studio_Ghibli")] 99 | 100 | 101 | trainer = RAGTrainer(model_name = "MyFineTunedColBERT", 102 | pretrained_model_name = "colbert-ir/colbertv2.0") # In this example, we run fine-tuning 103 | 104 | # This step handles all the data processing; check the examples for more details! 105 | trainer.prepare_training_data(raw_data=pairs, 106 | data_out_path="./data/", 107 | all_documents=my_full_corpus) 108 | 109 | trainer.train(batch_size=32) # Train with the default hyperparams 110 | ``` 111 | 112 | When you run `train()`, it'll by default inherit its parent ColBERT's hyperparameters if fine-tuning, or use the default training parameters if training a new ColBERT. Feel free to modify them as you see fit (check the examples and API reference for more details!) 113 | 114 | 115 | ### 🗄️ Indexing 116 | 117 | To create an index, you'll need to load a trained model; this can be one of your own or a pretrained one from the hub! Creating an index with the default configuration is just a few lines of code: 118 | 119 | ```python 120 | from ragatouille import RAGPretrainedModel 121 | from ragatouille.utils import get_wikipedia_page 122 | 123 | RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0") 124 | my_documents = [get_wikipedia_page("Hayao_Miyazaki"), get_wikipedia_page("Studio_Ghibli")] 125 | index_path = RAG.index(index_name="my_index", collection=my_documents) 126 | ``` 127 | You can also optionally add document IDs or document metadata when creating the index: 128 | 129 | ```python 130 | document_ids = ["miyazaki", "ghibli"] 131 | document_metadatas = [ 132 | {"entity": "person", "source": "wikipedia"}, 133 | {"entity": "organisation", "source": "wikipedia"}, 134 | ] 135 | index_path = RAG.index( 136 | index_name="my_index_with_ids_and_metadata", 137 | collection=my_documents, 138 | document_ids=document_ids, 139 | document_metadatas=document_metadatas, 140 | ) 141 | ``` 142 | 143 | Once this is done running, your index will be saved on-disk and ready to be queried! RAGatouille and ColBERT handle everything here: 144 | - Splitting your documents 145 | - Tokenizing your documents 146 | - Identifying the individual terms 147 | - Embedding the documents and generating the bags-of-embeddings 148 | - Compressing the vectors and storing them on disk 149 | 150 | Curious about how this works? Check out the [Late-Interaction & ColBERT concept explainer](https://ben.clavie.eu/ragatouille/#late-interaction) 151 | 152 | 153 | ### 🔎 Retrieving Documents 154 | 155 | Once an index is created, querying it is just as simple as creating it! You can either load the model you need directly from an index's configuration: 156 | 157 | ```python 158 | from ragatouille import RAGPretrainedModel 159 | 160 | query = "ColBERT my dear ColBERT, who is the fairest document of them all?" 161 | RAG = RAGPretrainedModel.from_index("path_to_your_index") 162 | results = RAG.search(query) 163 | ``` 164 | 165 | This is the preferred way of doing things, since every index saves the full configuration of the model used to create it, and you can easily load it back up. 166 | 167 | `RAG.search` is a flexible method! You can set the `k` value to however many results you want (it defaults to `10`), and you can also use it to search for multiple queries at once: 168 | 169 | ```python 170 | RAG.search(["What manga did Hayao Miyazaki write?", 171 | "Who are the founders of Ghibli?",
172 | "Who is the director of Spirited Away?"],) 173 | ``` 174 | 175 | `RAG.search` returns results in the form of a list of dictionaries, or a list of list of dictionaries if you used multiple queries: 176 | 177 | ```python 178 | # single-query result 179 | [ 180 | {"content": "blablabla", "score": 42.424242, "rank": 1, "document_id": "x"}, 181 | ..., 182 | {"content": "albalbalba", "score": 24.242424, "rank": k, "document_id": "y"}, 183 | ] 184 | # multi-query result 185 | [ 186 | [ 187 | {"content": "blablabla", "score": 42.424242, "rank": 1, "document_id": "x"}, 188 | ..., 189 | {"content": "albalbalba", "score": 24.242424, "rank": k, "document_id": "y"}, 190 | ], 191 | [ 192 | {"content": "blablabla", "score": 42.424242, "rank": 1, "document_id": "x"}, 193 | ..., 194 | {"content": "albalbalba", "score": 24.242424, "rank": k, "document_id": "y"}, 195 | ], 196 | ], 197 | ``` 198 | If your index includes document metadata, it'll be returned as a dictionary in the `document_metadata` key of the result dictionary: 199 | 200 | ```python 201 | [ 202 | {"content": "blablabla", "score": 42.424242, "rank": 1, "document_id": "x", "document_metadata": {"A": 1, "B": 2}}, 203 | ..., 204 | {"content": "albalbalba", "score": 24.242424, "rank": k, "document_id": "y", "document_metadata": {"A": 3, "B": 4}}, 205 | ] 206 | ``` 207 | 208 | ## I'm sold, can I integrate late-interaction RAG into my project? 209 | 210 | To get started, RAGatouille bundles everything you need to build a ColBERT native index and query it. Just look at the docs! RAGatouille persists indices on disk in compressed format, and a very viable production deployment is to simply integrate the index you need into your project and query it directly. Don't just take our word for it, this is what Spotify does in production with their own vector search framework, serving dozens of millions of users: 211 | 212 | > Statelessness: Many of Spotify’s systems use nearest-neighbor search in memory, enabling stateless deployments (via Kubernetes) and almost entirely removing the maintenance and cost burden of maintaining a stateful database cluster. (_[Spotify, announcing Voyager](https://engineering.atspotify.com/2023/10/introducing-voyager-spotifys-new-nearest-neighbor-search-library/)_) 213 | 214 | 215 | ### Integrations 216 | 217 | If you'd like to use more than RAGatouille, ColBERT has a growing number of integrations, and they all fully support models trained or fine-tuned with RAGatouille! 218 | 219 | - The [official ColBERT implementation](https://github.com/stanford-futuredata/ColBERT) has a built-in query server (using Flask), which you can easily query via API requests and does support indexes generated with RAGatouille! This should be enough for most small applications, so long as you can persist the index on disk. 220 | - [Vespa](https://vespa.ai) offers a fully managed RAG engine with ColBERT support: it's essentially just like a vector DB, except with many more retrieval options! Full support for ColBERT models will be released in the next couple weeks, and using a RAGatouille-trained model will be as simple as loading it from the huggingface hub! 
**Vespa is a well-tested, widely used framework and is [fully-supported in LangChain](https://python.langchain.com/docs/integrations/providers/vespa), making it an ideal slot-in replacement to bring ColBERT into your current RAG pipeline!** 221 | - [Intel's FastRAG](https://github.com/IntelLabs/fastRAG) supports ColBERT models for RAG, and is fully compatible with RAGatouille-trained models. 222 | - [LlamaIndex](https://www.llamaindex.ai) is building ColBERT integrations and already [has early ColBERT support, with active development continuing](https://github.com/run-llama/llama_index/pull/9656). 223 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | ::: ragatouille.RAGTrainer 4 | 5 | ::: ragatouille.RAGPretrainedModel -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # RAGatouille 2 | 3 | _State-of-the-art document retrieval methods, in just a few lines of code._ 4 | 5 | --- 6 | 7 | Welcome to the docs for RAGatouille. This page presents RAGatouille's philosophy. We also discuss late interaction retrievers, like ColBERT, and what makes their quality so high. 8 | 9 | While you're here, check out the [API reference](https://ben.clavie.eu/ragatouille/api) and the [evolving roadmap](https://ben.clavie.eu/ragatouille/roadmap). The docs will be actively updated in the next few weeks! 10 | 11 | ## Philosophy 12 | 13 | RAGatouille's philosophy is two-fold: 14 | 15 | ### Motivation 16 | 17 | The aim of RAGatouille is to close the growing gap between the Information Retrieval literature and everyday production retrieval uses. I'm not myself an IR researcher (my own background is in NLP), but I arrived at late interaction retrievers through the many papers highlighting that they're consistently better than dense embeddings on just about any zero-shot task, at least when tested apples-to-apples! Maybe more importantly, you don't need to use them zero-shot: they're very easy to adapt to new domains due to their **bag-of-embeddings** approach. 18 | 19 | However, it's been a consistently hard sell, and starting to use ColBERT on real projects wasn't particularly smooth. IR is a field that is outwardly more _stern_ than NLP, and the barrier-to-entry is higher. A lot of the IR frameworks, like Terrier or Anserini, are absolutely fantastic, but they just don't fit into the pythonic day-to-day workflows we're used to. 20 | 21 | For the adoption of sentence transformers, the aptly (re-)named SentenceTransformers library has been a boon. RAGatouille doesn't quite have the pretension to be that, but it aims to help democratise the easy training and use of ColBERT and pals. To do so, we also take an approach of avoiding re-implementing whenever possible, to speed up iteration. 22 | 23 | If a paper has open-sourced its code, our goal is for RAGatouille to be a gateway to it, rather than a complete replacement, whenever possible! (_This might change in the future as the library grows!_) Moreover, this is a mutually beneficial parasitic relationship, as the early development of this lib has already resulted in a few upstreamed fixes on the main ColBERT repo!
24 | 25 | ### Code Philosophy 26 | 27 | The actual programming philosophy is fairly simple, as stated on the [github README](https://github.com/bclavie/RAGatouille): 28 | 29 | - Strong, but parameterizable defaults: you should be able to get started with just a few lines of code and still leverage the full power of ColBERT, and you should be able to tweak any relevant parameter if you need to! 30 | - Powerful yet simple re-usable components under-the-hood: any part of the library should be usable stand-alone. You can use our DataProcessor or our negative miners outside of `RAGPretrainedModel` and `RAGTrainer`, and you can even write your own negative miner and use it in the pipeline if you want to! 31 | 32 | In practice, this manifests in the fact that RAGPretrainedModel and RAGTrainer should ultimately be _all you need_ to leverage the power of ColBERT in your pipelines. Everything that gets added to RAGatouille should be workable into these two classes, which are really just interfaces between the underlying models and processors. 33 | 34 | However, re-usable components are another very important aspect. RAGatouille aims to be built in such a way that every core component, such as our negative miners (`SimpleMiner` for dense retrieval at the moment) or data processors, should be usable outside the main classes, if you so desire. If you're a seasoned ColBERT aficionado, nothing should stop you from importing `TrainingDataProcessor` to streamline processing and exporting triplets! 35 | 36 | Finally, there's a third point that's fairly important: 37 | 38 | - __Don't re-invent the wheel.__ 39 | 40 | If a component needs to do something, we won't seek to do it our way; we'll seek to do it the way people already do it. This means using `LlamaIndex` to chunk documents, `instructor` and `pydantic` to constrain OpenAI calls, or `DSPy` whenever we need more complex LLM-based components! 41 | 42 | ## Late-Interaction 43 | 44 | ### tl;dr 45 | 46 | So, why is late-interaction so good? Why should you use RAGatouille/ColBERT? 47 | 48 | The underlying concept is simple. Quickly put, I like to explain ColBERT as a `bag-of-embeddings` approach, as this makes it immediately obvious to NLP practitioners how and why ColBERT works: 49 | 50 | - Just like bag-of-words, it works on small information units, and represents a document as the sum of them 51 | - Just like embeddings, it works on the semantic level: the actual way something is phrased doesn't matter, the model learns __meaning__. 52 | 53 | ### longer, might read 54 | 55 | A full blog post with more detail about this is coming (soon™), but for now, here's a quick explainer: 56 | 57 | Take it this way: here are the existing widely-used retrieval approaches, with a quick overview of their pros and cons: 58 | 59 | ##### BM25/Keyword-based Sparse Retrieval 60 | 61 | ➕ Fast 62 | ➕ Consistent performance 63 | ➕ No real training required 64 | ➕ Intuitive 65 | ➖ Requires exact matches 66 | ➖ Does not leverage any semantic information, and thus hits __a hard performance ceiling__ 67 | 68 | ##### Cross-Encoders 69 | 70 | ➕ Very strong performance 71 | ➕ Leverages semantic information to a large extent ("understands" negative form so that "I love apples" and "I hate apples" are not similar, etc...) 72 | ➖ Major scalability issues: scores can only be retrieved by running the model to compare a query to every single document in the corpus (see the sketch just below).
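To make that scalability point concrete, here is a minimal, illustrative sketch (not RAGatouille code; it assumes a sentence-transformers-style `CrossEncoder`, and the model name is just an example): ranking a corpus with a cross-encoder costs one full forward pass per (query, document) pair, and nothing can be precomputed per document ahead of time.

```python
from sentence_transformers import CrossEncoder

# Illustrative model choice; any cross-encoder behaves the same way here.
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def cross_encoder_rank(query: str, corpus: list[str], k: int = 10):
    # Every query costs len(corpus) forward passes: fine for reranking a small
    # shortlist, prohibitive for searching a large collection directly.
    scores = model.predict([(query, document) for document in corpus])
    ranked = sorted(zip(corpus, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

Dense retrievers and late-interaction models avoid this by encoding documents ahead of time, so only the query needs to go through the model at search time.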
73 | 74 | ##### Dense Retrieval/Embeddings 75 | 76 | ➕ Fast 77 | ➕ Decent performance overall, once pre-trained 78 | ➕ Leverages semantic information... 79 | ➖ ... but not contrastive information (e.g. "I love apples" and "I hate apples" will have a high similarity score.) 80 | ➖ Fine-tuning can be finicky 81 | ➖ Requires either billions of parameters (e5-mistral) or billions of pre-training examples to reach top performance 82 | ➖ __Often generalises poorly__ 83 | 84 | ### Generalisation 85 | 86 | This last point is particularly important. __Generalisation__ is what you want, because the documents that you're trying to retrieve for your users, as well as the way that they phrase their queries, are __not the ones present in academic datasets__. 87 | 88 | Strong performance on academic benchmarks is a solid signal for predicting how well a model will perform, but it's far from the only one. Single-vector embedding approaches, or _dense retrieval_ methods, often do well on benchmarks, as they're trained specifically for them. However, the IR literature has shown many times that these models often generalise worse than other approaches. 89 | 90 | This is not a slight on them; it's actually very logical! If you think about it: 91 | 92 | - A single-vector embedding is **the representation of a sentence or document in a very small vector space, with at most 1024 dimensions**. 93 | - In retrieval settings, the **same model must also be able to create similar representations for very short queries and long documents, to be able to retrieve them**. 94 | - Then, **these vectors must be able to represent your documents and your users' queries, which it has never seen, in the same way that it has learned to represent its training data**. 95 | - And finally, **it must be able to encode all possible information contained in a document or in a query, so that it may be able to find a relevant document no matter how a question is phrased**. 96 | 97 | The fact that dense embeddings perform well in these circumstances is very impressive! But sadly, embedding all this information into just a thousand dimensions isn't a problem that has been cracked yet. 98 | 99 | ### Bag-of-Embeddings: the Late Interaction trick 100 | 101 | Alleviating this is where late-interaction comes in. 102 | 103 | ColBERT does not represent documents as a single vector. In fact, ColBERT, at its core, is basically a **keyword-based approach**. 104 | 105 | But why does it perform so well, then, when we've established that keyword matchers have a hard ceiling? 106 | 107 | Because ColBERT is a **semantic keyword matcher**. It leverages the power of strong encoder models, like BERT, to break down each document into a bag of **contextualised units of information**. 108 | 109 | When a document is embedded by ColBERT, it isn't represented as one document-level vector, but as the sum of its parts. 110 | 111 | This fundamentally changes the nature of training our model, turning it into a much easier task: it doesn't need to cram every possible meaning into a single vector; it just needs to capture the meaning of a few tokens at a time. When you do this, it doesn't really matter how you phrase something at retrieval time: the likelihood that the model is able to relate the way you mention a topic to the way a document discusses it is considerably higher. This is, quite intuitively, because the model has so much more space to store information on individual topics!
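To make this concrete, here is a minimal, illustrative sketch of the late-interaction ("MaxSim") scoring idea. This is not RAGatouille's actual implementation (that lives in the upstream ColBERT library and is heavily optimised); it simply assumes you already have one normalised embedding per token for the query and for the document:

```python
import torch

def maxsim_score(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> float:
    """Late-interaction scoring sketch.

    query_embeddings: (num_query_tokens, dim) -- one vector per query token
    doc_embeddings:   (num_doc_tokens, dim)   -- one vector per document token
    """
    # Similarity between every query token and every document token
    similarity_matrix = query_embeddings @ doc_embeddings.T  # (num_query_tokens, num_doc_tokens)
    # For each query token, keep only its best-matching document token ("MaxSim")...
    best_match_per_query_token = similarity_matrix.max(dim=1).values
    # ...and sum those maxima to get the document's relevance score.
    return best_match_per_query_token.sum().item()
```

Each query token only has to find its own best match in the document, so the model never needs to squeeze an entire document's meaning into one vector.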
Additionally, because this allows us to create smaller vectors for individual information units, they become very compressible, which means our indexes don't balloon up in size. 112 | -------------------------------------------------------------------------------- /docs/roadmap.md: -------------------------------------------------------------------------------- 1 | # Roadmap 2 | 3 | This page is incorrectly named: RAGatouille doesn't have a set-in-stone roadmap, but rather, a set of objectives. 4 | 5 | Below, you'll find things that we're hoping to integrate and/or support in upcoming versions (⛰️ denotes a major milestone): 6 | 7 | #### New Features 8 | 9 | ##### Synthetic Data Generation 10 | 11 | - Build upon our [tutorial 3](https://github.com/bclavie/RAGatouille/blob/main/examples/03-finetuning_without_annotations_with_instructor_and_RAGatouille.ipynb) and integrate OpenAI query generation into a built-in DataProcessor. 12 | - Leverage [DSPy](https://github.com/stanfordnlp/dspy) to perform data augmentation via LLM compiling, reducing the reliance on API providers by enabling locally-run models to generate data. 13 | - ⛰️ Integrate [UDAPDR](https://arxiv.org/abs/2303.00807) - UDAPDR is an extremely impressive method to adapt retrievers to a target domain via entirely synthetic queries: all you need to provide is your document collection. We're hoping to integrate this in an upcoming version of RAGatouille. 14 | - Provide a toolkit to generate synthetic passages for provided queries. 15 | 16 | 17 | #### Improvements 18 | 19 | - ⛰️ Full ColBERTv2-style training: transparently use an existing cross-encoder teacher model to generate distillation scores and improve model training. 20 | - Evaluation support: at the moment, RAGatouille doesn't roll out any evaluation metrics, as these are more commonly available already. Future versions of RAGatouille will include some form of evaluation for convenience! 21 | - Support for more "late-interaction" models, such as Google's [SparseEmbed](https://research.google/pubs/sparseembed-learning-sparse-lexical-representations-with-contextual-embeddings-for-retrieval/). 22 | - New negative miners, such as ColBERTMiner (not a huge priority as dense hard negatives work well enough, but this would be a nice feature for thoroughness) 23 | - Full LlamaIndex integration 24 | 25 | #### Library Upkeep 26 | 27 | - ⛰️ Improve the documentation to cover every component and concept of the library in-depth. 28 | - Comprehensive test coverage -------------------------------------------------------------------------------- /examples/02-basic_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c71ca3ef", 6 | "metadata": {}, 7 | "source": [ 8 | "# Basic training and fine-tuning with RAGatouille\n", 9 | "\n", 10 | "In this quick example, we'll use the `RAGTrainer` magic class to demonstrate how to very easily fine-tune an existing ColBERT model, or train one from any BERT/RoBERTa-like model (to [train one for a previously unsupported language like Japanese](https://huggingface.co/bclavie/jacolbert), for example!)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "d6f69fca", 16 | "metadata": {}, 17 | "source": [ 18 | "First, we'll create an instance of `RAGTrainer`. We need to give it two arguments: the `model_name` we want to give to the model we're training, and the `pretrained_model_name` of the base model.
This can either be a local path, or the name of a model on the HuggingFace Hub. If you're training for a language other than English, you should also include a `language_code` two-letter argument (PLACEHOLDER ISO) so we can get the relevant processing utils!\n", 19 | "\n", 20 | "The trainer will auto-detect whether it's an existing ColBERT model or a BERT base model, and will set itself up accordingly!\n", 21 | "\n", 22 | "Please note: Training can currently only be run on GPU, and will error out if using CPU/MPS! Training is also currently not functional on Google Colab and Windows 10.\n", 23 | "\n", 24 | "Whether we're training from scratch or fine-tuning doesn't matter: all the steps are the same. For this example, let's fine-tune ColBERTv2:" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "id": "1e81ac64-d222-412c-925c-2a5262266c0e", 31 | "metadata": { 32 | "tags": [] 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "from ragatouille import RAGTrainer\n", 37 | "trainer = RAGTrainer(model_name=\"GhibliColBERT\", pretrained_model_name=\"colbert-ir/colbertv2.0\", language_code=\"en\")" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "6fa8016a", 43 | "metadata": {}, 44 | "source": [ 45 | "To train retrieval models like ColBERT, we need training triplets: queries, positive passages, and negative passages for each query.\n", 46 | "\n", 47 | "In the next tutorial, we'll see [how to generate synthetic queries when we don't have any annotated passages]. For this tutorial, we'll assume that we have queries and relevant passages, but that we're lacking negative ones (because it's not information we gather from our users).\n", 48 | "\n", 49 | "Let's assume our corpus is the same as the one we [used for our example about indexing and searching](PLACEHOLDER): Hayao Miyazaki's wikipedia page.
Let's first fetch the content from Wikipedia:" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "id": "3ca72a0d", 56 | "metadata": { 57 | "tags": [] 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "import requests\n", 62 | "\n", 63 | "def get_wikipedia_page(title: str):\n", 64 | "    \"\"\"\n", 65 | "    Retrieve the full text content of a Wikipedia page.\n", 66 | "    \n", 67 | "    :param title: str - Title of the Wikipedia page.\n", 68 | "    :return: str - Full text content of the page as raw string.\n", 69 | "    \"\"\"\n", 70 | "    # Wikipedia API endpoint\n", 71 | "    URL = \"https://en.wikipedia.org/w/api.php\"\n", 72 | "\n", 73 | "    # Parameters for the API request\n", 74 | "    params = {\n", 75 | "        \"action\": \"query\",\n", 76 | "        \"format\": \"json\",\n", 77 | "        \"titles\": title,\n", 78 | "        \"prop\": \"extracts\",\n", 79 | "        \"explaintext\": True,\n", 80 | "    }\n", 81 | "\n", 82 | "    # Custom User-Agent header to comply with Wikipedia's best practices\n", 83 | "    headers = {\n", 84 | "        \"User-Agent\": \"RAGatouille_tutorial/0.0.1 (ben@clavie.eu)\"\n", 85 | "    }\n", 86 | "\n", 87 | "    response = requests.get(URL, params=params, headers=headers)\n", 88 | "    data = response.json()\n", 89 | "\n", 90 | "    # Extracting page content\n", 91 | "    page = next(iter(data['query']['pages'].values()))\n", 92 | "    return page['extract'] if 'extract' in page else None" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "id": "8d8ade7f", 99 | "metadata": { 100 | "tags": [] 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "my_full_corpus = [get_wikipedia_page(\"Hayao_Miyazaki\")]\n", 105 | "my_full_corpus += [get_wikipedia_page(\"Studio_Ghibli\")]\n", 106 | "my_full_corpus += [get_wikipedia_page(\"Toei_Animation\")]" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "id": "31778aed", 112 | "metadata": {}, 113 | "source": [ 114 | "We're also adding some Toei Animation content -- it helps to have things in our corpus that aren't directly relevant to our queries but are likely to cover similar topics, so we can get some more adjacent negative examples.\n", 115 | "\n", 116 | "The documents are very long, so let's use a `CorpusProcessor` to split them into chunks of around 256 tokens:" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "id": "a14fe476", 123 | "metadata": { 124 | "tags": [] 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "from ragatouille.data import CorpusProcessor, llama_index_sentence_splitter\n", 129 | "\n", 130 | "corpus_processor = CorpusProcessor(document_splitter_fn=llama_index_sentence_splitter)\n", 131 | "documents = corpus_processor.process_corpus(my_full_corpus, chunk_size=256)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "id": "a26eea4c", 137 | "metadata": {}, 138 | "source": [ 139 | "Now that we have a corpus of documents, let's generate fake query-relevant passage pairs.
Obviously, you wouldn't want that in the real world, but that's the topic of the next tutorial:" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "id": "53b48c50", 146 | "metadata": { 147 | "tags": [] 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "import random\n", 152 | "\n", 153 | "queries = [\"What manga did Hayao Miyazaki write?\",\n", 154 | "           \"which film made ghibli famous internationally\",\n", 155 | "           \"who directed Spirited Away?\",\n", 156 | "           \"when was Hikotei Jidai published?\",\n", 157 | "           \"where's studio ghibli based?\",\n", 158 | "           \"where is the ghibli museum?\"\n", 159 | "] * 3\n", 160 | "pairs = []\n", 161 | "\n", 162 | "for query in queries:\n", 163 | "    fake_relevant_docs = random.sample(documents, 10)\n", 164 | "    for doc in fake_relevant_docs:\n", 165 | "        pairs.append((query, doc))" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "46f29b15", 171 | "metadata": {}, 172 | "source": [ 173 | "Here, we have created pairs. It's common for retrieval training data to be stored in a lot of different ways: pairs of [query, positive], pairs of [query, passage, label], triplets of [query, positive, negative], or triplets of [query, list_of_positives, list_of_negatives]. No matter which format your data's in, you don't need to worry about it: RAGatouille will generate ColBERT-friendly triplets for you, and export them to disk for easy `dvc` or `wandb` data tracking.\n", 174 | "\n", 175 | "Speaking of, let's process the data so it's ready for training. `RAGTrainer` has a `prepare_training_data` function, which will perform all the necessary steps. One of the steps it performs is called **hard negative mining**: that's searching the full collection of documents (even those not linked to a query) to retrieve passages that are semantically close to a query, but aren't actually relevant. Using those to train retrieval models has repeatedly been shown to greatly improve their ability to find actually relevant documents, so it's a very important step! \n", 176 | "\n", 177 | "RAGatouille handles all of this for you. By default, it'll fetch 10 negative examples per query, but you can customise this with `num_new_negatives`. You can also choose not to mine negatives and just sample random examples to speed things up: this might lower performance, but will run much quicker on large volumes of data; just set `mine_hard_negatives` to `False`. If you've already mined negatives yourself, you can set `num_new_negatives` to 0 to bypass this entirely."
178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 6, 183 | "id": "1ddaf3fb", 184 | "metadata": { 185 | "tags": [] 186 | }, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "Loading Hard Negative SimpleMiner dense embedding model BAAI/bge-small-en-v1.5...\n", 193 | "Building hard negative index for 93 documents...\n", 194 | "All documents embedded, now adding to index...\n", 195 | "save_index set to False, skipping saving hard negative index\n", 196 | "Hard negative index generated\n", 197 | "mining\n", 198 | "mining\n", 199 | "mining\n", 200 | "mining\n", 201 | "mining\n", 202 | "mining\n" 203 | ] 204 | }, 205 | { 206 | "data": { 207 | "text/plain": [ 208 | "'./data/'" 209 | ] 210 | }, 211 | "execution_count": 6, 212 | "metadata": {}, 213 | "output_type": "execute_result" 214 | } 215 | ], 216 | "source": [ 217 | "trainer.prepare_training_data(raw_data=pairs, data_out_path=\"./data/\", all_documents=my_full_corpus, num_new_negatives=10, mine_hard_negatives=True)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "id": "b9468e25", 223 | "metadata": {}, 224 | "source": [ 225 | "Our training data's now fully processed and saved to disk in `data_out_path`! We're now ready to begin training our model with the `train` function. `train` takes many arguments, but the set of default is already fairly strong!\n", 226 | "\n", 227 | "Don't be surprised you don't see an `epochs` parameter here, ColBERT will train until it either reaches `maxsteps` or has seen the entire training data once (a full epoch), this is by design!" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 8, 233 | "id": "06773f75-c844-42a4-b786-e56ef899a96e", 234 | "metadata": { 235 | "tags": [] 236 | }, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "#> Starting...\n", 243 | "nranks = 1 \t num_gpus = 1 \t device=0\n", 244 | "{\n", 245 | " \"query_token_id\": \"[unused0]\",\n", 246 | " \"doc_token_id\": \"[unused1]\",\n", 247 | " \"query_token\": \"[Q]\",\n", 248 | " \"doc_token\": \"[D]\",\n", 249 | " \"ncells\": null,\n", 250 | " \"centroid_score_threshold\": null,\n", 251 | " \"ndocs\": null,\n", 252 | " \"load_index_with_mmap\": false,\n", 253 | " \"index_path\": null,\n", 254 | " \"nbits\": 4,\n", 255 | " \"kmeans_niters\": 20,\n", 256 | " \"resume\": false,\n", 257 | " \"similarity\": \"cosine\",\n", 258 | " \"bsize\": 32,\n", 259 | " \"accumsteps\": 1,\n", 260 | " \"lr\": 5e-6,\n", 261 | " \"maxsteps\": 500000,\n", 262 | " \"save_every\": 0,\n", 263 | " \"warmup\": 0,\n", 264 | " \"warmup_bert\": null,\n", 265 | " \"relu\": false,\n", 266 | " \"nway\": 2,\n", 267 | " \"use_ib_negatives\": true,\n", 268 | " \"reranker\": false,\n", 269 | " \"distillation_alpha\": 1.0,\n", 270 | " \"ignore_scores\": false,\n", 271 | " \"model_name\": \"GhibliColBERT\",\n", 272 | " \"query_maxlen\": 32,\n", 273 | " \"attend_to_mask_tokens\": false,\n", 274 | " \"interaction\": \"colbert\",\n", 275 | " \"dim\": 128,\n", 276 | " \"doc_maxlen\": 256,\n", 277 | " \"mask_punctuation\": true,\n", 278 | " \"checkpoint\": \"colbert-ir\\/colbertv2.0\",\n", 279 | " \"triples\": \"data\\/triples.train.colbert.jsonl\",\n", 280 | " \"collection\": \"data\\/corpus.train.colbert.tsv\",\n", 281 | " \"queries\": \"data\\/queries.train.colbert.tsv\",\n", 282 | " \"index_name\": null,\n", 283 | " \"overwrite\": false,\n", 284 | " \"root\": \".ragatouille\\/\",\n", 285 | " \"experiment\": 
\"colbert\",\n", 286 | " \"index_root\": null,\n", 287 | " \"name\": \"2024-01\\/03\\/17.31.59\",\n", 288 | " \"rank\": 0,\n", 289 | " \"nranks\": 1,\n", 290 | " \"amp\": true,\n", 291 | " \"gpus\": 1\n", 292 | "}\n", 293 | "Using config.bsize = 32 (per process) and config.accumsteps = 1\n", 294 | "[Jan 03, 17:32:08] #> Loading the queries from data/queries.train.colbert.tsv ...\n", 295 | "[Jan 03, 17:32:08] #> Got 6 queries. All QIDs are unique.\n", 296 | "\n", 297 | "[Jan 03, 17:32:08] #> Loading collection...\n", 298 | "0M " 299 | ] 300 | }, 301 | { 302 | "name": "stderr", 303 | "output_type": "stream", 304 | "text": [ 305 | "/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:429: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", 306 | " warnings.warn(\n", 307 | "/opt/conda/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:139: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n", 308 | " warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n" 309 | ] 310 | }, 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "\n", 316 | "#> LR will use 0 warmup steps and linear decay over 500000 steps.\n", 317 | "\n", 318 | "#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==\n", 319 | "#> Input: . 
What manga did Hayao Miyazaki write?, \t\t True, \t\t None\n", 320 | "#> Output IDs: torch.Size([32]), tensor([ 101, 1, 2054, 8952, 2106, 10974, 7113, 2771, 3148, 18637,\n", 321 | " 4339, 1029, 102, 103, 103, 103, 103, 103, 103, 103,\n", 322 | " 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,\n", 323 | " 103, 103], device='cuda:0')\n", 324 | "#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 325 | " 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')\n", 326 | "\n", 327 | "\t\t\t\t 4.857693195343018 8.471616744995117\n", 328 | "#>>> 16.0 20.45 \t\t|\t\t -4.449999999999999\n", 329 | "[Jan 03, 17:32:12] 0 13.329309463500977\n", 330 | "\t\t\t\t 5.369233131408691 9.21822452545166\n", 331 | "#>>> 15.58 20.6 \t\t|\t\t -5.020000000000001\n", 332 | "[Jan 03, 17:32:12] 1 13.330567611694336\n", 333 | "\t\t\t\t 9.961652755737305 15.010833740234375\n", 334 | "#>>> 10.04 19.99 \t\t|\t\t -9.95\n", 335 | "[Jan 03, 17:32:13] 2 13.342209530578614\n", 336 | "\t\t\t\t 6.136704921722412 12.674407005310059\n", 337 | "#>>> 11.99 17.86 \t\t|\t\t -5.869999999999999\n", 338 | "[Jan 03, 17:32:14] 3 13.347678432498231\n", 339 | "\t\t\t\t 8.609644889831543 13.769636154174805\n", 340 | "#>>> 13.18 21.75 \t\t|\t\t -8.57\n", 341 | "[Jan 03, 17:32:15] 4 13.356710034156064\n", 342 | "[Jan 03, 17:32:15] #> Done with all triples!\n", 343 | "#> Saving a checkpoint to .ragatouille/colbert/none/2024-01/03/17.31.59/checkpoints/colbert ..\n", 344 | "#> Joined...\n" 345 | ] 346 | } 347 | ], 348 | "source": [ 349 | "\n", 350 | "trainer.train(batch_size=32,\n", 351 | " nbits=4, # How many bits will the trained model use when compressing indexes\n", 352 | " maxsteps=500000, # Maximum steps hard stop\n", 353 | " use_ib_negatives=True, # Use in-batch negative to calculate loss\n", 354 | " dim=128, # How many dimensions per embedding. 128 is the default and works well.\n", 355 | " learning_rate=5e-6, # Learning rate, small values ([3e-6,3e-5] work best if the base model is BERT-like, 5e-6 is often the sweet spot)\n", 356 | " doc_maxlen=256, # Maximum document length. Because of how ColBERT works, smaller chunks (128-256) work very well.\n", 357 | " use_relu=False, # Disable ReLU -- doesn't improve performance\n", 358 | " warmup_steps=\"auto\", # Defaults to 10%\n", 359 | " )\n" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "id": "44ab6a51-a4bd-4fab-96d7-dd8e93dac462", 365 | "metadata": {}, 366 | "source": [ 367 | "And you're now done training! Your model is saved at the path it outputs, with the final checkpoint always being in the `.../checkpoints/colbert` path, and intermediate checkpoints saved at `.../checkpoints/colbert-{N_STEPS}`.\n", 368 | "\n", 369 | "You can now use your model by pointing at its local path, or upload it to the huggingface hub to share it with the world!" 
370 | ] 371 | } 372 | ], 373 | "metadata": { 374 | "environment": { 375 | "kernel": "python3", 376 | "name": ".m114", 377 | "type": "gcloud", 378 | "uri": "gcr.io/deeplearning-platform-release/:m114" 379 | }, 380 | "kernelspec": { 381 | "display_name": "Python 3", 382 | "language": "python", 383 | "name": "python3" 384 | }, 385 | "language_info": { 386 | "codemirror_mode": { 387 | "name": "ipython", 388 | "version": 3 389 | }, 390 | "file_extension": ".py", 391 | "mimetype": "text/x-python", 392 | "name": "python", 393 | "nbconvert_exporter": "python", 394 | "pygments_lexer": "ipython3", 395 | "version": "3.10.13" 396 | } 397 | }, 398 | "nbformat": 4, 399 | "nbformat_minor": 5 400 | } 401 | -------------------------------------------------------------------------------- /examples/06-index_free_use.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using ColBERT in-memory: Index-Free Encodings & Search\n", 8 | "\n", 9 | "Sometimes, building an index doesn't make sense. Maybe you're working with a really small dataset, or one that is really fleeting nature, and will only be relevant to the lifetime of your current instance. In these cases, it can be more efficient to skip all the time-consuming index optimisation, and keep your encodings in-memory to perform ColBERT's magical MaxSim on-the-fly. This doesn't scale very well, but can be very useful in certain settings.\n", 10 | "\n", 11 | "In this quick example, we'll use the `RAGPretrainedModel` magic class to demonstrate how to **encode documents in-memory**, before **retrieving them with `search_encoded_docs`**.\n", 12 | "\n", 13 | "First, as usual, let's load up a pre-trained ColBERT model:" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stderr", 23 | "output_type": "stream", 24 | "text": [ 25 | "/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 26 | " from .autonotebook import tqdm as notebook_tqdm\n" 27 | ] 28 | }, 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "[Jan 27, 16:56:39] Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" 34 | ] 35 | }, 36 | { 37 | "name": "stderr", 38 | "output_type": "stream", 39 | "text": [ 40 | "/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py:125: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. 
Disabling.\n", 41 | " warnings.warn(\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "from ragatouille import RAGPretrainedModel\n", 47 | "\n", 48 | "RAG = RAGPretrainedModel.from_pretrained(\"colbert-ir/colbertv2.0\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "Now that our model is loaded, we can load and preprocess some data, as in the previous tutorials:" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "from ragatouille.utils import get_wikipedia_page\n", 65 | "from ragatouille.data import CorpusProcessor\n", 66 | "\n", 67 | "corpus_processor = CorpusProcessor()\n", 68 | "\n", 69 | "documents = [get_wikipedia_page(\"Hayao Miyazaki\"), get_wikipedia_page(\"Studio Ghibli\"), get_wikipedia_page(\"Princess Mononoke\"), get_wikipedia_page(\"Shrek\")]\n", 70 | "documents = corpus_processor.process_corpus(documents, chunk_size=200)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "Our documents are now fully ready to be encoded! \n", 78 | "\n", 79 | "One important note: `encode()` itself will not split your documents, you must pre-process them yourself (using corpus_processor or your preferred chunking approach). However, `encode()` will dynamically set the maximum token length, calculated based on the token length distribution in your corpus, up to the maximum length supported by the model you're using.\n", 80 | "\n", 81 | "Just like normal indexing, `encode()` also supports adding metadata to the encoded documents, which will be returned as part of query results:" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "Encoding 212 documents...\n" 94 | ] 95 | }, 96 | { 97 | "name": "stderr", 98 | "output_type": "stream", 99 | "text": [ 100 | " 0%| | 0/7 [00:00=0.2.19", 33 | "langchain", 34 | "onnx", 35 | "srsly", 36 | "voyager", 37 | "torch>=1.13", 38 | "fast-pytorch-kmeans", 39 | "sentence-transformers", 40 | ] 41 | 42 | [project.optional-dependencies] 43 | all = [ 44 | "llama-index", 45 | "langchain", 46 | "rerankers", 47 | "voyager", 48 | ] 49 | llamaindex = ["llama-index"] 50 | langchain = ["langchain"] 51 | train = ["sentence-transformers", "pylate", "rerankers"] 52 | 53 | [project.urls] 54 | "Homepage" = "https://github.com/answerdotai/ragatouille" 55 | 56 | [tool.pytest.ini_options] 57 | filterwarnings = [ 58 | "ignore::Warning" 59 | ] 60 | 61 | target-version = "py39" -------------------------------------------------------------------------------- /ragatouille/RAGPretrainedModel.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union 3 | from uuid import uuid4 4 | 5 | from langchain.retrievers.document_compressors.base import BaseDocumentCompressor 6 | from langchain_core.retrievers import BaseRetriever 7 | 8 | from ragatouille.data.corpus_processor import CorpusProcessor 9 | from ragatouille.data.preprocessors import llama_index_sentence_splitter 10 | from ragatouille.integrations import ( 11 | RAGatouilleLangChainCompressor, 12 | RAGatouilleLangChainRetriever, 13 | ) 14 | from ragatouille.models import ColBERT, LateInteractionModel 15 | 16 | 17 | class RAGPretrainedModel: 18 | """ 19 | Wrapper class for a 
pretrained RAG late-interaction model, and all the associated utilities. 20 | Allows you to load a pretrained model from disk or from the hub, build or query an index. 21 | 22 | ## Usage 23 | 24 | Load a pre-trained checkpoint: 25 | 26 | ```python 27 | from ragatouille import RAGPretrainedModel 28 | 29 | RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0") 30 | ``` 31 | 32 | Load checkpoint from an existing index: 33 | 34 | ```python 35 | from ragatouille import RAGPretrainedModel 36 | 37 | RAG = RAGPretrainedModel.from_index("path/to/my/index") 38 | ``` 39 | 40 | Both methods will load a fully initialised instance of ColBERT, which you can use to build and query indexes. 41 | 42 | ```python 43 | RAG.search("How many people live in France?") 44 | ``` 45 | """ 46 | 47 | model_name: Union[str, None] = None 48 | model: Union[LateInteractionModel, None] = None 49 | corpus_processor: Optional[CorpusProcessor] = None 50 | 51 | @classmethod 52 | def from_pretrained( 53 | cls, 54 | pretrained_model_name_or_path: Union[str, Path], 55 | n_gpu: int = -1, 56 | verbose: int = 1, 57 | index_root: Optional[str] = None, 58 | ): 59 | """Load a ColBERT model from a pre-trained checkpoint. 60 | 61 | Parameters: 62 | pretrained_model_name_or_path (str): Local path or huggingface model name. 63 | n_gpu (int): Number of GPUs to use. By default, value is -1, which means use all available GPUs or none if no GPU is available. 64 | verbose (int): The level of ColBERT verbosity requested. By default, 1, which will filter out most internal logs. 65 | index_root (Optional[str]): The root directory where indexes will be stored. If None, will use the default directory, '.ragatouille/'. 66 | 67 | Returns: 68 | cls (RAGPretrainedModel): The current instance of RAGPretrainedModel, with the model initialised. 69 | """ 70 | instance = cls() 71 | instance.model = ColBERT( 72 | pretrained_model_name_or_path, n_gpu, index_root=index_root, verbose=verbose 73 | ) 74 | return instance 75 | 76 | @classmethod 77 | def from_index( 78 | cls, index_path: Union[str, Path], n_gpu: int = -1, verbose: int = 1 79 | ): 80 | """Load an Index and the associated ColBERT encoder from an existing document index. 81 | 82 | Parameters: 83 | index_path (Union[str, path]): Path to the index. 84 | n_gpu (int): Number of GPUs to use. By default, value is -1, which means use all available GPUs or none if no GPU is available. 85 | verbose (int): The level of ColBERT verbosity requested. By default, 1, which will filter out most internal logs. 86 | 87 | Returns: 88 | cls (RAGPretrainedModel): The current instance of RAGPretrainedModel, with the model and index initialised. 
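A brief sketch of the optional arguments described above; the paths and values are illustrative.

```python
from ragatouille import RAGPretrainedModel

# Load a checkpoint and keep any future indexes outside the default ".ragatouille/" root.
RAG = RAGPretrainedModel.from_pretrained(
    "colbert-ir/colbertv2.0",
    n_gpu=-1,               # default: use all available GPUs, or CPU if there are none
    index_root="./indexes",
)

# Or resume from an index that was built previously.
RAG = RAGPretrainedModel.from_index("path/to/my/index", n_gpu=-1, verbose=1)
```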
89 | """ 90 | instance = cls() 91 | index_path = Path(index_path) 92 | instance.model = ColBERT( 93 | index_path, n_gpu, verbose=verbose, load_from_index=True 94 | ) 95 | 96 | return instance 97 | 98 | def _process_metadata( 99 | self, 100 | document_ids: Optional[Union[TypeVar("T"), List[TypeVar("T")]]], 101 | document_metadatas: Optional[list[dict[Any, Any]]], 102 | collection_len: int, 103 | ) -> tuple[list[str], Optional[dict[Any, Any]]]: 104 | if document_ids is None: 105 | document_ids = [str(uuid4()) for i in range(collection_len)] 106 | else: 107 | if len(document_ids) != collection_len: 108 | raise ValueError("document_ids must be the same length as collection") 109 | if len(document_ids) != len(set(document_ids)): 110 | raise ValueError("document_ids must be unique") 111 | if any(not id.strip() for id in document_ids): 112 | raise ValueError("document_ids must not contain empty strings") 113 | if not all(isinstance(id, type(document_ids[0])) for id in document_ids): 114 | raise ValueError("All document_ids must be of the same type") 115 | 116 | if document_metadatas is not None: 117 | if len(document_metadatas) != collection_len: 118 | raise ValueError( 119 | "document_metadatas must be the same length as collection" 120 | ) 121 | docid_metadata_map = { 122 | x: y for x, y in zip(document_ids, document_metadatas) 123 | } 124 | else: 125 | docid_metadata_map = None 126 | 127 | return document_ids, docid_metadata_map 128 | 129 | def _process_corpus( 130 | self, 131 | collection: List[str], 132 | document_ids: List[str], 133 | document_metadatas: List[Dict[Any, Any]], 134 | document_splitter_fn: Optional[Callable[[str], List[str]]], 135 | preprocessing_fn: Optional[Callable[[str], str]], 136 | max_document_length: int, 137 | ) -> Tuple[List[str], Dict[int, str], Dict[str, Dict[Any, Any]]]: 138 | """ 139 | Processes a collection of documents by assigning unique IDs, splitting documents if necessary, 140 | applying preprocessing, and organizing metadata. 
141 | """ 142 | document_ids, docid_metadata_map = self._process_metadata( 143 | document_ids=document_ids, 144 | document_metadatas=document_metadatas, 145 | collection_len=len(collection), 146 | ) 147 | 148 | if document_splitter_fn is not None or preprocessing_fn is not None: 149 | self.corpus_processor = CorpusProcessor( 150 | document_splitter_fn=document_splitter_fn, 151 | preprocessing_fn=preprocessing_fn, 152 | ) 153 | collection_with_ids = self.corpus_processor.process_corpus( 154 | collection, 155 | document_ids, 156 | chunk_size=max_document_length, 157 | ) 158 | else: 159 | collection_with_ids = [ 160 | {"document_id": x, "content": y} 161 | for x, y in zip(document_ids, collection) 162 | ] 163 | 164 | pid_docid_map = { 165 | index: item["document_id"] for index, item in enumerate(collection_with_ids) 166 | } 167 | collection = [x["content"] for x in collection_with_ids] 168 | 169 | return collection, pid_docid_map, docid_metadata_map 170 | 171 | def index( 172 | self, 173 | collection: list[str], 174 | document_ids: Union[TypeVar("T"), List[TypeVar("T")]] = None, 175 | document_metadatas: Optional[list[dict]] = None, 176 | index_name: str = None, 177 | overwrite_index: Union[bool, str] = True, 178 | max_document_length: int = 256, 179 | split_documents: bool = True, 180 | document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter, 181 | preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None, 182 | bsize: int = 32, 183 | use_faiss: bool = False, 184 | ): 185 | """Build an index from a list of documents. 186 | 187 | Parameters: 188 | collection (list[str]): The collection of documents to index. 189 | document_ids (Optional[list[str]]): An optional list of document ids. Ids will be generated at index time if not supplied. 190 | index_name (str): The name of the index that will be built. 191 | overwrite_index (Union[bool, str]): Whether to overwrite an existing index with the same name. 192 | max_document_length (int): The maximum length of a document. Documents longer than this will be split into chunks. 193 | split_documents (bool): Whether to split documents into chunks. 194 | document_splitter_fn (Optional[Callable]): A function to split documents into chunks. If None and by default, will use the llama_index_sentence_splitter. 195 | preprocessing_fn (Optional[Union[Callable, list[Callable]]]): A function or list of functions to preprocess documents. If None and by default, will not preprocess documents. 196 | bsize (int): The batch size to use for encoding the passages. 197 | 198 | Returns: 199 | index (str): The path to the index that was built. 
200 | """ 201 | if not split_documents: 202 | document_splitter_fn = None 203 | collection, pid_docid_map, docid_metadata_map = self._process_corpus( 204 | collection, 205 | document_ids, 206 | document_metadatas, 207 | document_splitter_fn, 208 | preprocessing_fn, 209 | max_document_length, 210 | ) 211 | return self.model.index( 212 | collection, 213 | pid_docid_map=pid_docid_map, 214 | docid_metadata_map=docid_metadata_map, 215 | index_name=index_name, 216 | max_document_length=max_document_length, 217 | overwrite=overwrite_index, 218 | bsize=bsize, 219 | use_faiss=use_faiss, 220 | ) 221 | 222 | def add_to_index( 223 | self, 224 | new_collection: list[str], 225 | new_document_ids: Optional[Union[TypeVar("T"), List[TypeVar("T")]]] = None, 226 | new_document_metadatas: Optional[list[dict]] = None, 227 | index_name: Optional[str] = None, 228 | split_documents: bool = True, 229 | document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter, 230 | preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None, 231 | bsize: int = 32, 232 | use_faiss: bool = False, 233 | ): 234 | """Add documents to an existing index. 235 | 236 | Parameters: 237 | new_collection (list[str]): The documents to add to the index. 238 | new_document_metadatas (Optional[list[dict]]): An optional list of metadata dicts 239 | index_name (Optional[str]): The name of the index to add documents to. If None and by default, will add documents to the already initialised one. 240 | bsize (int): The batch size to use for encoding the passages. 241 | """ 242 | if not split_documents: 243 | document_splitter_fn = None 244 | 245 | ( 246 | new_collection, 247 | new_pid_docid_map, 248 | new_docid_metadata_map, 249 | ) = self._process_corpus( 250 | new_collection, 251 | new_document_ids, 252 | new_document_metadatas, 253 | document_splitter_fn, 254 | preprocessing_fn, 255 | self.model.config.doc_maxlen, 256 | ) 257 | 258 | self.model.add_to_index( 259 | new_collection, 260 | new_pid_docid_map, 261 | new_docid_metadata_map=new_docid_metadata_map, 262 | index_name=index_name, 263 | bsize=bsize, 264 | use_faiss=use_faiss, 265 | ) 266 | 267 | def delete_from_index( 268 | self, 269 | document_ids: Union[TypeVar("T"), List[TypeVar("T")]], 270 | index_name: Optional[str] = None, 271 | ): 272 | """Delete documents from an index by their IDs. 273 | 274 | Parameters: 275 | document_ids (Union[TypeVar("T"), List[TypeVar("T")]]): The IDs of the documents to delete. 276 | index_name (Optional[str]): The name of the index to delete documents from. If None and by default, will delete documents from the already initialised one. 277 | """ 278 | self.model.delete_from_index( 279 | document_ids, 280 | index_name=index_name, 281 | ) 282 | 283 | def search( 284 | self, 285 | query: Union[str, list[str]], 286 | index_name: Optional["str"] = None, 287 | k: int = 10, 288 | force_fast: bool = False, 289 | zero_index_ranks: bool = False, 290 | doc_ids: Optional[list[str]] = None, 291 | **kwargs, 292 | ): 293 | """Query an index. 294 | 295 | Parameters: 296 | query (Union[str, list[str]]): The query or list of queries to search for. 297 | index_name (Optional[str]): Provide the name of an index to query. If None and by default, will query an already initialised one. 298 | k (int): The number of results to return for each query. 299 | force_fast (bool): Whether to force the use of a faster but less accurate search method. 300 | zero_index_ranks (bool): Whether to zero the index ranks of the results. 
By default, result rank 1 is the highest ranked result 301 | 302 | Returns: 303 | results (Union[list[dict], list[list[dict]]]): A list of dict containing individual results for each query. If a list of queries is provided, returns a list of lists of dicts. Each result is a dict with keys `content`, `score`, `rank`, and 'document_id'. If metadata was indexed for the document, it will be returned under the "document_metadata" key. 304 | 305 | Individual results are always in the format: 306 | ```python3 307 | {"content": "text of the relevant passage", "score": 0.123456, "rank": 1, "document_id": "x"} 308 | ``` 309 | or 310 | ```python3 311 | {"content": "text of the relevant passage", "score": 0.123456, "rank": 1, "document_id": "x", "document_metadata": {"metadata_key": "metadata_value", ...}} 312 | ``` 313 | 314 | """ 315 | return self.model.search( 316 | query=query, 317 | index_name=index_name, 318 | k=k, 319 | force_fast=force_fast, 320 | zero_index_ranks=zero_index_ranks, 321 | doc_ids=doc_ids, 322 | **kwargs, 323 | ) 324 | 325 | def rerank( 326 | self, 327 | query: Union[str, list[str]], 328 | documents: list[str], 329 | k: int = 10, 330 | zero_index_ranks: bool = False, 331 | bsize: Union[Literal["auto"], int] = "auto", 332 | ): 333 | """Encode documents and rerank them in-memory. Performance degrades rapidly with more documents. 334 | 335 | Parameters: 336 | query (Union[str, list[str]]): The query or list of queries to search for. 337 | documents (list[str]): The documents to rerank. 338 | k (int): The number of results to return for each query. 339 | zero_index_ranks (bool): Whether to zero the index ranks of the results. By default, result rank 1 is the highest ranked result 340 | bsize (int): The batch size to use for re-ranking. 341 | 342 | Returns: 343 | results (Union[list[dict], list[list[dict]]]): A list of dict containing individual results for each query. If a list of queries is provided, returns a list of lists of dicts. Each result is a dict with keys `content`, `score` and `rank`. 344 | 345 | Individual results are always in the format: 346 | ```python3 347 | {"content": "text of the relevant passage", "score": 0.123456, "rank": 1} 348 | ``` 349 | """ 350 | 351 | return self.model.rank( 352 | query=query, 353 | documents=documents, 354 | k=k, 355 | zero_index_ranks=zero_index_ranks, 356 | bsize=bsize, 357 | ) 358 | 359 | def encode( 360 | self, 361 | documents: list[str], 362 | bsize: Union[Literal["auto"], int] = "auto", 363 | document_metadatas: Optional[list[dict]] = None, 364 | verbose: bool = True, 365 | max_document_length: Union[Literal["auto"], int] = "auto", 366 | ): 367 | """Encode documents in memory to be searched through with no Index. Performance degrades rapidly with more documents. 368 | 369 | Parameters: 370 | documents (list[str]): The documents to encode. 371 | bsize (int): The batch size to use for encoding. 372 | document_metadatas (Optional[list[dict]]): An optional list of metadata dicts. Each entry must correspond to a document. 
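A short sketch of in-memory reranking as described above, assuming `RAG` was previously loaded with `from_pretrained`; the candidate passages are illustrative.

```python
candidates = [
    "Princess Mononoke is a 1997 film directed by Hayao Miyazaki.",
    "Shrek is a 2001 animated film produced by DreamWorks.",
    "Toei Animation is a Japanese animation studio founded in 1948.",
]
ranked = RAG.rerank(
    query="Who directed Princess Mononoke?",
    documents=candidates,
    k=3,
)
# Each entry follows the documented shape: {"content": ..., "score": ..., "rank": ...}
print(ranked[0]["content"])
```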
373 | """ 374 | if verbose: 375 | print(f"Encoding {len(documents)} documents...") 376 | self.model.encode( 377 | documents=documents, 378 | bsize=bsize, 379 | document_metadatas=document_metadatas, 380 | verbose=verbose, 381 | max_tokens=max_document_length, 382 | ) 383 | if verbose: 384 | print("Documents encoded!") 385 | 386 | def search_encoded_docs( 387 | self, 388 | query: Union[str, list[str]], 389 | k: int = 10, 390 | bsize: int = 32, 391 | ) -> list[dict[str, Any]]: 392 | """Search through documents encoded in-memory. 393 | 394 | Parameters: 395 | query (Union[str, list[str]]): The query or list of queries to search for. 396 | k (int): The number of results to return for each query. 397 | batch_size (int): The batch size to use for searching. 398 | 399 | Returns: 400 | results (list[dict[str, Any]]): A list of dict containing individual results for each query. If a list of queries is provided, returns a list of lists of dicts. 401 | """ 402 | return self.model.search_encoded_docs( 403 | queries=query, 404 | k=k, 405 | bsize=bsize, 406 | ) 407 | 408 | def clear_encoded_docs(self, force: bool = False): 409 | """Clear documents encoded in-memory. 410 | 411 | Parameters: 412 | force (bool): Whether to force the clearing of encoded documents without enforcing a 10s wait time. 413 | """ 414 | self.model.clear_encoded_docs(force=force) 415 | 416 | def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever: 417 | return RAGatouilleLangChainRetriever(model=self, kwargs=kwargs) 418 | 419 | def as_langchain_document_compressor( 420 | self, k: int = 5, **kwargs: Any 421 | ) -> BaseDocumentCompressor: 422 | return RAGatouilleLangChainCompressor(model=self, k=k, kwargs=kwargs) 423 | -------------------------------------------------------------------------------- /ragatouille/RAGTrainer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal, Optional, Union 3 | 4 | from colbert.infra import ColBERTConfig 5 | 6 | from ragatouille.data import TrainingDataProcessor 7 | from ragatouille.models import ColBERT, LateInteractionModel 8 | from ragatouille.negative_miners import HardNegativeMiner, SimpleMiner 9 | from ragatouille.utils import seeded_shuffle 10 | 11 | 12 | class RAGTrainer: 13 | """Main trainer to fine-tune/train ColBERT models with a few lines.""" 14 | 15 | def __init__( 16 | self, 17 | model_name: str, 18 | pretrained_model_name: str, 19 | language_code: str = "en", 20 | n_usable_gpus: int = -1, 21 | ): 22 | """ 23 | Initialise a RAGTrainer instance. This will load a base model: either an existing ColBERT model to fine-tune or a BERT/RoBERTa-like model to build a new ColBERT model from. 24 | 25 | Parameters: 26 | model_name: str - Name of the model to train. This will be used to name the checkpoints and the index. 27 | pretrained_model_name: str - Name of the pretrained model to use as a base. Can be a local path to a checkpoint or a huggingface model name. 28 | language_code: str - Language code of the model to train. This will be used to name the checkpoints and the index. 29 | n_usable_gpus: int - Number of GPUs to use. By default, value is -1, which means use all available GPUs or none if no GPU is available. 30 | 31 | Returns: 32 | self (RAGTrainer): The current instance of RAGTrainer, with the base model initialised. 
33 | """ 34 | self.model_name = model_name 35 | self.pretrained_model_name = pretrained_model_name 36 | self.language_code = language_code 37 | self.model: Union[LateInteractionModel, None] = ColBERT( 38 | pretrained_model_name_or_path=pretrained_model_name, 39 | n_gpu=n_usable_gpus, 40 | training_mode=True, 41 | ) 42 | self.negative_miner: Union[HardNegativeMiner, None] = None 43 | self.collection: list[str] = [] 44 | self.queries: Union[list[str], None] = None 45 | self.raw_data: Union[list[tuple], list[list], None] = None 46 | self.training_triplets: list[list[int]] = list() 47 | 48 | def add_documents(self, documents: list[str]): 49 | self.collection += documents 50 | seeded_shuffle(self.collection) 51 | 52 | def export_training_data(self, path: Union[str, Path]): 53 | """ 54 | Manually export the training data processed by prepare_training_data to a given path. 55 | 56 | Parameters: 57 | path: Union[str, Path] - Path to the directory where the data will be exported.""" 58 | self.data_processor.export_training_data(path) 59 | 60 | def _add_to_collection(self, content: Union[str, list, dict]): 61 | if isinstance(content, str): 62 | self.collection.append(content) 63 | elif isinstance(content, list): 64 | self.collection += [txt for txt in content if isinstance(txt, str)] 65 | elif isinstance(content, dict): 66 | self.collection += [content["content"]] 67 | 68 | def prepare_training_data( 69 | self, 70 | raw_data: Union[list[tuple], list[list]], 71 | all_documents: Optional[list[str]] = None, 72 | data_out_path: Union[str, Path] = "./data/", 73 | num_new_negatives: int = 10, 74 | hard_negative_minimum_rank: int = 10, 75 | mine_hard_negatives: bool = True, 76 | hard_negative_model_size: str = "small", 77 | pairs_with_labels: bool = False, 78 | positive_label: Union[int, str] = 1, 79 | negative_label: Union[int, str] = 0, 80 | ) -> str: 81 | """ 82 | Fully pre-process input-data in various raw formats into ColBERT-ready files and triplets. 83 | Will accept a variety of formats, such as unannotated pairs, annotated pairs, triplets of strings and triplets of list of strings. 84 | Will process into a ColBERT-ready format and export to data_out_path. 85 | Will generate hard negatives if mine_hard_negatives is True. 86 | num_new_negatives decides how many negatives will be generated. if mine_hard_negatives is False and num_new_negatives is > 0, these negatives will be randomly sampled. 87 | 88 | Parameters: 89 | raw_data: Union[list[tuple], list[list]] - List of pairs, annotated pairs, or triplets of strings. 90 | all_documents: Optional[list[str]] - A corpus of documents to be used for sampling negatives. 91 | data_out_path: Union[str, Path] - Path to the directory where the data will be exported (can be a tmp directory). 92 | num_new_negatives: int - Number of new negatives to generate for each query. 93 | mine_hard_negatives: bool - Whether to use hard negatives mining or not. 94 | hard_negative_model_size: str - Size of the model to use for hard negatives mining. 95 | pairs_with_labels: bool - Whether the raw_data is a list of pairs with labels or not. 96 | positive_label: Union[int, str] - Label to use for positive pairs. 97 | negative_label: Union[int, str] - Label to use for negative pairs. 98 | 99 | Returns: 100 | data_out_path: Union[str, Path] - Path to the directory where the data has been exported. 
101 | """ 102 | if all_documents is not None: 103 | self.collection += [doc for doc in all_documents if isinstance(doc, str)] 104 | 105 | self.data_dir = Path(data_out_path) 106 | sample = raw_data[0] 107 | if len(sample) == 2: 108 | data_type = "pairs" 109 | elif len(sample) == 3: 110 | if pairs_with_labels: 111 | data_type = "labeled_pairs" 112 | if sample[-1] not in [positive_label, negative_label]: 113 | raise ValueError(f"Invalid value for label: {sample}") 114 | else: 115 | data_type = "triplets" 116 | else: 117 | raise ValueError("Raw data must be a list of pairs or triplets of strings.") 118 | 119 | self.queries = set() 120 | for x in raw_data: 121 | if isinstance(x[0], str): 122 | self.queries.add(x[0]) 123 | else: 124 | raise ValueError("Queries must be a strings.") 125 | 126 | self._add_to_collection(x[1]) 127 | 128 | if data_type == "triplets": 129 | self._add_to_collection(x[2]) 130 | 131 | self.collection = list(set(self.collection)) 132 | seeded_shuffle(self.collection) 133 | 134 | if mine_hard_negatives: 135 | self.negative_miner = SimpleMiner( 136 | language_code=self.language_code, 137 | model_size=hard_negative_model_size, 138 | ) 139 | self.negative_miner.build_index(self.collection) 140 | 141 | self.data_processor = TrainingDataProcessor( 142 | collection=self.collection, 143 | queries=self.queries, 144 | negative_miner=self.negative_miner if mine_hard_negatives else None, 145 | ) 146 | 147 | self.data_processor.process_raw_data( 148 | data_type=data_type, 149 | raw_data=raw_data, 150 | export=True, 151 | data_dir=data_out_path, 152 | num_new_negatives=num_new_negatives, 153 | positive_label=positive_label, 154 | negative_label=negative_label, 155 | mine_hard_negatives=mine_hard_negatives, 156 | hard_negative_minimum_rank=hard_negative_minimum_rank, 157 | ) 158 | if len(self.data_processor.training_triplets) == 0: 159 | if mine_hard_negatives: 160 | print( 161 | "Warning: No training triplets were generated with setting mine_hard_negatives=='True'. This may be due to the data being too small or the hard negative miner not being able to find enough hard negatives." 162 | ) 163 | self.data_processor.process_raw_data( 164 | data_type=data_type, 165 | raw_data=raw_data, 166 | export=True, 167 | data_dir=data_out_path, 168 | num_new_negatives=num_new_negatives, 169 | positive_label=positive_label, 170 | negative_label=negative_label, 171 | mine_hard_negatives=False, 172 | hard_negative_minimum_rank=hard_negative_minimum_rank, 173 | ) 174 | else: 175 | raise ValueError("No training triplets were generated.") 176 | 177 | self.training_triplets = self.data_processor.training_triplets 178 | 179 | return data_out_path 180 | 181 | def train( 182 | self, 183 | batch_size: int = 32, 184 | nbits: int = 2, 185 | maxsteps: int = 500_000, 186 | use_ib_negatives: bool = True, 187 | learning_rate: float = 5e-6, 188 | dim: int = 128, 189 | doc_maxlen: int = 256, 190 | use_relu: bool = False, 191 | warmup_steps: Union[int, Literal["auto"]] = "auto", 192 | accumsteps: int = 1, 193 | ) -> str: 194 | """ 195 | Launch training or fine-tuning of a ColBERT model. 196 | 197 | Parameters: 198 | batch_size: int - Total batch size -- divice by n_usable_gpus for per-GPU batch size. 199 | nbits: int - number of bits used for vector compression by the traiened model. 2 is usually ideal. 200 | maxsteps: int - End training early after maxsteps steps. 201 | use_ib_negatives: bool - Whether to use in-batch negatives to calculate loss or not. 
202 | learning_rate: float - ColBERT litterature usually has this performing best between 3e-6 - 2e-5 depending on data size 203 | dim: int - Size of individual vector representations. 204 | doc_maxlen: int - The maximum length after which passages will be truncated 205 | warmup_steps: Union[int, Literal["auto"]] - How many warmup steps to use for the learning rate. 206 | Auto will default to 10% of total steps 207 | accumsteps: How many gradient accummulation steps to use to simulate higher batch sizes. 208 | 209 | Returns: 210 | model_path: str - Path to the trained model. 211 | """ 212 | if not self.training_triplets: 213 | total_triplets = sum( 214 | 1 for _ in open(str(self.data_dir / "triples.train.colbert.jsonl"), "r") 215 | ) 216 | else: 217 | total_triplets = len(self.training_triplets) 218 | 219 | training_config = ColBERTConfig( 220 | bsize=batch_size, 221 | model_name=self.model_name, 222 | name=self.model_name, 223 | checkpoint=self.pretrained_model_name, 224 | use_ib_negatives=use_ib_negatives, 225 | maxsteps=maxsteps, 226 | nbits=nbits, 227 | lr=learning_rate, 228 | dim=dim, 229 | doc_maxlen=doc_maxlen, 230 | relu=use_relu, 231 | accumsteps=accumsteps, 232 | warmup=int(total_triplets // batch_size * 0.1) 233 | if warmup_steps == "auto" 234 | else warmup_steps, 235 | save_every=int(total_triplets // batch_size // 10), 236 | ) 237 | 238 | return self.model.train(data_dir=self.data_dir, training_config=training_config) 239 | -------------------------------------------------------------------------------- /ragatouille/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | _FUTURE_MIGRATION_WARNING_MESSAGE = ( 4 | "\n********************************************************************************\n" 5 | "RAGatouille WARNING: Future Release Notice\n" 6 | "--------------------------------------------\n" 7 | "RAGatouille version 0.0.10 will be migrating to a PyLate backend \n" 8 | "instead of the current Stanford ColBERT backend.\n" 9 | "PyLate is a fully mature, feature-equivalent backend, that greatly facilitates compatibility.\n" 10 | "However, please pin version <0.0.10 if you require the Stanford ColBERT backend.\n" 11 | "********************************************************************************" 12 | ) 13 | 14 | warnings.warn( 15 | _FUTURE_MIGRATION_WARNING_MESSAGE, 16 | UserWarning, 17 | stacklevel=2 # Ensures the warning points to the user's import line 18 | ) 19 | 20 | __version__ = "0.0.9post2" 21 | from .RAGPretrainedModel import RAGPretrainedModel 22 | from .RAGTrainer import RAGTrainer 23 | 24 | __all__ = ["RAGPretrainedModel", "RAGTrainer"] -------------------------------------------------------------------------------- /ragatouille/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .corpus_processor import CorpusProcessor 2 | from .preprocessors import llama_index_sentence_splitter 3 | from .training_data_processor import TrainingDataProcessor 4 | 5 | __all__ = [ 6 | "TrainingDataProcessor", 7 | "CorpusProcessor", 8 | "llama_index_sentence_splitter", 9 | ] 10 | -------------------------------------------------------------------------------- /ragatouille/data/corpus_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Union 2 | from uuid import uuid4 3 | 4 | from ragatouille.data.preprocessors import llama_index_sentence_splitter 5 | 6 | 7 | class 
CorpusProcessor: 8 | def __init__( 9 | self, 10 | document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter, 11 | preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None, 12 | ): 13 | self.document_splitter_fn = document_splitter_fn 14 | self.preprocessing_fn = preprocessing_fn 15 | 16 | def process_corpus( 17 | self, 18 | documents: list[str], 19 | document_ids: Optional[list[str]] = None, 20 | **splitter_kwargs, 21 | ) -> List[dict]: 22 | # TODO CHECK KWARGS 23 | document_ids = ( 24 | [str(uuid4()) for _ in range(len(documents))] 25 | if document_ids is None 26 | else document_ids 27 | ) 28 | if self.document_splitter_fn is not None: 29 | documents = self.document_splitter_fn( 30 | documents, document_ids, **splitter_kwargs 31 | ) 32 | if self.preprocessing_fn is not None: 33 | if isinstance(self.preprocessing_fn, list): 34 | for fn in self.preprocessing_fn: 35 | documents = fn(documents, document_ids) 36 | return documents 37 | return self.preprocessing_fn(documents, document_ids) 38 | return documents 39 | -------------------------------------------------------------------------------- /ragatouille/data/preprocessors.py: -------------------------------------------------------------------------------- 1 | try: 2 | from llama_index import Document 3 | from llama_index.text_splitter import SentenceSplitter 4 | except ImportError: 5 | from llama_index.core import Document 6 | from llama_index.core.text_splitter import SentenceSplitter 7 | 8 | 9 | def llama_index_sentence_splitter( 10 | documents: list[str], document_ids: list[str], chunk_size=256 11 | ): 12 | chunk_overlap = min(chunk_size / 4, min(chunk_size / 2, 64)) 13 | chunks = [] 14 | node_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) 15 | docs = [[Document(text=doc)] for doc in documents] 16 | for doc_id, doc in zip(document_ids, docs): 17 | chunks += [ 18 | {"document_id": doc_id, "content": node.text} for node in node_parser(doc) 19 | ] 20 | return chunks 21 | -------------------------------------------------------------------------------- /ragatouille/data/training_data_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import defaultdict 4 | from itertools import product 5 | from pathlib import Path 6 | from typing import Literal, Union 7 | 8 | import srsly 9 | 10 | 11 | class TrainingDataProcessor: 12 | def __init__( 13 | self, 14 | collection: list[str], 15 | queries: list[str], 16 | negative_miner=None, 17 | ): 18 | self.collection = collection 19 | self.queries = queries 20 | self.negative_miner = negative_miner 21 | self._make_data_map() 22 | self.training_triplets = [] 23 | 24 | def process_raw_data( 25 | self, 26 | raw_data, 27 | data_type: Literal["pairs", "triplets", "labeled_pairs"], 28 | data_dir: Union[str, Path], 29 | export: bool = True, 30 | mine_hard_negatives: bool = True, 31 | num_new_negatives: int = 10, 32 | positive_label: int = 1, 33 | negative_label: int = 0, 34 | hard_negative_minimum_rank: int = 10, 35 | ): 36 | if self.negative_miner is None and mine_hard_negatives: 37 | raise ValueError( 38 | "mine_hard_negatives is True but no negative miner was provided!" 
39 | ) 40 | if self.negative_miner: 41 | self.negative_miner.min_rank = hard_negative_minimum_rank 42 | if data_type == "pairs": 43 | self._process_raw_pairs( 44 | raw_data=raw_data, 45 | mine_hard_negatives=mine_hard_negatives, 46 | n_new_negatives=num_new_negatives, 47 | ) 48 | elif data_type == "labeled_pairs": 49 | self._process_raw_labeled_pairs( 50 | raw_data=raw_data, 51 | mine_hard_negatives=mine_hard_negatives, 52 | n_new_negatives=num_new_negatives, 53 | positive_label=positive_label, 54 | negative_label=negative_label, 55 | ) 56 | elif data_type == "triplets": 57 | self._process_raw_triplets( 58 | raw_data=raw_data, 59 | mine_hard_negatives=mine_hard_negatives, 60 | n_new_negatives=num_new_negatives, 61 | ) 62 | 63 | if export: 64 | self.export_training_data(data_dir) 65 | 66 | def _make_individual_triplets(self, query, positives, negatives): 67 | """Create the training data in ColBERT(v1) format from raw lists of triplets""" 68 | if len(positives) == 0 or len(negatives) == 0: 69 | return [] 70 | triplets = [] 71 | q = self.query_map[query] 72 | random.seed(42) 73 | if len(positives) > 1: 74 | all_pos_texts = [p for p in positives] 75 | max_triplets_per_query = 20 76 | negs_per_positive = max(1, max_triplets_per_query // len(all_pos_texts)) 77 | initial_triplets_count = 0 78 | for pos in all_pos_texts: 79 | p = self.passage_map[pos] 80 | chosen_negs = random.sample( 81 | negatives, min(len(negatives), negs_per_positive) 82 | ) 83 | for neg in chosen_negs: 84 | n = self.passage_map[neg] 85 | initial_triplets_count += 1 86 | triplets.append([q, p, n]) 87 | 88 | extra_triplets_needed = max_triplets_per_query - initial_triplets_count 89 | if extra_triplets_needed > 0: 90 | all_combinations = list(product(all_pos_texts, negatives)) 91 | random.seed(42) 92 | random.shuffle(all_combinations) 93 | for pos, neg in all_combinations: 94 | p = self.passage_map[pos] 95 | n = self.passage_map[neg] 96 | if [q, p, n] not in triplets: 97 | triplets.append([q, p, n]) 98 | extra_triplets_needed -= 1 99 | if extra_triplets_needed <= 0: 100 | break 101 | 102 | else: 103 | p = self.passage_map[positives[0]] 104 | for n in negatives: 105 | triplets.append([q, p, self.passage_map[n]]) 106 | 107 | return triplets 108 | 109 | def _get_new_negatives(self, query, passages, mine_hard_negatives, n_new_negatives): 110 | """Generate new negatives for each query, using either: 111 | - The assigned hard negative miner if mine_hard_negatives is True 112 | - Randomly sampling from the full collection otherwise 113 | """ 114 | if mine_hard_negatives: 115 | hard_negatives = self.negative_miner.mine_hard_negatives(query) 116 | candidates = [ 117 | x 118 | for x in hard_negatives 119 | if x not in passages["positives"] and x not in passages["negatives"] 120 | ] 121 | new_negatives = random.sample( 122 | candidates, 123 | min(n_new_negatives, len(candidates)), 124 | ) 125 | else: 126 | new_negatives = [ 127 | x 128 | for x in random.sample( 129 | self.collection, min(n_new_negatives, len(self.collection)) 130 | ) 131 | if x not in passages["positives"] and x not in passages["negatives"] 132 | ] 133 | 134 | return new_negatives 135 | 136 | def _process_raw_pairs(self, raw_data, mine_hard_negatives, n_new_negatives): 137 | """Convert unlabeled pairs into training triplets. 
138 | It's assumed unlabeled pairs are always in the format (query, relevant_passage)""" 139 | training_triplets = [] 140 | raw_grouped_triplets = defaultdict(lambda: defaultdict(list)) 141 | 142 | for query, positive in raw_data: 143 | if isinstance(positive, str): 144 | positive = [positive] 145 | elif isinstance(positive, dict): 146 | positive = [positive["content"]] 147 | raw_grouped_triplets[query]["positives"] += positive 148 | 149 | for query, passages in raw_grouped_triplets.items(): 150 | if n_new_negatives > 0: 151 | passages["negatives"] += self._get_new_negatives( 152 | query=query, 153 | passages=passages, 154 | mine_hard_negatives=mine_hard_negatives, 155 | n_new_negatives=n_new_negatives, 156 | ) 157 | training_triplets += self._make_individual_triplets( 158 | query=query, 159 | positives=list(set(passages["positives"])), 160 | negatives=list(set(passages["negatives"])), 161 | ) 162 | self.training_triplets = training_triplets 163 | 164 | def _process_raw_labeled_pairs( 165 | self, 166 | raw_data, 167 | mine_hard_negatives, 168 | n_new_negatives, 169 | positive_label, 170 | negative_label, 171 | ): 172 | """ 173 | Convert labeled pairs intro training triplets. 174 | Labeled pairs are in the format (query, passage, label) 175 | """ 176 | training_triplets = [] 177 | raw_grouped_triplets = defaultdict(lambda: defaultdict(list)) 178 | 179 | for query, passage, label in raw_data: 180 | if isinstance(passage, str): 181 | passage = [passage] 182 | if label == positive_label: 183 | label = "positives" 184 | elif label == negative_label: 185 | label = "negatives" 186 | else: 187 | raise ValueError( 188 | f"Label {label} must correspond to either positive_label or negative_label!" 189 | ) 190 | 191 | raw_grouped_triplets[query][label] += passage 192 | 193 | for query, passages in raw_grouped_triplets.items(): 194 | if n_new_negatives > 0: 195 | passages["negatives"] += self._get_new_negatives( 196 | query=query, 197 | passages=passages, 198 | mine_hard_negatives=mine_hard_negatives, 199 | n_new_negatives=n_new_negatives, 200 | ) 201 | 202 | training_triplets += self._make_individual_triplets( 203 | query=query, 204 | positives=passages["positives"], 205 | negatives=passages["negatives"], 206 | ) 207 | self.training_triplets = training_triplets 208 | 209 | def _process_raw_triplets(self, raw_data, mine_hard_negatives, n_new_negatives): 210 | """ 211 | Convert raw triplets 212 | (query, positives : str | list[str], negatives: str | list[str]) 213 | into training triplets. 
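For reference, a small sketch of the two raw triplet shapes this method accepts, with positives and negatives given either as single strings or as lists of strings; the examples are illustrative.

```python
raw_triplets = [
    # (query, positive, negative) with plain strings
    ("Who directed Princess Mononoke?",
     "Princess Mononoke was directed by Hayao Miyazaki.",
     "Shrek was directed by Andrew Adamson and Vicky Jenson."),
    # (query, positives, negatives) with lists of strings
    ("When was Studio Ghibli founded?",
     ["Studio Ghibli was founded in 1985."],
     ["Toei Animation was founded in 1948.", "Shrek premiered in 2001."]),
]
```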
214 | """ 215 | training_triplets = [] 216 | raw_grouped_triplets = defaultdict(lambda: defaultdict(list)) 217 | for query, positive, negative in raw_data: 218 | if isinstance(positive, str): 219 | positive = [positive] 220 | if isinstance(negative, str): 221 | negative = [negative] 222 | 223 | raw_grouped_triplets[query]["positives"] += positive 224 | raw_grouped_triplets[query]["negatives"] += negative 225 | 226 | for query, passages in raw_grouped_triplets.items(): 227 | if n_new_negatives > 0: 228 | passages["negatives"] += self._get_new_negatives( 229 | query=query, 230 | passages=passages, 231 | mine_hard_negatives=mine_hard_negatives, 232 | n_new_negatives=n_new_negatives, 233 | ) 234 | training_triplets += self._make_individual_triplets( 235 | query=query, 236 | positives=passages["positives"], 237 | negatives=passages["negatives"], 238 | ) 239 | self.training_triplets = training_triplets 240 | 241 | def _make_data_map(self): 242 | """ 243 | Generate a query_text: query_id and passage_text: passage_id mapping 244 | To easily generate ColBERT-format training data. 245 | """ 246 | self.query_map = {} 247 | self.passage_map = {} 248 | 249 | for i, query in enumerate(self.queries): 250 | self.query_map[query] = i 251 | for i, passage in enumerate(list(self.collection)): 252 | self.passage_map[passage] = i 253 | 254 | def export_training_data(self, path: Union[str, Path]): 255 | """ 256 | Export training data for both training and versioning purposes. 257 | {path} should ideally be dvc versioned. 258 | """ 259 | 260 | path = Path(path) 261 | 262 | # Create the directory if it does not exist 263 | os.makedirs(path, exist_ok=True) 264 | 265 | with open(path / "queries.train.colbert.tsv", "w") as f: 266 | for query, idx in self.query_map.items(): 267 | query = query.replace("\t", " ").replace("\n", " ") 268 | f.write(f"{idx}\t{query}\n") 269 | with open(path / "corpus.train.colbert.tsv", "w") as f: 270 | for document, idx in self.passage_map.items(): 271 | document = document.replace("\t", " ").replace("\n", " ") 272 | f.write(f"{idx}\t{document}\n") 273 | 274 | random.seed(42) 275 | random.shuffle(self.training_triplets) 276 | srsly.write_jsonl(path / "triples.train.colbert.jsonl", self.training_triplets) 277 | -------------------------------------------------------------------------------- /ragatouille/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | from ragatouille.integrations._langchain import ( 2 | RAGatouilleLangChainCompressor, 3 | RAGatouilleLangChainRetriever, 4 | ) 5 | 6 | __all__ = ["RAGatouilleLangChainRetriever", "RAGatouilleLangChainCompressor"] 7 | -------------------------------------------------------------------------------- /ragatouille/integrations/_langchain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Sequence 2 | 3 | from langchain.retrievers.document_compressors.base import BaseDocumentCompressor 4 | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks 5 | from langchain_core.documents import Document 6 | from langchain_core.retrievers import BaseRetriever 7 | 8 | 9 | class RAGatouilleLangChainRetriever(BaseRetriever): 10 | model: Any 11 | kwargs: dict = {} 12 | 13 | def _get_relevant_documents( 14 | self, 15 | query: str, 16 | *, 17 | run_manager: CallbackManagerForRetrieverRun, # noqa 18 | ) -> List[Document]: 19 | """Get documents relevant to a query.""" 20 | docs = 
self.model.search(query, **self.kwargs) 21 | return [ 22 | Document( 23 | page_content=doc["content"], metadata=doc.get("document_metadata", {}) 24 | ) 25 | for doc in docs 26 | ] 27 | 28 | 29 | class RAGatouilleLangChainCompressor(BaseDocumentCompressor): 30 | model: Any 31 | kwargs: dict = {} 32 | k: int = 5 33 | 34 | class Config: 35 | """Configuration for this pydantic object.""" 36 | 37 | arbitrary_types_allowed = True 38 | 39 | def compress_documents( 40 | self, 41 | documents: Sequence[Document], 42 | query: str, 43 | callbacks: Optional[Callbacks] = None, # noqa 44 | **kwargs, 45 | ) -> Any: 46 | """Rerank a list of documents relevant to a query.""" 47 | doc_list = list(documents) 48 | _docs = [d.page_content for d in doc_list] 49 | results = self.model.rerank( 50 | query=query, 51 | documents=_docs, 52 | k=kwargs.get("k", self.k), 53 | **self.kwargs, 54 | ) 55 | final_results = [] 56 | for r in results: 57 | doc = doc_list[r["result_index"]] 58 | doc.metadata["relevance_score"] = r["score"] 59 | final_results.append(doc) 60 | return final_results 61 | -------------------------------------------------------------------------------- /ragatouille/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import LateInteractionModel 2 | from .colbert import ColBERT 3 | 4 | __all__ = ["LateInteractionModel", "ColBERT"] 5 | -------------------------------------------------------------------------------- /ragatouille/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | 6 | class LateInteractionModel(ABC): 7 | @abstractmethod 8 | def __init__( 9 | self, 10 | pretrained_model_name_or_path: Union[str, Path], 11 | n_gpu, 12 | ): 13 | ... 14 | 15 | @abstractmethod 16 | def train(): 17 | ... 18 | 19 | @abstractmethod 20 | def index(self, name: str, collection: list[str]): 21 | ... 22 | 23 | @abstractmethod 24 | def add_to_index(self): 25 | ... 26 | 27 | @abstractmethod 28 | def search(self, name: str, query: Union[str, list[str]]): 29 | ... 30 | 31 | @abstractmethod 32 | def _search(self, name: str, query: str): 33 | ... 34 | 35 | @abstractmethod 36 | def _batch_search(self, name: str, queries: list[str]): 37 | ... 38 | -------------------------------------------------------------------------------- /ragatouille/models/index.py: -------------------------------------------------------------------------------- 1 | import time 2 | from abc import ABC, abstractmethod 3 | from copy import deepcopy 4 | from pathlib import Path 5 | from typing import Any, List, Literal, Optional, TypeVar, Union 6 | 7 | import srsly 8 | import torch 9 | from colbert import Indexer, IndexUpdater, Searcher 10 | from colbert.indexing.collection_indexer import CollectionIndexer 11 | from colbert.infra import ColBERTConfig 12 | 13 | from ragatouille.models import torch_kmeans 14 | 15 | IndexType = Literal["FLAT", "HNSW", "PLAID"] 16 | 17 | 18 | class ModelIndex(ABC): 19 | index_type: IndexType 20 | 21 | def __init__( 22 | self, 23 | config: ColBERTConfig, 24 | ) -> None: 25 | self.config = config 26 | 27 | @staticmethod 28 | @abstractmethod 29 | def construct( 30 | config: ColBERTConfig, 31 | checkpoint: str, 32 | collection: List[str], 33 | index_name: Optional["str"] = None, 34 | overwrite: Union[bool, str] = "reuse", 35 | verbose: bool = True, 36 | **kwargs, 37 | ) -> "ModelIndex": ... 
38 | 39 | @staticmethod 40 | @abstractmethod 41 | def load_from_file( 42 | index_path: str, 43 | index_name: Optional[str], 44 | index_config: dict[str, Any], 45 | config: ColBERTConfig, 46 | verbose: bool = True, 47 | ) -> "ModelIndex": ... 48 | 49 | @abstractmethod 50 | def build( 51 | self, 52 | checkpoint: Union[str, Path], 53 | collection: List[str], 54 | index_name: Optional["str"] = None, 55 | overwrite: Union[bool, str] = "reuse", 56 | verbose: bool = True, 57 | ) -> None: ... 58 | 59 | @abstractmethod 60 | def search( 61 | self, 62 | config: ColBERTConfig, 63 | checkpoint: Union[str, Path], 64 | collection: List[str], 65 | index_name: Optional[str], 66 | base_model_max_tokens: int, 67 | query: Union[str, list[str]], 68 | k: int = 10, 69 | pids: Optional[List[int]] = None, 70 | force_reload: bool = False, 71 | **kwargs, 72 | ) -> list[tuple[list, list, list]]: ... 73 | 74 | @abstractmethod 75 | def _search(self, query: str, k: int, pids: Optional[List[int]] = None): ... 76 | 77 | @abstractmethod 78 | def _batch_search(self, query: list[str], k: int): ... 79 | 80 | @abstractmethod 81 | def add( 82 | self, 83 | config: ColBERTConfig, 84 | checkpoint: Union[str, Path], 85 | collection: List[str], 86 | index_root: str, 87 | index_name: str, 88 | new_collection: List[str], 89 | verbose: bool = True, 90 | **kwargs, 91 | ) -> None: ... 92 | 93 | @abstractmethod 94 | def delete( 95 | self, 96 | config: ColBERTConfig, 97 | checkpoint: Union[str, Path], 98 | collection: List[str], 99 | index_name: str, 100 | pids_to_remove: Union[TypeVar("T"), List[TypeVar("T")]], 101 | verbose: bool = True, 102 | ) -> None: ... 103 | 104 | @abstractmethod 105 | def _export_config(self) -> dict[str, Any]: ... 106 | 107 | def export_metadata(self) -> dict[str, Any]: 108 | config = self._export_config() 109 | config["index_type"] = self.index_type 110 | return config 111 | 112 | 113 | class FLATModelIndex(ModelIndex): 114 | index_type = "FLAT" 115 | 116 | 117 | class HNSWModelIndex(ModelIndex): 118 | index_type = "HNSW" 119 | 120 | 121 | class PLAIDModelIndex(ModelIndex): 122 | _DEFAULT_INDEX_BSIZE = 32 123 | index_type = "PLAID" 124 | faiss_kmeans = staticmethod(deepcopy(CollectionIndexer._train_kmeans)) 125 | pytorch_kmeans = staticmethod(torch_kmeans._train_kmeans) 126 | 127 | def __init__(self, config: ColBERTConfig) -> None: 128 | super().__init__(config) 129 | self.searcher: Optional[Searcher] = None 130 | 131 | @staticmethod 132 | def construct( 133 | config: ColBERTConfig, 134 | checkpoint: Union[str, Path], 135 | collection: List[str], 136 | index_name: Optional["str"] = None, 137 | overwrite: Union[bool, str] = "reuse", 138 | verbose: bool = True, 139 | **kwargs, 140 | ) -> "PLAIDModelIndex": 141 | return PLAIDModelIndex(config).build( 142 | checkpoint, collection, index_name, overwrite, verbose, **kwargs 143 | ) 144 | 145 | @staticmethod 146 | def load_from_file( 147 | index_path: str, 148 | index_name: Optional[str], 149 | index_config: dict[str, Any], 150 | config: ColBERTConfig, 151 | verbose: bool = True, 152 | ) -> "PLAIDModelIndex": 153 | _, _, _, _ = index_path, index_name, index_config, verbose 154 | return PLAIDModelIndex(config) 155 | 156 | def build( 157 | self, 158 | checkpoint: Union[str, Path], 159 | collection: List[str], 160 | index_name: Optional["str"] = None, 161 | overwrite: Union[bool, str] = "reuse", 162 | verbose: bool = True, 163 | **kwargs, 164 | ) -> "PLAIDModelIndex": 165 | bsize = kwargs.get("bsize", PLAIDModelIndex._DEFAULT_INDEX_BSIZE) 166 | assert isinstance(bsize, 
int) 167 | 168 | nbits = 2 169 | if len(collection) < 10000: 170 | nbits = 4 171 | self.config = ColBERTConfig.from_existing( 172 | self.config, ColBERTConfig(nbits=nbits, index_bsize=bsize) 173 | ) 174 | 175 | # Instruct colbert-ai to disable forking if nranks == 1 176 | self.config.avoid_fork_if_possible = True 177 | 178 | if len(collection) > 100000: 179 | self.config.kmeans_niters = 4 180 | elif len(collection) > 50000: 181 | self.config.kmeans_niters = 10 182 | else: 183 | self.config.kmeans_niters = 20 184 | 185 | # Monkey-patch colbert-ai to avoid using FAISS 186 | monkey_patching = ( 187 | len(collection) < 75000 and kwargs.get("use_faiss", False) is False 188 | ) 189 | if monkey_patching: 190 | print( 191 | "---- WARNING! You are using PLAID with an experimental replacement for FAISS for greater compatibility ----" 192 | ) 193 | print("This is a behaviour change from RAGatouille 0.8.0 onwards.") 194 | print( 195 | "This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations." 196 | ) 197 | print( 198 | "If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour." 199 | ) 200 | print("--------------------") 201 | CollectionIndexer._train_kmeans = self.pytorch_kmeans 202 | 203 | # Try to keep runtime stable -- these are values that empirically didn't degrade performance at all on 3 benchmarks. 204 | # More tests required before warning can be removed. 205 | try: 206 | indexer = Indexer( 207 | checkpoint=checkpoint, 208 | config=self.config, 209 | verbose=verbose, 210 | ) 211 | indexer.configure(avoid_fork_if_possible=True) 212 | indexer.index( 213 | name=index_name, collection=collection, overwrite=overwrite 214 | ) 215 | except Exception as err: 216 | print( 217 | f"PyTorch-based indexing did not succeed with error: {err}", 218 | "! Reverting to using FAISS and attempting again...", 219 | ) 220 | monkey_patching = False 221 | if monkey_patching is False: 222 | CollectionIndexer._train_kmeans = self.faiss_kmeans 223 | if torch.cuda.is_available(): 224 | import faiss 225 | 226 | if not hasattr(faiss, "StandardGpuResources"): 227 | print( 228 | "________________________________________________________________________________\n" 229 | "WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n", 230 | "This means that indexing will be slow. 
To make use of your GPU.\n" 231 | "Please install `faiss-gpu` by running:\n" 232 | "pip uninstall --y faiss-cpu & pip install faiss-gpu\n", 233 | "________________________________________________________________________________", 234 | ) 235 | print("Will continue with CPU indexing in 5 seconds...") 236 | time.sleep(5) 237 | indexer = Indexer( 238 | checkpoint=checkpoint, 239 | config=self.config, 240 | verbose=verbose, 241 | ) 242 | indexer.configure(avoid_fork_if_possible=True) 243 | indexer.index(name=index_name, collection=collection, overwrite=overwrite) 244 | 245 | return self 246 | 247 | def _load_searcher( 248 | self, 249 | checkpoint: Union[str, Path], 250 | collection: List[str], 251 | index_name: Optional[str], 252 | force_fast: bool = False, 253 | ): 254 | print( 255 | f"Loading searcher for index {index_name} for the first time...", 256 | "This may take a few seconds", 257 | ) 258 | self.searcher = Searcher( 259 | checkpoint=checkpoint, 260 | config=None, 261 | collection=collection, 262 | index_root=self.config.root, 263 | index=index_name, 264 | ) 265 | 266 | if not force_fast: 267 | self.searcher.configure(ndocs=1024) 268 | self.searcher.configure(ncells=16) 269 | if len(self.searcher.collection) < 10000: 270 | self.searcher.configure(ncells=8) 271 | self.searcher.configure(centroid_score_threshold=0.4) 272 | elif len(self.searcher.collection) < 100000: 273 | self.searcher.configure(ncells=4) 274 | self.searcher.configure(centroid_score_threshold=0.45) 275 | # Otherwise, use defaults for k 276 | else: 277 | # Use fast settingss 278 | self.searcher.configure(ncells=1) 279 | self.searcher.configure(centroid_score_threshold=0.5) 280 | self.searcher.configure(ndocs=256) 281 | 282 | print("Searcher loaded!") 283 | 284 | def _search(self, query: str, k: int, pids: Optional[List[int]] = None): 285 | assert self.searcher is not None 286 | return self.searcher.search(query, k=k, pids=pids) 287 | 288 | def _batch_search(self, query: list[str], k: int): 289 | assert self.searcher is not None 290 | queries = {i: x for i, x in enumerate(query)} 291 | results = self.searcher.search_all(queries, k=k) 292 | results = [ 293 | [list(zip(*value))[i] for i in range(3)] 294 | for value in results.todict().values() 295 | ] 296 | return results 297 | 298 | def _upgrade_searcher_maxlen(self, maxlen: int, base_model_max_tokens: int): 299 | assert self.searcher is not None 300 | # Keep maxlen stable at 32 for short queries for easier visualisation 301 | maxlen = min(max(maxlen, 32), base_model_max_tokens) 302 | self.searcher.config.query_maxlen = maxlen 303 | self.searcher.checkpoint.query_tokenizer.query_maxlen = maxlen 304 | 305 | def search( 306 | self, 307 | config: ColBERTConfig, 308 | checkpoint: Union[str, Path], 309 | collection: List[str], 310 | index_name: Optional[str], 311 | base_model_max_tokens: int, 312 | query: Union[str, list[str]], 313 | k: int = 10, 314 | pids: Optional[List[int]] = None, 315 | force_reload: bool = False, 316 | **kwargs, 317 | ) -> list[tuple[list, list, list]]: 318 | self.config = config 319 | 320 | force_fast = kwargs.get("force_fast", False) 321 | assert isinstance(force_fast, bool) 322 | 323 | if self.searcher is None or force_reload: 324 | self._load_searcher( 325 | checkpoint, 326 | collection, 327 | index_name, 328 | force_fast, 329 | ) 330 | assert self.searcher is not None 331 | 332 | base_ncells = self.searcher.config.ncells 333 | base_ndocs = self.searcher.config.ndocs 334 | 335 | if k > len(self.searcher.collection): 336 | print( 337 | "WARNING: k 
value is larger than the number of documents in the index!", 338 | f"Lowering k to {len(self.searcher.collection)}...", 339 | ) 340 | k = len(self.searcher.collection) 341 | 342 | # For smaller collections, we need a higher ncells value to ensure we return enough results 343 | if k > (32 * self.searcher.config.ncells): 344 | self.searcher.configure(ncells=min((k // 32 + 2), base_ncells)) 345 | 346 | self.searcher.configure(ndocs=max(k * 4, base_ndocs)) 347 | 348 | if isinstance(query, str): 349 | query_length = int(len(query.split(" ")) * 1.35) 350 | self._upgrade_searcher_maxlen(query_length, base_model_max_tokens) 351 | results = [self._search(query, k, pids)] 352 | else: 353 | longest_query_length = max([int(len(x.split(" ")) * 1.35) for x in query]) 354 | self._upgrade_searcher_maxlen(longest_query_length, base_model_max_tokens) 355 | results = self._batch_search(query, k) 356 | 357 | # Restore original ncells&ndocs if it had to be changed for large k values 358 | self.searcher.configure(ncells=base_ncells) 359 | self.searcher.configure(ndocs=base_ndocs) 360 | 361 | return results # type: ignore 362 | 363 | @staticmethod 364 | def _should_rebuild(current_len: int, new_doc_len: int) -> bool: 365 | """ 366 | Heuristic to determine if it is more efficient to rebuild the index instead of updating it. 367 | """ 368 | return current_len + new_doc_len < 5000 or new_doc_len > current_len * 0.05 369 | 370 | def add( 371 | self, 372 | config: ColBERTConfig, 373 | checkpoint: Union[str, Path], 374 | collection: List[str], 375 | index_root: str, 376 | index_name: str, 377 | new_collection: List[str], 378 | verbose: bool = True, 379 | **kwargs, 380 | ) -> None: 381 | self.config = config 382 | 383 | bsize = kwargs.get("bsize", PLAIDModelIndex._DEFAULT_INDEX_BSIZE) 384 | assert isinstance(bsize, int) 385 | 386 | searcher = Searcher( 387 | checkpoint=checkpoint, 388 | config=None, 389 | collection=collection, 390 | index=index_name, 391 | index_root=index_root, 392 | verbose=verbose, 393 | ) 394 | 395 | if PLAIDModelIndex._should_rebuild( 396 | len(searcher.collection), len(new_collection) 397 | ): 398 | self.build( 399 | checkpoint=checkpoint, 400 | collection=collection + new_collection, 401 | index_name=index_name, 402 | overwrite="force_silent_overwrite", 403 | verbose=verbose, 404 | **kwargs, 405 | ) 406 | else: 407 | if self.config.index_bsize != bsize: # Update bsize if it's different 408 | self.config.index_bsize = bsize 409 | 410 | updater = IndexUpdater( 411 | config=self.config, searcher=searcher, checkpoint=checkpoint 412 | ) 413 | updater.add(new_collection) 414 | updater.persist_to_disk() 415 | 416 | def delete( 417 | self, 418 | config: ColBERTConfig, 419 | checkpoint: Union[str, Path], 420 | collection: List[str], 421 | index_name: str, 422 | pids_to_remove: Union[TypeVar("T"), List[TypeVar("T")]], 423 | verbose: bool = True, 424 | ) -> None: 425 | self.config = config 426 | 427 | # Initialize the searcher and updater 428 | searcher = Searcher( 429 | checkpoint=checkpoint, 430 | config=None, 431 | collection=collection, 432 | index=index_name, 433 | verbose=verbose, 434 | ) 435 | updater = IndexUpdater(config=config, searcher=searcher, checkpoint=checkpoint) 436 | 437 | updater.remove(pids_to_remove) 438 | updater.persist_to_disk() 439 | 440 | def _export_config(self) -> dict[str, Any]: 441 | return {} 442 | 443 | 444 | class ModelIndexFactory: 445 | _MODEL_INDEX_BY_NAME = { 446 | "FLAT": FLATModelIndex, 447 | "HNSW": HNSWModelIndex, 448 | "PLAID": PLAIDModelIndex, 449 | } 450 | 451 | 
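    # Illustrative sketch only -- not part of the original file. It shows how the
    # factory below is expected to be driven by RAGatouille's higher-level classes;
    # `config`, the checkpoint name, the collection and index name are assumptions
    # prepared by the caller, and "auto" currently always resolves to a PLAID index.
    #
    #   index = ModelIndexFactory.construct(
    #       index_type="auto",
    #       config=config,                      # a colbert.infra.ColBERTConfig
    #       checkpoint="colbert-ir/colbertv2.0",
    #       collection=["document one", "document two"],
    #       index_name="my_index",
    #   )
    #   results = index.search(
    #       config,
    #       "colbert-ir/colbertv2.0",
    #       ["document one", "document two"],
    #       "my_index",
    #       base_model_max_tokens=512,          # placeholder value
    #       query="what does document one talk about?",
    #       k=2,
    #   )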
@staticmethod 452 | def _raise_if_invalid_index_type(index_type: str) -> IndexType: 453 | if index_type not in ["FLAT", "HNSW", "PLAID"]: 454 | raise ValueError( 455 | f"Unsupported index_type `{index_type}`; it must be one of 'FLAT', 'HNSW', OR 'PLAID'" 456 | ) 457 | return index_type # type: ignore 458 | 459 | @staticmethod 460 | def construct( 461 | index_type: Union[Literal["auto"], IndexType], 462 | config: ColBERTConfig, 463 | checkpoint: str, 464 | collection: List[str], 465 | index_name: Optional["str"] = None, 466 | overwrite: Union[bool, str] = "reuse", 467 | verbose: bool = True, 468 | **kwargs, 469 | ) -> ModelIndex: 470 | # Automatically choose the appropriate index for the desired "workload". 471 | if index_type == "auto": 472 | # NOTE: For now only PLAID indexes are supported. 473 | index_type = "PLAID" 474 | return ModelIndexFactory._MODEL_INDEX_BY_NAME[ 475 | ModelIndexFactory._raise_if_invalid_index_type(index_type) 476 | ].construct( 477 | config, checkpoint, collection, index_name, overwrite, verbose, **kwargs 478 | ) 479 | 480 | @staticmethod 481 | def load_from_file( 482 | index_path: str, 483 | index_name: Optional[str], 484 | config: ColBERTConfig, 485 | verbose: bool = True, 486 | ) -> ModelIndex: 487 | metadata = srsly.read_json(index_path + "/metadata.json") 488 | try: 489 | index_config = metadata["RAGatouille"]["index_config"] # type: ignore 490 | except KeyError: 491 | if verbose: 492 | print( 493 | f"Constructing default index configuration for index `{index_name}` as it does not contain RAGatouille specific metadata." 494 | ) 495 | index_config = { 496 | "index_type": "PLAID", 497 | "index_name": index_name, 498 | } 499 | index_name = ( 500 | index_name if index_name is not None else index_config["index_name"] # type: ignore 501 | ) 502 | return ModelIndexFactory._MODEL_INDEX_BY_NAME[ 503 | ModelIndexFactory._raise_if_invalid_index_type(index_config["index_type"]) # type: ignore 504 | ].load_from_file(index_path, index_name, index_config, config, verbose) 505 | -------------------------------------------------------------------------------- /ragatouille/models/torch_kmeans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fast_pytorch_kmeans import KMeans 3 | 4 | 5 | def _train_kmeans(self, sample, shared_lists): # noqa: ARG001 6 | if self.use_gpu: 7 | torch.cuda.empty_cache() 8 | centroids = compute_pytorch_kmeans( 9 | sample, 10 | self.config.dim, 11 | self.num_partitions, 12 | self.config.kmeans_niters, 13 | self.use_gpu, 14 | ) 15 | centroids = torch.nn.functional.normalize(centroids, dim=-1) 16 | if self.use_gpu: 17 | centroids = centroids.half() 18 | else: 19 | centroids = centroids.float() 20 | return centroids 21 | 22 | 23 | def compute_pytorch_kmeans( 24 | sample, 25 | dim, # noqa compatibility 26 | num_partitions, 27 | kmeans_niters, 28 | use_gpu, 29 | batch_size=16000, 30 | verbose=1, 31 | seed=123, 32 | max_points_per_centroid=256, 33 | min_points_per_centroid=10, 34 | ): 35 | device = torch.device("cuda" if use_gpu else "cpu") 36 | sample = sample.to(device) 37 | total_size = sample.shape[0] 38 | 39 | torch.manual_seed(seed) 40 | 41 | # Subsample the training set if too large 42 | if total_size > num_partitions * max_points_per_centroid: 43 | print("too many!") 44 | print("partitions:", num_partitions) 45 | print(total_size) 46 | perm = torch.randperm(total_size, device=device)[ 47 | : num_partitions * max_points_per_centroid 48 | ] 49 | sample = sample[perm] 50 | total_size = 
sample.shape[0] 51 | print("reduced size:") 52 | print(total_size) 53 | elif total_size < num_partitions * min_points_per_centroid: 54 | if verbose: 55 | print( 56 | f"Warning: number of training points ({total_size}) is less than " 57 | f"the minimum recommended ({num_partitions * min_points_per_centroid})" 58 | ) 59 | 60 | sample = sample.float() 61 | minibatch = None 62 | if num_partitions > 15000: 63 | minibatch = batch_size 64 | if num_partitions > 30000: 65 | minibatch = int(batch_size / 2) 66 | 67 | kmeans = KMeans( 68 | n_clusters=num_partitions, 69 | mode="euclidean", 70 | verbose=1, 71 | max_iter=kmeans_niters, 72 | minibatch=minibatch, 73 | ) 74 | kmeans.fit(sample) 75 | return kmeans.centroids 76 | 77 | 78 | """ Archived homebrew quick implementation to be revisited. """ 79 | # def compute_pytorch_kmeans( 80 | # sample, 81 | # dim, 82 | # num_partitions, 83 | # kmeans_niters, 84 | # use_gpu, 85 | # batch_size=8000, 86 | # tol=1e-4, 87 | # verbose=1, 88 | # seed=1234, 89 | # max_points_per_centroid=256, 90 | # min_points_per_centroid=10, 91 | # nredo=1, 92 | # ): 93 | # device = torch.device("cuda" if use_gpu else "cpu") 94 | # sample = sample.to(device) 95 | # total_size = sample.shape[0] 96 | 97 | # torch.manual_seed(seed) 98 | 99 | # # Subsample the training set if too large 100 | # if total_size > num_partitions * max_points_per_centroid: 101 | # print("too many!") 102 | # print("partitions:", num_partitions) 103 | # print(total_size) 104 | # perm = torch.randperm(total_size, device=device)[ 105 | # : num_partitions * max_points_per_centroid 106 | # ] 107 | # sample = sample[perm] 108 | # total_size = sample.shape[0] 109 | # print("reduced size:") 110 | # print(total_size) 111 | # elif total_size < num_partitions * min_points_per_centroid: 112 | # if verbose: 113 | # print( 114 | # f"Warning: number of training points ({total_size}) is less than " 115 | # f"the minimum recommended ({num_partitions * min_points_per_centroid})" 116 | # ) 117 | 118 | # sample = sample.float() 119 | 120 | # best_obj = float("inf") 121 | # best_centroids = None 122 | 123 | # for redo in range(nredo): # noqa 124 | # centroids = torch.randn(num_partitions, dim, dtype=sample.dtype, device=device) 125 | # centroids /= torch.norm(centroids, dim=1, keepdim=True) 126 | 127 | # with torch.no_grad(): 128 | # for i in range(kmeans_niters): 129 | # if verbose >= 1: 130 | # print(f"KMEANS - Iteration {i+1} of {kmeans_niters}") 131 | # start_time = time.time() 132 | # obj = 0.0 133 | # counts = torch.zeros(num_partitions, dtype=torch.long, device=device) 134 | 135 | # # Process data in batches 136 | # if verbose >= 1: 137 | # _iterator = tqdm.tqdm( 138 | # range(0, total_size, batch_size), 139 | # total=math.ceil(total_size / batch_size), 140 | # desc="Batches for iteration", 141 | # ) 142 | # else: 143 | # _iterator = range(0, total_size, batch_size) 144 | # for batch_start in _iterator: 145 | # batch_end = min(batch_start + batch_size, total_size) 146 | # batch = sample[batch_start:batch_end] 147 | 148 | # distances = torch.cdist(batch, centroids, p=2.0) 149 | # batch_assignments = torch.min(distances, dim=1)[1] 150 | # obj += torch.sum( 151 | # distances[torch.arange(batch.size(0)), batch_assignments] 152 | # ) 153 | 154 | # counts.index_add_( 155 | # 0, batch_assignments, torch.ones_like(batch_assignments) 156 | # ) 157 | 158 | # for j in range(num_partitions): 159 | # assigned_points = batch[batch_assignments == j] 160 | # if len(assigned_points) > 0: 161 | # centroids[j] += assigned_points.sum(dim=0) 
162 | 163 | # # Handle empty clusters by assigning them a random data point from the largest cluster 164 | # empty_clusters = torch.where(counts == 0)[0] 165 | # if empty_clusters.numel() > 0: 166 | # for ec in empty_clusters: 167 | # largest_cluster = torch.argmax(counts) 168 | # idx = torch.randint(0, total_size, (1,), device=device) 169 | # counts[largest_cluster] -= 1 170 | # counts[ec] = 1 171 | # centroids[ec] = sample[idx] 172 | 173 | # centroids /= torch.norm(centroids, dim=1, keepdim=True) 174 | 175 | # if verbose >= 2: 176 | # print( 177 | # f"Iteration: {i+1}, Objective: {obj.item():.4f}, Time: {time.time() - start_time:.4f}s" 178 | # ) 179 | 180 | # # Check for convergence 181 | # if obj < best_obj: 182 | # best_obj = obj 183 | # best_centroids = centroids.clone() 184 | # if obj <= tol: 185 | # break 186 | 187 | # torch.cuda.empty_cache() # Move outside the inner loop 188 | 189 | # if verbose >= 1: 190 | # print(f"Best objective: {best_obj.item():.4f}") 191 | 192 | # print(best_centroids) 193 | 194 | # return best_centroids 195 | 196 | 197 | """ Extremely slow implementation using voyager ANN below. CPU-only. Results ~= FAISS but would only be worth it if storing in int8, which might be for later.""" 198 | # from voyager import Index, Space 199 | 200 | # def compute_pytorch_kmeans_via_voyager( 201 | # sample, 202 | # dim, 203 | # num_partitions, 204 | # kmeans_niters, 205 | # use_gpu, 206 | # batch_size=16000, 207 | # tol=1e-4, 208 | # verbose=3, 209 | # seed=1234, 210 | # max_points_per_centroid=256, 211 | # min_points_per_centroid=10, 212 | # nredo=1, 213 | # ): 214 | # device = torch.device("cuda" if use_gpu else "cpu") 215 | # total_size = sample.shape[0] 216 | 217 | # # Set random seed for reproducibility 218 | # torch.manual_seed(seed) 219 | 220 | # # Convert to float32 for better performance 221 | # sample = sample.float() 222 | 223 | # best_obj = float("inf") 224 | # best_centroids = None 225 | 226 | # for redo in range(nredo): 227 | # # Initialize centroids randomly 228 | # centroids = torch.randn(num_partitions, dim, dtype=sample.dtype, device=device) 229 | # centroids = centroids / torch.norm(centroids, dim=1, keepdim=True) 230 | 231 | # # Build Voyager index if the number of data points exceeds 128,000 232 | # # use_index = total_size > 128000 233 | # use_index = True 234 | # if use_index: 235 | # index = Index(Space.Euclidean, num_dimensions=dim) 236 | # index.add_items(centroids.cpu().numpy()) 237 | 238 | # for i in range(kmeans_niters): 239 | # start_time = time.time() 240 | # obj = 0.0 241 | # counts = torch.zeros(num_partitions, dtype=torch.long, device=device) 242 | 243 | # # Process data in batches 244 | # for batch_start in range(0, total_size, batch_size): 245 | # batch_end = min(batch_start + batch_size, total_size) 246 | # batch = sample[batch_start:batch_end].to(device) 247 | 248 | # if use_index: 249 | # # Search for nearest centroids using Voyager index 250 | # batch_assignments, batch_distances = index.query( 251 | # batch.cpu().numpy(), k=1, num_threads=-1 252 | # ) 253 | # batch_assignments = ( 254 | # torch.from_numpy(batch_assignments.astype(np.int64)) 255 | # .squeeze() 256 | # .to(device) 257 | # ) 258 | # batch_assignments = batch_assignments.long() 259 | # batch_distances = ( 260 | # torch.from_numpy(batch_distances.astype(np.float32)) 261 | # .squeeze() 262 | # .to(device) 263 | # ) 264 | # else: 265 | # # Compute distances using memory-efficient operations 266 | # distances = torch.cdist(batch, centroids, p=2.0) 267 | # batch_assignments = 
torch.min(distances, dim=1)[1] 268 | # batch_distances = distances[ 269 | # torch.arange(batch.size(0)), batch_assignments 270 | # ] 271 | 272 | # # Update objective and counts 273 | # obj += torch.sum(batch_distances) 274 | # counts += torch.bincount(batch_assignments, minlength=num_partitions) 275 | 276 | # # Update centroids 277 | # for j in range(num_partitions): 278 | # assigned_points = batch[batch_assignments == j] 279 | # if len(assigned_points) > 0: 280 | # centroids[j] += assigned_points.sum(dim=0) 281 | 282 | # # Clear the batch from memory 283 | # del batch 284 | # torch.cuda.empty_cache() 285 | 286 | # # Handle empty clusters 287 | # empty_clusters = torch.where(counts == 0)[0] 288 | # if empty_clusters.numel() > 0: 289 | # for ec in empty_clusters: 290 | # # Find the largest cluster 291 | # largest_cluster = torch.argmax(counts) 292 | # # Assign the empty cluster to a random data point from the largest cluster 293 | # indexes = torch.where(counts == counts[largest_cluster])[0] 294 | # if indexes.numel() > 0: 295 | # idx = torch.randint(0, indexes.numel(), (1,), device=device) 296 | # counts[largest_cluster] -= 1 297 | # counts[ec] = 1 298 | # centroids[ec] = sample[batch_start + indexes[idx].item()] 299 | # # Normalize centroids 300 | # centroids = centroids / torch.norm(centroids, dim=1, keepdim=True) 301 | 302 | # if use_index: 303 | # # Update the Voyager index with the new centroids 304 | # index = Index(Space.Euclidean, num_dimensions=dim) 305 | # index.add_items(centroids.cpu().numpy()) 306 | 307 | # if verbose >= 2: 308 | # print( 309 | # f"Iteration: {i+1}, Objective: {obj.item():.4f}, Time: {time.time() - start_time:.4f}s" 310 | # ) 311 | 312 | # # Check for convergence 313 | # if obj < best_obj: 314 | # best_obj = obj 315 | # best_centroids = centroids.clone() 316 | 317 | # if verbose >= 1: 318 | # print(f"Best objective: {best_obj.item():.4f}") 319 | 320 | # return best_centroids 321 | -------------------------------------------------------------------------------- /ragatouille/models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from colbert.infra import ColBERTConfig 10 | from colbert.modeling.colbert import ColBERT 11 | from huggingface_hub import HfApi 12 | from huggingface_hub.utils import HfHubHTTPError 13 | from transformers import AutoModel, BertPreTrainedModel 14 | 15 | 16 | def seeded_shuffle(collection: list, seed: int = 42): 17 | random.seed(seed) 18 | random.shuffle(collection) 19 | return collection 20 | 21 | 22 | """ HUGGINGFACE """ 23 | 24 | 25 | def export_to_huggingface_hub( 26 | colbert_path: Union[str, Path], 27 | huggingface_repo_name: str, 28 | export_vespa_onnx: bool = False, 29 | use_tmp_dir: bool = False, 30 | ): 31 | # ensure model contains a valid ColBERT config before exporting 32 | colbert_config = ColBERTConfig.load_from_checkpoint(colbert_path) 33 | try: 34 | assert colbert_config is not None 35 | except Exception: 36 | print(f"Path {colbert_path} does not contain a valid ColBERT config!") 37 | 38 | export_path = colbert_path 39 | if use_tmp_dir: 40 | export_path = ".tmp/hugging_face_export" 41 | print("Using tmp dir to store export files...") 42 | colbert_model = ColBERT( 43 | colbert_path, 44 | colbert_config=colbert_config, 45 | ) 46 | print(f"Model loaded... 
saving export files to disk at {export_path}") 47 | try: 48 | save_model = colbert_model.save 49 | except Exception: 50 | save_model = colbert_model.module.save 51 | save_model(export_path) 52 | 53 | if export_vespa_onnx: 54 | rust_tokenizer_available = True 55 | if use_tmp_dir: 56 | try: 57 | colbert_model.raw_tokenizer.save_pretrained( 58 | export_path, legacy_format=False 59 | ) 60 | except Exception: 61 | rust_tokenizer_available = False 62 | else: 63 | rust_tokenizer_available = os.path.exists( 64 | Path(colbert_path) / "tokenizer.json" 65 | ) 66 | if not rust_tokenizer_available: 67 | print( 68 | "The tokenizer for your model does not seem to have a Fast Tokenizer implementation...\n", 69 | "This may cause problems when trying to use with Vespa!\n", 70 | "Proceeding anyway...", 71 | ) 72 | 73 | export_to_vespa_onnx(colbert_path, out_path=export_path) 74 | try: 75 | api = HfApi() 76 | api.create_repo(repo_id=huggingface_repo_name, repo_type="model", exist_ok=True) 77 | api.upload_folder( 78 | folder_path=export_path, 79 | repo_id=huggingface_repo_name, 80 | repo_type="model", 81 | ) 82 | print(f"Successfully uploaded model to {huggingface_repo_name}") 83 | except ValueError as e: 84 | print( 85 | f"Could not create repository on the huggingface hub.\n", 86 | f"Error: {e}\n", 87 | "Please make sure you are logged in (run huggingface-cli login)\n", 88 | "If the error persists, please open an issue on github. This is a beta feature!", 89 | ) 90 | except HfHubHTTPError: 91 | print( 92 | "You don't seem to have the rights to create a repository with this name...\n", 93 | "Please make sure your repo name is in the format 'yourusername/your-repo-name'", 94 | ) 95 | finally: 96 | if use_tmp_dir: 97 | shutil.rmtree(export_path) 98 | 99 | 100 | """ VESPA """ 101 | 102 | 103 | class VespaColBERT(BertPreTrainedModel): 104 | def __init__(self, config, dim): 105 | super().__init__(config) 106 | self.bert = AutoModel.from_config(config) 107 | self.linear = nn.Linear(config.hidden_size, dim, bias=False) 108 | self.init_weights() 109 | 110 | def forward(self, input_ids, attention_mask): 111 | Q = self.bert(input_ids, attention_mask=attention_mask)[0] 112 | Q = self.linear(Q) 113 | return torch.nn.functional.normalize(Q, p=2, dim=2) 114 | 115 | 116 | def export_to_vespa_onnx( 117 | colbert_path: Union[str, Path], 118 | out_path: Union[str, Path], 119 | out_file_name: str = "vespa_colbert.onnx", 120 | ): 121 | print(f"Exporting model {colbert_path} to {out_path}/{out_file_name}") 122 | out_path = Path(out_path) 123 | vespa_colbert = VespaColBERT.from_pretrained(colbert_path, dim=128) 124 | print("Model loaded, converting to ONNX...") 125 | input_names = ["input_ids", "attention_mask"] 126 | output_names = ["contextual"] 127 | input_ids = torch.ones(1, 32, dtype=torch.int64) 128 | attention_mask = torch.ones(1, 32, dtype=torch.int64) 129 | args = (input_ids, attention_mask) 130 | torch.onnx.export( 131 | vespa_colbert, 132 | args=args, 133 | f=str(out_path / out_file_name), 134 | input_names=input_names, 135 | output_names=output_names, 136 | dynamic_axes={ 137 | "input_ids": {0: "batch", 1: "batch"}, 138 | "attention_mask": {0: "batch", 1: "batch"}, 139 | "contextual": {0: "batch", 1: "batch"}, 140 | }, 141 | opset_version=17, 142 | ) 143 | print("Vespa ONNX export complete!") 144 | -------------------------------------------------------------------------------- /ragatouille/negative_miners/__init__.py: -------------------------------------------------------------------------------- 1 | from .base 
import HardNegativeMiner 2 | from .simpleminer import SimpleMiner 3 | 4 | __all__ = ["HardNegativeMiner", "SimpleMiner"] 5 | -------------------------------------------------------------------------------- /ragatouille/negative_miners/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | 6 | class HardNegativeMiner(ABC): 7 | @abstractmethod 8 | def export_index(self, path: Union[str, Path]) -> bool: 9 | ... 10 | 11 | @abstractmethod 12 | def mine_hard_negatives( 13 | self, 14 | queries: list[str], 15 | collection: list[str], 16 | neg_k: int, 17 | ): 18 | ... 19 | 20 | @abstractmethod 21 | def _mine( 22 | self, 23 | queries: list[str], 24 | k: int, 25 | ): 26 | ... 27 | -------------------------------------------------------------------------------- /ragatouille/negative_miners/simpleminer.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from pathlib import Path 3 | from typing import Literal, Optional, Union 4 | 5 | import torch 6 | from sentence_transformers import SentenceTransformer 7 | from tqdm import tqdm 8 | from voyager import Index, Space, StorageDataType 9 | 10 | from .base import HardNegativeMiner 11 | 12 | 13 | class DenseModels(Enum): 14 | en_small = "BAAI/bge-small-en-v1.5" 15 | zh_small = "thenlper/gte-small-zh" 16 | fr_small = "OrdalieTech/Solon-embeddings-base-0.1" 17 | other_small = "intfloat/multilingual-e5-small" 18 | en_base = "BAAI/bge-base-en-v1.5" 19 | zh_base = "thenlper/gte-base-zh" 20 | fr_base = "OrdalieTech/Solon-embeddings-base-0.1" 21 | other_base = "intfloat/multilingual-e5-base" 22 | en_large = "BAAI/bge-large-en-v1.5" 23 | zh_large = "thenlper/gte-large-zh" 24 | fr_large = "OrdalieTech/Solon-embeddings-large-0.1" 25 | other_large = "intfloat/multilingual-e5-large" 26 | 27 | 28 | class SimpleMiner(HardNegativeMiner): 29 | """The simplest approach to hard negatives mining. 30 | Select the most appropriate, small-sized embedding model for the target language. 31 | And retrieve random negatives in the top 10-100 results. 
32 | Strong baseline for quick, low-engineering hard negative mining.""" 33 | 34 | def __init__( 35 | self, 36 | language_code: str, 37 | model_size: Literal["small", "base", "large"] = "small", 38 | ) -> None: 39 | self.n_gpu = torch.cuda.device_count() 40 | self.target_language = language_code 41 | self.model_size = model_size 42 | if language_code not in ["en", "zh", "fr"]: 43 | language_code = "other" 44 | self.model_name = f"{language_code}_{model_size}" 45 | hub_model = DenseModels[self.model_name].value 46 | print(f"Loading Hard Negative SimpleMiner dense embedding model {hub_model}...") 47 | self.model = SentenceTransformer(hub_model) 48 | self.has_index = False 49 | self.min_rank = 10 50 | 51 | def build_index( 52 | self, 53 | collection, 54 | batch_size: int = 128, 55 | save_index: bool = False, 56 | save_path: Union[str, Path] = None, 57 | force_fp32: bool = True, 58 | ): 59 | print(f"Building hard negative index for {len(collection)} documents...") 60 | if len(collection) > 1000: 61 | pool = self.model.start_multi_process_pool() 62 | embeds = self.model.encode_multi_process( 63 | collection, pool, batch_size=batch_size 64 | ) 65 | self.model.stop_multi_process_pool(pool) 66 | else: 67 | embeds = self.model.encode(collection, batch_size=batch_size) 68 | 69 | print("All documents embedded, now adding to index...") 70 | 71 | self.max_rank = min(110, int(len(collection) // 10)) 72 | self.max_rank = min(self.max_rank, len(collection)) 73 | 74 | storage_type = StorageDataType.Float32 75 | if len(collection) > 500000 and not force_fp32: 76 | storage_type = StorageDataType.E4M3 77 | 78 | self.voyager_index = Index( 79 | Space.Cosine, 80 | num_dimensions=self.model.get_sentence_embedding_dimension(), 81 | storage_data_type=storage_type, 82 | ) 83 | 84 | self.corpus_map = {i: doc for i, doc in enumerate(collection)} 85 | id_to_vector = {} 86 | for i, emb in enumerate(embeds): 87 | id_to_vector[i] = emb 88 | self.corpus_map[i] = collection[i] 89 | del embeds 90 | 91 | self.voyager_index.add_items( 92 | vectors=[x for x in id_to_vector.values()], 93 | ids=[x for x in id_to_vector.keys()], 94 | num_threads=-1, 95 | ) 96 | 97 | del id_to_vector 98 | 99 | if save_index: 100 | print(f"Saving index to {save_path}...") 101 | self.export_index(save_path) 102 | else: 103 | print("save_index set to False, skipping saving hard negative index") 104 | print("Hard negative index generated") 105 | self.has_index = True 106 | 107 | def query_index(self, query, top_k=110): 108 | results = self.voyager_index.query( 109 | query, k=min(top_k, self.voyager_index.__len__()) 110 | ) 111 | return results 112 | 113 | def mine_hard_negatives( 114 | self, 115 | queries: Union[list[str], str], 116 | collection: Optional[list[str]] = None, 117 | save_index: bool = False, 118 | save_path: Union[str, Path] = None, 119 | force_fp32: bool = True, 120 | ): 121 | if self.has_index is False and collection is not None: 122 | self.build_index( 123 | collection, 124 | save_index=save_index, 125 | save_path=save_path, 126 | force_fp32=force_fp32, 127 | ) 128 | if isinstance(queries, str): 129 | return self._mine(queries) 130 | return self._batch_mine(queries) 131 | 132 | def _mine( 133 | self, 134 | query: str, 135 | ): 136 | q_emb = self.model.encode(query) 137 | query_results = self.query_index(q_emb, top_k=self.max_rank) 138 | if len(query_results) > self.min_rank: 139 | query_results = query_results[self.min_rank : self.max_rank] 140 | query_results = [self.corpus_map[x] for x in query_results[0]] 141 | return query_results 
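    # Illustrative usage sketch only -- not part of the original file. The name
    # `my_documents` is an assumption standing in for a list[str] prepared by the caller.
    #
    #   miner = SimpleMiner(language_code="en", model_size="small")
    #   negatives = miner.mine_hard_negatives(
    #       queries=["What is late-interaction retrieval?"],
    #       collection=my_documents,
    #   )
    #
    # A single string query is routed through _mine() above, while a list of
    # queries goes through _batch_mine() below.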
142 | 143 | def _batch_mine( 144 | self, 145 | queries: list[str], 146 | ): 147 | """Separate function to parallelise later on""" 148 | print(f"Retrieving hard negatives for {len(queries)} queries...") 149 | results = [] 150 | print("Embedding queries...") 151 | query_embeddings = self.model.encode(queries, show_progress_bar=True) 152 | print("Retrieving hard negatives...") 153 | for q_emb in tqdm(query_embeddings): 154 | query_results = self.query_index(q_emb, top_k=self.max_rank) 155 | query_results = query_results[self.min_rank : self.max_rank] 156 | query_results = [self.corpus_map[x.id] for x in query_results] 157 | results.append(query_results) 158 | print(f"""Done generating hard negatives.""") 159 | return results 160 | 161 | def export_index(self, path: Union[str, Path]) -> bool: 162 | self.voyager_index.save(path) 163 | return True 164 | -------------------------------------------------------------------------------- /ragatouille/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import requests 4 | 5 | 6 | def seeded_shuffle(collection: list, seed: int = 42): 7 | random.seed(seed) 8 | random.shuffle(collection) 9 | return collection 10 | 11 | 12 | def get_wikipedia_page(title: str): 13 | """ 14 | Retrieve the full text content of a Wikipedia page. 15 | 16 | :param title: str - Title of the Wikipedia page. 17 | :return: str - Full text content of the page as raw string. 18 | """ 19 | # Wikipedia API endpoint 20 | URL = "https://en.wikipedia.org/w/api.php" 21 | 22 | # Parameters for the API request 23 | params = { 24 | "action": "query", 25 | "format": "json", 26 | "titles": title, 27 | "prop": "extracts", 28 | "explaintext": True, 29 | } 30 | 31 | # Custom User-Agent header to comply with Wikipedia's best practices 32 | headers = {"User-Agent": "RAGatouille_tutorial/0.0.1 (ben@clavie.eu)"} 33 | 34 | response = requests.get(URL, params=params, headers=headers) 35 | data = response.json() 36 | 37 | # Extracting page content 38 | page = next(iter(data["query"]["pages"].values())) 39 | return page["extract"] if "extract" in page else None 40 | -------------------------------------------------------------------------------- /requirements-doc.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-material 3 | yarl 4 | cairosvg 5 | mkdocstrings 6 | mkdocstrings-python 7 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/RAGatouille/e75b8a964a870dea886548f78da1900804749040/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/Toei_Animation_wikipedia.txt: -------------------------------------------------------------------------------- 1 | Toei Animation Co., Ltd. (Japanese: 東映アニメーション株式会社, Hepburn: Tōei Animēshon Kabushiki-gaisha, ) is a Japanese animation studio primarily controlled by its namesake Toei Company. It has produced numerous series, including Sally the Witch, GeGeGe no Kitarō, Mazinger Z, Galaxy Express 999, Cutie Honey, Dr. Slump, Dragon Ball, Saint Seiya, Sailor Moon, Slam Dunk, Digimon, One Piece, Toriko, World Trigger, The Transformers (between 1984 and 1990, including several Japanese exclusive productions), and the Pretty Cure series. 
2 | 3 | 4 | == History == 5 | The studio was founded by animators Kenzō Masaoka and Zenjirō Yamamoto in 1948 as Japan Animated Films (日本動画映画, Nihon Dōga Eiga, often shortened to 日動映画 (Nichidō Eiga)). In 1956, Toei purchased the studio and it was renamed Toei Doga Co., Ltd. (東映動画株式会社, Tōei Dōga Kabushiki-gaisha, "dōga" is Japanese for "animation"), doing business as Toei Animation Co., Ltd. outside Japan. In 1998, the Japanese name was renamed to Toei Animation. It has created a number of TV series and movies and adapted Japanese comics as animated series, many popular worldwide. Hayao Miyazaki, Isao Takahata, Yasuji Mori, Leiji Matsumoto and Yōichi Kotabe have worked with the company. Toei is a shareholder in the Japanese anime satellite television network Animax with other anime studios and production companies, such as Sunrise, TMS Entertainment and Nihon Ad Systems Inc. The company is headquartered in the Ohizumi Studio in Nerima, Tokyo.Their mascot is the cat Pero, from the company's 1969 film adaptation of Puss in Boots. 6 | Toei Animation produced anime versions of works from manga series by manga artists, including Go Nagai (Mazinger Z), Eiichiro Oda (One Piece), Shotaro Ishinomori (Cyborg 009), Mitsutoshi Shimabukuro (Toriko), Takehiko Inoue (Slam Dunk), Mitsuteru Yokoyama (Sally the Witch), Masami Kurumada (Saint Seiya), Akira Toriyama (Dragon Ball and Dr. Slump), Leiji Matsumoto (Galaxy Express 999), and Naoko Takeuchi (Sailor Moon). The studio helped propel the popularity of the Magical Girl and Super Robot genres of anime; Toei's TV series include the first magical-girl anime series, Mahoutsukai Sally (an adaptation of Mitsuteru Yokoyama's manga of the same name), and Go Nagai's Mazinger Z, an adaptation of his manga which set the standard for Super Robot anime. Although the Toei Company usually contracts Toei Animation to handle its animation internally, they occasionally hire other companies to provide animation; although the Toei Company produced the Robot Romance Trilogy, Sunrise (then known as Nippon Sunrise) provided the animation. Toei Company would also enlist the help of other studios such as hiring Academy Productions to produce the animation for Space Emperor God Sigma, rather than use their own studio. 7 | Toei Animation's anime which have won the Animage Anime Grand Prix award are Galaxy Express 999 in 1981, Saint Seiya in 1987 and Sailor Moon in 1992. In addition to producing anime for release in Japan, Toei Animation began providing animation for American films and television series during the 1960s and particularly during the 1980s. 8 | In October 2021, Toei Animation announced that they had signed a strategic partnership with the South Korean entertainment conglomerate CJ ENM. 9 | 10 | 11 | === 2022 ransomware attack === 12 | On March 6, 2022, an incident occurred in which an unauthorized third party attempted to hack Toei Animation's network, which resulted in the company's online store and internal systems becoming temporarily suspended. The company investigated the incident and stated that the hack would affect the broadcast schedules of several anime series, including One Piece. In addition, Dragon Ball Super: Super Hero was also rescheduled to June 11, 2022, due to the hack. On April 6, 2022, Toei Animation announced that it would resume broadcasting the anime series, including One Piece. The following day, the Japanese public broadcaster NHK reported that the hack was caused by a targeted ransomware attack. 
13 | 14 | 15 | == Subsidiaries == 16 | 17 | 18 | == Currently in production == 19 | 20 | 21 | == TV animation == 22 | 23 | 24 | === 1960–69 === 25 | 26 | 27 | === 1970–79 === 28 | 29 | 30 | === 1980–89 === 31 | 32 | 33 | === 1990–99 === 34 | 35 | 36 | === 2000–09 === 37 | 38 | 39 | === 2010–19 === 40 | 41 | 42 | === 2020–present === 43 | 44 | 45 | == Television films and specials == 46 | 47 | 48 | == Theatrical films == 49 | 50 | 51 | == CGI films == 52 | 53 | 54 | == Original video animation and original net animation == 55 | 56 | 57 | == Video game animation == 58 | 59 | 60 | == Video game development == 61 | 62 | 63 | == Dubbing == 64 | Animated productions by foreign studios dubbed in Japanese by Toei are The Mystery of the Third Planet (1981 Russian film, dubbed in 2008); Les Maîtres du temps (1982 French-Hungarian film, dubbed in 2014), Alice's Birthday (2009 Russian film, dubbed in 2013) and Becca's Bunch (2018 television series, dubbed in 2021 to 2022). 65 | 66 | 67 | == Foreign Production History == 68 | Toei has been commissioned to provide animation by Japanese and American studios such as Sunbow Entertainment, Marvel Productions, Hanna-Barbera, DIC Entertainment, Rankin/Bass Productions and World Events Productions (DreamWorks Animation). In the 60's, they primarily worked with Rankin/Bass, but beginning in the 80's, they worked with Marvel Productions and their list of clients grew, until the end of the decade. Toei didn't provide much outsourced animation work in the 90's and since the 2000s has only rarely worked with other companies outside Japan. 69 | 70 | 71 | == Controversies == 72 | 73 | 74 | === Fair use disputes === 75 | Between 2008 and 2018, Toei Animation had copyright claimed TeamFourStar's parody series, DragonBall Z Abridged. TFS stated that the parody series is protected under fair use.On December 7, 2021, Toei Animation copyright claimed over 150 videos by YouTuber Totally Not Mark, real name Mark Fitzpatrick. He uploaded a video addressing the issue, claiming that they were protected under fair use, and that nine of the videos do not include any Toei footage. He also outlined the appeal process on YouTube, and estimated having the videos reinstated could take over 37 years. He then goes on to announce that he would not be supporting new Toei releases until the issue had been resolved, and also called for a boycott on the upcoming Dragon Ball Super: Super Hero film. The dispute sparked discussion on YouTube on the vulnerability of creators against the copyright system and lack of fair use laws in Japan, with YouTubers such as PewDiePie and The Anime Man speaking out on the issue.On January 26, 2022, Fitzpatrick had his videos reinstated after negotiations with YouTube. 76 | 77 | 78 | === Treatment of employees === 79 | On January 20, 2021, two employees have accused Toei Animation of overworking their employees and discrimination towards sexual minorities. The company had inappropriately referred to employees who identifies as X-gender (a non-binary identity in Japan). 80 | 81 | 82 | == See also == 83 | SynergySP, Studio Junio and Hal Film Maker/Yumeta Company, animation studios founded by former Toei animators. 84 | Topcraft, an animation studio founded by former Toei Animation producer Toru Hara. 85 | Studio Ghibli, an animation studio founded by former Toei animators Hayao Miyazaki and Isao Takahata. 86 | Mushi Production, an animation studio founded by Osamu Tezuka and former Toei animators. 
87 | Shin-Ei Animation, formally A Production, an animation studio founded by former Toei animator Daikichirō Kusube. 88 | Yamamura Animation, an animation studio founded by former Toei animator Kōji Yamamura. 89 | Doga Kobo, an animation studio formed by former Toei animators Hideo Furusawa and Megumu Ishiguro. 90 | 91 | 92 | == References == 93 | 94 | 95 | == External links == 96 | 97 | Official website (in English) 98 | Toei Animation at Anime News Network's encyclopedia 99 | Toei Animation at IMDb -------------------------------------------------------------------------------- /tests/e2e/test_e2e_indexing_searching.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import srsly 3 | 4 | from ragatouille import RAGPretrainedModel 5 | from ragatouille.utils import get_wikipedia_page 6 | 7 | 8 | def test_indexing(): 9 | RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0") 10 | with open("tests/data/miyazaki_wikipedia.txt", "r") as f: 11 | full_document = f.read() 12 | RAG.index( 13 | collection=[full_document], 14 | index_name="Miyazaki", 15 | max_document_length=180, 16 | split_documents=True, 17 | ) 18 | # ensure collection is stored to disk 19 | collection = srsly.read_json( 20 | ".ragatouille/colbert/indexes/Miyazaki/collection.json" 21 | ) 22 | assert len(collection) > 1 23 | 24 | 25 | def test_search(): 26 | RAG = RAGPretrainedModel.from_index(".ragatouille/colbert/indexes/Miyazaki/") 27 | k = 3 # How many documents you want to retrieve, defaults to 10, we set it to 3 here for readability 28 | results = RAG.search(query="What animation studio did Miyazaki found?", k=k) 29 | assert len(results) == k 30 | assert ( 31 | "In April 1984, Miyazaki opened his own office in Suginami Ward" 32 | in results[0]["content"] 33 | ) 34 | assert ( 35 | "Hayao Miyazaki (宮崎 駿 or 宮﨑 駿, Miyazaki Hayao, [mijaꜜzaki hajao]; born January 5, 1941)" # noqa 36 | in results[1]["content"] 37 | ) 38 | assert ( 39 | 'Glen Keane said Miyazaki is a "huge influence" on Walt Disney Animation Studios and has been' # noqa 40 | in results[2]["content"] 41 | ) 42 | 43 | all_results = RAG.search( 44 | query=["What animation studio did Miyazaki found?", "Miyazaki son name"], k=k 45 | ) 46 | assert ( 47 | "In April 1984, Miyazaki opened his own office in Suginami Ward" 48 | in all_results[0][0]["content"] 49 | ) 50 | assert ( 51 | "Hayao Miyazaki (宮崎 駿 or 宮﨑 駿, Miyazaki Hayao, [mijaꜜzaki hajao]; born January 5, 1941)" # noqa 52 | in all_results[0][1]["content"] 53 | ) 54 | assert ( 55 | 'Glen Keane said Miyazaki is a "huge influence" on Walt Disney Animation Studios and has been' # noqa 56 | in all_results[0][2]["content"] 57 | ) 58 | assert ( 59 | "== Early life ==\nHayao Miyazaki was born on January 5, 1941" 60 | in all_results[1][0]["content"] # noqa 61 | ) 62 | assert ( 63 | "Directed by Isao Takahata, with whom Miyazaki would continue to collaborate for the remainder of his career" # noqa 64 | in all_results[1][1]["content"] 65 | ) 66 | actual = all_results[1][2]["content"] 67 | assert ( 68 | "Specific works that have influenced Miyazaki include Animal Farm (1945)" 69 | in actual 70 | or "She met with Suzuki" in actual 71 | ) 72 | print(all_results) 73 | 74 | 75 | @pytest.mark.skip(reason="experimental feature.") 76 | def test_basic_CRUD_addition(): 77 | old_collection = srsly.read_json( 78 | ".ragatouille/colbert/indexes/Miyazaki/collection.json" 79 | ) 80 | old_collection_len = len(old_collection) 81 | path_to_index = 
".ragatouille/colbert/indexes/Miyazaki/" 82 | RAG = RAGPretrainedModel.from_index(path_to_index) 83 | 84 | new_documents = get_wikipedia_page("Studio_Ghibli") 85 | 86 | RAG.add_to_index([new_documents]) 87 | new_collection = srsly.read_json( 88 | ".ragatouille/colbert/indexes/Miyazaki/collection.json" 89 | ) 90 | assert len(new_collection) > old_collection_len 91 | assert len(new_collection) == 140 92 | -------------------------------------------------------------------------------- /tests/test_pretrained_loading.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.skip(reason="NotImplemented") 5 | def test_from_checkpoint(): 6 | pass 7 | 8 | 9 | @pytest.mark.skip(reason="NotImplemented") 10 | def test_from_index(): 11 | pass 12 | 13 | 14 | @pytest.mark.skip(reason="NotImplemented") 15 | def test_searcher(): 16 | pass 17 | -------------------------------------------------------------------------------- /tests/test_pretrained_optional_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import srsly 5 | 6 | from ragatouille import RAGPretrainedModel 7 | 8 | collection = [ 9 | "Hayao Miyazaki (宮崎 駿 or 宮﨑 駿, Miyazaki Hayao, [mijaꜜzaki hajao]; born January 5, 1941) is a Japanese animator, filmmaker, and manga artist. A co-founder of Studio Ghibli, he has attained international acclaim as a masterful storyteller and creator of Japanese animated feature films, and is widely regarded as one of the most accomplished filmmakers in the history of animation.\nBorn in Tokyo City in the Empire of Japan, Miyazaki expressed interest in manga and animation from an early age, and he joined Toei Animation in 1963. During his early years at Toei Animation he worked as an in-between artist and later collaborated with director Isao Takahata. Notable films to which Miyazaki contributed at Toei include Doggie March and Gulliver's Travels Beyond the Moon. He provided key animation to other films at Toei, such as Puss in Boots and Animal Treasure Island, before moving to A-Pro in 1971, where he co-directed Lupin the Third Part I alongside Takahata. After moving to Zuiyō Eizō (later known as Nippon Animation) in 1973, Miyazaki worked as an animator on World Masterpiece Theater, and directed the television series Future Boy Conan (1978). He joined Tokyo Movie Shinsha in 1979 to direct his first feature film The Castle of Cagliostro as well as the television series Sherlock Hound. In the same period, he also began writing and illustrating the manga Nausicaä of the Valley of the Wind (1982–1994), and he also directed the 1984 film adaptation produced by Topcraft.\nMiyazaki co-founded Studio Ghibli in 1985. He directed numerous films with Ghibli, including Laputa: Castle in the Sky (1986), My Neighbor Totoro (1988), Kiki's Delivery Service (1989), and Porco Rosso (1992). The films were met with critical and commercial success in Japan. Miyazaki's film Princess Mononoke was the first animated film ever to win the Japan Academy Prize for Picture of the Year, and briefly became the highest-grossing film in Japan following its release in 1997; its distribution to the Western world greatly increased Ghibli's popularity and influence outside Japan. His 2001 film Spirited Away became the highest-grossing film in Japanese history, winning the Academy Award for Best Animated Feature, and is frequently ranked among the greatest films of the 21st century. 
Miyazaki's later films—Howl's Moving Castle (2004), Ponyo (2008), and The Wind Rises (2013)—also enjoyed critical and commercial success.", 10 | "Studio Ghibli, Inc. (Japanese: 株式会社スタジオジブリ, Hepburn: Kabushiki gaisha Sutajio Jiburi) is a Japanese animation studio based in Koganei, Tokyo. It has a strong presence in the animation industry and has expanded its portfolio to include various media formats, such as short subjects, television commercials, and two television films. Their work has been well-received by audiences and recognized with numerous awards. Their mascot and most recognizable symbol, the character Totoro from the 1988 film My Neighbor Totoro, is a giant spirit inspired by raccoon dogs (tanuki) and cats (neko). Among the studio's highest-grossing films are Spirited Away (2001), Howl's Moving Castle (2004), and Ponyo (2008). Studio Ghibli was founded on June 15, 1985, by the directors Hayao Miyazaki and Isao Takahata and producer Toshio Suzuki, after acquiring Topcraft's assets. The studio has also collaborated with video game studios on the visual development of several games.Five of the studio's films are among the ten highest-grossing anime feature films made in Japan. Spirited Away is second, grossing 31.68 billion yen in Japan and over US$380 million worldwide, and Princess Mononoke is fourth, grossing 20.18 billion yen. Three of their films have won the Animage Grand Prix award, four have won the Japan Academy Prize for Animation of the Year, and five have received Academy Award nominations. Spirited Away won the 2002 Golden Bear and the 2003 Academy Award for Best Animated Feature.On August 3, 2014, Studio Ghibli temporarily suspended production following Miyazaki's retirement.", 11 | ] 12 | 13 | document_ids = ["miyazaki", "ghibli"] 14 | 15 | document_metadatas = [ 16 | {"entity": "person", "source": "wikipedia"}, 17 | {"entity": "organisation", "source": "wikipedia"}, 18 | ] 19 | 20 | 21 | @pytest.fixture(scope="session") 22 | def persistent_temp_index_root(tmp_path_factory): 23 | return tmp_path_factory.mktemp("temp_test_indexes") 24 | 25 | 26 | @pytest.fixture(scope="session") 27 | def RAG_from_pretrained_model(persistent_temp_index_root): 28 | return RAGPretrainedModel.from_pretrained( 29 | "colbert-ir/colbertv2.0", index_root=str(persistent_temp_index_root) 30 | ) 31 | 32 | 33 | @pytest.fixture(scope="session") 34 | def index_path_fixture(persistent_temp_index_root, index_creation_inputs): 35 | index_path = os.path.join( 36 | str(persistent_temp_index_root), 37 | "colbert", 38 | "indexes", 39 | index_creation_inputs["index_name"], 40 | ) 41 | return str(index_path) 42 | 43 | 44 | @pytest.fixture(scope="session") 45 | def collection_path_fixture(index_path_fixture): 46 | collection_path = os.path.join(index_path_fixture, "collection.json") 47 | return str(collection_path) 48 | 49 | 50 | @pytest.fixture(scope="session") 51 | def document_metadata_path_fixture(index_path_fixture): 52 | document_metadata_path = os.path.join(index_path_fixture, "docid_metadata_map.json") 53 | return str(document_metadata_path) 54 | 55 | 56 | @pytest.fixture(scope="session") 57 | def pid_docid_map_path_fixture(index_path_fixture): 58 | pid_docid_map_path = os.path.join(index_path_fixture, "pid_docid_map.json") 59 | return str(pid_docid_map_path) 60 | 61 | 62 | @pytest.fixture( 63 | scope="session", 64 | params=[ 65 | { 66 | "collection": collection, 67 | "index_name": "no_optional_args", 68 | "split_documents": False, 69 | }, 70 | { 71 | "collection": collection, 72 | "document_ids": 
document_ids, 73 | "index_name": "with_docid", 74 | "split_documents": False, 75 | }, 76 | { 77 | "collection": collection, 78 | "document_metadatas": document_metadatas, 79 | "index_name": "with_metadata", 80 | "split_documents": False, 81 | }, 82 | { 83 | "collection": collection, 84 | "index_name": "with_split", 85 | "split_documents": True, 86 | }, 87 | { 88 | "collection": collection, 89 | "document_ids": document_ids, 90 | "document_metadatas": document_metadatas, 91 | "index_name": "with_docid_metadata", 92 | "split_documents": False, 93 | }, 94 | { 95 | "collection": collection, 96 | "document_ids": document_ids, 97 | "index_name": "with_docid_split", 98 | "split_documents": True, 99 | }, 100 | { 101 | "collection": collection, 102 | "document_metadatas": document_metadatas, 103 | "index_name": "with_metadata_split", 104 | "split_documents": True, 105 | }, 106 | { 107 | "collection": collection, 108 | "document_ids": document_ids, 109 | "document_metadatas": document_metadatas, 110 | "index_name": "with_docid_metadata_split", 111 | "split_documents": True, 112 | }, 113 | ], 114 | ids=[ 115 | "No optional arguments", 116 | "With document IDs", 117 | "With metadata", 118 | "With document splitting", 119 | "With document IDs and metadata", 120 | "With document IDs and splitting", 121 | "With metadata and splitting", 122 | "With document IDs, metadata, and splitting", 123 | ], 124 | ) 125 | def index_creation_inputs(request): 126 | params = request.param 127 | return params 128 | 129 | 130 | @pytest.fixture(scope="session") 131 | def create_index(RAG_from_pretrained_model, index_creation_inputs): 132 | index_path = RAG_from_pretrained_model.index(**index_creation_inputs) 133 | return index_path 134 | 135 | 136 | def test_index_creation(create_index): 137 | assert os.path.exists(create_index) == True 138 | 139 | 140 | @pytest.fixture(scope="session", autouse=True) 141 | def add_docids_to_index_inputs( 142 | create_index, # noqa: ARG001 143 | index_creation_inputs, 144 | pid_docid_map_path_fixture, 145 | ): 146 | if "document_ids" not in index_creation_inputs: 147 | pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture) 148 | seen_ids = set() 149 | index_creation_inputs["document_ids"] = [ 150 | x 151 | for x in list(pid_docid_map_data.values()) 152 | if not (x in seen_ids or seen_ids.add(x)) 153 | ] 154 | 155 | 156 | def test_collection_creation(collection_path_fixture): 157 | assert os.path.exists(collection_path_fixture) == True 158 | collection_data = srsly.read_json(collection_path_fixture) 159 | assert isinstance( 160 | collection_data, list 161 | ), "The collection.json file should contain a list." 162 | 163 | 164 | def test_pid_docid_map_creation(pid_docid_map_path_fixture): 165 | assert os.path.exists(pid_docid_map_path_fixture) == True 166 | # TODO check pid_docid_map_data 167 | pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture) 168 | assert isinstance( 169 | pid_docid_map_data, dict 170 | ), "The pid_docid_map.json file should contain a dictionary." 
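# Illustrative note, not part of the original file: pid_docid_map.json is assumed to
# map each ColBERT passage ID back to the user-supplied document_id it was split from,
# e.g. {"0": "miyazaki", "1": "miyazaki", "2": "ghibli", ...}, which is why the tests
# below recover the document_ids via set(pid_docid_map_data.values()).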
171 | 172 | 173 | def test_document_metadata_creation( 174 | index_creation_inputs, document_metadata_path_fixture 175 | ): 176 | if "document_metadatas" in index_creation_inputs: 177 | assert os.path.exists(document_metadata_path_fixture) == True 178 | document_metadata_dict = srsly.read_json(document_metadata_path_fixture) 179 | assert ( 180 | set(document_metadata_dict.keys()) 181 | == set(index_creation_inputs["document_ids"]) 182 | ), "The keys in document_metadata.json should match the document_ids provided for index creation." 183 | for doc_id, metadata in document_metadata_dict.items(): 184 | assert ( 185 | metadata 186 | == index_creation_inputs["document_metadatas"][ 187 | index_creation_inputs["document_ids"].index(doc_id) 188 | ] 189 | ), f"The metadata for document_id {doc_id} should match the provided metadata." 190 | else: 191 | assert os.path.exists(document_metadata_path_fixture) == False 192 | 193 | 194 | def test_document_metadata_returned_in_search_results( 195 | index_creation_inputs, index_path_fixture 196 | ): 197 | RAG = RAGPretrainedModel.from_index(index_path_fixture) 198 | results = RAG.search( 199 | "when was miyazaki born", index_name=index_creation_inputs["index_name"] 200 | ) 201 | 202 | if "document_metadatas" in index_creation_inputs: 203 | for result in results: 204 | assert ( 205 | "document_metadata" in result 206 | ), "The metadata should be returned in the results." 207 | doc_id = result["document_id"] 208 | expected_metadata = index_creation_inputs["document_metadatas"][ 209 | index_creation_inputs["document_ids"].index(doc_id) 210 | ] 211 | assert ( 212 | result["document_metadata"] == expected_metadata 213 | ), f"The metadata for document_id {doc_id} should match the provided metadata." 214 | 215 | else: 216 | for result in results: 217 | assert ( 218 | "metadata" not in result 219 | ), "The metadata should not be returned in the results." 220 | 221 | 222 | # TODO: move this to a separate CRUD test file 223 | # TODO: add checks for metadata and doc content 224 | def test_add_to_existing_index( 225 | index_creation_inputs, 226 | document_metadata_path_fixture, 227 | pid_docid_map_path_fixture, 228 | index_path_fixture, 229 | ): 230 | RAG = RAGPretrainedModel.from_index(index_path_fixture) 231 | existing_doc_ids = index_creation_inputs["document_ids"] 232 | new_doc_ids = ["mononoke", "sabaku_no_tami"] 233 | new_docs = [ 234 | "Princess Mononoke (Japanese: もののけ姫, Hepburn: Mononoke-hime) is a 1997 Japanese animated epic historical fantasy film written and directed by Hayao Miyazaki and animated by Studio Ghibli for Tokuma Shoten, Nippon Television Network and Dentsu. The film stars the voices of Yōji Matsuda, Yuriko Ishida, Yūko Tanaka, Kaoru Kobayashi, Masahiko Nishimura, Tsunehiko Kamijo, Akihiro Miwa, Mitsuko Mori, and Hisaya Morishige.\nPrincess Mononoke is set in the late Muromachi period of Japan (approximately 1336 to 1573 AD) and includes fantasy elements. The story follows a young Emishi prince named Ashitaka, and his involvement in a struggle between the gods (kami) of a forest and the humans who consume its resources. The film deals with themes of Shinto and environmentalism.\nThe film was released in Japan on July 12, 1997, by Toho, and in the United States on October 29, 1999. This was the first Studio Ghibli film in the United States to be rated PG-13 by the MPA. 
It was a critical and commercial blockbuster, becoming the highest-grossing film in Japan of 1997, and also held Japan's box office record for domestic films until 2001's Spirited Away, another Miyazaki film. It was dubbed into English with a script by Neil Gaiman and initially distributed in North America by Miramax, where it sold well on home media despite not performing strongly at the box office. The film greatly increased Ghibli's popularity and influence outside Japan.", 235 | "People of the Desert (砂漠の民, Sabaku no Tami, translated on the cover as The People of Desert), or The Desert Tribe, is a comic strip written and illustrated by Hayao Miyazaki. It was serialized, under the pseudonym Akitsu Saburō (秋津三朗), and ran in Boys and Girls Newspaper (少年少女新聞, Shōnen Shōjo Shinbun) between September 12, 1969, and March 15, 1970.\n\n\n== Story ==\nThe story is set in the distant past, on the fictionalised desert plains of Central Asia. Part of the story takes place in the fortified city named Pejite (ペジテ). The story follows the exploits of the main character, Tem (テム, Temu), a shepherd boy of the fictional Sokut (ソクート, Sokūto) tribe, as he tries to evade the mounted militia of the nomadic Kittāru (キッタール) tribe. In order to restore peace to the realm, Tem rallies his remaining compatriots and rebels against the Kittāru's attempts to gain control of the Sokut territory and enslave its inhabitants through military force.\n\n\n== Background, publication and influences ==\nMiyazaki initially wanted to become a manga artist but started his professional career as an animator for Toei Animation in 1963. Here he worked on animated television series and animated feature-length films for theatrical release. He never abandoned his childhood dream of becoming a manga artist completely, however, and his professional debut as a manga creator came in 1969 with the publication of his manga interpretation of Puss 'n Boots, which was serialized in 12 weekly instalments in the Sunday edition of Tokyo Shimbun, from January to March 1969. Printed in colour and created for promotional purposes in conjunction with his work on Toei's animated film of the same title, directed by Kimio Yabuki.\nIn 1969 pseudonymous serialization also started of Miyazaki's original manga People of the Desert (砂漠の民, Sabaku no Tami). This strip was created in the style of illustrated stories (絵物語, emonogatari) he read in boys' magazines and tankōbon volumes while growing up, such as Soji Yamakawa's Shōnen Ōja (少年王者) and in particular Tetsuji Fukushima's Evil Lord of the Desert (沙漠の魔王, Sabaku no Maō). Miyazaki's People of the Desert is a continuation of that tradition. In People of the Desert expository text is presented separately from the monochrome artwork but Miyazaki progressively used additional text balloons inside the panels for dialogue.\nPeople of the Desert was serialized in 26 weekly instalments which were printed in Boys and Girls Newspaper (少年少女新聞, Shōnen shōjo shinbun), a publication of the Japanese Communist Party, between September 12, 1969 (issue 28) and March 15, 1970 (issue 53). 
The strip was published under the pseudonym Akitsu Saburō (秋津三朗).\nThe strip has been identified as a precursor for Miyazaki's manga Nausicaä of the Valley of the Wind (1982–1995) and the one-off graphic novel Shuna's Journey (1983), published by Tokuma Shoten.",
236 |     ]
237 |     new_doc_metadata = [
238 |         {"entity": "film", "source": "wikipedia"},
239 |         {"entity": "manga", "source": "wikipedia"},
240 |     ]
241 |     RAG.add_to_index(
242 |         new_collection=new_docs,
243 |         new_document_ids=new_doc_ids,
244 |         new_document_metadatas=new_doc_metadata,
245 |         index_name=index_creation_inputs["index_name"],
246 |     )
247 |     pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture)
248 |     document_ids = set(pid_docid_map_data.values())
249 | 
250 |     document_metadata_dict = srsly.read_json(document_metadata_path_fixture)
251 |     # check for new docs
252 |     for new_doc_id in new_doc_ids:
253 |         assert (
254 |             new_doc_id in document_ids
255 |         ), f"New document ID '{new_doc_id}' should be in the pid_docid_map's document_ids: {document_ids}."
256 | 
257 |         assert (
258 |             new_doc_id in document_metadata_dict
259 |         ), f"New document ID '{new_doc_id}' should be in the document metadata keys: {document_metadata_dict.keys()}."
260 | 
261 |     for existing_doc_id in existing_doc_ids:
262 |         assert (
263 |             existing_doc_id in document_ids
264 |         ), f"Old document ID '{existing_doc_id}' should be in the pid_docid_map's document_ids: {document_ids}."
265 | 
266 |         if "document_metadatas" in index_creation_inputs:
267 |             assert (
268 |                 existing_doc_id in document_metadata_dict
269 |             ), f"Old document ID '{existing_doc_id}' should be in the document metadata keys: {document_metadata_dict.keys()}."
270 | 
271 | 
272 | # TODO: move this to a separate CRUD test file
273 | def test_delete_from_index(
274 |     index_creation_inputs,
275 |     pid_docid_map_path_fixture,
276 |     document_metadata_path_fixture,
277 |     index_path_fixture,
278 | ):
279 |     RAG = RAGPretrainedModel.from_index(index_path_fixture)
280 |     deleted_doc_id = index_creation_inputs["document_ids"][0]
281 |     original_doc_ids = set(index_creation_inputs["document_ids"])
282 |     RAG.delete_from_index(
283 |         index_name=index_creation_inputs["index_name"],
284 |         document_ids=[deleted_doc_id],
285 |     )
286 |     pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture)
287 |     updated_document_ids = set(pid_docid_map_data.values())
288 | 
289 |     assert (
290 |         deleted_doc_id not in updated_document_ids
291 |     ), f"Deleted document ID '{deleted_doc_id}' should not be in the pid_docid_map's document_ids: {updated_document_ids}."
292 | 
293 |     assert (
294 |         original_doc_ids - updated_document_ids == {deleted_doc_id}
295 |     ), f"Only the deleted document ID '{deleted_doc_id}' should be missing from the pid_docid_map's document_ids: {updated_document_ids}."
296 | 
297 |     if "document_metadatas" in index_creation_inputs:
298 |         document_metadata_dict = srsly.read_json(document_metadata_path_fixture)
299 |         assert (
300 |             deleted_doc_id not in document_metadata_dict
301 |         ), f"Deleted document ID '{deleted_doc_id}' should not be in the document metadata: {document_metadata_dict.keys()}."
302 |         assert (
303 |             original_doc_ids - set(document_metadata_dict.keys()) == {deleted_doc_id}
304 |         ), f"Only the deleted document ID '{deleted_doc_id}' should be missing from the document metadata: {document_metadata_dict.keys()}."
305 | 
--------------------------------------------------------------------------------
/tests/test_trainer_loading.py:
--------------------------------------------------------------------------------
1 | from colbert.infra import ColBERTConfig
2 | 
3 | from ragatouille import RAGTrainer
4 | 
5 | 
6 | def test_finetune():
7 |     """Ensure that the initially loaded config is the one from the pretrained model."""
8 |     trainer = RAGTrainer(
9 |         model_name="test",
10 |         pretrained_model_name="colbert-ir/colbertv2.0",
11 |         language_code="en",
12 |     )
13 |     trainer_config = trainer.model.config
14 | 
15 |     assert ColBERTConfig() != trainer_config
16 |     assert trainer_config.query_token == "[Q]"
17 |     assert trainer_config.doc_token == "[D]"
18 |     assert trainer_config.nbits == 1
19 |     assert trainer_config.kmeans_niters == 20
20 |     assert trainer_config.lr == 1e-05
21 |     assert trainer_config.relu is False
22 |     assert trainer_config.nway == 64
23 |     assert trainer_config.doc_maxlen == 180
24 |     assert trainer_config.use_ib_negatives is True
25 |     assert trainer_config.name == "kldR2.nway64.ib"
26 | 
27 | 
28 | def test_raw_model():
29 |     """Ensure that the default ColBERT configuration is properly loaded when initialising from a BERT-like model"""  # noqa: E501
30 |     trainer = RAGTrainer(
31 |         model_name="test",
32 |         pretrained_model_name="bert-base-uncased",
33 |         language_code="en",
34 |     )
35 |     trainer_config = trainer.model.config
36 | 
37 |     default_config = ColBERTConfig()
38 | 
39 |     assert trainer_config.query_token == default_config.query_token
40 |     assert trainer_config.doc_token == default_config.doc_token
41 |     assert trainer_config.nway == default_config.nway
42 |     assert trainer_config.doc_maxlen == default_config.doc_maxlen
43 |     assert trainer_config.bsize == default_config.bsize
44 |     assert trainer_config.use_ib_negatives == default_config.use_ib_negatives
45 | 
--------------------------------------------------------------------------------
/tests/test_training.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import random
3 | import signal
4 | from contextlib import contextmanager
5 | 
6 | import pytest
7 | import torch
8 | 
9 | from ragatouille import RAGTrainer
10 | from ragatouille.data import CorpusProcessor, llama_index_sentence_splitter
11 | 
12 | DATA_DIR = pathlib.Path(__file__).parent / "data"
13 | 
14 | 
15 | skip_if_no_cuda = pytest.mark.skipif(
16 |     not torch.cuda.is_available(),
17 |     reason="Skip test. Training currently only works when CUDA is available",
18 | )
19 | 
20 | 
21 | class TimeoutException(Exception):
22 |     pass
23 | 
24 | 
25 | @contextmanager
26 | def time_limit(seconds):
27 |     """Time limit context manager as in https://stackoverflow.com/a/601168"""
28 | 
29 |     def signal_handler(_, __):
30 |         raise TimeoutException("Timed out!")
31 | 
32 |     signal.signal(signal.SIGALRM, signal_handler)
33 |     signal.alarm(seconds)
34 |     try:
35 |         yield
36 |     finally:
37 |         signal.alarm(0)
38 | 
39 | 
40 | @skip_if_no_cuda
41 | @pytest.mark.slow
42 | def test_training(tmp_path):
43 |     """This test is based on the content of examples/02-basic_training.ipynb
44 |     and mainly checks that no exceptions occur, e.g. due to bugs in data
45 |     processing.
46 | """ 47 | trainer = RAGTrainer( 48 | model_name="GhibliColBERT", 49 | pretrained_model_name="colbert-ir/colbertv2.0", 50 | language_code="en", 51 | ) 52 | pages = ["miyazaki", "Studio_Ghibli", "Toei_Animation"] 53 | my_full_corpus = [(DATA_DIR / f"{p}_wikipedia.txt").open().read() for p in pages] 54 | 55 | corpus_processor = CorpusProcessor( 56 | document_splitter_fn=llama_index_sentence_splitter 57 | ) 58 | documents = corpus_processor.process_corpus(my_full_corpus, chunk_size=256) 59 | 60 | queries = [ 61 | "What manga did Hayao Miyazaki write?", 62 | "which film made ghibli famous internationally", 63 | "who directed Spirited Away?", 64 | "when was Hikotei Jidai published?", 65 | "where's studio ghibli based?", 66 | "where is the ghibli museum?", 67 | ] 68 | pairs = [] 69 | 70 | for query in queries: 71 | fake_relevant_docs = random.sample(documents, 10) 72 | for doc in fake_relevant_docs: 73 | pairs.append((query, doc)) 74 | trainer.prepare_training_data( 75 | raw_data=pairs, 76 | data_out_path=str(tmp_path), 77 | all_documents=my_full_corpus, 78 | num_new_negatives=10, 79 | mine_hard_negatives=True, 80 | ) 81 | try: 82 | with time_limit(30): 83 | trainer.train( 84 | batch_size=32, 85 | nbits=4, # How many bits will the trained model use when compressing indexes 86 | maxsteps=1, # Maximum steps hard stop 87 | use_ib_negatives=True, # Use in-batch negative to calculate loss 88 | dim=128, # How many dimensions per embedding. 128 is the default and works well. 89 | learning_rate=5e-6, 90 | # Learning rate, small values ([3e-6,3e-5] work best if the base model is BERT-like, 5e-6 is often the sweet spot) 91 | doc_maxlen=256, 92 | # Maximum document length. Because of how ColBERT works, smaller chunks (128-256) work very well. 93 | use_relu=False, # Disable ReLU -- doesn't improve performance 94 | warmup_steps="auto", # Defaults to 10% 95 | ) 96 | # Simply test that some of the files generated have really been made. 
97 |         assert (tmp_path / "corpus.train.colbert.tsv").exists()
98 |     except TimeoutException:
99 |         print("Timed out!")
100 |         raise AssertionError("Timeout in training") from None
101 | 
102 | 
103 | if __name__ == "__main__":
104 |     test_training()
105 | 
--------------------------------------------------------------------------------
/tests/test_training_data_loading.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from ragatouille import RAGTrainer
4 | 
5 | 
6 | @pytest.fixture
7 | def rag_trainer():
8 |     # Setup for a RAGTrainer instance
9 |     instance = RAGTrainer(
10 |         model_name="test_model", pretrained_model_name="bert-base-uncased"
11 |     )
12 |     return instance
13 | 
14 | 
15 | @pytest.mark.parametrize(
16 |     "input_data,pairs_with_labels,expected_queries,expected_collection",
17 |     [
18 |         # Unlabeled pairs
19 |         (
20 |             [("Query1", "Document1"), ("Query2", "Document2")],
21 |             False,
22 |             {"Query1", "Query2"},
23 |             ["Document1", "Document2"],
24 |         ),
25 |         # Labeled pairs
26 |         (
27 |             [
28 |                 ("Query1", "Document1Pos", 1),
29 |                 ("Query1", "Document1Neg", 0),
30 |                 ("Query2", "Document2", 0),
31 |             ],
32 |             True,
33 |             {"Query1", "Query2"},
34 |             ["Document1Pos", "Document1Neg", "Document2"],
35 |         ),
36 |         # Triplets
37 |         (
38 |             [("Query1", "Positive Doc", "Negative Doc")],
39 |             False,
40 |             {"Query1"},
41 |             ["Positive Doc", "Negative Doc"],
42 |         ),
43 |     ],
44 | )
45 | def test_prepare_training_data(
46 |     rag_trainer, input_data, pairs_with_labels, expected_queries, expected_collection
47 | ):
48 |     rag_trainer.prepare_training_data(
49 |         raw_data=input_data, pairs_with_labels=pairs_with_labels
50 |     )
51 | 
52 |     assert rag_trainer.queries == expected_queries
53 | 
54 |     assert len(rag_trainer.collection) == len(expected_collection)
55 |     assert set(rag_trainer.collection) == set(expected_collection)
56 | 
57 | 
58 | def test_prepare_training_data_with_all_documents(rag_trainer):
59 |     input_data = [("Query1", "Document1")]
60 |     all_documents = ["Document2", "Document3"]
61 | 
62 |     rag_trainer.prepare_training_data(raw_data=input_data, all_documents=all_documents)
63 | 
64 |     assert rag_trainer.queries == {"Query1"}
65 | 
66 |     assert len(rag_trainer.collection) == 3
67 |     assert set(rag_trainer.collection) == {"Document1", "Document2", "Document3"}
68 | 
69 | 
70 | def test_prepare_training_data_invalid_input(rag_trainer):
71 |     # Providing an invalid input format
72 |     with pytest.raises(ValueError):
73 |         rag_trainer.prepare_training_data(raw_data=[("Query1")])  # A bare string, not a (query, document) pair
74 | 
--------------------------------------------------------------------------------
/tests/test_training_data_processor.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock
2 | 
3 | import pytest
4 | 
5 | from ragatouille.data import TrainingDataProcessor
6 | 
7 | 
8 | @pytest.fixture
9 | def collection():
10 |     return ["doc1", "doc2", "doc3"]
11 | 
12 | 
13 | @pytest.fixture
14 | def queries():
15 |     return ["query1", "query2"]
16 | 
17 | 
18 | def test_process_raw_data_without_miner(collection, queries):
19 |     processor = TrainingDataProcessor(collection, queries, None)
20 |     processor._process_raw_pairs = MagicMock(return_value=None)
21 | 
22 |     processor.process_raw_data(
23 |         raw_data=[], data_type="pairs", data_dir="./", mine_hard_negatives=False
24 |     )
25 | 
26 |     processor._process_raw_pairs.assert_called_once()
27 | 
28 | 
29 | def test_process_raw_data_with_miner(collection, queries):
30 |     negative_miner = MagicMock()
31 |     processor = TrainingDataProcessor(collection, queries, negative_miner)
32 |     processor._process_raw_pairs = MagicMock(return_value=None)
33 | 
34 |     processor.process_raw_data(raw_data=[], data_type="pairs", data_dir="./")
35 | 
36 |     processor._process_raw_pairs.assert_called_once()
37 | 
--------------------------------------------------------------------------------
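
Taken together, the CRUD tests above exercise the full lifecycle of a RAGatouille index: loading it with RAGPretrainedModel.from_index, adding documents with add_to_index, querying it with search, and removing documents with delete_from_index. The sketch below strings those same calls together in the order the tests use them; it relies only on the keyword arguments exercised above, and the index path, index name, document texts, and metadata are illustrative placeholders rather than values taken from this repository.

from ragatouille import RAGPretrainedModel

# Placeholder path and index name: point these at an index you have already built.
index_path = ".ragatouille/colbert/indexes/my_index"
RAG = RAGPretrainedModel.from_index(index_path)

# Add new documents, optionally with per-document metadata.
RAG.add_to_index(
    new_collection=["Some new document text."],
    new_document_ids=["new_doc"],
    new_document_metadatas=[{"source": "example"}],
    index_name="my_index",
)

# Search the index; each result exposes document_id and, when metadata was
# supplied at indexing time, document_metadata.
results = RAG.search("an example query", index_name="my_index")
for result in results:
    print(result["document_id"], result.get("document_metadata"))

# Delete a document by its ID.
RAG.delete_from_index(index_name="my_index", document_ids=["new_doc"])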