├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENSE
├── RAGatouille.png
├── README.md
├── docs
│   ├── api.md
│   ├── index.md
│   └── roadmap.md
├── examples
│   ├── 01-basic_indexing_and_search.ipynb
│   ├── 02-basic_training.ipynb
│   ├── 03-finetuning_without_annotations_with_instructor_and_RAGatouille.ipynb
│   ├── 04-reranking.ipynb
│   ├── 05-llama_hub.ipynb
│   ├── 06-index_free_use.ipynb
│   └── data
│       └── llama2.pdf
├── mkdocs.yml
├── pyproject.toml
├── ragatouille
│   ├── RAGPretrainedModel.py
│   ├── RAGTrainer.py
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── corpus_processor.py
│   │   ├── preprocessors.py
│   │   └── training_data_processor.py
│   ├── integrations
│   │   ├── __init__.py
│   │   └── _langchain.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── colbert.py
│   │   ├── index.py
│   │   ├── torch_kmeans.py
│   │   └── utils.py
│   ├── negative_miners
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── simpleminer.py
│   └── utils.py
├── requirements-doc.txt
└── tests
    ├── __init__.py
    ├── data
    │   ├── Studio_Ghibli_wikipedia.txt
    │   ├── Toei_Animation_wikipedia.txt
    │   └── miyazaki_wikipedia.txt
    ├── e2e
    │   └── test_e2e_indexing_searching.py
    ├── test_pretrained_loading.py
    ├── test_pretrained_optional_args.py
    ├── test_trainer_loading.py
    ├── test_training.py
    ├── test_training_data_loading.py
    └── test_training_data_processor.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: bclavie
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 | .pdm.toml
87 | __pypackages__/
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 | .envrc
98 |
99 | # mkdocs documentation
100 | /site
101 |
102 | .ragatouille
103 |
104 | # mypy
105 | .mypy_cache/
106 | .dmypy.json
107 | dmypy.json
108 |
109 | # Pyre type checker
110 | .pyre/
111 |
112 | # pytype static type analyzer
113 | .pytype/
114 |
115 | # Cython debug symbols
116 | cython_debug/
117 |
118 | # data files
119 | *.tsv
120 | *.jsonl
121 |
122 | .mypy.ipynb_checkpoints
123 | .mkdocs.yml
124 |
125 |
126 | archive/
127 |
128 | */.ragatouille
129 |
130 | local/
131 |
132 | .vscode/
133 |
134 | .devcontainer/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/RAGatouille.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnswerDotAI/RAGatouille/e75b8a964a870dea886548f78da1900804749040/RAGatouille.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Welcome to RAGatouille
2 |
3 | _Easily use and train state of the art retrieval methods in any RAG pipeline. Designed for modularity and ease-of-use, backed by research._
4 |
5 | [](https://github.com/bclavie/ragatouille/stargazers)
6 | 
7 | [](https://pepy.tech/project/ragatouille)
8 | [](https://ben.clavie.eu/ragatouille/)
9 | [](https://twitter.com/bclavie)
10 |
11 |

12 |
13 | ---
14 |
15 | The main motivation of RAGatouille is simple: bridging the gap between state-of-the-art research and alchemical RAG pipeline practices. RAG is complex, and there are many moving parts. To get the best performance, you need to optimise for many components: among them, a very important one is the models you use for retrieval.
16 |
 17 | Dense retrieval, i.e. using embeddings such as OpenAI's `text-embedding-ada-002`, is a good baseline, but there's a lot of research [showing dense embeddings might not be the](https://arxiv.org/abs/2104.08663) [best fit for **your** use case](https://arxiv.org/abs/2204.11447).
18 |
 19 | The Information Retrieval research field has recently been booming, and models like ColBERT have been shown to [generalise better](https://arxiv.org/abs/2203.10053) [to new or complex domains](https://aclanthology.org/2022.findings-emnlp.78/) [than dense embeddings](https://arxiv.org/abs/2205.02870), are [ridiculously data-efficient](https://arxiv.org/abs/2309.06131) and are even [better suited to being efficiently trained](https://arxiv.org/abs/2312.09508) [on non-English languages with small amounts of data](https://arxiv.org/abs/2312.16144)! Unfortunately, most of those new approaches aren't very well known, and are much harder to use than dense embeddings.
20 |
 21 | This is where __RAGatouille__ comes in: its purpose is to make it easy to use state-of-the-art retrieval methods in your RAG pipeline, without having to worry about the details or the years of literature! At the moment, RAGatouille focuses on making ColBERT simple to use. If you want to check out what's coming next, you can check out our [broad roadmap](https://ben.clavie.eu/ragatouille/roadmap)!
22 |
23 | _If you want to read more about the motivations, philosophy, and why the late-interaction approach used by ColBERT works so well, check out the [introduction in the docs](https://ben.clavie.eu/ragatouille/)._
24 |
25 |
26 |
27 |
28 | Want to give it a try? Nothing easier, just run `pip install ragatouille` and you're good to go!
29 |
30 | ⚠️ Running notes/requirements: ⚠️
31 |
 32 | - If running inside a script, you must run it inside `if __name__ == "__main__"` (see the minimal sketch below)
33 | - Windows is not supported. RAGatouille doesn't appear to work outside WSL and has issues with WSL1. Some users have had success running RAGatouille in WSL2.
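For reference, here's a minimal sketch of what the guard looks like when using RAGatouille from a script, reusing the same calls shown in the sections below:

```python
from ragatouille import RAGPretrainedModel

if __name__ == "__main__":
    # When running RAGatouille from a script, all calls must live under this guard.
    RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
    RAG.index(index_name="my_index", collection=["a first document", "a second document"])
    print(RAG.search("a query about the documents", k=3))
```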
34 |
35 | ## Get Started
36 |
37 | RAGatouille makes it as simple as can be to use ColBERT! We want the library to work on two levels:
38 |
39 | - Strong, but parameterizable defaults: you should be able to get started with just a few lines of code and still leverage the full power of ColBERT, and you should be able to tweak any relevant parameter if you need to!
 40 | - Powerful yet simple re-usable components under-the-hood: any part of the library should be usable stand-alone. You can use our DataProcessor or our negative miners outside of `RAGPretrainedModel` and `RAGTrainer`, and you can even write your own negative miner and use it in the pipeline if you want to!
41 |
42 |
43 | In this section, we'll quickly walk you through the three core aspects of RAGatouille:
44 |
45 | - [🚀 Training and Fine-Tuning ColBERT models](#-training-and-fine-tuning)
46 | - [🗄️ Embedding and Indexing Documents](#%EF%B8%8F-indexing)
47 | - [🔎 Retrieving documents](#-retrieving-documents)
48 |
 49 | ➡️ If you just want to see fully functional code examples, head over to the [examples](https://github.com/bclavie/RAGatouille/tree/main/examples)⬅️
50 |
51 | ### 🚀 Training and fine-tuning
52 |
 53 | _If you're just prototyping, you don't need to train your own model! While fine-tuning can be useful, one of the strengths of ColBERT is that the pretrained models are particularly good at generalisation, and [ColBERTv2](https://huggingface.co/colbert-ir/colbertv2.0) has [repeatedly been shown to be extremely strong](https://arxiv.org/abs/2303.00807) at zero-shot retrieval in new domains!_
54 |
55 | #### Data Processing
56 |
57 | RAGatouille's RAGTrainer has a built-in `TrainingDataProcessor`, which can take most forms of retrieval training data, and automatically convert it to training triplets, with data enhancements. The pipeline works as follows:
58 |
 59 | - Accepts pairs, labelled pairs and various forms of triplets as inputs (strings or lists of strings) -- transparently!
 60 | - Automatically removes all duplicates and maps all positives/negatives to their respective query.
 61 | - By default, mines hard negatives: this means generating negatives that are hard to distinguish from positives, and that are therefore more useful for training.
62 |
 63 | This is all handled by `RAGTrainer.prepare_training_data()`, and is as easy as passing your data to it:
64 |
65 | ```python
66 | from ragatouille import RAGTrainer
67 |
68 | my_data = [
69 | ("What is the meaning of life ?", "The meaning of life is 42"),
70 | ("What is Neural Search?", "Neural Search is a terms referring to a family of ..."),
71 | ...
72 | ] # Unlabelled pairs here
 73 | trainer = RAGTrainer(model_name="MyFineTunedColBERT", pretrained_model_name="colbert-ir/colbertv2.0")
74 | trainer.prepare_training_data(raw_data=my_data)
75 | ```
76 |
 77 | ColBERT prefers to store processed training data on-file, which also makes it easier to properly version training data via `wandb` or `dvc`. By default, it will write to `./data/`, but you can override this by passing a `data_out_path` argument to `prepare_training_data()`.
78 |
79 | Just like all things in RAGatouille, `prepare_training_data` uses strong defaults, but is also fully parameterizable.
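For instance, overriding the output directory is just one keyword argument away (the path below is purely illustrative):

```python
trainer.prepare_training_data(
    raw_data=my_data,
    data_out_path="./my_training_data/",  # illustrative custom path; defaults to "./data/"
)
```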
80 |
81 |
82 | #### Running the Training/Fine-Tuning
83 |
 84 | Training and Fine-Tuning follow the exact same process. When you instantiate `RAGTrainer`, you must pass it a `pretrained_model_name`. If this pretrained model is a ColBERT instance, the trainer will be in fine-tuning mode; if it's another kind of transformer, it will be in training mode and begin training a new ColBERT initialised from the model's weights!
85 |
86 |
87 | ```python
88 | from ragatouille import RAGTrainer
89 | from ragatouille.utils import get_wikipedia_page
90 |
91 | pairs = [
92 | ("What is the meaning of life ?", "The meaning of life is 42"),
93 | ("What is Neural Search?", "Neural Search is a terms referring to a family of ..."),
94 | # You need many more pairs to train! Check the examples for more details!
95 | ...
96 | ]
97 |
98 | my_full_corpus = [get_wikipedia_page("Hayao_Miyazaki"), get_wikipedia_page("Studio_Ghibli")]
99 |
100 |
101 | trainer = RAGTrainer(model_name = "MyFineTunedColBERT",
102 | pretrained_model_name = "colbert-ir/colbertv2.0") # In this example, we run fine-tuning
103 |
104 | # This step handles all the data processing, check the examples for more details!
105 | trainer.prepare_training_data(raw_data=pairs,
106 | data_out_path="./data/",
107 | all_documents=my_full_corpus)
108 |
109 | trainer.train(batch_size=32) # Train with the default hyperparams
110 | ```
111 |
112 | When you run `train()`, it'll by default inherit its parent ColBERT hyperparameters if fine-tuning, or use the default training parameters if training a new ColBERT. Feel free to modify them as you see fit (check the example and API reference for more details!)
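As a small sketch, overriding a few of the defaults looks like this (these keyword arguments mirror the ones used in the basic training example notebook):

```python
trainer.train(
    batch_size=32,          # training batch size
    learning_rate=5e-6,     # small learning rates work best for BERT-like base models
    maxsteps=500000,        # hard stop on the number of training steps
    use_ib_negatives=True,  # use in-batch negatives when computing the loss
)
```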
113 |
114 |
115 | ### 🗄️ Indexing
116 |
117 | To create an index, you'll need to load a trained model; this can be one of your own or a pretrained one from the hub! Creating an index with the default configuration is just a few lines of code:
118 |
119 | ```python
120 | from ragatouille import RAGPretrainedModel
121 | from ragatouille.utils import get_wikipedia_page
122 |
123 | RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
124 | my_documents = [get_wikipedia_page("Hayao_Miyazaki"), get_wikipedia_page("Studio_Ghibli")]
125 | index_path = RAG.index(index_name="my_index", collection=my_documents)
126 | ```
127 | You can also optionally add document IDs or document metadata when creating the index:
128 |
129 | ```python
130 | document_ids = ["miyazaki", "ghibli"]
131 | document_metadatas = [
132 | {"entity": "person", "source": "wikipedia"},
133 | {"entity": "organisation", "source": "wikipedia"},
134 | ]
135 | index_path = RAG.index(
136 | index_name="my_index_with_ids_and_metadata",
137 | collection=my_documents,
138 | document_ids=document_ids,
139 | document_metadatas=document_metadatas,
140 | )
141 | ```
142 |
143 | Once this is done running, your index will be saved on-disk and ready to be queried! RAGatouille and ColBERT handle everything here:
144 | - Splitting your documents
145 | - Tokenizing your documents
146 | - Identifying the individual terms
147 | - Embedding the documents and generating the bags-of-embeddings
148 | - Compressing the vectors and storing them on disk
149 |
150 | Curious about how this works? Check out the [Late-Interaction & ColBERT concept explainer](https://ben.clavie.eu/ragatouille/#late-interaction)
151 |
152 |
153 | ### 🔎 Retrieving Documents
154 |
155 | Once an index is created, querying it is just as simple as creating it! You can load the model you need directly from an index's configuration:
156 |
157 | ```python
158 | from ragatouille import RAGPretrainedModel
159 |
160 | query = "ColBERT my dear ColBERT, who is the fairest document of them all?"
161 | RAG = RAGPretrainedModel.from_index("path_to_your_index")
162 | results = RAG.search(query)
163 | ```
164 |
165 | This is the preferred way of doing things, since every index saves the full configuration of the model used to create it, and you can easily load it back up.
166 |
167 | `RAG.search` is a flexible method! You can set the `k` value to however many results you want (it defaults to `10`), and you can also use it to search for multiple queries at once:
168 |
169 | ```python
170 | RAG.search(["What manga did Hayao Miyazaki write?",
171 |             "Who are the founders of Ghibli?",
172 | "Who is the director of Spirited Away?"],)
173 | ```
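For a single query, you can pass `k` in the same way to control how many results come back:

```python
RAG.search("What manga did Hayao Miyazaki write?", k=3)  # top-3 results instead of the default 10
```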
174 |
175 | `RAG.search` returns results in the form of a list of dictionaries, or a list of lists of dictionaries if you used multiple queries:
176 |
177 | ```python
178 | # single-query result
179 | [
180 | {"content": "blablabla", "score": 42.424242, "rank": 1, "document_id": "x"},
181 | ...,
182 | {"content": "albalbalba", "score": 24.242424, "rank": k, "document_id": "y"},
183 | ]
184 | # multi-query result
185 | [
186 | [
187 | {"content": "blablabla", "score": 42.424242, "rank": 1, "document_id": "x"},
188 | ...,
189 | {"content": "albalbalba", "score": 24.242424, "rank": k, "document_id": "y"},
190 | ],
191 | [
192 | {"content": "blablabla", "score": 42.424242, "rank": 1, "document_id": "x"},
193 | ...,
194 | {"content": "albalbalba", "score": 24.242424, "rank": k, "document_id": "y"},
195 | ],
196 | ],
197 | ```
198 | If your index includes document metadata, it'll be returned as a dictionary in the `document_metadata` key of the result dictionary:
199 |
200 | ```python
201 | [
202 | {"content": "blablabla", "score": 42.424242, "rank": 1, "document_id": "x", "document_metadata": {"A": 1, "B": 2}},
203 | ...,
204 | {"content": "albalbalba", "score": 24.242424, "rank": k, "document_id": "y", "document_metadata": {"A": 3, "B": 4}},
205 | ]
206 | ```
207 |
208 | ## I'm sold, can I integrate late-interaction RAG into my project?
209 |
210 | To get started, RAGatouille bundles everything you need to build a ColBERT native index and query it. Just look at the docs! RAGatouille persists indices on disk in a compressed format, and a very viable production deployment is to simply integrate the index you need into your project and query it directly. Don't just take our word for it: this is what Spotify does in production with their own vector search framework, serving tens of millions of users:
211 |
212 | > Statelessness: Many of Spotify’s systems use nearest-neighbor search in memory, enabling stateless deployments (via Kubernetes) and almost entirely removing the maintenance and cost burden of maintaining a stateful database cluster. (_[Spotify, announcing Voyager](https://engineering.atspotify.com/2023/10/introducing-voyager-spotifys-new-nearest-neighbor-search-library/)_)
213 |
214 |
215 | ### Integrations
216 |
217 | If you'd like to use more than RAGatouille, ColBERT has a growing number of integrations, and they all fully support models trained or fine-tuned with RAGatouille!
218 |
219 | - The [official ColBERT implementation](https://github.com/stanford-futuredata/ColBERT) has a built-in query server (using Flask), which you can easily query via API requests and does support indexes generated with RAGatouille! This should be enough for most small applications, so long as you can persist the index on disk.
220 | - [Vespa](https://vespa.ai) offers a fully managed RAG engine with ColBERT support: it's essentially just like a vector DB, except with many more retrieval options! Full support for ColBERT models will be released in the next couple weeks, and using a RAGatouille-trained model will be as simple as loading it from the huggingface hub! **Vespa is a well-tested, widely used framework and is [fully-supported in LangChain](https://python.langchain.com/docs/integrations/providers/vespa), making it the ideal slot-in replacement to replace your current RAG pipeline with ColBERT!**
221 | - [Intel's FastRAG](https://github.com/IntelLabs/fastRAG) supports ColBERT models for RAG, and is fully compatible with RAGatouille-trained models.
222 | - [LlamaIndex](https://www.llamaindex.ai) is building ColBERT integrations and already [has early ColBERT support, with active development continuing](https://github.com/run-llama/llama_index/pull/9656).
223 |
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # API Reference
2 |
3 | ::: ragatouille.RAGTrainer
4 |
5 | ::: ragatouille.RAGPretrainedModel
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # RAGatouille
2 |
3 | _State-of-the-art document retrieval methods, in just a few lines of code._
4 |
5 | ---
6 |
7 | Welcome to the docs for RAGatouille. This page presents RAGatouille's philosophy. We also discuss late interaction retrievers, like ColBERT, and what makes their quality so high.
8 |
9 | While you're here, check out the [API reference](https://ben.clavie.eu/ragatouille/api) and the [evolving roadmap](https://ben.clavie.eu/ragatouille/roadmap). The docs will be actively updated in the next few weeks!
10 |
11 | ## Philosophy
12 |
13 | RAGatouille's philosophy is two-fold:
14 |
15 | ### Motivation
16 |
 17 | The aim of RAGatouille is to close the growing gap between the Information Retrieval literature and everyday production retrieval uses. I'm not an IR researcher myself (my own background is in NLP), but I arrived at late interaction retrievers through the many papers highlighting that they're consistently better than dense embeddings on just about any zero-shot task, at least when tested apples-to-apples! Maybe more importantly, you don't need to use them zero-shot: they're very easy to adapt to new domains due to their **bag-of-embeddings** approach.
18 |
19 | However, it's been a consistently hard sell, and starting to use ColBERT on real projects wasn't particularly smooth. IR is a field that is outwardly more _stern_ than NLP, and the barrier-to-entry is higher. A lot of the IR frameworks, like Terrier or Anserini, are absolutely fantastic, but they just don't fit into the pythonic day-to-day workflows we're used to.
20 |
 21 | For the adoption of sentence transformers, the aptly (re-)named SentenceTransformers library has been a boon. RAGatouille doesn't quite have the pretension to be that, but it aims to help democratise the easy training and use of ColBERT and pals. To do so, we also take the approach of avoiding re-implementation whenever possible, to speed up iteration.
22 |
 23 | If a paper has open-sourced its code, our goal is for RAGatouille to be a gateway to it, rather than a complete replacement, whenever possible! (_This might change in the future as the library grows!_) Moreover, this is a mutually beneficial relationship, as the early development of this lib has already resulted in a few upstreamed fixes to the main ColBERT repo!
24 |
25 | ### Code Philosophy
26 |
27 | The actual programming philosophy is fairly simple, as stated on the [github README](https://github.com/bclavie/RAGatouille):
28 |
 29 | - Strong, but parameterizable defaults: you should be able to get started with just a few lines of code and still leverage the full power of ColBERT, and you should be able to tweak any relevant parameter if you need to!
30 | - Powerful yet simple re-usable components under-the-hood: any part of the library should be usable stand-alone. You can use our DataProcessor or our negative miners outside of `RAGPretrainedModel` and `RAGTrainer`, and you can even write your own negative miner and use it in the pipeline if you want to!
31 |
 32 | In practice, this means that `RAGPretrainedModel` and `RAGTrainer` should ultimately be _all you need_ to leverage the power of ColBERT in your pipelines. Everything that gets added to RAGatouille should be workable into these two classes, which are really just interfaces between the underlying models and processors.
33 |
 34 | However, re-usable components are another very important aspect. RAGatouille aims to be built in such a way that every core component, such as our negative miners (`SimpleMiner` for dense retrieval at the moment) or data processors, is usable outside the main classes, if you so desire. If you're a seasoned ColBERT aficionado, nothing should stop you from importing `TrainingDataProcessor` to streamline processing and exporting triplets!
35 |
36 | Finally, there's a third point that's fairly important:
37 |
38 | - __Don't re-invent the wheel.__
39 |
 40 | If a component needs to do something, we won't seek to do it our way; we'll seek to do it the way people already do it. This means using `LlamaIndex` to chunk documents, `instructor` and `pydantic` to constrain OpenAI calls, or `DSPy` whenever we need more complex LLM-based components!
41 |
42 | ## Late-Interaction
43 |
44 | ### tl;dr
45 |
46 | So, why is late-interaction so good? Why should you use RAGatouille/ColBERT?
47 |
 48 | The underlying concept is simple. I like to explain ColBERT as a `bag-of-embeddings` approach, as this makes it immediately obvious to NLP practitioners how and why ColBERT works:
49 |
50 | - Just like bag-of-words, it works on small information units, and represents a document as the sum of them
51 | - Just like embeddings, it works on the semantic level: the actual way something is phrased doesn't matter, the model learns __meaning__.
52 |
53 | ### longer, might read
54 |
55 | A full blog post with more detail about this is coming (soon™), but for now, here's a quick explainer:
56 |
 57 | Consider the existing widely-used retrieval approaches, with a quick overview of their pros and cons:
58 |
59 | ##### BM25/Keyword-based Sparse Retrieval
60 |
61 | ➕ Fast
62 | ➕ Consistent performance
63 | ➕ No real training required
64 | ➕ Intuitive
65 | ➖ Requires exact matches
66 | ➖ Does not leverage any semantic information, and thus hits __a hard performance ceiling__
67 |
68 | ##### Cross-Encoders
69 |
70 | ➕ Very strong performance
 71 | ➕ Leverages semantic information to a large extent ("understands" negation, so that "I love apples" and "I hate apples" are not similar, etc.)
72 | ➖ Major scalability issues: can only retrieve scores by running the model to compare a query to every single document in the corpus.
73 |
74 | ##### Dense Retrieval/Embeddings
75 |
76 | ➕ Fast
77 | ➕ Decent performance overall, once pre-trained
78 | ➕ Leverages semantic information...
 79 | ➖ ... but not contrastive information (e.g. "I love apples" and "I hate apples" will have a high similarity score.)
 80 | ➖ Fine-tuning can be finicky
81 | ➖ Requires either billions of parameters (e5-mistral) or billions of pre-training examples to reach top performance
82 | ➖ __Often generalises poorly__
83 |
84 | ### Generalisation
85 |
86 | This last point is particularly important. __Generalisation__ is what you want, because the documents that you're trying to retrieve for your users, as well as the way that they phrase their queries, are __not the ones present in academic datasets__.
87 |
 88 | Strong performance on academic benchmarks is a solid signal for predicting how well a model will perform, but it's far from the only one. Single-vector embedding approaches, or _dense retrieval_ methods, often do well on benchmarks, as they're trained specifically for them. However, the IR literature has shown many times that these models often generalise worse than other approaches.
89 |
 90 | This is not a slight on them; it's actually very logical! If you think about it:
91 |
92 | - A single-vector embedding is **the representation of a sentence or document into a very small vector space, with at most 1024 dimensions**.
 93 | - In retrieval settings, the **same model must also be able to create similar representations for very short queries and long documents, to be able to retrieve them**.
 94 | - Then, **these vectors must be able to represent your documents and your users' queries, which it has never seen, in the same way that it has learned to represent its training data**.
 95 | - And finally, **it must be able to encode all possible information contained in a document or in a query, so that it may be able to find a relevant document no matter how a question is phrased**.
96 |
97 | The fact that dense embeddings perform well in these circumstances is very impressive! But sadly, embedding all this information into just a thousand dimensions isn't a problem that has been cracked yet.
98 |
99 | ### Bag-of-Embeddings: the Late Interaction trick
100 |
101 | Alleviating this is where late-interaction comes in.
102 |
103 | ColBERT does not represent documents as a single vector. In fact, ColBERT, at its core, is basically a **keyword-based approach**.
104 |
105 | But why does it perform so well, then, when we've established that keyword matchers have a hard ceiling?
106 |
107 | Because ColBERT is a **semantic keyword matcher**. It leverages the power of strong encoder models, like BERT, to break down each document into a bag of **contextualised units of information**.
108 |
109 | When a document is embedded by ColBERT, it isn't represented as a document, but as the sum of its parts.
110 |
111 | This fundamentally changes the nature of training our model into a much easier task: it doesn't need to cram every possible meaning into a single vector; it just needs to capture the meaning of a few tokens at a time. When you do this, it doesn't really matter how you phrase something at retrieval time: the likelihood that the model is able to relate the way you mention a topic to the way a document discusses it is considerably higher. Intuitively, this is because the model has so much more space to store information on individual topics! Additionally, because this allows us to create smaller vectors for individual information units, they become very compressible, which means our indexes don't balloon up in size.
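To make this concrete, here's a minimal, purely illustrative sketch of the late-interaction (MaxSim) scoring step, with random vectors standing in for real ColBERT token embeddings:

```python
import numpy as np

# Illustrative only: random vectors stand in for contextualised token embeddings.
rng = np.random.default_rng(0)
query_embeddings = rng.normal(size=(32, 128))  # one 128-dim vector per query token
doc_embeddings = rng.normal(size=(300, 128))   # one 128-dim vector per document token

# Late interaction: for each query token, keep only its best-matching document
# token (MaxSim), then sum those maxima to get the document's relevance score.
similarity = query_embeddings @ doc_embeddings.T  # (32, 300) token-level similarities
score = similarity.max(axis=1).sum()
```

Because each document is stored as this bag of small token-level vectors rather than one big one, matching happens at the level of contextualised tokens -- the "semantic keyword matching" described above.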
112 |
--------------------------------------------------------------------------------
/docs/roadmap.md:
--------------------------------------------------------------------------------
1 | # Roadmap
2 |
3 | This page is incorrectly named: RAGatouille doesn't have a set-in-stone roadmap, but rather, a set of objectives.
4 |
5 | Below, you'll find things that we're hoping to integrate and/or support in upcoming versions (⛰️ denotes a major milestone):
6 |
7 | #### New Features
8 |
9 | ##### Synthetic Data Generation
10 |
11 | - Build upon our [tutorial 3](https://github.com/bclavie/RAGatouille/blob/main/examples/03-finetuning_without_annotations_with_instructor_and_RAGatouille.ipynb) and integrate OpenAI query generation into a built-in DataProcessor.
12 | - Leverage [DSPy](https://github.com/stanfordnlp/dspy) to perform data augmentation via LLM compiling, reducing the reliance on API providers by enabling locally-ran models to generate data.
 13 | - ⛰️ Integrate [UDAPDR](https://arxiv.org/abs/2303.00807) - UDAPDR is an extremely impressive method to adapt retrievers to a target domain via entirely synthetic queries: all you need to provide is your document collection. We're hoping to integrate this in an upcoming version of RAGatouille.
14 | - Provide a toolkit to generate synthetic passages for provided queries.
15 |
16 |
17 | #### Improvements
18 |
19 | - ⛰️ Full ColBERTv2 style training: transparently use an existing cross-encoder teacher model to generate distillation scores and improve model training.
20 | - Evaluation support: at the moment, RAGatouille doesn't roll out any evaluation metrics, as these are more commonly available already. Future versions of RAGatouille will include some form of evaluation for convenience!
21 | - Support for more "late-interaction" models, such as Google's [SparseEmbed](https://research.google/pubs/sparseembed-learning-sparse-lexical-representations-with-contextual-embeddings-for-retrieval/).
 22 | - New negative miners, such as ColBERTMiner (not a huge priority as dense hard negatives work well enough, but would be a nice feature for thoroughness)
23 | - Full LlamaIndex integration
24 |
25 | #### Library Upkeep
26 |
27 | - ⛰️ Improve the documentation to cover every component and concept of the library in-depth.
28 | - Comprehensive test coverage
--------------------------------------------------------------------------------
/examples/02-basic_training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "c71ca3ef",
6 | "metadata": {},
7 | "source": [
8 | "# Basic training and fine-tuning with RAGatouille\n",
9 | "\n",
10 | "In this quick example, we'll use the `RAGTrainer` magic class to demonstrate how to very easily fine-tune an existing ColBERT model, or train one from any BERT/RoBERTa-like model (to [train one for a previously unsupported language like Japanese](https://huggingface.co/bclavie/jacolbert), for example!)"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "d6f69fca",
16 | "metadata": {},
17 | "source": [
18 | "First, we'll create an instance of `RAGtrainer`. We need to give it two arguments: the `model_name` we want to give to the model we're training, and the `pretrained_model_name` of the base model. This can either be a local path, or the name of a model on the HuggingFace Hub. If you're training for a language other than English, you should also include a `language_code` two-letter argument (PLACEHOLDER ISO) so we can get the relevant processing utils!\n",
19 | "\n",
20 | "The trainer will auto-detect whether it's an existing ColBERT model or a BERT base model, and will set itself up accordingly!\n",
21 | "\n",
22 | "Please note: Training can currently only be ran on GPU, and will error out if using CPU/MPS! Training is also currently not functional on Google Colab and Windows 10.\n",
23 | "\n",
24 | "Whether we're training from scratch or fine-tuning doesn't matter, all the steps are the same. For this example, let's fine-tune ColBERTv2:"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 1,
30 | "id": "1e81ac64-d222-412c-925c-2a5262266c0e",
31 | "metadata": {
32 | "tags": []
33 | },
34 | "outputs": [],
35 | "source": [
36 | "from ragatouille import RAGTrainer\n",
37 | "trainer = RAGTrainer(model_name=\"GhibliColBERT\", pretrained_model_name=\"colbert-ir/colbertv2.0\", language_code=\"en\")"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "6fa8016a",
43 | "metadata": {},
44 | "source": [
45 | "To train retrieval models like colberts, we need training triplets: queries, positive passages, and negative passages for each query.\n",
46 | "\n",
47 | "In the next tutorial, we'll see [how to generate synthetic queries when we don't have any annotated passage]. For this tutorial, we'll assume that we have queries and relevant passages, but that we're lacking negative ones (because it's not an information we gather from our users).\n",
48 | "\n",
49 | "Let's assume our corpus is the same as the one we [used for our example about indexing an searching](PLACEHOLDER): Hayao Miyazaki's wikipedia page. Let's first fetch the content from Wikipedia:"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 2,
55 | "id": "3ca72a0d",
56 | "metadata": {
57 | "tags": []
58 | },
59 | "outputs": [],
60 | "source": [
61 | "import requests\n",
62 | "\n",
63 | "def get_wikipedia_page(title: str):\n",
64 | " \"\"\"\n",
65 | " Retrieve the full text content of a Wikipedia page.\n",
66 | " \n",
67 | " :param title: str - Title of the Wikipedia page.\n",
68 | " :return: str - Full text content of the page as raw string.\n",
69 | " \"\"\"\n",
70 | " # Wikipedia API endpoint\n",
71 | " URL = \"https://en.wikipedia.org/w/api.php\"\n",
72 | "\n",
73 | " # Parameters for the API request\n",
74 | " params = {\n",
75 | " \"action\": \"query\",\n",
76 | " \"format\": \"json\",\n",
77 | " \"titles\": title,\n",
78 | " \"prop\": \"extracts\",\n",
79 | " \"explaintext\": True,\n",
80 | " }\n",
81 | "\n",
82 | " # Custom User-Agent header to comply with Wikipedia's best practices\n",
83 | " headers = {\n",
84 | " \"User-Agent\": \"RAGatouille_tutorial/0.0.1 (ben@clavie.eu)\"\n",
85 | " }\n",
86 | "\n",
87 | " response = requests.get(URL, params=params, headers=headers)\n",
88 | " data = response.json()\n",
89 | "\n",
90 | " # Extracting page content\n",
91 | " page = next(iter(data['query']['pages'].values()))\n",
92 | " return page['extract'] if 'extract' in page else None"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 3,
98 | "id": "8d8ade7f",
99 | "metadata": {
100 | "tags": []
101 | },
102 | "outputs": [],
103 | "source": [
104 | "my_full_corpus = [get_wikipedia_page(\"Hayao_Miyazaki\")]\n",
105 | "my_full_corpus += [get_wikipedia_page(\"Studio_Ghibli\")]\n",
106 | "my_full_corpus += [get_wikipedia_page(\"Toei_Animation\")]"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "id": "31778aed",
112 | "metadata": {},
113 | "source": [
114 | "We're also some Toei Animation content -- it helps to have things in our corpus that aren't directly relevant to our queries but are likely to cover similar topics, so we can get some more adjacent negative examples.\n",
115 | "\n",
116 | "The documents are very long, so let's use a `CorpusProcessor` to split them into chunks of around 256 tokens:"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 4,
122 | "id": "a14fe476",
123 | "metadata": {
124 | "tags": []
125 | },
126 | "outputs": [],
127 | "source": [
128 | "from ragatouille.data import CorpusProcessor, llama_index_sentence_splitter\n",
129 | "\n",
130 | "corpus_processor = CorpusProcessor(document_splitter_fn=llama_index_sentence_splitter)\n",
131 | "documents = corpus_processor.process_corpus(my_full_corpus, chunk_size=256)"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "id": "a26eea4c",
137 | "metadata": {},
138 | "source": [
139 | "Now that we have a corpus of documents, let's generate fake query-relevant passage pair. Obviously, you wouldn't want that in the real world, but that's the topic of the next tutorial:"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 5,
145 | "id": "53b48c50",
146 | "metadata": {
147 | "tags": []
148 | },
149 | "outputs": [],
150 | "source": [
151 | "import random\n",
152 | "\n",
153 | "queries = [\"What manga did Hayao Miyazaki write?\",\n",
154 | " \"which film made ghibli famous internationally\",\n",
155 | " \"who directed Spirited Away?\",\n",
156 | " \"when was Hikotei Jidai published?\",\n",
157 | " \"where's studio ghibli based?\",\n",
158 | " \"where is the ghibli museum?\"\n",
159 | "] * 3\n",
160 | "pairs = []\n",
161 | "\n",
162 | "for query in queries:\n",
163 | " fake_relevant_docs = random.sample(documents, 10)\n",
164 | " for doc in fake_relevant_docs:\n",
165 | " pairs.append((query, doc))"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "id": "46f29b15",
171 | "metadata": {},
172 | "source": [
173 | "Here, we have created pairs.It's common for retrieval training data to be stored in a lot of different ways: pairs of [query, positive], pairs of [query, passage, label], triplets of [query, positive, negative], or triplets of [query, list_of_positives, list_of_negatives]. No matter which format your data's in, you don't need to worry about it: RAGatouille will generate ColBERT-friendly triplets for you, and export them to disk for easy `dvc` or `wandb` data tracking.\n",
174 | "\n",
175 | "Speaking of, let's process the data so it's ready for training. `RAGTrainer` has a `prepare_training_data` function, which will perform all the necessary steps. One of the steps it performs is called **hard negative mining**: that's searching the full collection of documents (even those not linked to a query) to retrieve passages that are semantically close to a query, but aren't actually relevant. Using those to train retrieval models has repeatedly been shown to greatly improve their ability to find actually relevant documents, so it's a very important step! \n",
176 | "\n",
177 | "RAGatouille handles all of this for you. By default, it'll fetch 10 negative examples per query, but you can customise this with `num_new_negatives`. You can also choose not to mine negatives and just sample random examples to speed up things, this might lower performance but will run done much quicker on large volumes of data, just set `mine_hard_negatives` to `False`. If you've already mined negatives yourself, you can set `num_new_negatives` to 0 to bypass this entirely."
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 6,
183 | "id": "1ddaf3fb",
184 | "metadata": {
185 | "tags": []
186 | },
187 | "outputs": [
188 | {
189 | "name": "stdout",
190 | "output_type": "stream",
191 | "text": [
192 | "Loading Hard Negative SimpleMiner dense embedding model BAAI/bge-small-en-v1.5...\n",
193 | "Building hard negative index for 93 documents...\n",
194 | "All documents embedded, now adding to index...\n",
195 | "save_index set to False, skipping saving hard negative index\n",
196 | "Hard negative index generated\n",
197 | "mining\n",
198 | "mining\n",
199 | "mining\n",
200 | "mining\n",
201 | "mining\n",
202 | "mining\n"
203 | ]
204 | },
205 | {
206 | "data": {
207 | "text/plain": [
208 | "'./data/'"
209 | ]
210 | },
211 | "execution_count": 6,
212 | "metadata": {},
213 | "output_type": "execute_result"
214 | }
215 | ],
216 | "source": [
217 | "trainer.prepare_training_data(raw_data=pairs, data_out_path=\"./data/\", all_documents=my_full_corpus, num_new_negatives=10, mine_hard_negatives=True)"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "id": "b9468e25",
223 | "metadata": {},
224 | "source": [
225 | "Our training data's now fully processed and saved to disk in `data_out_path`! We're now ready to begin training our model with the `train` function. `train` takes many arguments, but the set of default is already fairly strong!\n",
226 | "\n",
227 | "Don't be surprised you don't see an `epochs` parameter here, ColBERT will train until it either reaches `maxsteps` or has seen the entire training data once (a full epoch), this is by design!"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 8,
233 | "id": "06773f75-c844-42a4-b786-e56ef899a96e",
234 | "metadata": {
235 | "tags": []
236 | },
237 | "outputs": [
238 | {
239 | "name": "stdout",
240 | "output_type": "stream",
241 | "text": [
242 | "#> Starting...\n",
243 | "nranks = 1 \t num_gpus = 1 \t device=0\n",
244 | "{\n",
245 | " \"query_token_id\": \"[unused0]\",\n",
246 | " \"doc_token_id\": \"[unused1]\",\n",
247 | " \"query_token\": \"[Q]\",\n",
248 | " \"doc_token\": \"[D]\",\n",
249 | " \"ncells\": null,\n",
250 | " \"centroid_score_threshold\": null,\n",
251 | " \"ndocs\": null,\n",
252 | " \"load_index_with_mmap\": false,\n",
253 | " \"index_path\": null,\n",
254 | " \"nbits\": 4,\n",
255 | " \"kmeans_niters\": 20,\n",
256 | " \"resume\": false,\n",
257 | " \"similarity\": \"cosine\",\n",
258 | " \"bsize\": 32,\n",
259 | " \"accumsteps\": 1,\n",
260 | " \"lr\": 5e-6,\n",
261 | " \"maxsteps\": 500000,\n",
262 | " \"save_every\": 0,\n",
263 | " \"warmup\": 0,\n",
264 | " \"warmup_bert\": null,\n",
265 | " \"relu\": false,\n",
266 | " \"nway\": 2,\n",
267 | " \"use_ib_negatives\": true,\n",
268 | " \"reranker\": false,\n",
269 | " \"distillation_alpha\": 1.0,\n",
270 | " \"ignore_scores\": false,\n",
271 | " \"model_name\": \"GhibliColBERT\",\n",
272 | " \"query_maxlen\": 32,\n",
273 | " \"attend_to_mask_tokens\": false,\n",
274 | " \"interaction\": \"colbert\",\n",
275 | " \"dim\": 128,\n",
276 | " \"doc_maxlen\": 256,\n",
277 | " \"mask_punctuation\": true,\n",
278 | " \"checkpoint\": \"colbert-ir\\/colbertv2.0\",\n",
279 | " \"triples\": \"data\\/triples.train.colbert.jsonl\",\n",
280 | " \"collection\": \"data\\/corpus.train.colbert.tsv\",\n",
281 | " \"queries\": \"data\\/queries.train.colbert.tsv\",\n",
282 | " \"index_name\": null,\n",
283 | " \"overwrite\": false,\n",
284 | " \"root\": \".ragatouille\\/\",\n",
285 | " \"experiment\": \"colbert\",\n",
286 | " \"index_root\": null,\n",
287 | " \"name\": \"2024-01\\/03\\/17.31.59\",\n",
288 | " \"rank\": 0,\n",
289 | " \"nranks\": 1,\n",
290 | " \"amp\": true,\n",
291 | " \"gpus\": 1\n",
292 | "}\n",
293 | "Using config.bsize = 32 (per process) and config.accumsteps = 1\n",
294 | "[Jan 03, 17:32:08] #> Loading the queries from data/queries.train.colbert.tsv ...\n",
295 | "[Jan 03, 17:32:08] #> Got 6 queries. All QIDs are unique.\n",
296 | "\n",
297 | "[Jan 03, 17:32:08] #> Loading collection...\n",
298 | "0M "
299 | ]
300 | },
301 | {
302 | "name": "stderr",
303 | "output_type": "stream",
304 | "text": [
305 | "/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:429: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
306 | " warnings.warn(\n",
307 | "/opt/conda/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:139: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
308 | " warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n"
309 | ]
310 | },
311 | {
312 | "name": "stdout",
313 | "output_type": "stream",
314 | "text": [
315 | "\n",
316 | "#> LR will use 0 warmup steps and linear decay over 500000 steps.\n",
317 | "\n",
318 | "#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==\n",
319 | "#> Input: . What manga did Hayao Miyazaki write?, \t\t True, \t\t None\n",
320 | "#> Output IDs: torch.Size([32]), tensor([ 101, 1, 2054, 8952, 2106, 10974, 7113, 2771, 3148, 18637,\n",
321 | " 4339, 1029, 102, 103, 103, 103, 103, 103, 103, 103,\n",
322 | " 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,\n",
323 | " 103, 103], device='cuda:0')\n",
324 | "#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
325 | " 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')\n",
326 | "\n",
327 | "\t\t\t\t 4.857693195343018 8.471616744995117\n",
328 | "#>>> 16.0 20.45 \t\t|\t\t -4.449999999999999\n",
329 | "[Jan 03, 17:32:12] 0 13.329309463500977\n",
330 | "\t\t\t\t 5.369233131408691 9.21822452545166\n",
331 | "#>>> 15.58 20.6 \t\t|\t\t -5.020000000000001\n",
332 | "[Jan 03, 17:32:12] 1 13.330567611694336\n",
333 | "\t\t\t\t 9.961652755737305 15.010833740234375\n",
334 | "#>>> 10.04 19.99 \t\t|\t\t -9.95\n",
335 | "[Jan 03, 17:32:13] 2 13.342209530578614\n",
336 | "\t\t\t\t 6.136704921722412 12.674407005310059\n",
337 | "#>>> 11.99 17.86 \t\t|\t\t -5.869999999999999\n",
338 | "[Jan 03, 17:32:14] 3 13.347678432498231\n",
339 | "\t\t\t\t 8.609644889831543 13.769636154174805\n",
340 | "#>>> 13.18 21.75 \t\t|\t\t -8.57\n",
341 | "[Jan 03, 17:32:15] 4 13.356710034156064\n",
342 | "[Jan 03, 17:32:15] #> Done with all triples!\n",
343 | "#> Saving a checkpoint to .ragatouille/colbert/none/2024-01/03/17.31.59/checkpoints/colbert ..\n",
344 | "#> Joined...\n"
345 | ]
346 | }
347 | ],
348 | "source": [
349 | "\n",
350 | "trainer.train(batch_size=32,\n",
351 | " nbits=4, # How many bits will the trained model use when compressing indexes\n",
352 | " maxsteps=500000, # Maximum steps hard stop\n",
353 | " use_ib_negatives=True, # Use in-batch negative to calculate loss\n",
354 | " dim=128, # How many dimensions per embedding. 128 is the default and works well.\n",
355 | " learning_rate=5e-6, # Learning rate, small values ([3e-6,3e-5] work best if the base model is BERT-like, 5e-6 is often the sweet spot)\n",
356 | " doc_maxlen=256, # Maximum document length. Because of how ColBERT works, smaller chunks (128-256) work very well.\n",
357 | " use_relu=False, # Disable ReLU -- doesn't improve performance\n",
358 | " warmup_steps=\"auto\", # Defaults to 10%\n",
359 | " )\n"
360 | ]
361 | },
362 | {
363 | "cell_type": "markdown",
364 | "id": "44ab6a51-a4bd-4fab-96d7-dd8e93dac462",
365 | "metadata": {},
366 | "source": [
367 | "And you're now done training! Your model is saved at the path it outputs, with the final checkpoint always being in the `.../checkpoints/colbert` path, and intermediate checkpoints saved at `.../checkpoints/colbert-{N_STEPS}`.\n",
368 | "\n",
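369 |     "For example, you could load the final checkpoint straight back into a `RAGPretrainedModel` (the checkpoint path below is the one from the training logs above; yours will differ):\n",
370 |     "\n",
371 |     "```python\n",
372 |     "from ragatouille import RAGPretrainedModel\n",
373 |     "\n",
374 |     "RAG = RAGPretrainedModel.from_pretrained(\".ragatouille/colbert/none/2024-01/03/17.31.59/checkpoints/colbert\")\n",
375 |     "```\n",
376 |     "\n",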
369 | "You can now use your model by pointing at its local path, or upload it to the huggingface hub to share it with the world!"
370 | ]
371 | }
372 | ],
373 | "metadata": {
374 | "environment": {
375 | "kernel": "python3",
376 | "name": ".m114",
377 | "type": "gcloud",
378 | "uri": "gcr.io/deeplearning-platform-release/:m114"
379 | },
380 | "kernelspec": {
381 | "display_name": "Python 3",
382 | "language": "python",
383 | "name": "python3"
384 | },
385 | "language_info": {
386 | "codemirror_mode": {
387 | "name": "ipython",
388 | "version": 3
389 | },
390 | "file_extension": ".py",
391 | "mimetype": "text/x-python",
392 | "name": "python",
393 | "nbconvert_exporter": "python",
394 | "pygments_lexer": "ipython3",
395 | "version": "3.10.13"
396 | }
397 | },
398 | "nbformat": 4,
399 | "nbformat_minor": 5
400 | }
401 |
--------------------------------------------------------------------------------
/examples/06-index_free_use.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Using ColBERT in-memory: Index-Free Encodings & Search\n",
8 | "\n",
9 | "Sometimes, building an index doesn't make sense. Maybe you're working with a really small dataset, or one that is really fleeting nature, and will only be relevant to the lifetime of your current instance. In these cases, it can be more efficient to skip all the time-consuming index optimisation, and keep your encodings in-memory to perform ColBERT's magical MaxSim on-the-fly. This doesn't scale very well, but can be very useful in certain settings.\n",
10 | "\n",
11 | "In this quick example, we'll use the `RAGPretrainedModel` magic class to demonstrate how to **encode documents in-memory**, before **retrieving them with `search_encoded_docs`**.\n",
12 | "\n",
13 | "First, as usual, let's load up a pre-trained ColBERT model:"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 1,
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "name": "stderr",
23 | "output_type": "stream",
24 | "text": [
25 | "/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
26 | " from .autonotebook import tqdm as notebook_tqdm\n"
27 | ]
28 | },
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "[Jan 27, 16:56:39] Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n"
34 | ]
35 | },
36 | {
37 | "name": "stderr",
38 | "output_type": "stream",
39 | "text": [
40 | "/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py:125: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n",
41 | " warnings.warn(\n"
42 | ]
43 | }
44 | ],
45 | "source": [
46 | "from ragatouille import RAGPretrainedModel\n",
47 | "\n",
48 | "RAG = RAGPretrainedModel.from_pretrained(\"colbert-ir/colbertv2.0\")"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "Now that our model is loaded, we can load and preprocess some data, as in the previous tutorials:"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 2,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "from ragatouille.utils import get_wikipedia_page\n",
65 | "from ragatouille.data import CorpusProcessor\n",
66 | "\n",
67 | "corpus_processor = CorpusProcessor()\n",
68 | "\n",
69 | "documents = [get_wikipedia_page(\"Hayao Miyazaki\"), get_wikipedia_page(\"Studio Ghibli\"), get_wikipedia_page(\"Princess Mononoke\"), get_wikipedia_page(\"Shrek\")]\n",
70 | "documents = corpus_processor.process_corpus(documents, chunk_size=200)"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "Our documents are now fully ready to be encoded! \n",
78 | "\n",
79 | "One important note: `encode()` itself will not split your documents, you must pre-process them yourself (using corpus_processor or your preferred chunking approach). However, `encode()` will dynamically set the maximum token length, calculated based on the token length distribution in your corpus, up to the maximum length supported by the model you're using.\n",
80 | "\n",
81 | "Just like normal indexing, `encode()` also supports adding metadata to the encoded documents, which will be returned as part of query results:"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 3,
87 | "metadata": {},
88 | "outputs": [
89 | {
90 | "name": "stdout",
91 | "output_type": "stream",
92 | "text": [
93 | "Encoding 212 documents...\n"
94 | ]
95 | },
96 | {
97 | "name": "stderr",
98 | "output_type": "stream",
99 | "text": [
100 | " 0%| | 0/7 [00:00, ?it/s]/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
101 | " warnings.warn(\n",
102 | " 14%|█▍ | 1/7 [00:03<00:21, 3.62s/it]/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
103 | " warnings.warn(\n",
104 | "100%|██████████| 7/7 [00:20<00:00, 2.90s/it]"
105 | ]
106 | },
107 | {
108 | "name": "stdout",
109 | "output_type": "stream",
110 | "text": [
111 | "Shapes:\n",
112 | "encodings: torch.Size([212, 256, 128])\n",
113 | "doc_masks: torch.Size([212, 256])\n",
114 | "Documents encoded!\n"
115 | ]
116 | },
117 | {
118 | "name": "stderr",
119 | "output_type": "stream",
120 | "text": [
121 | "\n"
122 | ]
123 | }
124 | ],
125 | "source": [
126 | "RAG.encode([x['content'] for x in documents], document_metadatas=[{\"about\": \"ghibli\"} for _ in range(len(documents))])"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 4,
132 | "metadata": {},
133 | "outputs": [
134 | {
135 | "data": {
136 | "text/plain": [
137 | "[{'content': 'The studio is also known for its strict \"no-edits\" policy in licensing their films abroad due to Nausicaä of the Valley of the Wind being heavily edited for the film\\'s release in the United States as Warriors of the Wind.\\n\\n\\n=== Independent era ===\\nBetween 1999 and 2005, Studio Ghibli was a subsidiary brand of Tokuma Shoten; however, that partnership ended in April 2005, when Studio Ghibli was spun off from Tokuma Shoten and was re-established as an independent company with relocated headquarters.\\nOn February 1, 2008, Toshio Suzuki stepped down from the position of Studio Ghibli president, which he had held since 2005, and Koji Hoshino (former president of Walt Disney Japan) took over. Suzuki said he wanted to improve films with his own hands as a producer, rather than demanding this from his employees.',\n",
138 | " 'score': 15.333166122436523,\n",
139 | " 'rank': 0,\n",
140 | " 'result_index': 80,\n",
141 | " 'document_metadata': {'about': 'ghibli'}},\n",
142 | " {'content': 'Saeko Himuro\\'s novel Umi ga Kikoeru was serialised in the magazine and subsequently adapted into Ocean Waves, Studio Ghibli\\'s first animated feature-length film created for television. It was directed by Tomomi Mochizuki.In October 2001, the Ghibli Museum opened in Mitaka, Tokyo. It contains exhibits based on Studio Ghibli films and shows animations, including a number of short Studio Ghibli films not available elsewhere.\\nThe studio is also known for its strict \"no-edits\" policy in licensing their films abroad due to Nausicaä of the Valley of the Wind being heavily edited for the film\\'s release in the United States as Warriors of the Wind.',\n",
143 | " 'score': 14.356232643127441,\n",
144 | " 'rank': 1,\n",
145 | " 'result_index': 79,\n",
146 | " 'document_metadata': {'about': 'ghibli'}},\n",
147 | " {'content': \"Studio Ghibli, Inc. (Japanese: 株式会社スタジオジブリ, Hepburn: Kabushiki-gaisha Sutajio Jiburi) is a Japanese animation studio based in Koganei, Tokyo. It has a strong presence in the animation industry and has expanded its portfolio to include various media formats, such as short subjects, television commercials, and two television films. Their work has been well-received by audiences and recognized with numerous awards. Their mascot and most recognizable symbol, the character Totoro from the 1988 film My Neighbor Totoro, is a giant spirit inspired by raccoon dogs (tanuki) and cats (neko). Among the studio's highest-grossing films are Spirited Away (2001), Howl's Moving Castle (2004), and Ponyo (2008).\",\n",
148 | " 'score': 12.45059871673584,\n",
149 | " 'rank': 2,\n",
150 | " 'result_index': 71,\n",
151 | " 'document_metadata': {'about': 'ghibli'}}]"
152 | ]
153 | },
154 | "execution_count": 4,
155 | "metadata": {},
156 | "output_type": "execute_result"
157 | }
158 | ],
159 | "source": [
160 | "RAG.search_encoded_docs(query = \"What's Gihbli's famous policy?\", k=3)"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "And that's pretty much it for index-free encoding & querying!\n",
168 | "\n",
169 | "But wait, what if your application needs to update dynamically, and accept new documents? Well, that's easy too! A `RAGPretrainedModel` will keep its encoded docs in-memory, and further `encode()` calls will add to it:"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 5,
175 | "metadata": {},
176 | "outputs": [
177 | {
178 | "name": "stdout",
179 | "output_type": "stream",
180 | "text": [
181 | "Encoding 2 documents...\n"
182 | ]
183 | },
184 | {
185 | "name": "stderr",
186 | "output_type": "stream",
187 | "text": [
188 | "100%|██████████| 1/1 [00:00<00:00, 10.43it/s]"
189 | ]
190 | },
191 | {
192 | "name": "stdout",
193 | "output_type": "stream",
194 | "text": [
195 | "Shapes:\n",
196 | "encodings: torch.Size([2, 256, 128])\n",
197 | "doc_masks: torch.Size([2, 256])\n",
198 | "Documents encoded!\n"
199 | ]
200 | },
201 | {
202 | "name": "stderr",
203 | "output_type": "stream",
204 | "text": [
205 | "\n"
206 | ]
207 | },
208 | {
209 | "data": {
210 | "text/plain": [
211 | "[{'content': \"I'm a new document about the importance of Curry! I love curry, it's the best food! Do you like Curry too?\",\n",
212 | " 'score': 18.96149444580078,\n",
213 | " 'rank': 0,\n",
214 | " 'result_index': 212,\n",
215 | " 'document_metadata': {'about': 'new_document'}}]"
216 | ]
217 | },
218 | "execution_count": 5,
219 | "metadata": {},
220 | "output_type": "execute_result"
221 | }
222 | ],
223 | "source": [
224 | "my_new_document = [\n",
225 | " \"I'm a new document about the importance of Curry! I love curry, it's the best food! Do you like Curry too?\",\n",
226 | " \"I'm a second new document!\"\n",
227 | "]\n",
228 | "RAG.encode(my_new_document, document_metadatas=[{\"about\": \"new_document\"} for _ in range(len(my_new_document))])\n",
229 | "RAG.search_encoded_docs(query = \"What's the best food?\", k=1)"
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "metadata": {},
235 | "source": [
236 | "What if you want to keep your current `RAGPretrainedModel` loaded, but empty the in-memory encodings because the docs are expired and you need to encode new ones? You can do that easily too: just call `clear_encoded_docs()`. By default, this will wait for 10 seconds before deleting everything, but you can pass `force=True` to delete immediately:"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 6,
242 | "metadata": {},
243 | "outputs": [
244 | {
245 | "name": "stdout",
246 | "output_type": "stream",
247 | "text": [
248 | "All in-memory encodings will be deleted in 10 seconds, interrupt now if you want to keep them!\n",
249 | "...\n"
250 | ]
251 | }
252 | ],
253 | "source": [
254 | "RAG.clear_encoded_docs()"
255 | ]
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {},
260 | "source": [
261 | "And we can now encode new documents and query them, with no trace of the previous encodings:"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 7,
267 | "metadata": {},
268 | "outputs": [
269 | {
270 | "name": "stdout",
271 | "output_type": "stream",
272 | "text": [
273 | "Encoding 2 documents...\n"
274 | ]
275 | },
276 | {
277 | "name": "stderr",
278 | "output_type": "stream",
279 | "text": [
280 | " 0%| | 0/1 [00:00, ?it/s]"
281 | ]
282 | },
283 | {
284 | "name": "stderr",
285 | "output_type": "stream",
286 | "text": [
287 | "100%|██████████| 1/1 [00:00<00:00, 4.49it/s]"
288 | ]
289 | },
290 | {
291 | "name": "stdout",
292 | "output_type": "stream",
293 | "text": [
294 | "Shapes:\n",
295 | "encodings: torch.Size([2, 256, 128])\n",
296 | "doc_masks: torch.Size([2, 256])\n",
297 | "Documents encoded!\n"
298 | ]
299 | },
300 | {
301 | "name": "stderr",
302 | "output_type": "stream",
303 | "text": [
304 | "\n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "RAG.encode(documents=[\"This a really good document about Ratatouille. Ratatouille is a French dish...\",\n",
310 | " \"This is a document that is absolutely and utterly relevant to anything\"])"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 8,
316 | "metadata": {},
317 | "outputs": [
318 | {
319 | "data": {
320 | "text/plain": [
321 | "[{'content': 'This a really good document about Ratatouille. Ratatouille is a French dish...',\n",
322 | " 'score': 8.764448165893555,\n",
323 | " 'rank': 0,\n",
324 | " 'result_index': 0}]"
325 | ]
326 | },
327 | "execution_count": 8,
328 | "metadata": {},
329 | "output_type": "execute_result"
330 | }
331 | ],
332 | "source": [
333 | "RAG.search_encoded_docs(query = \"What do you know about dishes? Curry maybe?\", k=1)"
334 | ]
335 | },
336 | {
337 | "cell_type": "markdown",
338 | "metadata": {},
339 | "source": [
340 | "Here it is! No trace of our previous, very important document about curry, but we can enjoy some Ratatouille facts instead."
341 | ]
342 | }
343 | ],
344 | "metadata": {
345 | "kernelspec": {
346 | "display_name": "ragatouille",
347 | "language": "python",
348 | "name": "python3"
349 | },
350 | "language_info": {
351 | "codemirror_mode": {
352 | "name": "ipython",
353 | "version": 3
354 | },
355 | "file_extension": ".py",
356 | "mimetype": "text/x-python",
357 | "name": "python",
358 | "nbconvert_exporter": "python",
359 | "pygments_lexer": "ipython3",
360 | "version": "3.9.18"
361 | }
362 | },
363 | "nbformat": 4,
364 | "nbformat_minor": 2
365 | }
366 |
--------------------------------------------------------------------------------
/examples/data/llama2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnswerDotAI/RAGatouille/e75b8a964a870dea886548f78da1900804749040/examples/data/llama2.pdf
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: RAGatouille
2 | site_author: Benjamin Clavié
3 | site_description: Bridging the gap between state-of-the-art Retrieval research and RAG applications.
4 | repo_name: RAGatouille
5 | repo_url: https://github.com/bclavie/RAGatouille/
6 | site_url: https://ragatouille.clavie.eu
7 | edit_uri: edit/main/docs/
8 | copyright: Copyright © 2024 Benjamin Clavié
9 | plugins:
10 | - search
11 | - mkdocstrings:
12 | handlers:
13 | python:
14 | options:
15 | members_order: alphabetical
16 | allow_inspection: true
17 | show_bases: true
18 | theme:
19 | name: material
20 | features:
21 | - content.code.annotate
22 | - content.code.copy
23 | - content.code.select
24 | - content.tabs.link
25 | - content.tooltips
26 | - header.autohide
27 | - navigation.expand
28 | - navigation.footer
29 | - navigation.indexes
30 | - navigation.instant
31 | - navigation.instant.prefetch
32 | - navigation.instant.progress
33 | - navigation.prune
34 | - navigation.sections
35 | - navigation.tabs
36 | - navigation.top
37 | - navigation.tracking
38 | - toc.follow
39 | palette:
40 | - scheme: default
41 | primary: green
42 | accent: teal
43 | toggle:
44 | icon: material/brightness-5
45 | name: Switch to dark mode
46 | - scheme: slate
47 | primary: black
48 | accent: green
49 | toggle:
50 | icon: material/brightness-6
51 | name: Switch to light mode
52 | font:
53 | text: Roboto
54 | code: Roboto Mono
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.setuptools]
6 | packages = [
7 | "ragatouille",
8 | "ragatouille.models",
9 | "ragatouille.data",
10 | "ragatouille.integrations",
11 | "ragatouille.negative_miners",
12 | ]
13 |
14 | [project]
15 | name = "RAGatouille"
16 | version = "0.0.9post2"
17 | description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
18 | keywords = ["reranking", "retrieval", "rag", "nlp"]
19 | authors = [
20 | {name = "Ben Clavié", email = "bc@answer.ai" }
21 | ]
22 | maintainers = [
23 | {name = "Ben Clavié", email = "bc@answer.ai" }
24 | ]
25 | license = {file = "LICENSE"}
26 | readme = "README.md"
27 |
28 | dependencies = [
29 | "llama-index",
30 | "faiss-cpu",
31 | "langchain_core",
32 | "colbert-ai>=0.2.19",
33 | "langchain",
34 | "onnx",
35 | "srsly",
36 | "voyager",
37 | "torch>=1.13",
38 | "fast-pytorch-kmeans",
39 | "sentence-transformers",
40 | ]
41 |
42 | [project.optional-dependencies]
43 | all = [
44 | "llama-index",
45 | "langchain",
46 | "rerankers",
47 | "voyager",
48 | ]
49 | llamaindex = ["llama-index"]
50 | langchain = ["langchain"]
51 | train = ["sentence-transformers", "pylate", "rerankers"]
52 |
53 | [project.urls]
54 | "Homepage" = "https://github.com/answerdotai/ragatouille"
55 |
56 | [tool.pytest.ini_options]
57 | filterwarnings = [
58 | "ignore::Warning"
59 | ]
60 |
61 | [tool.ruff]
62 | target-version = "py39"
--------------------------------------------------------------------------------
/ragatouille/RAGPretrainedModel.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union
3 | from uuid import uuid4
4 |
5 | from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
6 | from langchain_core.retrievers import BaseRetriever
7 |
8 | from ragatouille.data.corpus_processor import CorpusProcessor
9 | from ragatouille.data.preprocessors import llama_index_sentence_splitter
10 | from ragatouille.integrations import (
11 | RAGatouilleLangChainCompressor,
12 | RAGatouilleLangChainRetriever,
13 | )
14 | from ragatouille.models import ColBERT, LateInteractionModel
15 |
16 |
17 | class RAGPretrainedModel:
18 | """
19 | Wrapper class for a pretrained RAG late-interaction model, and all the associated utilities.
20 |     Allows you to load a pretrained model from disk or from the hub, and to build or query an index.
21 |
22 | ## Usage
23 |
24 | Load a pre-trained checkpoint:
25 |
26 | ```python
27 | from ragatouille import RAGPretrainedModel
28 |
29 | RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
30 | ```
31 |
32 | Load checkpoint from an existing index:
33 |
34 | ```python
35 | from ragatouille import RAGPretrainedModel
36 |
37 | RAG = RAGPretrainedModel.from_index("path/to/my/index")
38 | ```
39 |
40 | Both methods will load a fully initialised instance of ColBERT, which you can use to build and query indexes.
41 |
42 | ```python
43 | RAG.search("How many people live in France?")
44 | ```
45 | """
46 |
47 | model_name: Union[str, None] = None
48 | model: Union[LateInteractionModel, None] = None
49 | corpus_processor: Optional[CorpusProcessor] = None
50 |
51 | @classmethod
52 | def from_pretrained(
53 | cls,
54 | pretrained_model_name_or_path: Union[str, Path],
55 | n_gpu: int = -1,
56 | verbose: int = 1,
57 | index_root: Optional[str] = None,
58 | ):
59 | """Load a ColBERT model from a pre-trained checkpoint.
60 |
61 | Parameters:
62 | pretrained_model_name_or_path (str): Local path or huggingface model name.
63 | n_gpu (int): Number of GPUs to use. By default, value is -1, which means use all available GPUs or none if no GPU is available.
64 | verbose (int): The level of ColBERT verbosity requested. By default, 1, which will filter out most internal logs.
65 | index_root (Optional[str]): The root directory where indexes will be stored. If None, will use the default directory, '.ragatouille/'.
66 |
67 | Returns:
68 | cls (RAGPretrainedModel): The current instance of RAGPretrainedModel, with the model initialised.
69 | """
70 | instance = cls()
71 | instance.model = ColBERT(
72 | pretrained_model_name_or_path, n_gpu, index_root=index_root, verbose=verbose
73 | )
74 | return instance
75 |
76 | @classmethod
77 | def from_index(
78 | cls, index_path: Union[str, Path], n_gpu: int = -1, verbose: int = 1
79 | ):
80 | """Load an Index and the associated ColBERT encoder from an existing document index.
81 |
82 | Parameters:
83 |             index_path (Union[str, Path]): Path to the index.
84 | n_gpu (int): Number of GPUs to use. By default, value is -1, which means use all available GPUs or none if no GPU is available.
85 | verbose (int): The level of ColBERT verbosity requested. By default, 1, which will filter out most internal logs.
86 |
87 | Returns:
88 | cls (RAGPretrainedModel): The current instance of RAGPretrainedModel, with the model and index initialised.
89 | """
90 | instance = cls()
91 | index_path = Path(index_path)
92 | instance.model = ColBERT(
93 | index_path, n_gpu, verbose=verbose, load_from_index=True
94 | )
95 |
96 | return instance
97 |
98 | def _process_metadata(
99 | self,
100 | document_ids: Optional[Union[TypeVar("T"), List[TypeVar("T")]]],
101 | document_metadatas: Optional[list[dict[Any, Any]]],
102 | collection_len: int,
103 | ) -> tuple[list[str], Optional[dict[Any, Any]]]:
104 | if document_ids is None:
105 | document_ids = [str(uuid4()) for i in range(collection_len)]
106 | else:
107 | if len(document_ids) != collection_len:
108 | raise ValueError("document_ids must be the same length as collection")
109 | if len(document_ids) != len(set(document_ids)):
110 | raise ValueError("document_ids must be unique")
111 | if any(not id.strip() for id in document_ids):
112 | raise ValueError("document_ids must not contain empty strings")
113 | if not all(isinstance(id, type(document_ids[0])) for id in document_ids):
114 | raise ValueError("All document_ids must be of the same type")
115 |
116 | if document_metadatas is not None:
117 | if len(document_metadatas) != collection_len:
118 | raise ValueError(
119 | "document_metadatas must be the same length as collection"
120 | )
121 | docid_metadata_map = {
122 | x: y for x, y in zip(document_ids, document_metadatas)
123 | }
124 | else:
125 | docid_metadata_map = None
126 |
127 | return document_ids, docid_metadata_map
128 |
129 | def _process_corpus(
130 | self,
131 | collection: List[str],
132 | document_ids: List[str],
133 | document_metadatas: List[Dict[Any, Any]],
134 | document_splitter_fn: Optional[Callable[[str], List[str]]],
135 | preprocessing_fn: Optional[Callable[[str], str]],
136 | max_document_length: int,
137 | ) -> Tuple[List[str], Dict[int, str], Dict[str, Dict[Any, Any]]]:
138 | """
139 | Processes a collection of documents by assigning unique IDs, splitting documents if necessary,
140 | applying preprocessing, and organizing metadata.
141 | """
142 | document_ids, docid_metadata_map = self._process_metadata(
143 | document_ids=document_ids,
144 | document_metadatas=document_metadatas,
145 | collection_len=len(collection),
146 | )
147 |
148 | if document_splitter_fn is not None or preprocessing_fn is not None:
149 | self.corpus_processor = CorpusProcessor(
150 | document_splitter_fn=document_splitter_fn,
151 | preprocessing_fn=preprocessing_fn,
152 | )
153 | collection_with_ids = self.corpus_processor.process_corpus(
154 | collection,
155 | document_ids,
156 | chunk_size=max_document_length,
157 | )
158 | else:
159 | collection_with_ids = [
160 | {"document_id": x, "content": y}
161 | for x, y in zip(document_ids, collection)
162 | ]
163 |
164 | pid_docid_map = {
165 | index: item["document_id"] for index, item in enumerate(collection_with_ids)
166 | }
167 | collection = [x["content"] for x in collection_with_ids]
168 |
169 | return collection, pid_docid_map, docid_metadata_map
170 |
171 | def index(
172 | self,
173 | collection: list[str],
174 | document_ids: Union[TypeVar("T"), List[TypeVar("T")]] = None,
175 | document_metadatas: Optional[list[dict]] = None,
176 | index_name: str = None,
177 | overwrite_index: Union[bool, str] = True,
178 | max_document_length: int = 256,
179 | split_documents: bool = True,
180 | document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
181 | preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
182 | bsize: int = 32,
183 | use_faiss: bool = False,
184 | ):
185 | """Build an index from a list of documents.
186 |
187 | Parameters:
188 | collection (list[str]): The collection of documents to index.
189 | document_ids (Optional[list[str]]): An optional list of document ids. Ids will be generated at index time if not supplied.
190 | index_name (str): The name of the index that will be built.
191 | overwrite_index (Union[bool, str]): Whether to overwrite an existing index with the same name.
192 | max_document_length (int): The maximum length of a document. Documents longer than this will be split into chunks.
193 | split_documents (bool): Whether to split documents into chunks.
194 |             document_splitter_fn (Optional[Callable]): A function to split documents into chunks. Defaults to llama_index_sentence_splitter; ignored when split_documents is False.
195 |             preprocessing_fn (Optional[Union[Callable, list[Callable]]]): A function or list of functions to preprocess documents. Defaults to None, meaning documents are not preprocessed.
196 | bsize (int): The batch size to use for encoding the passages.
197 |
198 | Returns:
199 | index (str): The path to the index that was built.
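200 | 
201 |         Example (a minimal sketch; the documents and index name below are illustrative):
202 | 
203 |         ```python
204 |         RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
205 |         index_path = RAG.index(
206 |             collection=["Hayao Miyazaki co-founded Studio Ghibli in 1985.", "Studio Ghibli is based in Koganei, Tokyo."],
207 |             index_name="my_index",
208 |             max_document_length=256,
209 |         )
210 |         ```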
200 | """
201 | if not split_documents:
202 | document_splitter_fn = None
203 | collection, pid_docid_map, docid_metadata_map = self._process_corpus(
204 | collection,
205 | document_ids,
206 | document_metadatas,
207 | document_splitter_fn,
208 | preprocessing_fn,
209 | max_document_length,
210 | )
211 | return self.model.index(
212 | collection,
213 | pid_docid_map=pid_docid_map,
214 | docid_metadata_map=docid_metadata_map,
215 | index_name=index_name,
216 | max_document_length=max_document_length,
217 | overwrite=overwrite_index,
218 | bsize=bsize,
219 | use_faiss=use_faiss,
220 | )
221 |
222 | def add_to_index(
223 | self,
224 | new_collection: list[str],
225 | new_document_ids: Optional[Union[TypeVar("T"), List[TypeVar("T")]]] = None,
226 | new_document_metadatas: Optional[list[dict]] = None,
227 | index_name: Optional[str] = None,
228 | split_documents: bool = True,
229 | document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
230 | preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
231 | bsize: int = 32,
232 | use_faiss: bool = False,
233 | ):
234 | """Add documents to an existing index.
235 |
236 | Parameters:
237 | new_collection (list[str]): The documents to add to the index.
238 | new_document_metadatas (Optional[list[dict]]): An optional list of metadata dicts
239 |             index_name (Optional[str]): The name of the index to add documents to. If None (the default), documents are added to the already-initialised index.
240 | bsize (int): The batch size to use for encoding the passages.
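241 | 
242 |         Example (a minimal sketch; assumes an index was already built or loaded, and the documents and metadata are illustrative):
243 | 
244 |         ```python
245 |         RAG.add_to_index(
246 |             new_collection=["A newly added document about Toei Animation."],
247 |             new_document_metadatas=[{"source": "wikipedia"}],
248 |         )
249 |         ```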
241 | """
242 | if not split_documents:
243 | document_splitter_fn = None
244 |
245 | (
246 | new_collection,
247 | new_pid_docid_map,
248 | new_docid_metadata_map,
249 | ) = self._process_corpus(
250 | new_collection,
251 | new_document_ids,
252 | new_document_metadatas,
253 | document_splitter_fn,
254 | preprocessing_fn,
255 | self.model.config.doc_maxlen,
256 | )
257 |
258 | self.model.add_to_index(
259 | new_collection,
260 | new_pid_docid_map,
261 | new_docid_metadata_map=new_docid_metadata_map,
262 | index_name=index_name,
263 | bsize=bsize,
264 | use_faiss=use_faiss,
265 | )
266 |
267 | def delete_from_index(
268 | self,
269 | document_ids: Union[TypeVar("T"), List[TypeVar("T")]],
270 | index_name: Optional[str] = None,
271 | ):
272 | """Delete documents from an index by their IDs.
273 |
274 | Parameters:
275 | document_ids (Union[TypeVar("T"), List[TypeVar("T")]]): The IDs of the documents to delete.
276 |             index_name (Optional[str]): The name of the index to delete documents from. If None (the default), documents are deleted from the already-initialised index.
277 | """
278 | self.model.delete_from_index(
279 | document_ids,
280 | index_name=index_name,
281 | )
282 |
283 | def search(
284 | self,
285 | query: Union[str, list[str]],
286 | index_name: Optional["str"] = None,
287 | k: int = 10,
288 | force_fast: bool = False,
289 | zero_index_ranks: bool = False,
290 | doc_ids: Optional[list[str]] = None,
291 | **kwargs,
292 | ):
293 | """Query an index.
294 |
295 | Parameters:
296 | query (Union[str, list[str]]): The query or list of queries to search for.
297 |             index_name (Optional[str]): The name of the index to query. If None (the default), queries the already-initialised index.
298 | k (int): The number of results to return for each query.
299 | force_fast (bool): Whether to force the use of a faster but less accurate search method.
300 |             zero_index_ranks (bool): Whether to return zero-indexed ranks, i.e. the top result has rank 0. By default, the top result has rank 1.
301 |
302 | Returns:
303 |             results (Union[list[dict], list[list[dict]]]): A list of dicts containing individual results for each query. If a list of queries is provided, returns a list of lists of dicts. Each result is a dict with keys `content`, `score`, `rank`, and `document_id`. If metadata was indexed for the document, it will be returned under the `document_metadata` key.
304 |
305 | Individual results are always in the format:
306 | ```python3
307 | {"content": "text of the relevant passage", "score": 0.123456, "rank": 1, "document_id": "x"}
308 | ```
309 | or
310 | ```python3
311 | {"content": "text of the relevant passage", "score": 0.123456, "rank": 1, "document_id": "x", "document_metadata": {"metadata_key": "metadata_value", ...}}
312 | ```
313 |
314 | """
315 | return self.model.search(
316 | query=query,
317 | index_name=index_name,
318 | k=k,
319 | force_fast=force_fast,
320 | zero_index_ranks=zero_index_ranks,
321 | doc_ids=doc_ids,
322 | **kwargs,
323 | )
324 |
325 | def rerank(
326 | self,
327 | query: Union[str, list[str]],
328 | documents: list[str],
329 | k: int = 10,
330 | zero_index_ranks: bool = False,
331 | bsize: Union[Literal["auto"], int] = "auto",
332 | ):
333 | """Encode documents and rerank them in-memory. Performance degrades rapidly with more documents.
334 |
335 | Parameters:
336 | query (Union[str, list[str]]): The query or list of queries to search for.
337 | documents (list[str]): The documents to rerank.
338 | k (int): The number of results to return for each query.
339 |             zero_index_ranks (bool): Whether to return zero-indexed ranks, i.e. the top result has rank 0. By default, the top result has rank 1.
340 | bsize (int): The batch size to use for re-ranking.
341 |
342 | Returns:
343 |             results (Union[list[dict], list[list[dict]]]): A list of dicts containing individual results for each query. If a list of queries is provided, returns a list of lists of dicts. Each result is a dict with keys `content`, `score`, and `rank`.
344 |
345 | Individual results are always in the format:
346 | ```python3
347 | {"content": "text of the relevant passage", "score": 0.123456, "rank": 1}
348 | ```
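349 | 
350 |         Example (a minimal sketch; the query and documents are illustrative):
351 | 
352 |         ```python
353 |         docs = ["Miyazaki directed Princess Mononoke.", "Shrek was produced by DreamWorks."]
354 |         RAG.rerank(query="Who directed Princess Mononoke?", documents=docs, k=2)
355 |         ```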
349 | """
350 |
351 | return self.model.rank(
352 | query=query,
353 | documents=documents,
354 | k=k,
355 | zero_index_ranks=zero_index_ranks,
356 | bsize=bsize,
357 | )
358 |
359 | def encode(
360 | self,
361 | documents: list[str],
362 | bsize: Union[Literal["auto"], int] = "auto",
363 | document_metadatas: Optional[list[dict]] = None,
364 | verbose: bool = True,
365 | max_document_length: Union[Literal["auto"], int] = "auto",
366 | ):
367 | """Encode documents in memory to be searched through with no Index. Performance degrades rapidly with more documents.
368 |
369 | Parameters:
370 | documents (list[str]): The documents to encode.
371 | bsize (int): The batch size to use for encoding.
372 | document_metadatas (Optional[list[dict]]): An optional list of metadata dicts. Each entry must correspond to a document.
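373 | 
374 |         Example (a minimal sketch pairing `encode` with `search_encoded_docs`; the documents and metadata are illustrative):
375 | 
376 |         ```python
377 |         RAG.encode(["Ponyo was released in 2008."], document_metadatas=[{"about": "ghibli"}])
378 |         RAG.search_encoded_docs(query="When was Ponyo released?", k=1)
379 |         ```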
373 | """
374 | if verbose:
375 | print(f"Encoding {len(documents)} documents...")
376 | self.model.encode(
377 | documents=documents,
378 | bsize=bsize,
379 | document_metadatas=document_metadatas,
380 | verbose=verbose,
381 | max_tokens=max_document_length,
382 | )
383 | if verbose:
384 | print("Documents encoded!")
385 |
386 | def search_encoded_docs(
387 | self,
388 | query: Union[str, list[str]],
389 | k: int = 10,
390 | bsize: int = 32,
391 | ) -> list[dict[str, Any]]:
392 | """Search through documents encoded in-memory.
393 |
394 | Parameters:
395 | query (Union[str, list[str]]): The query or list of queries to search for.
396 | k (int): The number of results to return for each query.
397 |             bsize (int): The batch size to use for searching.
398 |
399 | Returns:
400 |             results (list[dict[str, Any]]): A list of dicts containing individual results for each query. If a list of queries is provided, returns a list of lists of dicts.
401 | """
402 | return self.model.search_encoded_docs(
403 | queries=query,
404 | k=k,
405 | bsize=bsize,
406 | )
407 |
408 | def clear_encoded_docs(self, force: bool = False):
409 | """Clear documents encoded in-memory.
410 |
411 | Parameters:
412 | force (bool): Whether to force the clearing of encoded documents without enforcing a 10s wait time.
413 | """
414 | self.model.clear_encoded_docs(force=force)
415 |
416 | def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
417 | return RAGatouilleLangChainRetriever(model=self, kwargs=kwargs)
418 |
419 | def as_langchain_document_compressor(
420 | self, k: int = 5, **kwargs: Any
421 | ) -> BaseDocumentCompressor:
422 | return RAGatouilleLangChainCompressor(model=self, k=k, kwargs=kwargs)
423 |
--------------------------------------------------------------------------------
/ragatouille/RAGTrainer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Literal, Optional, Union
3 |
4 | from colbert.infra import ColBERTConfig
5 |
6 | from ragatouille.data import TrainingDataProcessor
7 | from ragatouille.models import ColBERT, LateInteractionModel
8 | from ragatouille.negative_miners import HardNegativeMiner, SimpleMiner
9 | from ragatouille.utils import seeded_shuffle
10 |
11 |
12 | class RAGTrainer:
13 | """Main trainer to fine-tune/train ColBERT models with a few lines."""
14 |
15 | def __init__(
16 | self,
17 | model_name: str,
18 | pretrained_model_name: str,
19 | language_code: str = "en",
20 | n_usable_gpus: int = -1,
21 | ):
22 | """
23 | Initialise a RAGTrainer instance. This will load a base model: either an existing ColBERT model to fine-tune or a BERT/RoBERTa-like model to build a new ColBERT model from.
24 |
25 | Parameters:
26 | model_name: str - Name of the model to train. This will be used to name the checkpoints and the index.
27 | pretrained_model_name: str - Name of the pretrained model to use as a base. Can be a local path to a checkpoint or a huggingface model name.
28 |             language_code: str - Language code of your training data. This is used by the hard negative miner to select an appropriate model.
29 | n_usable_gpus: int - Number of GPUs to use. By default, value is -1, which means use all available GPUs or none if no GPU is available.
30 |
31 | Returns:
32 | self (RAGTrainer): The current instance of RAGTrainer, with the base model initialised.
33 | """
34 | self.model_name = model_name
35 | self.pretrained_model_name = pretrained_model_name
36 | self.language_code = language_code
37 | self.model: Union[LateInteractionModel, None] = ColBERT(
38 | pretrained_model_name_or_path=pretrained_model_name,
39 | n_gpu=n_usable_gpus,
40 | training_mode=True,
41 | )
42 | self.negative_miner: Union[HardNegativeMiner, None] = None
43 | self.collection: list[str] = []
44 | self.queries: Union[list[str], None] = None
45 | self.raw_data: Union[list[tuple], list[list], None] = None
46 | self.training_triplets: list[list[int]] = list()
47 |
48 | def add_documents(self, documents: list[str]):
49 | self.collection += documents
50 | seeded_shuffle(self.collection)
51 |
52 | def export_training_data(self, path: Union[str, Path]):
53 | """
54 | Manually export the training data processed by prepare_training_data to a given path.
55 |
56 | Parameters:
57 | path: Union[str, Path] - Path to the directory where the data will be exported."""
58 | self.data_processor.export_training_data(path)
59 |
60 | def _add_to_collection(self, content: Union[str, list, dict]):
61 | if isinstance(content, str):
62 | self.collection.append(content)
63 | elif isinstance(content, list):
64 | self.collection += [txt for txt in content if isinstance(txt, str)]
65 | elif isinstance(content, dict):
66 | self.collection += [content["content"]]
67 |
68 | def prepare_training_data(
69 | self,
70 | raw_data: Union[list[tuple], list[list]],
71 | all_documents: Optional[list[str]] = None,
72 | data_out_path: Union[str, Path] = "./data/",
73 | num_new_negatives: int = 10,
74 | hard_negative_minimum_rank: int = 10,
75 | mine_hard_negatives: bool = True,
76 | hard_negative_model_size: str = "small",
77 | pairs_with_labels: bool = False,
78 | positive_label: Union[int, str] = 1,
79 | negative_label: Union[int, str] = 0,
80 | ) -> str:
81 | """
82 |         Fully pre-process input data in various raw formats into ColBERT-ready files and triplets.
83 | Will accept a variety of formats, such as unannotated pairs, annotated pairs, triplets of strings and triplets of list of strings.
84 | Will process into a ColBERT-ready format and export to data_out_path.
85 | Will generate hard negatives if mine_hard_negatives is True.
86 |         num_new_negatives decides how many negatives will be generated. If mine_hard_negatives is False and num_new_negatives > 0, these negatives will be randomly sampled.
87 |
88 | Parameters:
89 | raw_data: Union[list[tuple], list[list]] - List of pairs, annotated pairs, or triplets of strings.
90 | all_documents: Optional[list[str]] - A corpus of documents to be used for sampling negatives.
91 | data_out_path: Union[str, Path] - Path to the directory where the data will be exported (can be a tmp directory).
92 | num_new_negatives: int - Number of new negatives to generate for each query.
93 | mine_hard_negatives: bool - Whether to use hard negatives mining or not.
94 | hard_negative_model_size: str - Size of the model to use for hard negatives mining.
95 | pairs_with_labels: bool - Whether the raw_data is a list of pairs with labels or not.
96 | positive_label: Union[int, str] - Label to use for positive pairs.
97 | negative_label: Union[int, str] - Label to use for negative pairs.
98 |
99 | Returns:
100 | data_out_path: Union[str, Path] - Path to the directory where the data has been exported.
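101 | 
102 |         Example (a minimal sketch; the model name and query/passage pairs are illustrative):
103 | 
104 |         ```python
105 |         trainer = RAGTrainer(model_name="MyColBERT", pretrained_model_name="colbert-ir/colbertv2.0")
106 |         pairs = [
107 |             ("Who directed Princess Mononoke?", "Princess Mononoke was directed by Hayao Miyazaki."),
108 |             ("Where is Studio Ghibli based?", "Studio Ghibli is an animation studio based in Koganei, Tokyo."),
109 |         ]
110 |         trainer.prepare_training_data(raw_data=pairs, data_out_path="./data/", mine_hard_negatives=True)
111 |         ```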
101 | """
102 | if all_documents is not None:
103 | self.collection += [doc for doc in all_documents if isinstance(doc, str)]
104 |
105 | self.data_dir = Path(data_out_path)
106 | sample = raw_data[0]
107 | if len(sample) == 2:
108 | data_type = "pairs"
109 | elif len(sample) == 3:
110 | if pairs_with_labels:
111 | data_type = "labeled_pairs"
112 | if sample[-1] not in [positive_label, negative_label]:
113 | raise ValueError(f"Invalid value for label: {sample}")
114 | else:
115 | data_type = "triplets"
116 | else:
117 | raise ValueError("Raw data must be a list of pairs or triplets of strings.")
118 |
119 | self.queries = set()
120 | for x in raw_data:
121 | if isinstance(x[0], str):
122 | self.queries.add(x[0])
123 | else:
124 |                 raise ValueError("Queries must be strings.")
125 |
126 | self._add_to_collection(x[1])
127 |
128 | if data_type == "triplets":
129 | self._add_to_collection(x[2])
130 |
131 | self.collection = list(set(self.collection))
132 | seeded_shuffle(self.collection)
133 |
134 | if mine_hard_negatives:
135 | self.negative_miner = SimpleMiner(
136 | language_code=self.language_code,
137 | model_size=hard_negative_model_size,
138 | )
139 | self.negative_miner.build_index(self.collection)
140 |
141 | self.data_processor = TrainingDataProcessor(
142 | collection=self.collection,
143 | queries=self.queries,
144 | negative_miner=self.negative_miner if mine_hard_negatives else None,
145 | )
146 |
147 | self.data_processor.process_raw_data(
148 | data_type=data_type,
149 | raw_data=raw_data,
150 | export=True,
151 | data_dir=data_out_path,
152 | num_new_negatives=num_new_negatives,
153 | positive_label=positive_label,
154 | negative_label=negative_label,
155 | mine_hard_negatives=mine_hard_negatives,
156 | hard_negative_minimum_rank=hard_negative_minimum_rank,
157 | )
158 | if len(self.data_processor.training_triplets) == 0:
159 | if mine_hard_negatives:
160 | print(
161 |                     "Warning: No training triplets were generated with mine_hard_negatives=True. This may be because the dataset is too small or the hard negative miner could not find enough hard negatives. Falling back to randomly sampled negatives..."
162 | )
163 | self.data_processor.process_raw_data(
164 | data_type=data_type,
165 | raw_data=raw_data,
166 | export=True,
167 | data_dir=data_out_path,
168 | num_new_negatives=num_new_negatives,
169 | positive_label=positive_label,
170 | negative_label=negative_label,
171 | mine_hard_negatives=False,
172 | hard_negative_minimum_rank=hard_negative_minimum_rank,
173 | )
174 | else:
175 | raise ValueError("No training triplets were generated.")
176 |
177 | self.training_triplets = self.data_processor.training_triplets
178 |
179 | return data_out_path
180 |
181 | def train(
182 | self,
183 | batch_size: int = 32,
184 | nbits: int = 2,
185 | maxsteps: int = 500_000,
186 | use_ib_negatives: bool = True,
187 | learning_rate: float = 5e-6,
188 | dim: int = 128,
189 | doc_maxlen: int = 256,
190 | use_relu: bool = False,
191 | warmup_steps: Union[int, Literal["auto"]] = "auto",
192 | accumsteps: int = 1,
193 | ) -> str:
194 | """
195 | Launch training or fine-tuning of a ColBERT model.
196 |
197 | Parameters:
198 |             batch_size: int - Total batch size -- divide by n_usable_gpus for the per-GPU batch size.
199 |             nbits: int - Number of bits used for vector compression by the trained model. 2 is usually ideal.
200 | maxsteps: int - End training early after maxsteps steps.
201 | use_ib_negatives: bool - Whether to use in-batch negatives to calculate loss or not.
202 |             learning_rate: float - ColBERT literature usually has this performing best between 3e-6 and 2e-5, depending on data size.
203 | dim: int - Size of individual vector representations.
204 | doc_maxlen: int - The maximum length after which passages will be truncated
205 | warmup_steps: Union[int, Literal["auto"]] - How many warmup steps to use for the learning rate.
206 | Auto will default to 10% of total steps
207 |             accumsteps: int - How many gradient accumulation steps to use to simulate higher batch sizes.
208 |
209 | Returns:
210 | model_path: str - Path to the trained model.
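211 | 
212 |         Example (a minimal sketch; assumes prepare_training_data() was already called on this trainer):
213 | 
214 |         ```python
215 |         model_path = trainer.train(batch_size=32, nbits=2, warmup_steps="auto")
216 |         ```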
211 | """
212 | if not self.training_triplets:
213 | total_triplets = sum(
214 | 1 for _ in open(str(self.data_dir / "triples.train.colbert.jsonl"), "r")
215 | )
216 | else:
217 | total_triplets = len(self.training_triplets)
218 |
219 | training_config = ColBERTConfig(
220 | bsize=batch_size,
221 | model_name=self.model_name,
222 | name=self.model_name,
223 | checkpoint=self.pretrained_model_name,
224 | use_ib_negatives=use_ib_negatives,
225 | maxsteps=maxsteps,
226 | nbits=nbits,
227 | lr=learning_rate,
228 | dim=dim,
229 | doc_maxlen=doc_maxlen,
230 | relu=use_relu,
231 | accumsteps=accumsteps,
232 | warmup=int(total_triplets // batch_size * 0.1)
233 | if warmup_steps == "auto"
234 | else warmup_steps,
235 | save_every=int(total_triplets // batch_size // 10),
236 | )
237 |
238 | return self.model.train(data_dir=self.data_dir, training_config=training_config)
239 |
--------------------------------------------------------------------------------
/ragatouille/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | _FUTURE_MIGRATION_WARNING_MESSAGE = (
4 | "\n********************************************************************************\n"
5 | "RAGatouille WARNING: Future Release Notice\n"
6 | "--------------------------------------------\n"
7 | "RAGatouille version 0.0.10 will be migrating to a PyLate backend \n"
8 | "instead of the current Stanford ColBERT backend.\n"
9 | "PyLate is a fully mature, feature-equivalent backend, that greatly facilitates compatibility.\n"
10 | "However, please pin version <0.0.10 if you require the Stanford ColBERT backend.\n"
11 | "********************************************************************************"
12 | )
13 |
14 | warnings.warn(
15 | _FUTURE_MIGRATION_WARNING_MESSAGE,
16 | UserWarning,
17 | stacklevel=2 # Ensures the warning points to the user's import line
18 | )
19 |
20 | __version__ = "0.0.9post2"
21 | from .RAGPretrainedModel import RAGPretrainedModel
22 | from .RAGTrainer import RAGTrainer
23 |
24 | __all__ = ["RAGPretrainedModel", "RAGTrainer"]
--------------------------------------------------------------------------------
/ragatouille/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .corpus_processor import CorpusProcessor
2 | from .preprocessors import llama_index_sentence_splitter
3 | from .training_data_processor import TrainingDataProcessor
4 |
5 | __all__ = [
6 | "TrainingDataProcessor",
7 | "CorpusProcessor",
8 | "llama_index_sentence_splitter",
9 | ]
10 |
--------------------------------------------------------------------------------
/ragatouille/data/corpus_processor.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Union
2 | from uuid import uuid4
3 |
4 | from ragatouille.data.preprocessors import llama_index_sentence_splitter
5 |
6 |
7 | class CorpusProcessor:
8 | def __init__(
9 | self,
10 | document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
11 | preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
12 | ):
13 | self.document_splitter_fn = document_splitter_fn
14 | self.preprocessing_fn = preprocessing_fn
15 |
16 | def process_corpus(
17 | self,
18 | documents: list[str],
19 | document_ids: Optional[list[str]] = None,
20 | **splitter_kwargs,
21 | ) -> List[dict]:
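22 |         """Split (and optionally preprocess) raw documents.
23 | 
24 |         With the default splitter, returns a list of {"document_id": ..., "content": ...}
25 |         dicts, generating UUID document_ids when none are provided.
26 |         """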
22 | # TODO CHECK KWARGS
23 | document_ids = (
24 | [str(uuid4()) for _ in range(len(documents))]
25 | if document_ids is None
26 | else document_ids
27 | )
28 | if self.document_splitter_fn is not None:
29 | documents = self.document_splitter_fn(
30 | documents, document_ids, **splitter_kwargs
31 | )
32 | if self.preprocessing_fn is not None:
33 | if isinstance(self.preprocessing_fn, list):
34 | for fn in self.preprocessing_fn:
35 | documents = fn(documents, document_ids)
36 | return documents
37 | return self.preprocessing_fn(documents, document_ids)
38 | return documents
39 |
--------------------------------------------------------------------------------
/ragatouille/data/preprocessors.py:
--------------------------------------------------------------------------------
1 | try:
2 | from llama_index import Document
3 | from llama_index.text_splitter import SentenceSplitter
4 | except ImportError:
5 | from llama_index.core import Document
6 | from llama_index.core.text_splitter import SentenceSplitter
7 |
8 |
9 | def llama_index_sentence_splitter(
10 | documents: list[str], document_ids: list[str], chunk_size=256
11 | ):
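12 |     """Split each document into ~chunk_size-token chunks using llama_index's SentenceSplitter,
13 |     returning one {"document_id": ..., "content": ...} dict per chunk."""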
12 |     chunk_overlap = min(chunk_size // 4, 64)  # 25% overlap, capped at 64 tokens
13 | chunks = []
14 | node_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
15 | docs = [[Document(text=doc)] for doc in documents]
16 | for doc_id, doc in zip(document_ids, docs):
17 | chunks += [
18 | {"document_id": doc_id, "content": node.text} for node in node_parser(doc)
19 | ]
20 | return chunks
21 |
--------------------------------------------------------------------------------
/ragatouille/data/training_data_processor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from collections import defaultdict
4 | from itertools import product
5 | from pathlib import Path
6 | from typing import Literal, Union
7 |
8 | import srsly
9 |
10 |
11 | class TrainingDataProcessor:
12 | def __init__(
13 | self,
14 | collection: list[str],
15 | queries: list[str],
16 | negative_miner=None,
17 | ):
18 | self.collection = collection
19 | self.queries = queries
20 | self.negative_miner = negative_miner
21 | self._make_data_map()
22 | self.training_triplets = []
23 |
24 | def process_raw_data(
25 | self,
26 | raw_data,
27 | data_type: Literal["pairs", "triplets", "labeled_pairs"],
28 | data_dir: Union[str, Path],
29 | export: bool = True,
30 | mine_hard_negatives: bool = True,
31 | num_new_negatives: int = 10,
32 | positive_label: int = 1,
33 | negative_label: int = 0,
34 | hard_negative_minimum_rank: int = 10,
35 | ):
36 | if self.negative_miner is None and mine_hard_negatives:
37 | raise ValueError(
38 | "mine_hard_negatives is True but no negative miner was provided!"
39 | )
40 | if self.negative_miner:
41 | self.negative_miner.min_rank = hard_negative_minimum_rank
42 | if data_type == "pairs":
43 | self._process_raw_pairs(
44 | raw_data=raw_data,
45 | mine_hard_negatives=mine_hard_negatives,
46 | n_new_negatives=num_new_negatives,
47 | )
48 | elif data_type == "labeled_pairs":
49 | self._process_raw_labeled_pairs(
50 | raw_data=raw_data,
51 | mine_hard_negatives=mine_hard_negatives,
52 | n_new_negatives=num_new_negatives,
53 | positive_label=positive_label,
54 | negative_label=negative_label,
55 | )
56 | elif data_type == "triplets":
57 | self._process_raw_triplets(
58 | raw_data=raw_data,
59 | mine_hard_negatives=mine_hard_negatives,
60 | n_new_negatives=num_new_negatives,
61 | )
62 |
63 | if export:
64 | self.export_training_data(data_dir)
65 |
66 | def _make_individual_triplets(self, query, positives, negatives):
67 | """Create the training data in ColBERT(v1) format from raw lists of triplets"""
68 | if len(positives) == 0 or len(negatives) == 0:
69 | return []
70 | triplets = []
71 | q = self.query_map[query]
72 | random.seed(42)
73 | if len(positives) > 1:
74 | all_pos_texts = [p for p in positives]
75 | max_triplets_per_query = 20
76 | negs_per_positive = max(1, max_triplets_per_query // len(all_pos_texts))
77 | initial_triplets_count = 0
78 | for pos in all_pos_texts:
79 | p = self.passage_map[pos]
80 | chosen_negs = random.sample(
81 | negatives, min(len(negatives), negs_per_positive)
82 | )
83 | for neg in chosen_negs:
84 | n = self.passage_map[neg]
85 | initial_triplets_count += 1
86 | triplets.append([q, p, n])
87 |
88 | extra_triplets_needed = max_triplets_per_query - initial_triplets_count
89 | if extra_triplets_needed > 0:
90 | all_combinations = list(product(all_pos_texts, negatives))
91 | random.seed(42)
92 | random.shuffle(all_combinations)
93 | for pos, neg in all_combinations:
94 | p = self.passage_map[pos]
95 | n = self.passage_map[neg]
96 | if [q, p, n] not in triplets:
97 | triplets.append([q, p, n])
98 | extra_triplets_needed -= 1
99 | if extra_triplets_needed <= 0:
100 | break
101 |
102 | else:
103 | p = self.passage_map[positives[0]]
104 | for n in negatives:
105 | triplets.append([q, p, self.passage_map[n]])
106 |
107 | return triplets
108 |
109 | def _get_new_negatives(self, query, passages, mine_hard_negatives, n_new_negatives):
110 | """Generate new negatives for each query, using either:
111 | - The assigned hard negative miner if mine_hard_negatives is True
112 | - Randomly sampling from the full collection otherwise
113 | """
114 | if mine_hard_negatives:
115 | hard_negatives = self.negative_miner.mine_hard_negatives(query)
116 | candidates = [
117 | x
118 | for x in hard_negatives
119 | if x not in passages["positives"] and x not in passages["negatives"]
120 | ]
121 | new_negatives = random.sample(
122 | candidates,
123 | min(n_new_negatives, len(candidates)),
124 | )
125 | else:
126 | new_negatives = [
127 | x
128 | for x in random.sample(
129 | self.collection, min(n_new_negatives, len(self.collection))
130 | )
131 | if x not in passages["positives"] and x not in passages["negatives"]
132 | ]
133 |
134 | return new_negatives
135 |
136 | def _process_raw_pairs(self, raw_data, mine_hard_negatives, n_new_negatives):
137 | """Convert unlabeled pairs into training triplets.
138 | It's assumed unlabeled pairs are always in the format (query, relevant_passage)"""
139 | training_triplets = []
140 | raw_grouped_triplets = defaultdict(lambda: defaultdict(list))
141 |
142 | for query, positive in raw_data:
143 | if isinstance(positive, str):
144 | positive = [positive]
145 | elif isinstance(positive, dict):
146 | positive = [positive["content"]]
147 | raw_grouped_triplets[query]["positives"] += positive
148 |
149 | for query, passages in raw_grouped_triplets.items():
150 | if n_new_negatives > 0:
151 | passages["negatives"] += self._get_new_negatives(
152 | query=query,
153 | passages=passages,
154 | mine_hard_negatives=mine_hard_negatives,
155 | n_new_negatives=n_new_negatives,
156 | )
157 | training_triplets += self._make_individual_triplets(
158 | query=query,
159 | positives=list(set(passages["positives"])),
160 | negatives=list(set(passages["negatives"])),
161 | )
162 | self.training_triplets = training_triplets
163 |
164 | def _process_raw_labeled_pairs(
165 | self,
166 | raw_data,
167 | mine_hard_negatives,
168 | n_new_negatives,
169 | positive_label,
170 | negative_label,
171 | ):
172 | """
173 |         Convert labeled pairs into training triplets.
174 | Labeled pairs are in the format (query, passage, label)
175 | """
176 | training_triplets = []
177 | raw_grouped_triplets = defaultdict(lambda: defaultdict(list))
178 |
179 | for query, passage, label in raw_data:
180 | if isinstance(passage, str):
181 | passage = [passage]
182 | if label == positive_label:
183 | label = "positives"
184 | elif label == negative_label:
185 | label = "negatives"
186 | else:
187 | raise ValueError(
188 | f"Label {label} must correspond to either positive_label or negative_label!"
189 | )
190 |
191 | raw_grouped_triplets[query][label] += passage
192 |
193 | for query, passages in raw_grouped_triplets.items():
194 | if n_new_negatives > 0:
195 | passages["negatives"] += self._get_new_negatives(
196 | query=query,
197 | passages=passages,
198 | mine_hard_negatives=mine_hard_negatives,
199 | n_new_negatives=n_new_negatives,
200 | )
201 |
202 | training_triplets += self._make_individual_triplets(
203 | query=query,
204 | positives=passages["positives"],
205 | negatives=passages["negatives"],
206 | )
207 | self.training_triplets = training_triplets
208 |
209 | def _process_raw_triplets(self, raw_data, mine_hard_negatives, n_new_negatives):
210 | """
211 | Convert raw triplets
212 | (query, positives : str | list[str], negatives: str | list[str])
213 | into training triplets.
214 | """
215 | training_triplets = []
216 | raw_grouped_triplets = defaultdict(lambda: defaultdict(list))
217 | for query, positive, negative in raw_data:
218 | if isinstance(positive, str):
219 | positive = [positive]
220 | if isinstance(negative, str):
221 | negative = [negative]
222 |
223 | raw_grouped_triplets[query]["positives"] += positive
224 | raw_grouped_triplets[query]["negatives"] += negative
225 |
226 | for query, passages in raw_grouped_triplets.items():
227 | if n_new_negatives > 0:
228 | passages["negatives"] += self._get_new_negatives(
229 | query=query,
230 | passages=passages,
231 | mine_hard_negatives=mine_hard_negatives,
232 | n_new_negatives=n_new_negatives,
233 | )
234 | training_triplets += self._make_individual_triplets(
235 | query=query,
236 | positives=passages["positives"],
237 | negatives=passages["negatives"],
238 | )
239 | self.training_triplets = training_triplets
240 |
241 | def _make_data_map(self):
242 | """
243 | Build query_text -> query_id and passage_text -> passage_id mappings,
244 | used to generate ColBERT-format training data.
245 | """
246 | self.query_map = {}
247 | self.passage_map = {}
248 |
249 | for i, query in enumerate(self.queries):
250 | self.query_map[query] = i
251 | for i, passage in enumerate(list(self.collection)):
252 | self.passage_map[passage] = i
253 |
254 | def export_training_data(self, path: Union[str, Path]):
255 | """
256 | Export training data for both training and versioning purposes.
257 | {path} should ideally be DVC-versioned.
258 | """
259 |
260 | path = Path(path)
261 |
262 | # Create the directory if it does not exist
263 | os.makedirs(path, exist_ok=True)
264 |
265 | with open(path / "queries.train.colbert.tsv", "w") as f:
266 | for query, idx in self.query_map.items():
267 | query = query.replace("\t", " ").replace("\n", " ")
268 | f.write(f"{idx}\t{query}\n")
269 | with open(path / "corpus.train.colbert.tsv", "w") as f:
270 | for document, idx in self.passage_map.items():
271 | document = document.replace("\t", " ").replace("\n", " ")
272 | f.write(f"{idx}\t{document}\n")
273 |
274 | random.seed(42)
275 | random.shuffle(self.training_triplets)
276 | srsly.write_jsonl(path / "triples.train.colbert.jsonl", self.training_triplets)
277 |
--------------------------------------------------------------------------------
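Once export_training_data has written its artifacts, they can be read straight back for inspection. A minimal sketch, assuming the export path was a local "data/" directory (a placeholder for whatever path was passed in) and that each triple is a [query_id, positive_id, negative_id] list built from the maps above:

import srsly

# Placeholder path: whatever directory was passed to export_training_data.
triples = list(srsly.read_jsonl("data/triples.train.colbert.jsonl"))
print(len(triples), "training triplets")
print(triples[0])  # e.g. [query_id, positive_passage_id, negative_passage_id]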
/ragatouille/integrations/__init__.py:
--------------------------------------------------------------------------------
1 | from ragatouille.integrations._langchain import (
2 | RAGatouilleLangChainCompressor,
3 | RAGatouilleLangChainRetriever,
4 | )
5 |
6 | __all__ = ["RAGatouilleLangChainRetriever", "RAGatouilleLangChainCompressor"]
7 |
--------------------------------------------------------------------------------
/ragatouille/integrations/_langchain.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional, Sequence
2 |
3 | from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
4 | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks
5 | from langchain_core.documents import Document
6 | from langchain_core.retrievers import BaseRetriever
7 |
8 |
9 | class RAGatouilleLangChainRetriever(BaseRetriever):
10 | model: Any
11 | kwargs: dict = {}
12 |
13 | def _get_relevant_documents(
14 | self,
15 | query: str,
16 | *,
17 | run_manager: CallbackManagerForRetrieverRun, # noqa
18 | ) -> List[Document]:
19 | """Get documents relevant to a query."""
20 | docs = self.model.search(query, **self.kwargs)
21 | return [
22 | Document(
23 | page_content=doc["content"], metadata=doc.get("document_metadata", {})
24 | )
25 | for doc in docs
26 | ]
27 |
28 |
29 | class RAGatouilleLangChainCompressor(BaseDocumentCompressor):
30 | model: Any
31 | kwargs: dict = {}
32 | k: int = 5
33 |
34 | class Config:
35 | """Configuration for this pydantic object."""
36 |
37 | arbitrary_types_allowed = True
38 |
39 | def compress_documents(
40 | self,
41 | documents: Sequence[Document],
42 | query: str,
43 | callbacks: Optional[Callbacks] = None, # noqa
44 | **kwargs,
45 | ) -> Any:
46 | """Rerank a list of documents relevant to a query."""
47 | doc_list = list(documents)
48 | _docs = [d.page_content for d in doc_list]
49 | results = self.model.rerank(
50 | query=query,
51 | documents=_docs,
52 | k=kwargs.get("k", self.k),
53 | **self.kwargs,
54 | )
55 | final_results = []
56 | for r in results:
57 | doc = doc_list[r["result_index"]]
58 | doc.metadata["relevance_score"] = r["score"]
59 | final_results.append(doc)
60 | return final_results
61 |
--------------------------------------------------------------------------------
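A minimal sketch of wiring the retriever above up by hand, assuming an index already exists on disk (as built by the e2e tests further down); the exact invocation entry point (invoke vs. get_relevant_documents) depends on the installed langchain-core version:

from ragatouille import RAGPretrainedModel
from ragatouille.integrations import RAGatouilleLangChainRetriever

# Load a previously built index (path taken from the e2e tests in this repo).
RAG = RAGPretrainedModel.from_index(".ragatouille/colbert/indexes/Miyazaki/")
retriever = RAGatouilleLangChainRetriever(model=RAG, kwargs={"k": 3})

docs = retriever.invoke("What animation studio did Miyazaki found?")
print(docs[0].page_content)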
/ragatouille/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import LateInteractionModel
2 | from .colbert import ColBERT
3 |
4 | __all__ = ["LateInteractionModel", "ColBERT"]
5 |
--------------------------------------------------------------------------------
/ragatouille/models/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from typing import Union
4 |
5 |
6 | class LateInteractionModel(ABC):
7 | @abstractmethod
8 | def __init__(
9 | self,
10 | pretrained_model_name_or_path: Union[str, Path],
11 | n_gpu,
12 | ):
13 | ...
14 |
15 | @abstractmethod
16 | def train():
17 | ...
18 |
19 | @abstractmethod
20 | def index(self, name: str, collection: list[str]):
21 | ...
22 |
23 | @abstractmethod
24 | def add_to_index(self):
25 | ...
26 |
27 | @abstractmethod
28 | def search(self, name: str, query: Union[str, list[str]]):
29 | ...
30 |
31 | @abstractmethod
32 | def _search(self, name: str, query: str):
33 | ...
34 |
35 | @abstractmethod
36 | def _batch_search(self, name: str, queries: list[str]):
37 | ...
38 |
--------------------------------------------------------------------------------
/ragatouille/models/index.py:
--------------------------------------------------------------------------------
1 | import time
2 | from abc import ABC, abstractmethod
3 | from copy import deepcopy
4 | from pathlib import Path
5 | from typing import Any, List, Literal, Optional, TypeVar, Union
6 |
7 | import srsly
8 | import torch
9 | from colbert import Indexer, IndexUpdater, Searcher
10 | from colbert.indexing.collection_indexer import CollectionIndexer
11 | from colbert.infra import ColBERTConfig
12 |
13 | from ragatouille.models import torch_kmeans
14 |
15 | IndexType = Literal["FLAT", "HNSW", "PLAID"]
16 |
17 |
18 | class ModelIndex(ABC):
19 | index_type: IndexType
20 |
21 | def __init__(
22 | self,
23 | config: ColBERTConfig,
24 | ) -> None:
25 | self.config = config
26 |
27 | @staticmethod
28 | @abstractmethod
29 | def construct(
30 | config: ColBERTConfig,
31 | checkpoint: str,
32 | collection: List[str],
33 | index_name: Optional["str"] = None,
34 | overwrite: Union[bool, str] = "reuse",
35 | verbose: bool = True,
36 | **kwargs,
37 | ) -> "ModelIndex": ...
38 |
39 | @staticmethod
40 | @abstractmethod
41 | def load_from_file(
42 | index_path: str,
43 | index_name: Optional[str],
44 | index_config: dict[str, Any],
45 | config: ColBERTConfig,
46 | verbose: bool = True,
47 | ) -> "ModelIndex": ...
48 |
49 | @abstractmethod
50 | def build(
51 | self,
52 | checkpoint: Union[str, Path],
53 | collection: List[str],
54 | index_name: Optional["str"] = None,
55 | overwrite: Union[bool, str] = "reuse",
56 | verbose: bool = True,
57 | ) -> None: ...
58 |
59 | @abstractmethod
60 | def search(
61 | self,
62 | config: ColBERTConfig,
63 | checkpoint: Union[str, Path],
64 | collection: List[str],
65 | index_name: Optional[str],
66 | base_model_max_tokens: int,
67 | query: Union[str, list[str]],
68 | k: int = 10,
69 | pids: Optional[List[int]] = None,
70 | force_reload: bool = False,
71 | **kwargs,
72 | ) -> list[tuple[list, list, list]]: ...
73 |
74 | @abstractmethod
75 | def _search(self, query: str, k: int, pids: Optional[List[int]] = None): ...
76 |
77 | @abstractmethod
78 | def _batch_search(self, query: list[str], k: int): ...
79 |
80 | @abstractmethod
81 | def add(
82 | self,
83 | config: ColBERTConfig,
84 | checkpoint: Union[str, Path],
85 | collection: List[str],
86 | index_root: str,
87 | index_name: str,
88 | new_collection: List[str],
89 | verbose: bool = True,
90 | **kwargs,
91 | ) -> None: ...
92 |
93 | @abstractmethod
94 | def delete(
95 | self,
96 | config: ColBERTConfig,
97 | checkpoint: Union[str, Path],
98 | collection: List[str],
99 | index_name: str,
100 | pids_to_remove: Union[TypeVar("T"), List[TypeVar("T")]],
101 | verbose: bool = True,
102 | ) -> None: ...
103 |
104 | @abstractmethod
105 | def _export_config(self) -> dict[str, Any]: ...
106 |
107 | def export_metadata(self) -> dict[str, Any]:
108 | config = self._export_config()
109 | config["index_type"] = self.index_type
110 | return config
111 |
112 |
113 | class FLATModelIndex(ModelIndex):
114 | index_type = "FLAT"
115 |
116 |
117 | class HNSWModelIndex(ModelIndex):
118 | index_type = "HNSW"
119 |
120 |
121 | class PLAIDModelIndex(ModelIndex):
122 | _DEFAULT_INDEX_BSIZE = 32
123 | index_type = "PLAID"
124 | faiss_kmeans = staticmethod(deepcopy(CollectionIndexer._train_kmeans))
125 | pytorch_kmeans = staticmethod(torch_kmeans._train_kmeans)
126 |
127 | def __init__(self, config: ColBERTConfig) -> None:
128 | super().__init__(config)
129 | self.searcher: Optional[Searcher] = None
130 |
131 | @staticmethod
132 | def construct(
133 | config: ColBERTConfig,
134 | checkpoint: Union[str, Path],
135 | collection: List[str],
136 | index_name: Optional["str"] = None,
137 | overwrite: Union[bool, str] = "reuse",
138 | verbose: bool = True,
139 | **kwargs,
140 | ) -> "PLAIDModelIndex":
141 | return PLAIDModelIndex(config).build(
142 | checkpoint, collection, index_name, overwrite, verbose, **kwargs
143 | )
144 |
145 | @staticmethod
146 | def load_from_file(
147 | index_path: str,
148 | index_name: Optional[str],
149 | index_config: dict[str, Any],
150 | config: ColBERTConfig,
151 | verbose: bool = True,
152 | ) -> "PLAIDModelIndex":
153 | _, _, _, _ = index_path, index_name, index_config, verbose
154 | return PLAIDModelIndex(config)
155 |
156 | def build(
157 | self,
158 | checkpoint: Union[str, Path],
159 | collection: List[str],
160 | index_name: Optional["str"] = None,
161 | overwrite: Union[bool, str] = "reuse",
162 | verbose: bool = True,
163 | **kwargs,
164 | ) -> "PLAIDModelIndex":
165 | bsize = kwargs.get("bsize", PLAIDModelIndex._DEFAULT_INDEX_BSIZE)
166 | assert isinstance(bsize, int)
167 |
168 | nbits = 2
169 | if len(collection) < 10000:
170 | nbits = 4
171 | self.config = ColBERTConfig.from_existing(
172 | self.config, ColBERTConfig(nbits=nbits, index_bsize=bsize)
173 | )
174 |
175 | # Instruct colbert-ai to disable forking if nranks == 1
176 | self.config.avoid_fork_if_possible = True
177 |
178 | if len(collection) > 100000:
179 | self.config.kmeans_niters = 4
180 | elif len(collection) > 50000:
181 | self.config.kmeans_niters = 10
182 | else:
183 | self.config.kmeans_niters = 20
184 |
185 | # Monkey-patch colbert-ai to avoid using FAISS
186 | monkey_patching = (
187 | len(collection) < 75000 and kwargs.get("use_faiss", False) is False
188 | )
189 | if monkey_patching:
190 | print(
191 | "---- WARNING! You are using PLAID with an experimental replacement for FAISS for greater compatibility ----"
192 | )
193 | print("This is a behaviour change from RAGatouille 0.8.0 onwards.")
194 | print(
195 | "This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations."
196 | )
197 | print(
198 | "If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour."
199 | )
200 | print("--------------------")
201 | CollectionIndexer._train_kmeans = self.pytorch_kmeans
202 |
203 | # Try to keep runtime stable -- these are values that empirically didn't degrade performance at all on 3 benchmarks.
204 | # More tests required before warning can be removed.
205 | try:
206 | indexer = Indexer(
207 | checkpoint=checkpoint,
208 | config=self.config,
209 | verbose=verbose,
210 | )
211 | indexer.configure(avoid_fork_if_possible=True)
212 | indexer.index(
213 | name=index_name, collection=collection, overwrite=overwrite
214 | )
215 | except Exception as err:
216 | print(
217 | f"PyTorch-based indexing did not succeed with error: {err}",
218 | "! Reverting to using FAISS and attempting again...",
219 | )
220 | monkey_patching = False
221 | if monkey_patching is False:
222 | CollectionIndexer._train_kmeans = self.faiss_kmeans
223 | if torch.cuda.is_available():
224 | import faiss
225 |
226 | if not hasattr(faiss, "StandardGpuResources"):
227 | print(
228 | "________________________________________________________________________________\n"
229 | "WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
230 | "This means that indexing will be slow. To make use of your GPU.\n"
231 | "Please install `faiss-gpu` by running:\n"
232 | "pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
233 | "________________________________________________________________________________",
234 | )
235 | print("Will continue with CPU indexing in 5 seconds...")
236 | time.sleep(5)
237 | indexer = Indexer(
238 | checkpoint=checkpoint,
239 | config=self.config,
240 | verbose=verbose,
241 | )
242 | indexer.configure(avoid_fork_if_possible=True)
243 | indexer.index(name=index_name, collection=collection, overwrite=overwrite)
244 |
245 | return self
246 |
247 | def _load_searcher(
248 | self,
249 | checkpoint: Union[str, Path],
250 | collection: List[str],
251 | index_name: Optional[str],
252 | force_fast: bool = False,
253 | ):
254 | print(
255 | f"Loading searcher for index {index_name} for the first time...",
256 | "This may take a few seconds",
257 | )
258 | self.searcher = Searcher(
259 | checkpoint=checkpoint,
260 | config=None,
261 | collection=collection,
262 | index_root=self.config.root,
263 | index=index_name,
264 | )
265 |
266 | if not force_fast:
267 | self.searcher.configure(ndocs=1024)
268 | self.searcher.configure(ncells=16)
269 | if len(self.searcher.collection) < 10000:
270 | self.searcher.configure(ncells=8)
271 | self.searcher.configure(centroid_score_threshold=0.4)
272 | elif len(self.searcher.collection) < 100000:
273 | self.searcher.configure(ncells=4)
274 | self.searcher.configure(centroid_score_threshold=0.45)
275 | # Otherwise, use defaults for k
276 | else:
278 | # Use fast settings
278 | self.searcher.configure(ncells=1)
279 | self.searcher.configure(centroid_score_threshold=0.5)
280 | self.searcher.configure(ndocs=256)
281 |
282 | print("Searcher loaded!")
283 |
284 | def _search(self, query: str, k: int, pids: Optional[List[int]] = None):
285 | assert self.searcher is not None
286 | return self.searcher.search(query, k=k, pids=pids)
287 |
288 | def _batch_search(self, query: list[str], k: int):
289 | assert self.searcher is not None
290 | queries = {i: x for i, x in enumerate(query)}
291 | results = self.searcher.search_all(queries, k=k)
292 | results = [
293 | [list(zip(*value))[i] for i in range(3)]
294 | for value in results.todict().values()
295 | ]
296 | return results
297 |
298 | def _upgrade_searcher_maxlen(self, maxlen: int, base_model_max_tokens: int):
299 | assert self.searcher is not None
300 | # Keep maxlen stable at 32 for short queries for easier visualisation
301 | maxlen = min(max(maxlen, 32), base_model_max_tokens)
302 | self.searcher.config.query_maxlen = maxlen
303 | self.searcher.checkpoint.query_tokenizer.query_maxlen = maxlen
304 |
305 | def search(
306 | self,
307 | config: ColBERTConfig,
308 | checkpoint: Union[str, Path],
309 | collection: List[str],
310 | index_name: Optional[str],
311 | base_model_max_tokens: int,
312 | query: Union[str, list[str]],
313 | k: int = 10,
314 | pids: Optional[List[int]] = None,
315 | force_reload: bool = False,
316 | **kwargs,
317 | ) -> list[tuple[list, list, list]]:
318 | self.config = config
319 |
320 | force_fast = kwargs.get("force_fast", False)
321 | assert isinstance(force_fast, bool)
322 |
323 | if self.searcher is None or force_reload:
324 | self._load_searcher(
325 | checkpoint,
326 | collection,
327 | index_name,
328 | force_fast,
329 | )
330 | assert self.searcher is not None
331 |
332 | base_ncells = self.searcher.config.ncells
333 | base_ndocs = self.searcher.config.ndocs
334 |
335 | if k > len(self.searcher.collection):
336 | print(
337 | "WARNING: k value is larger than the number of documents in the index!",
338 | f"Lowering k to {len(self.searcher.collection)}...",
339 | )
340 | k = len(self.searcher.collection)
341 |
342 | # For larger k values, we need a higher ncells value to ensure we return enough results
343 | if k > (32 * self.searcher.config.ncells):
344 | self.searcher.configure(ncells=max((k // 32 + 2), base_ncells))
345 |
346 | self.searcher.configure(ndocs=max(k * 4, base_ndocs))
347 |
348 | if isinstance(query, str):
349 | query_length = int(len(query.split(" ")) * 1.35)
350 | self._upgrade_searcher_maxlen(query_length, base_model_max_tokens)
351 | results = [self._search(query, k, pids)]
352 | else:
353 | longest_query_length = max([int(len(x.split(" ")) * 1.35) for x in query])
354 | self._upgrade_searcher_maxlen(longest_query_length, base_model_max_tokens)
355 | results = self._batch_search(query, k)
356 |
357 | # Restore original ncells&ndocs if it had to be changed for large k values
358 | self.searcher.configure(ncells=base_ncells)
359 | self.searcher.configure(ndocs=base_ndocs)
360 |
361 | return results # type: ignore
362 |
363 | @staticmethod
364 | def _should_rebuild(current_len: int, new_doc_len: int) -> bool:
365 | """
366 | Heuristic to determine if it is more efficient to rebuild the index instead of updating it.
367 | """
368 | return current_len + new_doc_len < 5000 or new_doc_len > current_len * 0.05
369 |
370 | def add(
371 | self,
372 | config: ColBERTConfig,
373 | checkpoint: Union[str, Path],
374 | collection: List[str],
375 | index_root: str,
376 | index_name: str,
377 | new_collection: List[str],
378 | verbose: bool = True,
379 | **kwargs,
380 | ) -> None:
381 | self.config = config
382 |
383 | bsize = kwargs.get("bsize", PLAIDModelIndex._DEFAULT_INDEX_BSIZE)
384 | assert isinstance(bsize, int)
385 |
386 | searcher = Searcher(
387 | checkpoint=checkpoint,
388 | config=None,
389 | collection=collection,
390 | index=index_name,
391 | index_root=index_root,
392 | verbose=verbose,
393 | )
394 |
395 | if PLAIDModelIndex._should_rebuild(
396 | len(searcher.collection), len(new_collection)
397 | ):
398 | self.build(
399 | checkpoint=checkpoint,
400 | collection=collection + new_collection,
401 | index_name=index_name,
402 | overwrite="force_silent_overwrite",
403 | verbose=verbose,
404 | **kwargs,
405 | )
406 | else:
407 | if self.config.index_bsize != bsize: # Update bsize if it's different
408 | self.config.index_bsize = bsize
409 |
410 | updater = IndexUpdater(
411 | config=self.config, searcher=searcher, checkpoint=checkpoint
412 | )
413 | updater.add(new_collection)
414 | updater.persist_to_disk()
415 |
416 | def delete(
417 | self,
418 | config: ColBERTConfig,
419 | checkpoint: Union[str, Path],
420 | collection: List[str],
421 | index_name: str,
422 | pids_to_remove: Union[TypeVar("T"), List[TypeVar("T")]],
423 | verbose: bool = True,
424 | ) -> None:
425 | self.config = config
426 |
427 | # Initialize the searcher and updater
428 | searcher = Searcher(
429 | checkpoint=checkpoint,
430 | config=None,
431 | collection=collection,
432 | index=index_name,
433 | verbose=verbose,
434 | )
435 | updater = IndexUpdater(config=config, searcher=searcher, checkpoint=checkpoint)
436 |
437 | updater.remove(pids_to_remove)
438 | updater.persist_to_disk()
439 |
440 | def _export_config(self) -> dict[str, Any]:
441 | return {}
442 |
443 |
444 | class ModelIndexFactory:
445 | _MODEL_INDEX_BY_NAME = {
446 | "FLAT": FLATModelIndex,
447 | "HNSW": HNSWModelIndex,
448 | "PLAID": PLAIDModelIndex,
449 | }
450 |
451 | @staticmethod
452 | def _raise_if_invalid_index_type(index_type: str) -> IndexType:
453 | if index_type not in ["FLAT", "HNSW", "PLAID"]:
454 | raise ValueError(
455 | f"Unsupported index_type `{index_type}`; it must be one of 'FLAT', 'HNSW', OR 'PLAID'"
456 | )
457 | return index_type # type: ignore
458 |
459 | @staticmethod
460 | def construct(
461 | index_type: Union[Literal["auto"], IndexType],
462 | config: ColBERTConfig,
463 | checkpoint: str,
464 | collection: List[str],
465 | index_name: Optional["str"] = None,
466 | overwrite: Union[bool, str] = "reuse",
467 | verbose: bool = True,
468 | **kwargs,
469 | ) -> ModelIndex:
470 | # Automatically choose the appropriate index for the desired "workload".
471 | if index_type == "auto":
472 | # NOTE: For now only PLAID indexes are supported.
473 | index_type = "PLAID"
474 | return ModelIndexFactory._MODEL_INDEX_BY_NAME[
475 | ModelIndexFactory._raise_if_invalid_index_type(index_type)
476 | ].construct(
477 | config, checkpoint, collection, index_name, overwrite, verbose, **kwargs
478 | )
479 |
480 | @staticmethod
481 | def load_from_file(
482 | index_path: str,
483 | index_name: Optional[str],
484 | config: ColBERTConfig,
485 | verbose: bool = True,
486 | ) -> ModelIndex:
487 | metadata = srsly.read_json(index_path + "/metadata.json")
488 | try:
489 | index_config = metadata["RAGatouille"]["index_config"] # type: ignore
490 | except KeyError:
491 | if verbose:
492 | print(
493 | f"Constructing default index configuration for index `{index_name}` as it does not contain RAGatouille specific metadata."
494 | )
495 | index_config = {
496 | "index_type": "PLAID",
497 | "index_name": index_name,
498 | }
499 | index_name = (
500 | index_name if index_name is not None else index_config["index_name"] # type: ignore
501 | )
502 | return ModelIndexFactory._MODEL_INDEX_BY_NAME[
503 | ModelIndexFactory._raise_if_invalid_index_type(index_config["index_type"]) # type: ignore
504 | ].load_from_file(index_path, index_name, index_config, config, verbose)
505 |
--------------------------------------------------------------------------------
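ModelIndexFactory is internal plumbing that is normally driven by RAGPretrainedModel rather than called directly, but a rough, hypothetical sketch of how it is wired (my_passages and the config root are illustrative placeholders, not part of RAGatouille):

from colbert.infra import ColBERTConfig
from ragatouille.models.index import ModelIndexFactory

my_passages = load_my_passages()  # hypothetical helper returning a list[str] of documents

config = ColBERTConfig(root=".ragatouille/")  # assumed root; any writable directory works
index = ModelIndexFactory.construct(
    "auto",                               # currently always resolves to "PLAID"
    config=config,
    checkpoint="colbert-ir/colbertv2.0",
    collection=my_passages,
    index_name="toy_index",
)
print(index.export_metadata())            # {'index_type': 'PLAID'}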
/ragatouille/models/torch_kmeans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fast_pytorch_kmeans import KMeans
3 |
4 |
5 | def _train_kmeans(self, sample, shared_lists): # noqa: ARG001
6 | if self.use_gpu:
7 | torch.cuda.empty_cache()
8 | centroids = compute_pytorch_kmeans(
9 | sample,
10 | self.config.dim,
11 | self.num_partitions,
12 | self.config.kmeans_niters,
13 | self.use_gpu,
14 | )
15 | centroids = torch.nn.functional.normalize(centroids, dim=-1)
16 | if self.use_gpu:
17 | centroids = centroids.half()
18 | else:
19 | centroids = centroids.float()
20 | return centroids
21 |
22 |
23 | def compute_pytorch_kmeans(
24 | sample,
25 | dim, # noqa compatibility
26 | num_partitions,
27 | kmeans_niters,
28 | use_gpu,
29 | batch_size=16000,
30 | verbose=1,
31 | seed=123,
32 | max_points_per_centroid=256,
33 | min_points_per_centroid=10,
34 | ):
35 | device = torch.device("cuda" if use_gpu else "cpu")
36 | sample = sample.to(device)
37 | total_size = sample.shape[0]
38 |
39 | torch.manual_seed(seed)
40 |
41 | # Subsample the training set if too large
42 | if total_size > num_partitions * max_points_per_centroid:
43 | print("Too many k-means training points; subsampling...")
44 | print("partitions:", num_partitions)
45 | print("original size:", total_size)
46 | perm = torch.randperm(total_size, device=device)[
47 | : num_partitions * max_points_per_centroid
48 | ]
49 | sample = sample[perm]
50 | total_size = sample.shape[0]
51 | print("reduced size:")
52 | print(total_size)
53 | elif total_size < num_partitions * min_points_per_centroid:
54 | if verbose:
55 | print(
56 | f"Warning: number of training points ({total_size}) is less than "
57 | f"the minimum recommended ({num_partitions * min_points_per_centroid})"
58 | )
59 |
60 | sample = sample.float()
61 | minibatch = None
62 | if num_partitions > 15000:
63 | minibatch = batch_size
64 | if num_partitions > 30000:
65 | minibatch = int(batch_size / 2)
66 |
67 | kmeans = KMeans(
68 | n_clusters=num_partitions,
69 | mode="euclidean",
70 | verbose=1,
71 | max_iter=kmeans_niters,
72 | minibatch=minibatch,
73 | )
74 | kmeans.fit(sample)
75 | return kmeans.centroids
76 |
77 |
78 | """ Archived homebrew quick implementation to be revisited. """
79 | # def compute_pytorch_kmeans(
80 | # sample,
81 | # dim,
82 | # num_partitions,
83 | # kmeans_niters,
84 | # use_gpu,
85 | # batch_size=8000,
86 | # tol=1e-4,
87 | # verbose=1,
88 | # seed=1234,
89 | # max_points_per_centroid=256,
90 | # min_points_per_centroid=10,
91 | # nredo=1,
92 | # ):
93 | # device = torch.device("cuda" if use_gpu else "cpu")
94 | # sample = sample.to(device)
95 | # total_size = sample.shape[0]
96 |
97 | # torch.manual_seed(seed)
98 |
99 | # # Subsample the training set if too large
100 | # if total_size > num_partitions * max_points_per_centroid:
101 | # print("too many!")
102 | # print("partitions:", num_partitions)
103 | # print(total_size)
104 | # perm = torch.randperm(total_size, device=device)[
105 | # : num_partitions * max_points_per_centroid
106 | # ]
107 | # sample = sample[perm]
108 | # total_size = sample.shape[0]
109 | # print("reduced size:")
110 | # print(total_size)
111 | # elif total_size < num_partitions * min_points_per_centroid:
112 | # if verbose:
113 | # print(
114 | # f"Warning: number of training points ({total_size}) is less than "
115 | # f"the minimum recommended ({num_partitions * min_points_per_centroid})"
116 | # )
117 |
118 | # sample = sample.float()
119 |
120 | # best_obj = float("inf")
121 | # best_centroids = None
122 |
123 | # for redo in range(nredo): # noqa
124 | # centroids = torch.randn(num_partitions, dim, dtype=sample.dtype, device=device)
125 | # centroids /= torch.norm(centroids, dim=1, keepdim=True)
126 |
127 | # with torch.no_grad():
128 | # for i in range(kmeans_niters):
129 | # if verbose >= 1:
130 | # print(f"KMEANS - Iteration {i+1} of {kmeans_niters}")
131 | # start_time = time.time()
132 | # obj = 0.0
133 | # counts = torch.zeros(num_partitions, dtype=torch.long, device=device)
134 |
135 | # # Process data in batches
136 | # if verbose >= 1:
137 | # _iterator = tqdm.tqdm(
138 | # range(0, total_size, batch_size),
139 | # total=math.ceil(total_size / batch_size),
140 | # desc="Batches for iteration",
141 | # )
142 | # else:
143 | # _iterator = range(0, total_size, batch_size)
144 | # for batch_start in _iterator:
145 | # batch_end = min(batch_start + batch_size, total_size)
146 | # batch = sample[batch_start:batch_end]
147 |
148 | # distances = torch.cdist(batch, centroids, p=2.0)
149 | # batch_assignments = torch.min(distances, dim=1)[1]
150 | # obj += torch.sum(
151 | # distances[torch.arange(batch.size(0)), batch_assignments]
152 | # )
153 |
154 | # counts.index_add_(
155 | # 0, batch_assignments, torch.ones_like(batch_assignments)
156 | # )
157 |
158 | # for j in range(num_partitions):
159 | # assigned_points = batch[batch_assignments == j]
160 | # if len(assigned_points) > 0:
161 | # centroids[j] += assigned_points.sum(dim=0)
162 |
163 | # # Handle empty clusters by assigning them a random data point from the largest cluster
164 | # empty_clusters = torch.where(counts == 0)[0]
165 | # if empty_clusters.numel() > 0:
166 | # for ec in empty_clusters:
167 | # largest_cluster = torch.argmax(counts)
168 | # idx = torch.randint(0, total_size, (1,), device=device)
169 | # counts[largest_cluster] -= 1
170 | # counts[ec] = 1
171 | # centroids[ec] = sample[idx]
172 |
173 | # centroids /= torch.norm(centroids, dim=1, keepdim=True)
174 |
175 | # if verbose >= 2:
176 | # print(
177 | # f"Iteration: {i+1}, Objective: {obj.item():.4f}, Time: {time.time() - start_time:.4f}s"
178 | # )
179 |
180 | # # Check for convergence
181 | # if obj < best_obj:
182 | # best_obj = obj
183 | # best_centroids = centroids.clone()
184 | # if obj <= tol:
185 | # break
186 |
187 | # torch.cuda.empty_cache() # Move outside the inner loop
188 |
189 | # if verbose >= 1:
190 | # print(f"Best objective: {best_obj.item():.4f}")
191 |
192 | # print(best_centroids)
193 |
194 | # return best_centroids
195 |
196 |
197 | """ Extremely slow implementation using voyager ANN below. CPU-only. Results ~= FAISS but would only be worth it if storing in int8, which might be for later."""
198 | # from voyager import Index, Space
199 |
200 | # def compute_pytorch_kmeans_via_voyager(
201 | # sample,
202 | # dim,
203 | # num_partitions,
204 | # kmeans_niters,
205 | # use_gpu,
206 | # batch_size=16000,
207 | # tol=1e-4,
208 | # verbose=3,
209 | # seed=1234,
210 | # max_points_per_centroid=256,
211 | # min_points_per_centroid=10,
212 | # nredo=1,
213 | # ):
214 | # device = torch.device("cuda" if use_gpu else "cpu")
215 | # total_size = sample.shape[0]
216 |
217 | # # Set random seed for reproducibility
218 | # torch.manual_seed(seed)
219 |
220 | # # Convert to float32 for better performance
221 | # sample = sample.float()
222 |
223 | # best_obj = float("inf")
224 | # best_centroids = None
225 |
226 | # for redo in range(nredo):
227 | # # Initialize centroids randomly
228 | # centroids = torch.randn(num_partitions, dim, dtype=sample.dtype, device=device)
229 | # centroids = centroids / torch.norm(centroids, dim=1, keepdim=True)
230 |
231 | # # Build Voyager index if the number of data points exceeds 128,000
232 | # # use_index = total_size > 128000
233 | # use_index = True
234 | # if use_index:
235 | # index = Index(Space.Euclidean, num_dimensions=dim)
236 | # index.add_items(centroids.cpu().numpy())
237 |
238 | # for i in range(kmeans_niters):
239 | # start_time = time.time()
240 | # obj = 0.0
241 | # counts = torch.zeros(num_partitions, dtype=torch.long, device=device)
242 |
243 | # # Process data in batches
244 | # for batch_start in range(0, total_size, batch_size):
245 | # batch_end = min(batch_start + batch_size, total_size)
246 | # batch = sample[batch_start:batch_end].to(device)
247 |
248 | # if use_index:
249 | # # Search for nearest centroids using Voyager index
250 | # batch_assignments, batch_distances = index.query(
251 | # batch.cpu().numpy(), k=1, num_threads=-1
252 | # )
253 | # batch_assignments = (
254 | # torch.from_numpy(batch_assignments.astype(np.int64))
255 | # .squeeze()
256 | # .to(device)
257 | # )
258 | # batch_assignments = batch_assignments.long()
259 | # batch_distances = (
260 | # torch.from_numpy(batch_distances.astype(np.float32))
261 | # .squeeze()
262 | # .to(device)
263 | # )
264 | # else:
265 | # # Compute distances using memory-efficient operations
266 | # distances = torch.cdist(batch, centroids, p=2.0)
267 | # batch_assignments = torch.min(distances, dim=1)[1]
268 | # batch_distances = distances[
269 | # torch.arange(batch.size(0)), batch_assignments
270 | # ]
271 |
272 | # # Update objective and counts
273 | # obj += torch.sum(batch_distances)
274 | # counts += torch.bincount(batch_assignments, minlength=num_partitions)
275 |
276 | # # Update centroids
277 | # for j in range(num_partitions):
278 | # assigned_points = batch[batch_assignments == j]
279 | # if len(assigned_points) > 0:
280 | # centroids[j] += assigned_points.sum(dim=0)
281 |
282 | # # Clear the batch from memory
283 | # del batch
284 | # torch.cuda.empty_cache()
285 |
286 | # # Handle empty clusters
287 | # empty_clusters = torch.where(counts == 0)[0]
288 | # if empty_clusters.numel() > 0:
289 | # for ec in empty_clusters:
290 | # # Find the largest cluster
291 | # largest_cluster = torch.argmax(counts)
292 | # # Assign the empty cluster to a random data point from the largest cluster
293 | # indexes = torch.where(counts == counts[largest_cluster])[0]
294 | # if indexes.numel() > 0:
295 | # idx = torch.randint(0, indexes.numel(), (1,), device=device)
296 | # counts[largest_cluster] -= 1
297 | # counts[ec] = 1
298 | # centroids[ec] = sample[batch_start + indexes[idx].item()]
299 | # # Normalize centroids
300 | # centroids = centroids / torch.norm(centroids, dim=1, keepdim=True)
301 |
302 | # if use_index:
303 | # # Update the Voyager index with the new centroids
304 | # index = Index(Space.Euclidean, num_dimensions=dim)
305 | # index.add_items(centroids.cpu().numpy())
306 |
307 | # if verbose >= 2:
308 | # print(
309 | # f"Iteration: {i+1}, Objective: {obj.item():.4f}, Time: {time.time() - start_time:.4f}s"
310 | # )
311 |
312 | # # Check for convergence
313 | # if obj < best_obj:
314 | # best_obj = obj
315 | # best_centroids = centroids.clone()
316 |
317 | # if verbose >= 1:
318 | # print(f"Best objective: {best_obj.item():.4f}")
319 |
320 | # return best_centroids
321 |
--------------------------------------------------------------------------------
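A small, self-contained sketch of calling the PyTorch k-means routine above directly on toy data (in normal use it is monkey-patched into colbert-ai by PLAIDModelIndex rather than called by hand):

import torch
from ragatouille.models.torch_kmeans import compute_pytorch_kmeans

# Toy sample: 10,000 random 128-dim vectors clustered into 256 partitions on CPU.
sample = torch.randn(10_000, 128)
centroids = compute_pytorch_kmeans(
    sample,
    dim=128,
    num_partitions=256,
    kmeans_niters=4,
    use_gpu=False,
)
print(centroids.shape)  # torch.Size([256, 128])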
/ragatouille/models/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | from pathlib import Path
5 | from typing import Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from colbert.infra import ColBERTConfig
10 | from colbert.modeling.colbert import ColBERT
11 | from huggingface_hub import HfApi
12 | from huggingface_hub.utils import HfHubHTTPError
13 | from transformers import AutoModel, BertPreTrainedModel
14 |
15 |
16 | def seeded_shuffle(collection: list, seed: int = 42):
17 | random.seed(seed)
18 | random.shuffle(collection)
19 | return collection
20 |
21 |
22 | """ HUGGINGFACE """
23 |
24 |
25 | def export_to_huggingface_hub(
26 | colbert_path: Union[str, Path],
27 | huggingface_repo_name: str,
28 | export_vespa_onnx: bool = False,
29 | use_tmp_dir: bool = False,
30 | ):
31 | # ensure model contains a valid ColBERT config before exporting
32 | colbert_config = ColBERTConfig.load_from_checkpoint(colbert_path)
33 | try:
34 | assert colbert_config is not None
35 | except Exception:
36 | print(f"Path {colbert_path} does not contain a valid ColBERT config!")
37 |
38 | export_path = colbert_path
39 | if use_tmp_dir:
40 | export_path = ".tmp/hugging_face_export"
41 | print("Using tmp dir to store export files...")
42 | colbert_model = ColBERT(
43 | colbert_path,
44 | colbert_config=colbert_config,
45 | )
46 | print(f"Model loaded... saving export files to disk at {export_path}")
47 | try:
48 | save_model = colbert_model.save
49 | except Exception:
50 | save_model = colbert_model.module.save
51 | save_model(export_path)
52 |
53 | if export_vespa_onnx:
54 | rust_tokenizer_available = True
55 | if use_tmp_dir:
56 | try:
57 | colbert_model.raw_tokenizer.save_pretrained(
58 | export_path, legacy_format=False
59 | )
60 | except Exception:
61 | rust_tokenizer_available = False
62 | else:
63 | rust_tokenizer_available = os.path.exists(
64 | Path(colbert_path) / "tokenizer.json"
65 | )
66 | if not rust_tokenizer_available:
67 | print(
68 | "The tokenizer for your model does not seem to have a Fast Tokenizer implementation...\n",
69 | "This may cause problems when trying to use with Vespa!\n",
70 | "Proceeding anyway...",
71 | )
72 |
73 | export_to_vespa_onnx(colbert_path, out_path=export_path)
74 | try:
75 | api = HfApi()
76 | api.create_repo(repo_id=huggingface_repo_name, repo_type="model", exist_ok=True)
77 | api.upload_folder(
78 | folder_path=export_path,
79 | repo_id=huggingface_repo_name,
80 | repo_type="model",
81 | )
82 | print(f"Successfully uploaded model to {huggingface_repo_name}")
83 | except ValueError as e:
84 | print(
85 | f"Could not create repository on the huggingface hub.\n",
86 | f"Error: {e}\n",
87 | "Please make sure you are logged in (run huggingface-cli login)\n",
88 | "If the error persists, please open an issue on github. This is a beta feature!",
89 | )
90 | except HfHubHTTPError:
91 | print(
92 | "You don't seem to have the rights to create a repository with this name...\n",
93 | "Please make sure your repo name is in the format 'yourusername/your-repo-name'",
94 | )
95 | finally:
96 | if use_tmp_dir:
97 | shutil.rmtree(export_path)
98 |
99 |
100 | """ VESPA """
101 |
102 |
103 | class VespaColBERT(BertPreTrainedModel):
104 | def __init__(self, config, dim):
105 | super().__init__(config)
106 | self.bert = AutoModel.from_config(config)
107 | self.linear = nn.Linear(config.hidden_size, dim, bias=False)
108 | self.init_weights()
109 |
110 | def forward(self, input_ids, attention_mask):
111 | Q = self.bert(input_ids, attention_mask=attention_mask)[0]
112 | Q = self.linear(Q)
113 | return torch.nn.functional.normalize(Q, p=2, dim=2)
114 |
115 |
116 | def export_to_vespa_onnx(
117 | colbert_path: Union[str, Path],
118 | out_path: Union[str, Path],
119 | out_file_name: str = "vespa_colbert.onnx",
120 | ):
121 | print(f"Exporting model {colbert_path} to {out_path}/{out_file_name}")
122 | out_path = Path(out_path)
123 | vespa_colbert = VespaColBERT.from_pretrained(colbert_path, dim=128)
124 | print("Model loaded, converting to ONNX...")
125 | input_names = ["input_ids", "attention_mask"]
126 | output_names = ["contextual"]
127 | input_ids = torch.ones(1, 32, dtype=torch.int64)
128 | attention_mask = torch.ones(1, 32, dtype=torch.int64)
129 | args = (input_ids, attention_mask)
130 | torch.onnx.export(
131 | vespa_colbert,
132 | args=args,
133 | f=str(out_path / out_file_name),
134 | input_names=input_names,
135 | output_names=output_names,
136 | dynamic_axes={
137 | "input_ids": {0: "batch", 1: "batch"},
138 | "attention_mask": {0: "batch", 1: "batch"},
139 | "contextual": {0: "batch", 1: "batch"},
140 | },
141 | opset_version=17,
142 | )
143 | print("Vespa ONNX export complete!")
144 |
--------------------------------------------------------------------------------
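A hedged usage sketch for export_to_huggingface_hub; the checkpoint path and repository name are placeholders, and you need to be authenticated first (huggingface-cli login):

from ragatouille.models.utils import export_to_huggingface_hub

export_to_huggingface_hub(
    colbert_path="path/to/your/colbert/checkpoint",           # placeholder: any trained ColBERT checkpoint dir
    huggingface_repo_name="yourusername/your-colbert-model",  # placeholder namespace/repo
    export_vespa_onnx=True,   # also writes vespa_colbert.onnx alongside the weights
    use_tmp_dir=True,         # stage files in .tmp/hugging_face_export and clean up afterwards
)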
/ragatouille/negative_miners/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import HardNegativeMiner
2 | from .simpleminer import SimpleMiner
3 |
4 | __all__ = ["HardNegativeMiner", "SimpleMiner"]
5 |
--------------------------------------------------------------------------------
/ragatouille/negative_miners/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from typing import Union
4 |
5 |
6 | class HardNegativeMiner(ABC):
7 | @abstractmethod
8 | def export_index(self, path: Union[str, Path]) -> bool:
9 | ...
10 |
11 | @abstractmethod
12 | def mine_hard_negatives(
13 | self,
14 | queries: list[str],
15 | collection: list[str],
16 | neg_k: int,
17 | ):
18 | ...
19 |
20 | @abstractmethod
21 | def _mine(
22 | self,
23 | queries: list[str],
24 | k: int,
25 | ):
26 | ...
27 |
--------------------------------------------------------------------------------
/ragatouille/negative_miners/simpleminer.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from pathlib import Path
3 | from typing import Literal, Optional, Union
4 |
5 | import torch
6 | from sentence_transformers import SentenceTransformer
7 | from tqdm import tqdm
8 | from voyager import Index, Space, StorageDataType
9 |
10 | from .base import HardNegativeMiner
11 |
12 |
13 | class DenseModels(Enum):
14 | en_small = "BAAI/bge-small-en-v1.5"
15 | zh_small = "thenlper/gte-small-zh"
16 | fr_small = "OrdalieTech/Solon-embeddings-base-0.1"
17 | other_small = "intfloat/multilingual-e5-small"
18 | en_base = "BAAI/bge-base-en-v1.5"
19 | zh_base = "thenlper/gte-base-zh"
20 | fr_base = "OrdalieTech/Solon-embeddings-base-0.1"
21 | other_base = "intfloat/multilingual-e5-base"
22 | en_large = "BAAI/bge-large-en-v1.5"
23 | zh_large = "thenlper/gte-large-zh"
24 | fr_large = "OrdalieTech/Solon-embeddings-large-0.1"
25 | other_large = "intfloat/multilingual-e5-large"
26 |
27 |
28 | class SimpleMiner(HardNegativeMiner):
29 | """The simplest approach to hard negatives mining.
30 | Select the most appropriate, small-sized embedding model for the target language,
31 | and retrieve random negatives from the top 10-100 results.
32 | Strong baseline for quick, low-engineering hard negative mining."""
33 |
34 | def __init__(
35 | self,
36 | language_code: str,
37 | model_size: Literal["small", "base", "large"] = "small",
38 | ) -> None:
39 | self.n_gpu = torch.cuda.device_count()
40 | self.target_language = language_code
41 | self.model_size = model_size
42 | if language_code not in ["en", "zh", "fr"]:
43 | language_code = "other"
44 | self.model_name = f"{language_code}_{model_size}"
45 | hub_model = DenseModels[self.model_name].value
46 | print(f"Loading Hard Negative SimpleMiner dense embedding model {hub_model}...")
47 | self.model = SentenceTransformer(hub_model)
48 | self.has_index = False
49 | self.min_rank = 10
50 |
51 | def build_index(
52 | self,
53 | collection,
54 | batch_size: int = 128,
55 | save_index: bool = False,
56 | save_path: Union[str, Path] = None,
57 | force_fp32: bool = True,
58 | ):
59 | print(f"Building hard negative index for {len(collection)} documents...")
60 | if len(collection) > 1000:
61 | pool = self.model.start_multi_process_pool()
62 | embeds = self.model.encode_multi_process(
63 | collection, pool, batch_size=batch_size
64 | )
65 | self.model.stop_multi_process_pool(pool)
66 | else:
67 | embeds = self.model.encode(collection, batch_size=batch_size)
68 |
69 | print("All documents embedded, now adding to index...")
70 |
71 | self.max_rank = min(110, int(len(collection) // 10))
72 | self.max_rank = min(self.max_rank, len(collection))
73 |
74 | storage_type = StorageDataType.Float32
75 | if len(collection) > 500000 and not force_fp32:
76 | storage_type = StorageDataType.E4M3
77 |
78 | self.voyager_index = Index(
79 | Space.Cosine,
80 | num_dimensions=self.model.get_sentence_embedding_dimension(),
81 | storage_data_type=storage_type,
82 | )
83 |
84 | self.corpus_map = {i: doc for i, doc in enumerate(collection)}
85 | id_to_vector = {}
86 | for i, emb in enumerate(embeds):
87 | id_to_vector[i] = emb
88 | self.corpus_map[i] = collection[i]
89 | del embeds
90 |
91 | self.voyager_index.add_items(
92 | vectors=[x for x in id_to_vector.values()],
93 | ids=[x for x in id_to_vector.keys()],
94 | num_threads=-1,
95 | )
96 |
97 | del id_to_vector
98 |
99 | if save_index:
100 | print(f"Saving index to {save_path}...")
101 | self.export_index(save_path)
102 | else:
103 | print("save_index set to False, skipping saving hard negative index")
104 | print("Hard negative index generated")
105 | self.has_index = True
106 |
107 | def query_index(self, query, top_k=110):
108 | results = self.voyager_index.query(
109 | query, k=min(top_k, self.voyager_index.__len__())
110 | )
111 | return results
112 |
113 | def mine_hard_negatives(
114 | self,
115 | queries: Union[list[str], str],
116 | collection: Optional[list[str]] = None,
117 | save_index: bool = False,
118 | save_path: Union[str, Path] = None,
119 | force_fp32: bool = True,
120 | ):
121 | if self.has_index is False and collection is not None:
122 | self.build_index(
123 | collection,
124 | save_index=save_index,
125 | save_path=save_path,
126 | force_fp32=force_fp32,
127 | )
128 | if isinstance(queries, str):
129 | return self._mine(queries)
130 | return self._batch_mine(queries)
131 |
132 | def _mine(
133 | self,
134 | query: str,
135 | ):
136 | q_emb = self.model.encode(query)
137 | query_results = self.query_index(q_emb, top_k=self.max_rank)
138 | if len(query_results) > self.min_rank:
139 | query_results = query_results[self.min_rank : self.max_rank]
140 | query_results = [self.corpus_map[x] for x in query_results[0]]
141 | return query_results
142 |
143 | def _batch_mine(
144 | self,
145 | queries: list[str],
146 | ):
147 | """Separate function to parallelise later on"""
148 | print(f"Retrieving hard negatives for {len(queries)} queries...")
149 | results = []
150 | print("Embedding queries...")
151 | query_embeddings = self.model.encode(queries, show_progress_bar=True)
152 | print("Retrieving hard negatives...")
153 | for q_emb in tqdm(query_embeddings):
154 | query_results = self.query_index(q_emb, top_k=self.max_rank)
155 | query_results = query_results[self.min_rank : self.max_rank]
156 | query_results = [self.corpus_map[x.id] for x in query_results]
157 | results.append(query_results)
158 | print(f"""Done generating hard negatives.""")
159 | return results
160 |
161 | def export_index(self, path: Union[str, Path]) -> bool:
162 | self.voyager_index.save(path)
163 | return True
164 |
--------------------------------------------------------------------------------
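A minimal usage sketch for SimpleMiner; my_passages is a hypothetical list of passage strings (ideally a few hundred or more, since negatives are drawn from roughly ranks 10-110), and the voyager index is built on the first call:

from ragatouille.negative_miners import SimpleMiner

my_passages = load_my_passages()  # hypothetical helper returning a list[str] of passages

miner = SimpleMiner(language_code="en", model_size="small")
hard_negatives = miner.mine_hard_negatives(
    queries="What animation studio did Miyazaki found?",
    collection=my_passages,  # only needed on the first call, to build the index
)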
/ragatouille/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import requests
4 |
5 |
6 | def seeded_shuffle(collection: list, seed: int = 42):
7 | random.seed(seed)
8 | random.shuffle(collection)
9 | return collection
10 |
11 |
12 | def get_wikipedia_page(title: str):
13 | """
14 | Retrieve the full text content of a Wikipedia page.
15 |
16 | :param title: str - Title of the Wikipedia page.
17 | :return: str - Full text content of the page as raw string.
18 | """
19 | # Wikipedia API endpoint
20 | URL = "https://en.wikipedia.org/w/api.php"
21 |
22 | # Parameters for the API request
23 | params = {
24 | "action": "query",
25 | "format": "json",
26 | "titles": title,
27 | "prop": "extracts",
28 | "explaintext": True,
29 | }
30 |
31 | # Custom User-Agent header to comply with Wikipedia's best practices
32 | headers = {"User-Agent": "RAGatouille_tutorial/0.0.1 (ben@clavie.eu)"}
33 |
34 | response = requests.get(URL, params=params, headers=headers)
35 | data = response.json()
36 |
37 | # Extracting page content
38 | page = next(iter(data["query"]["pages"].values()))
39 | return page["extract"] if "extract" in page else None
40 |
--------------------------------------------------------------------------------
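get_wikipedia_page is also what the e2e tests below use to fetch source documents; a quick example:

from ragatouille.utils import get_wikipedia_page

full_document = get_wikipedia_page("Hayao_Miyazaki")
if full_document:  # None is returned when the page has no extract
    print(full_document[:300])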
/requirements-doc.txt:
--------------------------------------------------------------------------------
1 | mkdocs
2 | mkdocs-material
3 | yarl
4 | cairosvg
5 | mkdocstrings
6 | mkdocstrings-python
7 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnswerDotAI/RAGatouille/e75b8a964a870dea886548f78da1900804749040/tests/__init__.py
--------------------------------------------------------------------------------
/tests/data/Toei_Animation_wikipedia.txt:
--------------------------------------------------------------------------------
1 | Toei Animation Co., Ltd. (Japanese: 東映アニメーション株式会社, Hepburn: Tōei Animēshon Kabushiki-gaisha, ) is a Japanese animation studio primarily controlled by its namesake Toei Company. It has produced numerous series, including Sally the Witch, GeGeGe no Kitarō, Mazinger Z, Galaxy Express 999, Cutie Honey, Dr. Slump, Dragon Ball, Saint Seiya, Sailor Moon, Slam Dunk, Digimon, One Piece, Toriko, World Trigger, The Transformers (between 1984 and 1990, including several Japanese exclusive productions), and the Pretty Cure series.
2 |
3 |
4 | == History ==
5 | The studio was founded by animators Kenzō Masaoka and Zenjirō Yamamoto in 1948 as Japan Animated Films (日本動画映画, Nihon Dōga Eiga, often shortened to 日動映画 (Nichidō Eiga)). In 1956, Toei purchased the studio and it was renamed Toei Doga Co., Ltd. (東映動画株式会社, Tōei Dōga Kabushiki-gaisha, "dōga" is Japanese for "animation"), doing business as Toei Animation Co., Ltd. outside Japan. In 1998, the Japanese name was renamed to Toei Animation. It has created a number of TV series and movies and adapted Japanese comics as animated series, many popular worldwide. Hayao Miyazaki, Isao Takahata, Yasuji Mori, Leiji Matsumoto and Yōichi Kotabe have worked with the company. Toei is a shareholder in the Japanese anime satellite television network Animax with other anime studios and production companies, such as Sunrise, TMS Entertainment and Nihon Ad Systems Inc. The company is headquartered in the Ohizumi Studio in Nerima, Tokyo.Their mascot is the cat Pero, from the company's 1969 film adaptation of Puss in Boots.
6 | Toei Animation produced anime versions of works from manga series by manga artists, including Go Nagai (Mazinger Z), Eiichiro Oda (One Piece), Shotaro Ishinomori (Cyborg 009), Mitsutoshi Shimabukuro (Toriko), Takehiko Inoue (Slam Dunk), Mitsuteru Yokoyama (Sally the Witch), Masami Kurumada (Saint Seiya), Akira Toriyama (Dragon Ball and Dr. Slump), Leiji Matsumoto (Galaxy Express 999), and Naoko Takeuchi (Sailor Moon). The studio helped propel the popularity of the Magical Girl and Super Robot genres of anime; Toei's TV series include the first magical-girl anime series, Mahoutsukai Sally (an adaptation of Mitsuteru Yokoyama's manga of the same name), and Go Nagai's Mazinger Z, an adaptation of his manga which set the standard for Super Robot anime. Although the Toei Company usually contracts Toei Animation to handle its animation internally, they occasionally hire other companies to provide animation; although the Toei Company produced the Robot Romance Trilogy, Sunrise (then known as Nippon Sunrise) provided the animation. Toei Company would also enlist the help of other studios such as hiring Academy Productions to produce the animation for Space Emperor God Sigma, rather than use their own studio.
7 | Toei Animation's anime which have won the Animage Anime Grand Prix award are Galaxy Express 999 in 1981, Saint Seiya in 1987 and Sailor Moon in 1992. In addition to producing anime for release in Japan, Toei Animation began providing animation for American films and television series during the 1960s and particularly during the 1980s.
8 | In October 2021, Toei Animation announced that they had signed a strategic partnership with the South Korean entertainment conglomerate CJ ENM.
9 |
10 |
11 | === 2022 ransomware attack ===
12 | On March 6, 2022, an incident occurred in which an unauthorized third party attempted to hack Toei Animation's network, which resulted in the company's online store and internal systems becoming temporarily suspended. The company investigated the incident and stated that the hack would affect the broadcast schedules of several anime series, including One Piece. In addition, Dragon Ball Super: Super Hero was also rescheduled to June 11, 2022, due to the hack. On April 6, 2022, Toei Animation announced that it would resume broadcasting the anime series, including One Piece. The following day, the Japanese public broadcaster NHK reported that the hack was caused by a targeted ransomware attack.
13 |
14 |
15 | == Subsidiaries ==
16 |
17 |
18 | == Currently in production ==
19 |
20 |
21 | == TV animation ==
22 |
23 |
24 | === 1960–69 ===
25 |
26 |
27 | === 1970–79 ===
28 |
29 |
30 | === 1980–89 ===
31 |
32 |
33 | === 1990–99 ===
34 |
35 |
36 | === 2000–09 ===
37 |
38 |
39 | === 2010–19 ===
40 |
41 |
42 | === 2020–present ===
43 |
44 |
45 | == Television films and specials ==
46 |
47 |
48 | == Theatrical films ==
49 |
50 |
51 | == CGI films ==
52 |
53 |
54 | == Original video animation and original net animation ==
55 |
56 |
57 | == Video game animation ==
58 |
59 |
60 | == Video game development ==
61 |
62 |
63 | == Dubbing ==
64 | Animated productions by foreign studios dubbed in Japanese by Toei are The Mystery of the Third Planet (1981 Russian film, dubbed in 2008); Les Maîtres du temps (1982 French-Hungarian film, dubbed in 2014), Alice's Birthday (2009 Russian film, dubbed in 2013) and Becca's Bunch (2018 television series, dubbed in 2021 to 2022).
65 |
66 |
67 | == Foreign Production History ==
68 | Toei has been commissioned to provide animation by Japanese and American studios such as Sunbow Entertainment, Marvel Productions, Hanna-Barbera, DIC Entertainment, Rankin/Bass Productions and World Events Productions (DreamWorks Animation). In the 60's, they primarily worked with Rankin/Bass, but beginning in the 80's, they worked with Marvel Productions and their list of clients grew, until the end of the decade. Toei didn't provide much outsourced animation work in the 90's and since the 2000s has only rarely worked with other companies outside Japan.
69 |
70 |
71 | == Controversies ==
72 |
73 |
74 | === Fair use disputes ===
75 | Between 2008 and 2018, Toei Animation had copyright claimed TeamFourStar's parody series, DragonBall Z Abridged. TFS stated that the parody series is protected under fair use.On December 7, 2021, Toei Animation copyright claimed over 150 videos by YouTuber Totally Not Mark, real name Mark Fitzpatrick. He uploaded a video addressing the issue, claiming that they were protected under fair use, and that nine of the videos do not include any Toei footage. He also outlined the appeal process on YouTube, and estimated having the videos reinstated could take over 37 years. He then goes on to announce that he would not be supporting new Toei releases until the issue had been resolved, and also called for a boycott on the upcoming Dragon Ball Super: Super Hero film. The dispute sparked discussion on YouTube on the vulnerability of creators against the copyright system and lack of fair use laws in Japan, with YouTubers such as PewDiePie and The Anime Man speaking out on the issue.On January 26, 2022, Fitzpatrick had his videos reinstated after negotiations with YouTube.
76 |
77 |
78 | === Treatment of employees ===
79 | On January 20, 2021, two employees accused Toei Animation of overworking its staff and of discrimination towards sexual minorities. The company had inappropriately referred to employees who identify as X-gender (a non-binary identity in Japan).
80 |
81 |
82 | == See also ==
83 | SynergySP, Studio Junio and Hal Film Maker/Yumeta Company, animation studios founded by former Toei animators.
84 | Topcraft, an animation studio founded by former Toei Animation producer Toru Hara.
85 | Studio Ghibli, an animation studio founded by former Toei animators Hayao Miyazaki and Isao Takahata.
86 | Mushi Production, an animation studio founded by Osamu Tezuka and former Toei animators.
87 | Shin-Ei Animation, formerly A Production, an animation studio founded by former Toei animator Daikichirō Kusube.
88 | Yamamura Animation, an animation studio founded by former Toei animator Kōji Yamamura.
89 | Doga Kobo, an animation studio formed by former Toei animators Hideo Furusawa and Megumu Ishiguro.
90 |
91 |
92 | == References ==
93 |
94 |
95 | == External links ==
96 |
97 | Official website (in English)
98 | Toei Animation at Anime News Network's encyclopedia
99 | Toei Animation at IMDb
--------------------------------------------------------------------------------
/tests/e2e/test_e2e_indexing_searching.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import srsly
3 |
4 | from ragatouille import RAGPretrainedModel
5 | from ragatouille.utils import get_wikipedia_page
6 |
7 |
8 | def test_indexing():
9 | RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
10 | with open("tests/data/miyazaki_wikipedia.txt", "r") as f:
11 | full_document = f.read()
12 | RAG.index(
13 | collection=[full_document],
14 | index_name="Miyazaki",
15 | max_document_length=180,
16 | split_documents=True,
17 | )
18 | # ensure collection is stored to disk
19 | collection = srsly.read_json(
20 | ".ragatouille/colbert/indexes/Miyazaki/collection.json"
21 | )
22 | assert len(collection) > 1
23 |
24 |
25 | def test_search():
26 | RAG = RAGPretrainedModel.from_index(".ragatouille/colbert/indexes/Miyazaki/")
27 | k = 3  # Number of documents to retrieve (defaults to 10); set to 3 here for readability
28 | results = RAG.search(query="What animation studio did Miyazaki found?", k=k)
29 | assert len(results) == k
30 | assert (
31 | "In April 1984, Miyazaki opened his own office in Suginami Ward"
32 | in results[0]["content"]
33 | )
34 | assert (
35 | "Hayao Miyazaki (宮崎 駿 or 宮﨑 駿, Miyazaki Hayao, [mijaꜜzaki hajao]; born January 5, 1941)" # noqa
36 | in results[1]["content"]
37 | )
38 | assert (
39 | 'Glen Keane said Miyazaki is a "huge influence" on Walt Disney Animation Studios and has been' # noqa
40 | in results[2]["content"]
41 | )
42 |
43 | all_results = RAG.search(
44 | query=["What animation studio did Miyazaki found?", "Miyazaki son name"], k=k
45 | )
46 | assert (
47 | "In April 1984, Miyazaki opened his own office in Suginami Ward"
48 | in all_results[0][0]["content"]
49 | )
50 | assert (
51 | "Hayao Miyazaki (宮崎 駿 or 宮﨑 駿, Miyazaki Hayao, [mijaꜜzaki hajao]; born January 5, 1941)" # noqa
52 | in all_results[0][1]["content"]
53 | )
54 | assert (
55 | 'Glen Keane said Miyazaki is a "huge influence" on Walt Disney Animation Studios and has been' # noqa
56 | in all_results[0][2]["content"]
57 | )
58 | assert (
59 | "== Early life ==\nHayao Miyazaki was born on January 5, 1941"
60 | in all_results[1][0]["content"] # noqa
61 | )
62 | assert (
63 | "Directed by Isao Takahata, with whom Miyazaki would continue to collaborate for the remainder of his career" # noqa
64 | in all_results[1][1]["content"]
65 | )
66 | actual = all_results[1][2]["content"]
67 | assert (
68 | "Specific works that have influenced Miyazaki include Animal Farm (1945)"
69 | in actual
70 | or "She met with Suzuki" in actual
71 | )
72 | print(all_results)
73 |
74 |
75 | @pytest.mark.skip(reason="experimental feature.")
76 | def test_basic_CRUD_addition():
77 | old_collection = srsly.read_json(
78 | ".ragatouille/colbert/indexes/Miyazaki/collection.json"
79 | )
80 | old_collection_len = len(old_collection)
81 | path_to_index = ".ragatouille/colbert/indexes/Miyazaki/"
82 | RAG = RAGPretrainedModel.from_index(path_to_index)
83 |
84 | new_documents = get_wikipedia_page("Studio_Ghibli")
85 |
86 | RAG.add_to_index([new_documents])
87 | new_collection = srsly.read_json(
88 | ".ragatouille/colbert/indexes/Miyazaki/collection.json"
89 | )
90 | assert len(new_collection) > old_collection_len
91 | assert len(new_collection) == 140
92 |
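93 | # Note: test_search and test_basic_CRUD_addition reuse the on-disk index that
94 | # test_indexing builds under .ragatouille/colbert/indexes/Miyazaki/, so the tests
95 | # are meant to run in file order, e.g. (assuming pytest is installed):
96 | #   pytest tests/e2e/test_e2e_indexing_searching.py
97 | 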
--------------------------------------------------------------------------------
/tests/test_pretrained_loading.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | @pytest.mark.skip(reason="NotImplemented")
5 | def test_from_checkpoint():
6 | pass
7 |
8 |
9 | @pytest.mark.skip(reason="NotImplemented")
10 | def test_from_index():
11 | pass
12 |
13 |
14 | @pytest.mark.skip(reason="NotImplemented")
15 | def test_searcher():
16 | pass
17 |
--------------------------------------------------------------------------------
/tests/test_pretrained_optional_args.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | import srsly
5 |
6 | from ragatouille import RAGPretrainedModel
7 |
8 | collection = [
9 | "Hayao Miyazaki (宮崎 駿 or 宮﨑 駿, Miyazaki Hayao, [mijaꜜzaki hajao]; born January 5, 1941) is a Japanese animator, filmmaker, and manga artist. A co-founder of Studio Ghibli, he has attained international acclaim as a masterful storyteller and creator of Japanese animated feature films, and is widely regarded as one of the most accomplished filmmakers in the history of animation.\nBorn in Tokyo City in the Empire of Japan, Miyazaki expressed interest in manga and animation from an early age, and he joined Toei Animation in 1963. During his early years at Toei Animation he worked as an in-between artist and later collaborated with director Isao Takahata. Notable films to which Miyazaki contributed at Toei include Doggie March and Gulliver's Travels Beyond the Moon. He provided key animation to other films at Toei, such as Puss in Boots and Animal Treasure Island, before moving to A-Pro in 1971, where he co-directed Lupin the Third Part I alongside Takahata. After moving to Zuiyō Eizō (later known as Nippon Animation) in 1973, Miyazaki worked as an animator on World Masterpiece Theater, and directed the television series Future Boy Conan (1978). He joined Tokyo Movie Shinsha in 1979 to direct his first feature film The Castle of Cagliostro as well as the television series Sherlock Hound. In the same period, he also began writing and illustrating the manga Nausicaä of the Valley of the Wind (1982–1994), and he also directed the 1984 film adaptation produced by Topcraft.\nMiyazaki co-founded Studio Ghibli in 1985. He directed numerous films with Ghibli, including Laputa: Castle in the Sky (1986), My Neighbor Totoro (1988), Kiki's Delivery Service (1989), and Porco Rosso (1992). The films were met with critical and commercial success in Japan. Miyazaki's film Princess Mononoke was the first animated film ever to win the Japan Academy Prize for Picture of the Year, and briefly became the highest-grossing film in Japan following its release in 1997; its distribution to the Western world greatly increased Ghibli's popularity and influence outside Japan. His 2001 film Spirited Away became the highest-grossing film in Japanese history, winning the Academy Award for Best Animated Feature, and is frequently ranked among the greatest films of the 21st century. Miyazaki's later films—Howl's Moving Castle (2004), Ponyo (2008), and The Wind Rises (2013)—also enjoyed critical and commercial success.",
10 | "Studio Ghibli, Inc. (Japanese: 株式会社スタジオジブリ, Hepburn: Kabushiki gaisha Sutajio Jiburi) is a Japanese animation studio based in Koganei, Tokyo. It has a strong presence in the animation industry and has expanded its portfolio to include various media formats, such as short subjects, television commercials, and two television films. Their work has been well-received by audiences and recognized with numerous awards. Their mascot and most recognizable symbol, the character Totoro from the 1988 film My Neighbor Totoro, is a giant spirit inspired by raccoon dogs (tanuki) and cats (neko). Among the studio's highest-grossing films are Spirited Away (2001), Howl's Moving Castle (2004), and Ponyo (2008). Studio Ghibli was founded on June 15, 1985, by the directors Hayao Miyazaki and Isao Takahata and producer Toshio Suzuki, after acquiring Topcraft's assets. The studio has also collaborated with video game studios on the visual development of several games.Five of the studio's films are among the ten highest-grossing anime feature films made in Japan. Spirited Away is second, grossing 31.68 billion yen in Japan and over US$380 million worldwide, and Princess Mononoke is fourth, grossing 20.18 billion yen. Three of their films have won the Animage Grand Prix award, four have won the Japan Academy Prize for Animation of the Year, and five have received Academy Award nominations. Spirited Away won the 2002 Golden Bear and the 2003 Academy Award for Best Animated Feature.On August 3, 2014, Studio Ghibli temporarily suspended production following Miyazaki's retirement.",
11 | ]
12 |
13 | document_ids = ["miyazaki", "ghibli"]
14 |
15 | document_metadatas = [
16 | {"entity": "person", "source": "wikipedia"},
17 | {"entity": "organisation", "source": "wikipedia"},
18 | ]
19 |
20 |
21 | @pytest.fixture(scope="session")
22 | def persistent_temp_index_root(tmp_path_factory):
23 | return tmp_path_factory.mktemp("temp_test_indexes")
24 |
25 |
26 | @pytest.fixture(scope="session")
27 | def RAG_from_pretrained_model(persistent_temp_index_root):
28 | return RAGPretrainedModel.from_pretrained(
29 | "colbert-ir/colbertv2.0", index_root=str(persistent_temp_index_root)
30 | )
31 |
32 |
33 | @pytest.fixture(scope="session")
34 | def index_path_fixture(persistent_temp_index_root, index_creation_inputs):
35 | index_path = os.path.join(
36 | str(persistent_temp_index_root),
37 | "colbert",
38 | "indexes",
39 | index_creation_inputs["index_name"],
40 | )
41 | return str(index_path)
42 |
43 |
44 | @pytest.fixture(scope="session")
45 | def collection_path_fixture(index_path_fixture):
46 | collection_path = os.path.join(index_path_fixture, "collection.json")
47 | return str(collection_path)
48 |
49 |
50 | @pytest.fixture(scope="session")
51 | def document_metadata_path_fixture(index_path_fixture):
52 | document_metadata_path = os.path.join(index_path_fixture, "docid_metadata_map.json")
53 | return str(document_metadata_path)
54 |
55 |
56 | @pytest.fixture(scope="session")
57 | def pid_docid_map_path_fixture(index_path_fixture):
58 | pid_docid_map_path = os.path.join(index_path_fixture, "pid_docid_map.json")
59 | return str(pid_docid_map_path)
60 |
61 |
62 | @pytest.fixture(
63 | scope="session",
64 | params=[
65 | {
66 | "collection": collection,
67 | "index_name": "no_optional_args",
68 | "split_documents": False,
69 | },
70 | {
71 | "collection": collection,
72 | "document_ids": document_ids,
73 | "index_name": "with_docid",
74 | "split_documents": False,
75 | },
76 | {
77 | "collection": collection,
78 | "document_metadatas": document_metadatas,
79 | "index_name": "with_metadata",
80 | "split_documents": False,
81 | },
82 | {
83 | "collection": collection,
84 | "index_name": "with_split",
85 | "split_documents": True,
86 | },
87 | {
88 | "collection": collection,
89 | "document_ids": document_ids,
90 | "document_metadatas": document_metadatas,
91 | "index_name": "with_docid_metadata",
92 | "split_documents": False,
93 | },
94 | {
95 | "collection": collection,
96 | "document_ids": document_ids,
97 | "index_name": "with_docid_split",
98 | "split_documents": True,
99 | },
100 | {
101 | "collection": collection,
102 | "document_metadatas": document_metadatas,
103 | "index_name": "with_metadata_split",
104 | "split_documents": True,
105 | },
106 | {
107 | "collection": collection,
108 | "document_ids": document_ids,
109 | "document_metadatas": document_metadatas,
110 | "index_name": "with_docid_metadata_split",
111 | "split_documents": True,
112 | },
113 | ],
114 | ids=[
115 | "No optional arguments",
116 | "With document IDs",
117 | "With metadata",
118 | "With document splitting",
119 | "With document IDs and metadata",
120 | "With document IDs and splitting",
121 | "With metadata and splitting",
122 | "With document IDs, metadata, and splitting",
123 | ],
124 | )
125 | def index_creation_inputs(request):
126 | params = request.param
127 | return params
128 |
129 |
130 | @pytest.fixture(scope="session")
131 | def create_index(RAG_from_pretrained_model, index_creation_inputs):
132 | index_path = RAG_from_pretrained_model.index(**index_creation_inputs)
133 | return index_path
134 |
135 |
136 | def test_index_creation(create_index):
137 |     assert os.path.exists(create_index)
138 |
139 |
140 | @pytest.fixture(scope="session", autouse=True)
141 | def add_docids_to_index_inputs(
142 | create_index, # noqa: ARG001
143 | index_creation_inputs,
144 | pid_docid_map_path_fixture,
145 | ):
146 | if "document_ids" not in index_creation_inputs:
147 | pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture)
148 |         seen_ids = set()  # used to dedupe document IDs below while preserving first-seen order
149 | index_creation_inputs["document_ids"] = [
150 | x
151 | for x in list(pid_docid_map_data.values())
152 | if not (x in seen_ids or seen_ids.add(x))
153 | ]
154 |
155 |
156 | def test_collection_creation(collection_path_fixture):
157 |     assert os.path.exists(collection_path_fixture)
158 | collection_data = srsly.read_json(collection_path_fixture)
159 | assert isinstance(
160 | collection_data, list
161 | ), "The collection.json file should contain a list."
162 |
163 |
164 | def test_pid_docid_map_creation(pid_docid_map_path_fixture):
165 |     assert os.path.exists(pid_docid_map_path_fixture)
166 | # TODO check pid_docid_map_data
167 | pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture)
168 | assert isinstance(
169 | pid_docid_map_data, dict
170 | ), "The pid_docid_map.json file should contain a dictionary."
171 |
172 |
173 | def test_document_metadata_creation(
174 | index_creation_inputs, document_metadata_path_fixture
175 | ):
176 | if "document_metadatas" in index_creation_inputs:
177 |         assert os.path.exists(document_metadata_path_fixture)
178 | document_metadata_dict = srsly.read_json(document_metadata_path_fixture)
179 | assert (
180 | set(document_metadata_dict.keys())
181 | == set(index_creation_inputs["document_ids"])
182 | ), "The keys in document_metadata.json should match the document_ids provided for index creation."
183 | for doc_id, metadata in document_metadata_dict.items():
184 | assert (
185 | metadata
186 | == index_creation_inputs["document_metadatas"][
187 | index_creation_inputs["document_ids"].index(doc_id)
188 | ]
189 | ), f"The metadata for document_id {doc_id} should match the provided metadata."
190 | else:
191 |         assert not os.path.exists(document_metadata_path_fixture)
192 |
193 |
194 | def test_document_metadata_returned_in_search_results(
195 | index_creation_inputs, index_path_fixture
196 | ):
197 | RAG = RAGPretrainedModel.from_index(index_path_fixture)
198 | results = RAG.search(
199 | "when was miyazaki born", index_name=index_creation_inputs["index_name"]
200 | )
201 |
202 | if "document_metadatas" in index_creation_inputs:
203 | for result in results:
204 | assert (
205 | "document_metadata" in result
206 | ), "The metadata should be returned in the results."
207 | doc_id = result["document_id"]
208 | expected_metadata = index_creation_inputs["document_metadatas"][
209 | index_creation_inputs["document_ids"].index(doc_id)
210 | ]
211 | assert (
212 | result["document_metadata"] == expected_metadata
213 | ), f"The metadata for document_id {doc_id} should match the provided metadata."
214 |
215 | else:
216 | for result in results:
217 | assert (
218 | "metadata" not in result
219 | ), "The metadata should not be returned in the results."
220 |
221 |
222 | # TODO: move this to a separate CRUD test file
223 | # TODO: add checks for metadata and doc content
224 | def test_add_to_existing_index(
225 | index_creation_inputs,
226 | document_metadata_path_fixture,
227 | pid_docid_map_path_fixture,
228 | index_path_fixture,
229 | ):
230 | RAG = RAGPretrainedModel.from_index(index_path_fixture)
231 | existing_doc_ids = index_creation_inputs["document_ids"]
232 | new_doc_ids = ["mononoke", "sabaku_no_tami"]
233 | new_docs = [
234 | "Princess Mononoke (Japanese: もののけ姫, Hepburn: Mononoke-hime) is a 1997 Japanese animated epic historical fantasy film written and directed by Hayao Miyazaki and animated by Studio Ghibli for Tokuma Shoten, Nippon Television Network and Dentsu. The film stars the voices of Yōji Matsuda, Yuriko Ishida, Yūko Tanaka, Kaoru Kobayashi, Masahiko Nishimura, Tsunehiko Kamijo, Akihiro Miwa, Mitsuko Mori, and Hisaya Morishige.\nPrincess Mononoke is set in the late Muromachi period of Japan (approximately 1336 to 1573 AD) and includes fantasy elements. The story follows a young Emishi prince named Ashitaka, and his involvement in a struggle between the gods (kami) of a forest and the humans who consume its resources. The film deals with themes of Shinto and environmentalism.\nThe film was released in Japan on July 12, 1997, by Toho, and in the United States on October 29, 1999. This was the first Studio Ghibli film in the United States to be rated PG-13 by the MPA. It was a critical and commercial blockbuster, becoming the highest-grossing film in Japan of 1997, and also held Japan's box office record for domestic films until 2001's Spirited Away, another Miyazaki film. It was dubbed into English with a script by Neil Gaiman and initially distributed in North America by Miramax, where it sold well on home media despite not performing strongly at the box office. The film greatly increased Ghibli's popularity and influence outside Japan.",
235 | "People of the Desert (砂漠の民, Sabaku no Tami, translated on the cover as The People of Desert), or The Desert Tribe, is a comic strip written and illustrated by Hayao Miyazaki. It was serialized, under the pseudonym Akitsu Saburō (秋津三朗), and ran in Boys and Girls Newspaper (少年少女新聞, Shōnen Shōjo Shinbun) between September 12, 1969, and March 15, 1970.\n\n\n== Story ==\nThe story is set in the distant past, on the fictionalised desert plains of Central Asia. Part of the story takes place in the fortified city named Pejite (ペジテ). The story follows the exploits of the main character, Tem (テム, Temu), a shepherd boy of the fictional Sokut (ソクート, Sokūto) tribe, as he tries to evade the mounted militia of the nomadic Kittāru (キッタール) tribe. In order to restore peace to the realm, Tem rallies his remaining compatriots and rebels against the Kittāru's attempts to gain control of the Sokut territory and enslave its inhabitants through military force.\n\n\n== Background, publication and influences ==\nMiyazaki initially wanted to become a manga artist but started his professional career as an animator for Toei Animation in 1963. Here he worked on animated television series and animated feature-length films for theatrical release. He never abandoned his childhood dream of becoming a manga artist completely, however, and his professional debut as a manga creator came in 1969 with the publication of his manga interpretation of Puss 'n Boots, which was serialized in 12 weekly instalments in the Sunday edition of Tokyo Shimbun, from January to March 1969. Printed in colour and created for promotional purposes in conjunction with his work on Toei's animated film of the same title, directed by Kimio Yabuki.\nIn 1969 pseudonymous serialization also started of Miyazaki's original manga People of the Desert (砂漠の民, Sabaku no Tami). This strip was created in the style of illustrated stories (絵物語, emonogatari) he read in boys' magazines and tankōbon volumes while growing up, such as Soji Yamakawa's Shōnen Ōja (少年王者) and in particular Tetsuji Fukushima's Evil Lord of the Desert (沙漠の魔王, Sabaku no Maō). Miyazaki's People of the Desert is a continuation of that tradition. In People of the Desert expository text is presented separately from the monochrome artwork but Miyazaki progressively used additional text balloons inside the panels for dialogue.\nPeople of the Desert was serialized in 26 weekly instalments which were printed in Boys and Girls Newspaper (少年少女新聞, Shōnen shōjo shinbun), a publication of the Japanese Communist Party, between September 12, 1969 (issue 28) and March 15, 1970 (issue 53). The strip was published under the pseudonym Akitsu Saburō (秋津三朗).\nThe strip has been identified as a precursor for Miyazaki's manga Nausicaä of the Valley of the Wind (1982–1995) and the one-off graphic novel Shuna's Journey (1983), published by Tokuma Shoten.",
236 | ]
237 | new_doc_metadata = [
238 | {"entity": "film", "source": "wikipedia"},
239 | {"entity": "manga", "source": "wikipedia"},
240 | ]
241 | RAG.add_to_index(
242 | new_collection=new_docs,
243 | new_document_ids=new_doc_ids,
244 | new_document_metadatas=new_doc_metadata,
245 | index_name=index_creation_inputs["index_name"],
246 | )
247 | pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture)
248 | document_ids = set(list(pid_docid_map_data.values()))
249 |
250 | document_metadata_dict = srsly.read_json(document_metadata_path_fixture)
251 | # check for new docs
252 | for new_doc_id in new_doc_ids:
253 | assert (
254 | new_doc_id in document_ids
255 | ), f"New document ID '{new_doc_id}' should be in the pid_docid_map's document_ids:{document_ids}."
256 |
257 | assert (
258 | new_doc_id in document_metadata_dict
259 |         ), f"New document ID '{new_doc_id}' should be in the document metadata keys: {document_metadata_dict.keys()}."
260 |
261 | for existing_doc_id in existing_doc_ids:
262 | assert (
263 | existing_doc_id in document_ids
264 | ), f"Old document ID '{existing_doc_id}' should be in the pid_docid_map's document_ids:{document_ids}."
265 |
266 | if "document_metadatas" in index_creation_inputs:
267 | assert (
268 | existing_doc_id in document_metadata_dict
269 |             ), f"Old document ID '{existing_doc_id}' should be in the document metadata keys: {document_metadata_dict.keys()}."
270 |
271 |
272 | # TODO: move this to a separate CRUD test file
273 | def test_delete_from_index(
274 | index_creation_inputs,
275 | pid_docid_map_path_fixture,
276 | document_metadata_path_fixture,
277 | index_path_fixture,
278 | ):
279 | RAG = RAGPretrainedModel.from_index(index_path_fixture)
280 | deleted_doc_id = index_creation_inputs["document_ids"][0]
281 | original_doc_ids = set(index_creation_inputs["document_ids"])
282 | RAG.delete_from_index(
283 | index_name=index_creation_inputs["index_name"],
284 | document_ids=[deleted_doc_id],
285 | )
286 | pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture)
287 | updated_document_ids = set(list(pid_docid_map_data.values()))
288 |
289 | assert (
290 | deleted_doc_id not in updated_document_ids
291 | ), f"Deleted document ID '{deleted_doc_id}' should not be in the pid_docid_map's document_ids: {updated_document_ids}."
292 |
293 | assert (
294 | original_doc_ids - updated_document_ids == {deleted_doc_id}
295 | ), f"Only the deleted document ID '{deleted_doc_id}' should be missing from the pid_docid_map's document_ids: {updated_document_ids}."
296 |
297 | if "document_metadatas" in index_creation_inputs:
298 | document_metadata_dict = srsly.read_json(document_metadata_path_fixture)
299 | assert (
300 | deleted_doc_id not in document_metadata_dict
301 |         ), f"Deleted document ID '{deleted_doc_id}' should not be in the document metadata: {document_metadata_dict.keys()}."
302 | assert (
303 | original_doc_ids - set(document_metadata_dict.keys()) == {deleted_doc_id}
304 |         ), f"Only the deleted document ID '{deleted_doc_id}' should be missing from the document metadata: {document_metadata_dict.keys()}."
305 |
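306 | # On-disk artifacts exercised by the tests above (all relative to the index directory):
307 | #   collection.json         - list of stored passages (see test_collection_creation)
308 | #   pid_docid_map.json      - dict mapping passage IDs to document IDs
309 | #   docid_metadata_map.json - dict of document ID -> user-supplied metadata; only
310 | #                             written when document_metadatas is passed at indexing time
311 | 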
--------------------------------------------------------------------------------
/tests/test_trainer_loading.py:
--------------------------------------------------------------------------------
1 | from colbert.infra import ColBERTConfig
2 |
3 | from ragatouille import RAGTrainer
4 |
5 |
6 | def test_finetune():
7 | """Ensure that the initially loaded config is the one from the pretrained model."""
8 | trainer = RAGTrainer(
9 | model_name="test",
10 | pretrained_model_name="colbert-ir/colbertv2.0",
11 | language_code="en",
12 | )
13 | trainer_config = trainer.model.config
14 |
15 | assert ColBERTConfig() != trainer_config
16 | assert trainer_config.query_token == "[Q]"
17 | assert trainer_config.doc_token == "[D]"
18 | assert trainer_config.nbits == 1
19 | assert trainer_config.kmeans_niters == 20
20 | assert trainer_config.lr == 1e-05
21 | assert trainer_config.relu is False
22 | assert trainer_config.nway == 64
23 | assert trainer_config.doc_maxlen == 180
24 | assert trainer_config.use_ib_negatives is True
25 | assert trainer_config.name == "kldR2.nway64.ib"
26 |
27 |
28 | def test_raw_model():
29 | """Ensure that the default ColBERT configuration is properly loaded when initialising from a BERT-like model""" # noqa: E501
30 | trainer = RAGTrainer(
31 | model_name="test",
32 | pretrained_model_name="bert-base-uncased",
33 | language_code="en",
34 | )
35 | trainer_config = trainer.model.config
36 |
37 | default_config = ColBERTConfig()
38 |
39 | assert trainer_config.query_token == default_config.query_token
40 | assert trainer_config.doc_token == default_config.doc_token
41 | assert trainer_config.nway == default_config.nway
42 | assert trainer_config.doc_maxlen == default_config.doc_maxlen
43 | assert trainer_config.bsize == default_config.bsize
44 | assert trainer_config.use_ib_negatives == default_config.use_ib_negatives
45 |
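46 | # A minimal REPL sketch of the same check (hypothetical usage, mirroring test_finetune above):
47 | #   from ragatouille import RAGTrainer
48 | #   trainer = RAGTrainer(model_name="test", pretrained_model_name="colbert-ir/colbertv2.0", language_code="en")
49 | #   trainer.model.config.nbits  # -> 1 for the colbert-ir/colbertv2.0 checkpoint
50 | 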
--------------------------------------------------------------------------------
/tests/test_training.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import random
3 | import signal
4 | from contextlib import contextmanager
5 |
6 | import pytest
7 | import torch
8 |
9 | from ragatouille import RAGTrainer
10 | from ragatouille.data import CorpusProcessor, llama_index_sentence_splitter
11 |
12 | DATA_DIR = pathlib.Path(__file__).parent / "data"
13 |
14 |
15 | skip_if_no_cuda = pytest.mark.skipif(
16 | not torch.cuda.is_available(),
17 | reason="Skip test. Training currently only works when CUDA is available",
18 | )
19 |
20 |
21 | class TimeoutException(Exception):
22 | pass
23 |
24 |
25 | @contextmanager
26 | def time_limit(seconds):
27 | """Time limit context manager as in https://stackoverflow.com/a/601168"""
28 |
29 | def signal_handler(_, __):
30 | raise TimeoutException("Timed out!")
31 |
32 | signal.signal(signal.SIGALRM, signal_handler)
33 | signal.alarm(seconds)
34 | try:
35 | yield
36 | finally:
37 | signal.alarm(0)
38 |
39 |
40 | @skip_if_no_cuda
41 | @pytest.mark.slow
42 | def test_training(tmp_path):
43 | """This test is based on the content of examples/02-basic_training.ipynb
44 | and mainly tests that there are no exceptions which can happen e.g. due
45 | to bugs in data processing.
46 | """
47 | trainer = RAGTrainer(
48 | model_name="GhibliColBERT",
49 | pretrained_model_name="colbert-ir/colbertv2.0",
50 | language_code="en",
51 | )
52 | pages = ["miyazaki", "Studio_Ghibli", "Toei_Animation"]
53 |     my_full_corpus = [(DATA_DIR / f"{p}_wikipedia.txt").open(encoding="utf-8").read() for p in pages]
54 |
55 | corpus_processor = CorpusProcessor(
56 | document_splitter_fn=llama_index_sentence_splitter
57 | )
58 | documents = corpus_processor.process_corpus(my_full_corpus, chunk_size=256)
59 |
60 | queries = [
61 | "What manga did Hayao Miyazaki write?",
62 | "which film made ghibli famous internationally",
63 | "who directed Spirited Away?",
64 | "when was Hikotei Jidai published?",
65 | "where's studio ghibli based?",
66 | "where is the ghibli museum?",
67 | ]
68 | pairs = []
69 |
70 | for query in queries:
71 | fake_relevant_docs = random.sample(documents, 10)
72 | for doc in fake_relevant_docs:
73 | pairs.append((query, doc))
74 | trainer.prepare_training_data(
75 | raw_data=pairs,
76 | data_out_path=str(tmp_path),
77 | all_documents=my_full_corpus,
78 | num_new_negatives=10,
79 | mine_hard_negatives=True,
80 | )
81 | try:
82 | with time_limit(30):
83 | trainer.train(
84 | batch_size=32,
85 | nbits=4, # How many bits will the trained model use when compressing indexes
86 | maxsteps=1, # Maximum steps hard stop
87 | use_ib_negatives=True, # Use in-batch negative to calculate loss
88 | dim=128, # How many dimensions per embedding. 128 is the default and works well.
89 | learning_rate=5e-6,
90 | # Learning rate, small values ([3e-6,3e-5] work best if the base model is BERT-like, 5e-6 is often the sweet spot)
91 | doc_maxlen=256,
92 | # Maximum document length. Because of how ColBERT works, smaller chunks (128-256) work very well.
93 | use_relu=False, # Disable ReLU -- doesn't improve performance
94 | warmup_steps="auto", # Defaults to 10%
95 | )
96 | # Simply test that some of the files generated have really been made.
97 | assert (tmp_path / "corpus.train.colbert.tsv").exists()
98 |     except TimeoutException:
99 |         print("Timed out!")
100 |         raise AssertionError("Timeout in training") from None
101 |
102 |
103 | if __name__ == "__main__":
104 |     # test_training normally receives pytest's tmp_path; give it a directory when run directly
105 |     out_dir = pathlib.Path(__file__).parent / "training_test_output"
106 |     out_dir.mkdir(exist_ok=True)
107 |     test_training(out_dir)
108 | 
--------------------------------------------------------------------------------
/tests/test_training_data_loading.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from ragatouille import RAGTrainer
4 |
5 |
6 | @pytest.fixture
7 | def rag_trainer():
8 | # Setup for a RAGTrainer instance
9 | instance = RAGTrainer(
10 | model_name="test_model", pretrained_model_name="bert-base-uncased"
11 | )
12 | return instance
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "input_data,pairs_with_labels,expected_queries,expected_collection",
17 | [
18 | # Unlabeled pairs
19 | (
20 | [("Query1", "Document1"), ("Query2", "Document2")],
21 | False,
22 | {"Query1", "Query2"},
23 | ["Document1", "Document2"],
24 | ),
25 | # Labeled pairs
26 | (
27 | [
28 | ("Query1", "Document1Pos", 1),
29 | ("Query1", "Document1Neg", 0),
30 | ("Query2", "Document2", 0),
31 | ],
32 | True,
33 | {"Query1", "Query2"},
34 | ["Document1Pos", "Document1Neg", "Document2"],
35 | ),
36 | # Triplets
37 | (
38 | [("Query1", "Positive Doc", "Negative Doc")],
39 | False,
40 | {"Query1"},
41 | ["Positive Doc", "Negative Doc"],
42 | ),
43 | ],
44 | )
45 | def test_prepare_training_data(
46 | rag_trainer, input_data, pairs_with_labels, expected_queries, expected_collection
47 | ):
48 | rag_trainer.prepare_training_data(
49 | raw_data=input_data, pairs_with_labels=pairs_with_labels
50 | )
51 |
52 | assert rag_trainer.queries == expected_queries
53 |
54 | assert len(rag_trainer.collection) == len(expected_collection)
55 | assert set(rag_trainer.collection) == set(expected_collection)
56 |
57 |
58 | def test_prepare_training_data_with_all_documents(rag_trainer):
59 | input_data = [("Query1", "Document1")]
60 | all_documents = ["Document2", "Document3"]
61 |
62 | rag_trainer.prepare_training_data(raw_data=input_data, all_documents=all_documents)
63 |
64 | assert rag_trainer.queries == {"Query1"}
65 |
66 | assert len(rag_trainer.collection) == 3
67 | assert set(rag_trainer.collection) == {"Document1", "Document2", "Document3"}
68 |
69 |
70 | def test_prepare_training_data_invalid_input(rag_trainer):
71 | # Providing an invalid input format
72 | with pytest.raises(ValueError):
73 |         rag_trainer.prepare_training_data(raw_data=[("Query1")])  # ("Query1") is a bare string, not a pair -> missing document
74 |
--------------------------------------------------------------------------------
/tests/test_training_data_processor.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock
2 |
3 | import pytest
4 |
5 | from ragatouille.data import TrainingDataProcessor
6 |
7 |
8 | @pytest.fixture
9 | def collection():
10 | return ["doc1", "doc2", "doc3"]
11 |
12 |
13 | @pytest.fixture
14 | def queries():
15 | return ["query1", "query2"]
16 |
17 |
18 | def test_process_raw_data_without_miner(collection, queries):
19 | processor = TrainingDataProcessor(collection, queries, None)
20 | processor._process_raw_pairs = MagicMock(return_value=None)
21 |
22 | processor.process_raw_data(
23 | raw_data=[], data_type="pairs", data_dir="./", mine_hard_negatives=False
24 | )
25 |
26 | processor._process_raw_pairs.assert_called_once()
27 |
28 |
29 | def test_process_raw_data_with_miner(collection, queries):
30 | negative_miner = MagicMock()
31 | processor = TrainingDataProcessor(collection, queries, negative_miner)
32 | processor._process_raw_pairs = MagicMock(return_value=None)
33 |
34 | processor.process_raw_data(raw_data=[], data_type="pairs", data_dir="./")
35 |
36 | processor._process_raw_pairs.assert_called_once()
37 |
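38 | # Both tests above only assert that process_raw_data() delegates to _process_raw_pairs()
39 | # when data_type="pairs"; the actual pair/triplet handling is exercised from the
40 | # RAGTrainer side in tests/test_training_data_loading.py.
41 | 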
--------------------------------------------------------------------------------