├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── docs
│   ├── .pages
│   ├── CNAME
│   ├── api
│   │   ├── .pages
│   │   ├── clustering
│   │   │   ├── .pages
│   │   │   ├── KMeans.md
│   │   │   ├── average.md
│   │   │   ├── get-mapping-nodes-documents.md
│   │   │   └── optimize-leafs.md
│   │   ├── datasets
│   │   │   ├── .pages
│   │   │   ├── load-beir-test.md
│   │   │   ├── load-beir-train.md
│   │   │   └── load-beir.md
│   │   ├── leafs
│   │   │   ├── .pages
│   │   │   └── Leaf.md
│   │   ├── nodes
│   │   │   ├── .pages
│   │   │   └── Node.md
│   │   ├── overview.md
│   │   ├── retrievers
│   │   │   ├── .pages
│   │   │   ├── ColBERT.md
│   │   │   ├── SentenceTransformer.md
│   │   │   └── TfIdf.md
│   │   ├── scoring
│   │   │   ├── .pages
│   │   │   ├── BaseScore.md
│   │   │   ├── ColBERT.md
│   │   │   ├── SentenceTransformer.md
│   │   │   └── TfIdf.md
│   │   ├── trees
│   │   │   ├── .pages
│   │   │   ├── ColBERT.md
│   │   │   ├── SentenceTransformer.md
│   │   │   ├── TfIdf.md
│   │   │   └── Tree.md
│   │   └── utils
│   │       ├── .pages
│   │       ├── batchify.md
│   │       ├── evaluate.md
│   │       ├── iter.md
│   │       ├── leafs-precision.md
│   │       ├── sanity-check.md
│   │       └── set-env.md
│   ├── css
│   │   └── version-select.css
│   ├── evaluate
│   │   ├── .pages
│   │   └── evaluate.md
│   ├── existing_tree
│   │   ├── .pages
│   │   └── existing_tree.md
│   ├── img
│   │   ├── logo.png
│   │   └── neural_tree.png
│   ├── index.md
│   ├── javascripts
│   │   └── config.js
│   ├── js
│   │   └── version-select.js
│   ├── scripts
│   │   └── index.py
│   ├── stylesheets
│   │   └── extra.css
│   └── trees
│       ├── .pages
│       ├── colbert.md
│       ├── sentence_transformer.md
│       └── tfidf.md
├── mkdocs.yml
├── neural_tree
│   ├── __init__.py
│   ├── __version__.py
│   ├── clustering
│   │   ├── __init__.py
│   │   ├── average.py
│   │   ├── kmeans.py
│   │   └── optimize.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── beir.py
│   ├── leafs
│   │   ├── __init__.py
│   │   └── leaf.py
│   ├── nodes
│   │   ├── __init__.py
│   │   └── node.py
│   ├── retrievers
│   │   ├── __init__.py
│   │   ├── colbert.py
│   │   ├── sentence_transformer.py
│   │   └── tfidf.py
│   ├── scoring
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── colbert.py
│   │   ├── sentence_transformer.py
│   │   └── tfidf.py
│   ├── trees
│   │   ├── __init__.py
│   │   ├── colbert.py
│   │   ├── sentence_transformer.py
│   │   ├── tfidf.py
│   │   └── tree.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── batchify.py
│   │   ├── evaluate.py
│   │   ├── iter.py
│   │   ├── sanity_check.py
│   │   └── set_env.py
│   └── version.txt
├── pytest.ini
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Raphael Sourty
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | livedoc:
2 | mkdocs build --clean
3 | mkdocs serve --dirtyreload
4 |
5 | deploydoc:
6 | mkdocs gh-deploy --force
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural-Tree
2 |
3 | Neural Search
4 |
18 | Are tree-based indexes the counterpart of standard ANN algorithms for token-level embedding IR models? Neural-Tree replicates the SIGIR 2023 publication [Constructing Tree-based Index for Efficient and Effective Dense Retrieval](https://dl.acm.org/doi/10.1145/3539618.3591651) in order to accelerate ColBERT. Neural-Tree is also compatible with SentenceTransformer and TfIdf models, as in the original paper.
19 |
20 | Neural-Tree creates a tree using hierarchical clustering of documents and then learns the embeddings of each node of the tree using paired queries and documents. Alternatively, an existing tree structure can be provided in JSON format to build the index (a minimal example is shown below).
21 |
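An existing tree is passed as a nested dictionary in which intermediate nodes are dictionaries and leafs are lists of documents, as detailed in the documentation. A minimal sketch (the node names and documents below are purely illustrative):

```python
graph = {
    "root": {
        "science": [
            {"id": 0, "content": "machine learning"},
            {"id": 1, "content": "computer science"},
        ],
        "history": [
            {"id": 2, "content": "french revolution"},
            {"id": 3, "content": "history of rome"},
        ],
    }
}
```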
22 | Neural-Tree optimizes the index to maintain the performance of the original model while significantly speeding up the search process. Neural-Tree does not modify the underlying model, so it is advisable to start tree creation from a model that has already been fine-tuned; for the same reason, training the index is relatively quick.
23 |
24 | ## Installation
25 |
26 | We can install neural-tree using:
27 |
28 | ```
29 | pip install neural-tree
30 | ```
31 |
32 | If we plan to evaluate our model during training, we can install the evaluation dependencies:
33 |
34 | ```
35 | pip install "neural-tree[eval]"
36 | ```
37 |
38 | ## Documentation
39 |
40 | The complete documentation is available [here](https://raphaelsty.github.io/neural-tree/).
41 |
42 |
43 | ## Quick Start
44 |
45 | The following code shows how to train a tree index. Let's start by creating a fictional dataset:
46 |
47 | ```python
48 | documents = [
49 | {"id": 0, "content": "paris"},
50 | {"id": 1, "content": "london"},
51 | {"id": 2, "content": "berlin"},
52 | {"id": 3, "content": "rome"},
53 | {"id": 4, "content": "bordeaux"},
54 | {"id": 5, "content": "milan"},
55 | ]
56 |
57 | train_queries = [
58 | "paris is the capital of france",
59 | "london is the capital of england",
60 | "berlin is the capital of germany",
61 | "rome is the capital of italy",
62 | ]
63 |
64 | train_documents = [
65 | {"id": 0, "content": "paris"},
66 | {"id": 1, "content": "london"},
67 | {"id": 2, "content": "berlin"},
68 | {"id": 3, "content": "rome"},
69 | ]
70 |
71 | test_queries = [
72 | "bordeaux is the capital of france",
73 | "milan is the capital of italy",
74 | ]
75 | ```
76 |
77 | Let's train the index using the `documents`, `train_queries` and `train_documents` we have gathered.
78 |
79 | ```python
80 | import torch
81 | from neural_cherche import models
82 | from neural_tree import clustering, trees, utils
83 |
84 | model = models.ColBERT(
85 | model_name_or_path="raphaelsty/neural-cherche-colbert",
86 | device="cuda" if torch.cuda.is_available() else "cpu",
87 | )
88 |
89 | tree = trees.ColBERT(
90 | key="id",
91 | on=["content"],
92 | model=model,
93 | documents=documents,
94 | leaf_balance_factor=100, # Number of documents per leaf
95 | branch_balance_factor=5, # Number children per node
96 | n_jobs=-1, # set to 1 with Google Colab
97 | )
98 |
99 | optimizer = torch.optim.AdamW(lr=3e-3, params=list(tree.parameters()))
100 |
101 | for step, batch_queries, batch_documents in utils.iter(
102 | queries=train_queries,
103 | documents=train_documents,
104 | shuffle=True,
105 | epochs=50,
106 | batch_size=32,
107 | ):
108 | loss = tree.loss(
109 | queries=batch_queries,
110 | documents=batch_documents,
111 | )
112 |
113 | loss.backward()
114 | optimizer.step()
115 | optimizer.zero_grad(set_to_none=True)
116 | ```
117 |
118 |
119 | Let's now duplicate some documents across several leafs of the tree in order to increase accuracy.
120 |
121 | ```python
122 | documents_to_leafs = clustering.optimize_leafs(
123 | tree=tree,
124 | queries=train_queries + test_queries,
125 | documents=documents,
126 | )
127 |
128 | tree = tree.add(
129 | documents=documents,
130 | documents_to_leafs=documents_to_leafs,
131 | )
132 | ```
133 |
134 | We are now ready to retrieve documents:
135 |
136 | ```python
137 | scores = tree(
138 | queries=["bordeaux", "milan"],
139 | k_leafs=2,
140 | k=2,
141 | )
142 |
143 | print(scores["documents"])
144 | ```
145 |
146 | ```python
147 | [
148 | [
149 | {"id": 4, "similarity": 5.28, "leaf": "12"},
150 | {"id": 0, "similarity": 3.17, "leaf": "12"},
151 | ],
152 | [
153 | {"id": 5, "similarity": 5.11, "leaf": "10"},
154 | {"id": 2, "similarity": 3.57, "leaf": "10"},
155 | ],
156 | ]
157 | ```
158 |
159 | ## Benchmarks
160 |
161 | All results are reported on the Scifact dataset, comparing each vanilla model with its Neural-Tree counterpart.
162 |
163 | | Model | HuggingFace Checkpoint | ndcg@10 (vanilla) | hits@10 (vanilla) | hits@1 (vanilla) | queries / second (vanilla) | ndcg@10 (Neural-Tree) | hits@10 (Neural-Tree) | hits@1 (Neural-Tree) | queries / second (Neural-Tree) | Acceleration |
164 | |:--|:--|--:|--:|--:|--:|--:|--:|--:|--:|--:|
165 | | TfIdf (Cherche) | - | 0.61 | 0.85 | 0.47 | 760 | 0.56 | 0.82 | 0.42 | 1080 | +42.11% |
166 | | SentenceTransformer (GPU) + Faiss.IndexFlatL2 (CPU) | sentence-transformers/all-mpnet-base-v2 | 0.66 | 0.89 | 0.53 | 475 | 0.66 | 0.88 | 0.53 | 518 | +9.05% |
167 | | ColBERT (Neural-Cherche, GPU) | raphaelsty/neural-cherche-colbert | 0.70 | 0.92 | 0.58 | 3 | 0.70 | 0.91 | 0.59 | 256 | x85 |
168 |
228 | Note that this benchmark does not implement [ColBERTV2](https://arxiv.org/abs/2112.01488) efficient retrieval but rather compares raw ColBERT retrieval with Neural-Tree. The vanilla SentenceTransformer could also be accelerated further by using an optimized Faiss index.
229 |
230 | ## Contributing
231 |
232 | We welcome contributions to Neural-Tree. Our focus includes improving the clustering of ColBERT embeddings, which is currently handled by TfIdf. Neural-Tree also aims to provide tools to visualize trees, extract node topics, and leverage the tree structure to accelerate Large Language Model (LLM) retrieval.
233 |
234 | ## License
235 |
236 | This project is licensed under the terms of the MIT license.
237 |
238 | ## References
239 |
240 | - [Constructing Tree-based Index for Efficient and Effective Dense Retrieval, Github](https://github.com/cshaitao/jtr)
241 |
242 | - [ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT](https://arxiv.org/abs/2004.12832)
243 |
244 | - [Myriade](https://github.com/MaxHalford/myriade)
245 |
246 |
247 |
--------------------------------------------------------------------------------
/docs/.pages:
--------------------------------------------------------------------------------
1 | nav:
2 | - Home: trees
3 | - existing_tree
4 | - evaluate
5 | - api
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | raphaelsty.github.io/neural-tree/
--------------------------------------------------------------------------------
/docs/api/.pages:
--------------------------------------------------------------------------------
1 | title: API reference
2 | arrange:
3 | - overview.md
4 | - ...
5 |
--------------------------------------------------------------------------------
/docs/api/clustering/.pages:
--------------------------------------------------------------------------------
1 | title: clustering
--------------------------------------------------------------------------------
/docs/api/clustering/KMeans.md:
--------------------------------------------------------------------------------
1 | # KMeans
2 |
3 | KMeans clustering.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **documents_embeddings** (*torch.Tensor*)
10 |
11 | - **n_clusters** (*int*)
12 |
13 | - **max_iter** (*int*)
14 |
15 | - **n_init** (*int*)
16 |
17 | - **seed** (*int*)
18 |
19 | - **device** (*str*)
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/docs/api/clustering/average.md:
--------------------------------------------------------------------------------
1 | # average
2 |
3 | Replace KMeans clustering with average clustering when an existing graph is provided.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **key** (*str*)
10 |
11 | - **documents** (*list*)
12 |
13 | - **documents_embeddings** (*numpy.ndarray | scipy.sparse._csr.csr_matrix*)
14 |
15 | - **graph**
16 |
17 | - **scoring**
18 |
19 | - **device** (*str*)
20 |
21 |
22 |
23 | ## Examples
24 |
25 | ```python
26 | >>> from neural_tree import clustering, scoring
27 | >>> import numpy as np
28 |
29 | >>> documents = [
30 | ... {"id": 0, "text": "Paris is the capital of France."},
31 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
32 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
33 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
34 | ... ]
35 |
36 | >>> documents_embeddings = np.array([
37 | ... [1, 1],
38 | ... [1, 2],
39 | ... [10, 10],
40 | ... [1, 3],
41 | ... ])
42 |
43 | >>> graph = {1: {11: {111: [{'id': 0}, {'id': 3}], 112: [{'id': 1}]}, 12: {121: [{'id': 2}], 122: [{'id': 3}]}}}
44 |
45 | >>> clustering.average(
46 | ... key="id",
47 | ... documents_embeddings=documents_embeddings,
48 | ... documents=documents,
49 | ... graph=graph[1],
50 | ... scoring=scoring.SentenceTransformer(key="id", on=["text"], model=None),
51 | ... )
52 | ```
53 |
54 |
--------------------------------------------------------------------------------
/docs/api/clustering/get-mapping-nodes-documents.md:
--------------------------------------------------------------------------------
1 | # get_mapping_nodes_documents
2 |
3 | Get documents from a specific node.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **graph** (*dict | list*)
10 |
11 | - **documents** (*list | None*) – defaults to `None`
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/docs/api/clustering/optimize-leafs.md:
--------------------------------------------------------------------------------
1 | # optimize_leafs
2 |
3 | Optimize the clusters.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **tree**
10 |
11 | - **documents** (*list[dict]*)
12 |
13 | - **queries** (*list[str]*)
14 |
15 | - **k_tree** (*int*) – defaults to `2`
16 |
17 | - **k_retriever** (*int*) – defaults to `10`
18 |
19 | - **k_leafs** (*int*) – defaults to `2`
20 |
21 | - **kwargs**
22 |
23 |
24 |
25 |
26 |
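## Examples

A minimal usage sketch, mirroring the quick start; it assumes `tree`, `train_queries`, `test_queries` and `documents` have already been created as shown in the documentation.

```python
>>> from neural_tree import clustering

>>> documents_to_leafs = clustering.optimize_leafs(
...     tree=tree,
...     queries=train_queries + test_queries,
...     documents=documents,
... )

>>> tree = tree.add(
...     documents=documents,
...     documents_to_leafs=documents_to_leafs,
... )
```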
--------------------------------------------------------------------------------
/docs/api/datasets/.pages:
--------------------------------------------------------------------------------
1 | title: datasets
--------------------------------------------------------------------------------
/docs/api/datasets/load-beir-test.md:
--------------------------------------------------------------------------------
1 | # load_beir_test
2 |
3 | Load BEIR testing dataset.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **dataset_name** (*str*)
10 |
11 | Dataset name.
12 |
13 |
14 |
15 | ## Examples
16 |
17 | ```python
18 | >>> from neural_tree import datasets
19 |
20 | >>> documents, queries_ids, queries, qrels = datasets.load_beir_test(
21 | ... dataset_name="scifact",
22 | ... )
23 |
24 | >>> len(documents)
25 | 5183
26 |
27 | >>> assert len(queries_ids) == len(queries) == len(qrels)
28 | ```
29 |
30 |
--------------------------------------------------------------------------------
/docs/api/datasets/load-beir-train.md:
--------------------------------------------------------------------------------
1 | # load_beir_train
2 |
3 | Load training dataset.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **dataset_name** (*str*)
10 |
11 | Dataset name
12 |
13 |
14 |
15 | ## Examples
16 |
17 | ```python
18 | >>> from neural_tree import datasets
19 |
20 | >>> documents, train_queries, train_documents = datasets.load_beir_train(
21 | ... dataset_name="scifact",
22 | ... )
23 |
24 | >>> len(documents)
25 | 5183
26 |
27 | >>> assert len(train_queries) == len(train_documents)
28 | ```
29 |
30 |
--------------------------------------------------------------------------------
/docs/api/datasets/load-beir.md:
--------------------------------------------------------------------------------
1 | # load_beir
2 |
3 | Load BEIR dataset.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **dataset_name** (*str*)
10 |
11 | - **split** (*str*)
12 |
13 |
14 |
15 |
16 |
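## Examples

A usage sketch; it assumes the function returns the same tuple as `load_beir_test` (documents, queries ids, queries and qrels).

```python
>>> from neural_tree import datasets

>>> documents, queries_ids, queries, qrels = datasets.load_beir(
...     dataset_name="scifact",
...     split="test",
... )
```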
--------------------------------------------------------------------------------
/docs/api/leafs/.pages:
--------------------------------------------------------------------------------
1 | title: leafs
--------------------------------------------------------------------------------
/docs/api/nodes/.pages:
--------------------------------------------------------------------------------
1 | title: nodes
--------------------------------------------------------------------------------
/docs/api/overview.md:
--------------------------------------------------------------------------------
1 | # Overview
2 |
3 | ## clustering
4 |
5 | - [KMeans](../clustering/KMeans)
6 | - [average](../clustering/average)
7 | - [get_mapping_nodes_documents](../clustering/get-mapping-nodes-documents)
8 | - [optimize_leafs](../clustering/optimize-leafs)
9 |
10 | ## datasets
11 |
12 | - [load_beir](../datasets/load-beir)
13 | - [load_beir_test](../datasets/load-beir-test)
14 | - [load_beir_train](../datasets/load-beir-train)
15 |
16 | ## leafs
17 |
18 | - [Leaf](../leafs/Leaf)
19 |
20 | ## nodes
21 |
22 | - [Node](../nodes/Node)
23 |
24 | ## retrievers
25 |
26 | - [ColBERT](../retrievers/ColBERT)
27 | - [SentenceTransformer](../retrievers/SentenceTransformer)
28 | - [TfIdf](../retrievers/TfIdf)
29 |
30 | ## scoring
31 |
32 | - [BaseScore](../scoring/BaseScore)
33 | - [ColBERT](../scoring/ColBERT)
34 | - [SentenceTransformer](../scoring/SentenceTransformer)
35 | - [TfIdf](../scoring/TfIdf)
36 |
37 | ## trees
38 |
39 | - [ColBERT](../trees/ColBERT)
40 | - [SentenceTransformer](../trees/SentenceTransformer)
41 | - [TfIdf](../trees/TfIdf)
42 | - [Tree](../trees/Tree)
43 |
44 | ## utils
45 |
46 | - [batchify](../utils/batchify)
47 | - [evaluate](../utils/evaluate)
48 | - [iter](../utils/iter)
49 | - [leafs_precision](../utils/leafs-precision)
50 | - [sanity_check](../utils/sanity-check)
51 | - [set_env](../utils/set-env)
52 |
53 |
--------------------------------------------------------------------------------
/docs/api/retrievers/.pages:
--------------------------------------------------------------------------------
1 | title: retrievers
--------------------------------------------------------------------------------
/docs/api/retrievers/ColBERT.md:
--------------------------------------------------------------------------------
1 | # ColBERT
2 |
3 | ColBERT retriever.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **key** (*str*)
10 |
11 | - **on** (*str | list[str]*)
12 |
13 | - **device** (*str*)
14 |
15 |
16 |
17 |
18 | ## Methods
19 |
20 | ???- note "__call__"
21 |
22 | Rank documents given queries.
23 |
24 | **Parameters**
25 |
26 | - **queries_embeddings** (*dict[str, torch.Tensor]*)
27 | - **batch_size** (*int*) – defaults to `32`
28 | - **k** (*int*) – defaults to `None`
29 | - **tqdm_bar** (*bool*) – defaults to `False`
30 |
31 | ???- note "add"
32 |
33 | Add documents embeddings.
34 |
35 | **Parameters**
36 |
37 | - **documents_embeddings** (*dict*)
38 |
39 | ???- note "encode_documents"
40 |
41 | Encode documents.
42 |
43 | **Parameters**
44 |
45 | - **documents** (*list[str]*)
46 | - **batch_size** (*int*) – defaults to `32`
47 | - **tqdm_bar** (*bool*) – defaults to `True`
48 | - **query_mode** (*bool*) – defaults to `False`
49 | - **kwargs**
50 |
51 | ???- note "encode_queries"
52 |
53 | Encode queries.
54 |
55 | **Parameters**
56 |
57 | - **queries** (*list[str]*)
58 | - **batch_size** (*int*) – defaults to `32`
59 | - **tqdm_bar** (*bool*) – defaults to `True`
60 | - **query_mode** (*bool*) – defaults to `True`
61 | - **kwargs**
62 |
63 |
--------------------------------------------------------------------------------
/docs/api/retrievers/SentenceTransformer.md:
--------------------------------------------------------------------------------
1 | # SentenceTransformer
2 |
3 | Sentence Transformer retriever.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **key** (*str*)
10 |
11 | - **device** (*str*) – defaults to `cpu`
12 |
13 |
14 |
15 | ## Examples
16 |
17 | ```python
18 | >>> from neural_tree import retrievers
19 | >>> from sentence_transformers import SentenceTransformer
20 | >>> from pprint import pprint
21 |
22 | >>> model = SentenceTransformer("all-mpnet-base-v2")
23 |
24 | >>> retriever = retrievers.SentenceTransformer(key="id")
25 |
26 | >>> retriever = retriever.add(
27 | ... documents_embeddings={
28 | ... 0: model.encode("Paris is the capital of France."),
29 | ... 1: model.encode("Berlin is the capital of Germany."),
30 | ... 2: model.encode("Paris and Berlin are European cities."),
31 | ... 3: model.encode("Paris and Berlin are beautiful cities."),
32 | ... }
33 | ... )
34 |
35 | >>> queries_embeddings = {
36 | ... 0: model.encode("Paris"),
37 | ... 1: model.encode("Berlin"),
38 | ... }
39 |
40 | >>> candidates = retriever(queries_embeddings=queries_embeddings, k=2)
41 | >>> pprint(candidates)
42 | [[{'id': 0, 'similarity': 0.644777984318611},
43 | {'id': 3, 'similarity': 0.52865785276988}],
44 | [{'id': 1, 'similarity': 0.6901492368348436},
45 | {'id': 3, 'similarity': 0.5457692206973245}]]
46 | ```
47 |
48 | ## Methods
49 |
50 | ???- note "__call__"
51 |
52 | Retrieve documents.
53 |
54 | **Parameters**
55 |
56 | - **queries_embeddings** (*dict[int, numpy.ndarray]*)
57 | - **k** (*int | None*) – defaults to `100`
58 | - **kwargs**
59 |
60 | ???- note "add"
61 |
62 | Add documents to the faiss index.
63 |
64 | **Parameters**
65 |
66 | - **documents_embeddings** (*dict[int, numpy.ndarray]*)
67 |
68 |
--------------------------------------------------------------------------------
/docs/api/retrievers/TfIdf.md:
--------------------------------------------------------------------------------
1 | # TfIdf
2 |
3 | TfIdf retriever.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **key** (*str*)
10 |
11 | - **on** (*list[str]*)
12 |
13 |
14 |
15 |
16 | ## Methods
17 |
18 | ???- note "__call__"
19 |
20 | Retrieve documents from batch of queries.
21 |
22 | **Parameters**
23 |
24 | - **queries_embeddings** (*dict[str, scipy.sparse._csr.csr_matrix]*)
25 | - **k** (*int*) – defaults to `None`
26 | - **batch_size** (*int*) – defaults to `2000`
27 | - **tqdm_bar** (*bool*) – defaults to `True`
28 |
29 | ???- note "add"
30 |
31 | Add new documents to the TfIdf retriever. The TfIdf vectorizer won't be refitted.
32 |
33 | **Parameters**
34 |
35 | - **documents_embeddings** (*dict[str, scipy.sparse._csr.csr_matrix]*)
36 |
37 | ???- note "encode_documents"
38 |
39 | Encode documents into a sparse matrix.
40 |
41 | **Parameters**
42 |
43 | - **documents** (*list[dict]*)
44 | - **model** (*sklearn.feature_extraction.text.TfidfVectorizer*)
45 |
46 | ???- note "encode_queries"
47 |
48 | Encode queries into sparse matrix.
49 |
50 | **Parameters**
51 |
52 | - **queries** (*list[str]*)
53 | - **model** (*sklearn.feature_extraction.text.TfidfVectorizer*)
54 |
55 | ???- note "top_k"
56 |
57 | Return the top k documents for each query.
58 |
59 | **Parameters**
60 |
61 | - **similarities** (*scipy.sparse._csc.csc_matrix*)
62 | - **k** (*int*)
63 |
64 |
--------------------------------------------------------------------------------
/docs/api/scoring/.pages:
--------------------------------------------------------------------------------
1 | title: scoring
--------------------------------------------------------------------------------
/docs/api/scoring/BaseScore.md:
--------------------------------------------------------------------------------
1 | # BaseScore
2 |
3 | Base class for scoring functions.
4 |
5 |
6 |
7 |
8 |
9 |
10 | ## Methods
11 |
12 | ???- note "convert_to_tensor"
13 |
14 | Transform sparse matrix to tensor.
15 |
16 | **Parameters**
17 |
18 | - **embeddings** (*scipy.sparse._csr.csr_matrix | numpy.ndarray*)
19 | - **device** (*str*)
20 |
21 | ???- note "distinct_documents_encoder"
22 |
23 | Return True if the encoder is distinct for documents and nodes.
24 |
25 |
26 | ???- note "encode_queries_for_retrieval"
27 |
28 | Encode queries for retrieval.
29 |
30 | **Parameters**
31 |
32 | - **queries** (*list[str]*)
33 |
34 | ???- note "get_retriever"
35 |
36 | Create a retriever
37 |
38 |
39 | ???- note "leaf_scores"
40 |
41 | Return the scores of the embeddings.
42 |
43 | **Parameters**
44 |
45 | - **queries_embeddings** (*torch.Tensor*)
46 | - **leaf_embedding** (*torch.Tensor*)
47 |
48 | ???- note "nodes_scores"
49 |
50 | Score between queries and nodes embeddings.
51 |
52 | **Parameters**
53 |
54 | - **queries_embeddings** (*torch.Tensor*)
55 | - **nodes_embeddings** (*torch.Tensor*)
56 |
57 | ???- note "stack"
58 |
59 | Stack list of embeddings.
60 |
61 | - **embeddings** (*list[scipy.sparse._csr.csr_matrix | numpy.ndarray | dict]*)
62 |
63 | ???- note "transform_documents"
64 |
65 | Transform documents to embeddings.
66 |
67 | **Parameters**
68 |
69 | - **documents** (*list[dict]*)
70 |
71 | ???- note "transform_queries"
72 |
73 | Transform queries to embeddings.
74 |
75 | **Parameters**
76 |
77 | - **queries** (*list[str]*)
78 |
79 |
--------------------------------------------------------------------------------
/docs/api/scoring/ColBERT.md:
--------------------------------------------------------------------------------
1 | # ColBERT
2 |
3 | ColBERT scoring function.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **key** (*str*)
10 |
11 | - **on** (*list | str*)
12 |
13 | - **documents** (*list*)
14 |
15 | - **model** (*neural_cherche.models.colbert.ColBERT*) – defaults to `None`
16 |
17 | - **device** (*str*) – defaults to `cpu`
18 |
19 | - **kwargs**
20 |
21 |
22 | ## Attributes
23 |
24 | - **distinct_documents_encoder**
25 |
26 | Return True if the encoder is distinct for documents and nodes.
27 |
28 |
29 | ## Examples
30 |
31 | ```python
32 | >>> from neural_tree import trees, scoring
33 | >>> from neural_cherche import models
34 | >>> from sklearn.feature_extraction.text import TfidfVectorizer
35 | >>> from pprint import pprint
36 | >>> import torch
37 |
38 | >>> _ = torch.manual_seed(42)
39 |
40 | >>> documents = [
41 | ... {"id": 0, "text": "Paris is the capital of France."},
42 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
43 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
44 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
45 | ... ]
46 |
47 | >>> model = models.ColBERT(
48 | ... model_name_or_path="sentence-transformers/all-mpnet-base-v2",
49 | ... embedding_size=128,
50 | ... max_length_document=96,
51 | ... max_length_query=32,
52 | ... )
53 |
54 | >>> tree = trees.ColBERTTree(
55 | ... key="id",
56 | ... on="text",
57 | ... model=model,
58 | ... documents=documents,
59 | ... leaf_balance_factor=1,
60 | ... branch_balance_factor=2,
61 | ... n_jobs=1,
62 | ... )
63 |
64 | >>> print(tree)
65 | node 1
66 | node 10
67 | leaf 100
68 | leaf 101
69 | node 11
70 | leaf 110
71 | leaf 111
72 |
73 | >>> tree.leafs_to_documents
74 | {'100': [0], '101': [1], '110': [2], '111': [3]}
75 |
76 | >>> candidates = tree(
77 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."],
78 | ... k_leafs=2,
79 | ... k=2,
80 | ... )
81 |
82 | >>> candidates["scores"]
83 | array([[28.12037659, 18.32332611],
84 | [29.28324509, 21.38923264]])
85 |
86 | >>> candidates["leafs"]
87 | array([['100', '101'],
88 |        ['110', '111']], dtype='<U3')
89 |
90 | >>> pprint(candidates["tree_scores"])
91 | [{'10': tensor(28.1204),
92 | '100': tensor(28.1204),
93 | '101': tensor(18.3233),
94 | '11': tensor(20.9327)},
95 | {'10': tensor(21.6886),
96 | '11': tensor(29.2832),
97 | '110': tensor(29.2832),
98 | '111': tensor(21.3892)}]
99 |
100 | >>> pprint(candidates["documents"])
101 | [[{'id': 0, 'leaf': '100', 'similarity': 28.120376586914062},
102 | {'id': 1, 'leaf': '101', 'similarity': 18.323326110839844}],
103 | [{'id': 2, 'leaf': '110', 'similarity': 29.283245086669922},
104 | {'id': 3, 'leaf': '111', 'similarity': 21.389232635498047}]]
105 | ```
106 |
107 | ## Methods
108 |
109 | ???- note "average"
110 |
111 | Average embeddings.
112 |
113 | - **embeddings** (*torch.Tensor*)
114 |
115 | ???- note "convert_to_tensor"
116 |
117 | Transform sparse matrix to tensor.
118 |
119 | **Parameters**
120 |
121 | - **embeddings** (*numpy.ndarray | torch.Tensor*)
122 | - **device** (*str*)
123 |
124 | ???- note "encode_queries_for_retrieval"
125 |
126 | Encode queries for retrieval.
127 |
128 | **Parameters**
129 |
130 | - **queries** (*list[str]*)
131 |
132 | ???- note "get_retriever"
133 |
134 | Create a retriever
135 |
136 |
137 | ???- note "leaf_scores"
138 |
139 | Return the scores of the embeddings.
140 |
141 | **Parameters**
142 |
143 | - **queries_embeddings** (*torch.Tensor*)
144 | - **leaf_embedding** (*torch.Tensor*)
145 |
146 | ???- note "nodes_scores"
147 |
148 | Score between queries and nodes embeddings.
149 |
150 | **Parameters**
151 |
152 | - **queries_embeddings** (*torch.Tensor*)
153 | - **nodes_embeddings** (*torch.Tensor*)
154 |
155 | ???- note "stack"
156 |
157 | Stack list of embeddings.
158 |
159 | **Parameters**
160 |
161 | - **embeddings** (*list[torch.Tensor | numpy.ndarray]*)
162 |
163 | ???- note "transform_documents"
164 |
165 | Transform documents to embeddings.
166 |
167 | **Parameters**
168 |
169 | - **documents** (*list[dict]*)
170 | - **batch_size** (*int*)
171 | - **tqdm_bar** (*bool*)
172 | - **kwargs**
173 |
174 | ???- note "transform_queries"
175 |
176 | Transform queries to embeddings.
177 |
178 | **Parameters**
179 |
180 | - **queries** (*list[str]*)
181 | - **batch_size** (*int*)
182 | - **tqdm_bar** (*bool*)
183 | - **kwargs**
184 |
185 |
--------------------------------------------------------------------------------
/docs/api/scoring/SentenceTransformer.md:
--------------------------------------------------------------------------------
1 | # SentenceTransformer
2 |
3 | Sentence Transformer scoring function.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **key** (*str*)
10 |
11 | - **on** (*str | list*)
12 |
13 | - **model** (*sentence_transformers.SentenceTransformer.SentenceTransformer*)
14 |
15 | - **device** (*str*) – defaults to `cpu`
16 |
17 | - **faiss_device** (*str*) – defaults to `cpu`
18 |
19 |
20 | ## Attributes
21 |
22 | - **distinct_documents_encoder**
23 |
24 | Return True if the encoder is distinct for documents and nodes.
25 |
26 |
27 | ## Examples
28 |
29 | ```python
30 | >>> from neural_tree import trees, scoring
31 | >>> from sentence_transformers import SentenceTransformer
32 | >>> from pprint import pprint
33 |
34 | >>> documents = [
35 | ... {"id": 0, "text": "Paris is the capital of France."},
36 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
37 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
38 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
39 | ... ]
40 |
41 | >>> tree = trees.Tree(
42 | ... key="id",
43 | ... documents=documents,
44 | ... scoring=scoring.SentenceTransformer(key="id", on=["text"], model=SentenceTransformer("all-mpnet-base-v2")),
45 | ... leaf_balance_factor=1,
46 | ... branch_balance_factor=2,
47 | ... n_jobs=1,
48 | ... )
49 |
50 | >>> print(tree)
51 | node 1
52 | node 11
53 | node 110
54 | leaf 1100
55 | leaf 1101
56 | leaf 111
57 | leaf 10
58 |
59 | >>> candidates = tree(
60 | ... queries=["paris", "berlin"],
61 | ... k_leafs=2,
62 | ... )
63 |
64 | >>> candidates["scores"]
65 | array([[0.72453916, 0.60635257],
66 | [0.58386189, 0.57546711]])
67 |
68 | >>> candidates["leafs"]
69 | array([['111', '10'],
70 |        ['1101', '1100']], dtype='<U4')
71 |
72 | >>> pprint(candidates["tree_scores"])
73 | [{'10': tensor(0.6064),
74 | '11': tensor(0.7245),
75 | '110': tensor(0.5542),
76 | '1100': tensor(0.5403),
77 | '1101': tensor(0.5542),
78 | '111': tensor(0.7245)},
79 | {'10': tensor(0.5206),
80 | '11': tensor(0.5797),
81 | '110': tensor(0.5839),
82 | '1100': tensor(0.5755),
83 | '1101': tensor(0.5839),
84 | '111': tensor(0.4026)}]
85 |
86 | >>> pprint(candidates["documents"])
87 | [[{'id': 0, 'leaf': '111', 'similarity': 0.6447779347587058},
88 | {'id': 1, 'leaf': '10', 'similarity': 0.43175890864117644}],
89 | [{'id': 3, 'leaf': '1101', 'similarity': 0.545769273959571},
90 | {'id': 2, 'leaf': '1100', 'similarity': 0.54081365990618}]]
91 | ```
92 |
93 | ## Methods
94 |
95 | ???- note "average"
96 |
97 | Average embeddings.
98 |
99 | - **embeddings** (*numpy.ndarray*)
100 |
101 | ???- note "convert_to_tensor"
102 |
103 | Convert numpy array to torch tensor.
104 |
105 | **Parameters**
106 |
107 | - **embeddings** (*numpy.ndarray*)
108 | - **device** (*str*)
109 |
110 | ???- note "encode_queries_for_retrieval"
111 |
112 | Encode queries for retrieval.
113 |
114 | - **queries** (*list[str]*)
115 |
116 | ???- note "get_retriever"
117 |
118 | Create a retriever
119 |
120 |
121 | ???- note "leaf_scores"
122 |
123 | Computes scores between query and leaf embedding.
124 |
125 | **Parameters**
126 |
127 | - **queries_embeddings** (*torch.Tensor*)
128 | - **leaf_embedding** (*torch.Tensor*)
129 |
130 | ???- note "nodes_scores"
131 |
132 | Score between queries and nodes embeddings.
133 |
134 | **Parameters**
135 |
136 | - **queries_embeddings** (*torch.Tensor*)
137 | - **nodes_embeddings** (*torch.Tensor*)
138 |
139 | ???- note "stack"
140 |
141 | Stack embeddings.
142 |
143 | - **embeddings** (*list[numpy.ndarray]*)
144 |
145 | ???- note "transform_documents"
146 |
147 | Transform documents to embeddings.
148 |
149 | **Parameters**
150 |
151 | - **documents** (*list[dict]*)
152 | - **batch_size** (*int*)
153 | - **kwargs**
154 |
155 | ???- note "transform_queries"
156 |
157 | Transform queries to embeddings.
158 |
159 | **Parameters**
160 |
161 | - **queries** (*list[str]*)
162 | - **batch_size** (*int*)
163 | - **kwargs**
164 |
165 |
--------------------------------------------------------------------------------
/docs/api/scoring/TfIdf.md:
--------------------------------------------------------------------------------
1 | # TfIdf
2 |
3 | TfIdf scoring function.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **key** (*str*)
10 |
11 | - **on** (*list | str*)
12 |
13 | - **documents** (*list*)
14 |
15 | - **tfidf_nodes** (*sklearn.feature_extraction.text.TfidfVectorizer | None*) – defaults to `None`
16 |
17 | - **tfidf_documents** (*sklearn.feature_extraction.text.TfidfVectorizer | None*) – defaults to `None`
18 |
19 | - **kwargs**
20 |
21 |
22 | ## Attributes
23 |
24 | - **distinct_documents_encoder**
25 |
26 | Return True if the encoder is distinct for documents and nodes.
27 |
28 |
29 | ## Examples
30 |
31 | ```python
32 | >>> from neural_tree import trees, scoring
33 | >>> from sklearn.feature_extraction.text import TfidfVectorizer
34 | >>> from pprint import pprint
35 |
36 | >>> documents = [
37 | ... {"id": 0, "text": "Paris is the capital of France."},
38 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
39 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
40 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
41 | ... ]
42 |
43 | >>> tree = trees.Tree(
44 | ... key="id",
45 | ... documents=documents,
46 | ... scoring=scoring.TfIdf(key="id", on=["text"], documents=documents),
47 | ... leaf_balance_factor=1,
48 | ... branch_balance_factor=2,
49 | ... )
50 |
51 | >>> print(tree)
52 | node 1
53 | node 10
54 | leaf 100
55 | leaf 101
56 | node 11
57 | leaf 110
58 | leaf 111
59 |
60 | >>> tree.leafs_to_documents
61 | {'100': [0], '101': [1], '110': [2], '111': [3]}
62 |
63 | >>> candidates = tree(
64 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."],
65 | ... k_leafs=2,
66 | ... k=2,
67 | ... )
68 |
69 | >>> candidates["scores"]
70 | array([[0.99999994, 0.63854915],
71 | [0.99999994, 0.72823119]])
72 |
73 | >>> candidates["leafs"]
74 | array([['100', '101'],
75 |        ['110', '111']], dtype='<U3')
76 |
77 | >>> pprint(candidates["tree_scores"])
78 | [{'10': tensor(1.0000),
79 | '100': tensor(1.0000),
80 | '101': tensor(0.6385),
81 | '11': tensor(0.1076)},
82 | {'10': tensor(0.1076),
83 | '11': tensor(1.0000),
84 | '110': tensor(1.0000),
85 | '111': tensor(0.7282)}]
86 |
87 | >>> pprint(candidates["documents"])
88 | [[{'id': 0, 'leaf': '100', 'similarity': 0.9999999999999978},
89 | {'id': 1, 'leaf': '101', 'similarity': 0.39941742405759667}],
90 | [{'id': 2, 'leaf': '110', 'similarity': 0.9999999999999978},
91 | {'id': 3, 'leaf': '111', 'similarity': 0.5385719658738707}]]
92 | ```
93 |
94 | ## Methods
95 |
96 | ???- note "average"
97 |
98 | Average embeddings.
99 |
100 | - **embeddings** (*scipy.sparse._csr.csr_matrix*)
101 |
102 | ???- note "convert_to_tensor"
103 |
104 | Transform sparse matrix to tensor.
105 |
106 | **Parameters**
107 |
108 | - **embeddings** (*scipy.sparse._csr.csr_matrix*)
109 | - **device** (*str*)
110 |
111 | ???- note "encode_queries_for_retrieval"
112 |
113 | Encode queries for retrieval.
114 |
115 | **Parameters**
116 |
117 | - **queries** (*list[str]*)
118 |
119 | ???- note "get_retriever"
120 |
121 | Create a retriever
122 |
123 |
124 | ???- note "leaf_scores"
125 |
126 | Return the scores of the embeddings.
127 |
128 | **Parameters**
129 |
130 | - **queries_embeddings** (*torch.Tensor*)
131 | - **leaf_embedding** (*torch.Tensor*)
132 |
133 | ???- note "nodes_scores"
134 |
135 | Score between queries and nodes embeddings.
136 |
137 | **Parameters**
138 |
139 | - **queries_embeddings** (*torch.Tensor*)
140 | - **nodes_embeddings** (*torch.Tensor*)
141 |
142 | ???- note "stack"
143 |
144 | Stack list of embeddings.
145 |
146 | - **embeddings** (*list[scipy.sparse._csr.csr_matrix]*)
147 |
148 | ???- note "transform_documents"
149 |
150 | Transform documents to embeddings.
151 |
152 | **Parameters**
153 |
154 | - **documents** (*list[dict]*)
155 | - **kwargs**
156 |
157 | ???- note "transform_queries"
158 |
159 | Transform queries to embeddings.
160 |
161 | **Parameters**
162 |
163 | - **queries** (*list[str]*)
164 | - **kwargs**
165 |
166 |
--------------------------------------------------------------------------------
/docs/api/trees/.pages:
--------------------------------------------------------------------------------
1 | title: trees
--------------------------------------------------------------------------------
/docs/api/utils/.pages:
--------------------------------------------------------------------------------
1 | title: utils
--------------------------------------------------------------------------------
/docs/api/utils/batchify.md:
--------------------------------------------------------------------------------
1 | # batchify
2 |
3 |
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **X** (*list[str]*)
10 |
11 | - **batch_size** (*int*)
12 |
13 | - **desc** (*str*) – defaults to ``
14 |
15 | - **tqdm_bar** (*bool*) – defaults to `True`
16 |
17 |
18 |
19 |
20 |
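## Examples

An illustrative sketch; it assumes `batchify` yields successive batches of `X` containing at most `batch_size` elements.

```python
>>> from neural_tree import utils

>>> for batch in utils.batchify(
...     X=["a", "b", "c", "d"],
...     batch_size=2,
...     tqdm_bar=False,
... ):
...     pass  # Each batch holds at most two elements here.
```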
--------------------------------------------------------------------------------
/docs/api/utils/evaluate.md:
--------------------------------------------------------------------------------
1 | # evaluate
2 |
3 | Evaluate candidate matches.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **scores** (*list[list[dict]]*)
10 |
11 | - **qrels** (*dict*)
12 |
13 | Qrels.
14 |
15 | - **queries_ids** (*list[str]*)
16 |
17 | - **metrics** (*list*) – defaults to `[]`
18 |
19 | Metrics to compute.
20 |
21 | - **key** (*str*) – defaults to `id`
22 |
23 |
24 |
25 | ## Examples
26 |
27 | ```python
28 | >>> from neural_cherche import models, retrieve, utils
29 | >>> import torch
30 |
31 | >>> _ = torch.manual_seed(42)
32 |
33 | >>> model = models.Splade(
34 | ... model_name_or_path="distilbert-base-uncased",
35 | ... device="cpu",
36 | ... )
37 |
38 | >>> documents, queries_ids, queries, qrels = utils.load_beir(
39 | ... "scifact",
40 | ... split="test",
41 | ... )
42 |
43 | >>> documents = documents[:10]
44 |
45 | >>> retriever = retrieve.Splade(
46 | ... key="id",
47 | ... on=["title", "text"],
48 | ... model=model
49 | ... )
50 |
51 | >>> documents_embeddings = retriever.encode_documents(
52 | ... documents=documents,
53 | ... batch_size=1,
54 | ... )
55 |
56 | >>> documents_embeddings = retriever.add(
57 | ... documents_embeddings=documents_embeddings,
58 | ... )
59 |
60 | >>> queries_embeddings = retriever.encode_queries(
61 | ... queries=queries,
62 | ... batch_size=1,
63 | ... )
64 |
65 | >>> scores = retriever(
66 | ... queries_embeddings=queries_embeddings,
67 | ... k=30,
68 | ... batch_size=1,
69 | ... )
70 |
71 | >>> utils.evaluate(
72 | ... scores=scores,
73 | ... qrels=qrels,
74 | ... queries_ids=queries_ids,
75 | ... metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"]
76 | ... )
77 | {'map': 0.0033333333333333335, 'ndcg@10': 0.0033333333333333335, 'ndcg@100': 0.0033333333333333335, 'recall@10': 0.0033333333333333335, 'recall@100': 0.0033333333333333335}
78 | ```
79 |
80 |
--------------------------------------------------------------------------------
/docs/api/utils/iter.md:
--------------------------------------------------------------------------------
1 | # iter
2 |
3 | Iterate over the dataset.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **queries**
10 |
11 | List of queries paired with documents.
12 |
13 | - **documents**
14 |
15 | List of documents paired with queries.
16 |
17 | - **batch_size** – defaults to `512`
18 |
19 | Size of the batch.
20 |
21 | - **epochs** (*int*) – defaults to `1`
22 |
23 | Number of epochs.
24 |
25 | - **shuffle** – defaults to `True`
26 |
27 | - **tqdm_bar** – defaults to `True`
28 |
29 |
30 |
31 |
32 |
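## Examples

A minimal training-loop sketch, as used in the quick start; it assumes `tree`, `train_queries` and `train_documents` have already been created.

```python
import torch

from neural_tree import utils

optimizer = torch.optim.AdamW(lr=3e-3, params=list(tree.parameters()))

for step, batch_queries, batch_documents in utils.iter(
    queries=train_queries,
    documents=train_documents,
    shuffle=True,
    epochs=50,
    batch_size=32,
):
    # Compute the ranking loss on the batch and update the tree embeddings.
    loss = tree.loss(queries=batch_queries, documents=batch_documents)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```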
--------------------------------------------------------------------------------
/docs/api/utils/leafs-precision.md:
--------------------------------------------------------------------------------
1 | # leafs_precision
2 |
3 | Calculate the precision of the leafs.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **key** (*str*)
10 |
11 | - **documents** (*list*)
12 |
13 | - **leafs** (*numpy.ndarray*)
14 |
15 | - **documents_to_leaf** (*dict*)
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/docs/api/utils/sanity-check.md:
--------------------------------------------------------------------------------
1 | # sanity_check
2 |
3 | Check if the input is valid.
4 |
5 |
6 |
7 | ## Parameters
8 |
9 | - **branch_balance_factor** (*int*)
10 |
11 | - **leaf_balance_factor** (*int*)
12 |
13 | - **graph** (*dict*)
14 |
15 | - **documents** (*list*)
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/docs/api/utils/set-env.md:
--------------------------------------------------------------------------------
1 | # set_env
2 |
3 | Set environment variables.
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/docs/css/version-select.css:
--------------------------------------------------------------------------------
1 | @media only screen and (max-width:76.1875em) {
2 | #version-selector {
3 | padding: .6rem .8rem;
4 | }
5 | }
--------------------------------------------------------------------------------
/docs/evaluate/.pages:
--------------------------------------------------------------------------------
1 | title: Evaluate
2 | nav:
3 | - evaluate.md
4 |
5 |
--------------------------------------------------------------------------------
/docs/evaluate/evaluate.md:
--------------------------------------------------------------------------------
1 | # Evaluate
2 |
3 | Neural-Tree evaluation is based on [ranx](https://github.com/AmenRa/ranx). We can also download datasets from the [BEIR Benchmark](https://github.com/beir-cellar/beir) with the `datasets.load_beir` function.
4 |
5 |
6 | ## Installation
7 |
8 | ```bash
9 | pip install "neural-tree[eval]"
10 | ```
11 |
12 | ## Usage
13 |
14 | Here is an example of how to train a tree-based index using the `scifact` dataset and how to evaluate it.
15 |
16 | ```python
17 | import torch
18 | from neural_cherche import models
19 | from sentence_transformers import SentenceTransformer
20 |
21 | from neural_tree import clustering, datasets, trees, utils
22 |
23 | documents, train_queries, train_documents = datasets.load_beir_train(
24 | dataset_name="scifact",
25 | )
26 |
27 |
28 | model = models.ColBERT(
29 | model_name_or_path="raphaelsty/neural-cherche-colbert",
30 | device="cuda",
31 | )
32 |
33 | # We initialize a ColBERT index from a
34 | # SentenceTransformer-based hierarchical clustering.
35 | tree = trees.ColBERT(
36 | key="id",
37 | on=["title", "text"],
38 | model=model,
39 | sentence_transformer=SentenceTransformer(model_name_or_path="all-mpnet-base-v2"),
40 | documents=documents,
41 | leaf_balance_factor=100,
42 | branch_balance_factor=5,
43 | n_jobs=-1,
44 | device="cuda",
45 | faiss_device="cuda",
46 | )
47 |
48 | optimizer = torch.optim.AdamW(lr=3e-3, params=list(tree.parameters()))
49 |
50 |
51 | for step, batch_queries, batch_documents in utils.iter(
52 | queries=train_queries,
53 | documents=train_documents,
54 | shuffle=True,
55 | epochs=50,
56 | batch_size=128,
57 | ):
58 | loss = tree.loss(
59 | queries=batch_queries,
60 | documents=batch_documents,
61 | )
62 |
63 | loss.backward()
64 | optimizer.step()
65 | optimizer.zero_grad(set_to_none=True)
66 |
67 |
68 | documents, queries_ids, test_queries, qrels = datasets.load_beir_test(
69 | dataset_name="scifact",
70 | )
71 |
72 | documents_to_leafs = clustering.optimize_leafs(
73 | tree=tree,
74 | queries=train_queries + test_queries,
75 | documents=documents,
76 | )
77 |
78 | tree = tree.add(
79 | documents=documents,
80 | documents_to_leafs=documents_to_leafs,
81 | )
82 |
83 | candidates = tree(
84 | queries=test_queries,
85 | k_leafs=2, # number of leafs to search
86 | k=10, # number of documents to retrieve
87 | )
88 |
99 |
100 | scores = utils.evaluate(
101 | scores=candidates["documents"],
102 | qrels=qrels,
103 | queries_ids=queries_ids,
104 | )
105 |
106 | print(scores)
107 | ```
108 |
109 | ```python
110 | {"ndcg@10": 0.6957728027724698, "hits@1": 0.59, "hits@2": 0.69, "hits@3": 0.76, "hits@4": 0.8133333333333334, "hits@5": 0.8533333333333334, "hits@10": 0.91}
111 | ```
112 |
113 | ## Evaluation dataset
114 |
115 | Here is what documents should look like (an id along with multiple fields):
116 |
117 | ```python
118 | [
119 | {
120 | "id": "document_0",
121 | "title": "Bayesian measures of model complexity and fit",
122 | "text": "Summary. We consider the problem of comparing complex hierarchical models in which the number of parameters is not clearly defined. Using an information theoretic argument we derive a measure pD for the effective number of parameters in a model as the difference between the posterior mean of the deviance and the deviance at the posterior means of the parameters of interest. In general pD approximately corresponds to the trace of the product of Fisher's information and the posterior covariance, which in normal models is the trace of the ‘hat’ matrix projecting observations onto fitted values. Its properties in exponential families are explored. The posterior mean deviance is suggested as a Bayesian measure of fit or adequacy, and the contributions of individual observations to the fit and complexity can give rise to a diagnostic plot of deviance residuals against leverages. Adding pD to the posterior mean deviance gives a deviance information criterion for comparing models, which is related to other information criteria and has an approximate decision theoretic justification. The procedure is illustrated in some examples, and comparisons are drawn with alternative Bayesian and classical proposals. Throughout it is emphasized that the quantities required are trivial to compute in a Markov chain Monte Carlo analysis.",
123 | },
124 | {
125 | "id": "document_1",
126 | "title": "Simplifying likelihood ratios",
127 | "text": "Likelihood ratios are one of the best measures of diagnostic accuracy, although they are seldom used, because interpreting them requires a calculator to convert back and forth between “probability” and “odds” of disease. This article describes a simpler method of interpreting likelihood ratios, one that avoids calculators, nomograms, and conversions to “odds” of disease. Several examples illustrate how the clinician can use this method to refine diagnostic decisions at the bedside.",
128 | },
129 | ]
130 | ```
131 |
132 | Queries are provided as a list of strings:
133 |
134 | ```python
135 | [
136 | "Varenicline monotherapy is more effective after 12 weeks of treatment compared to combination nicotine replacement therapies with varenicline or bupropion.",
137 | "Venules have a larger lumen diameter than arterioles.",
138 | "Venules have a thinner or absent smooth layer compared to arterioles.",
139 | "Vitamin D deficiency effects the term of delivery.",
140 | "Vitamin D deficiency is unrelated to birth weight.",
141 | "Women with a higher birth weight are more likely to develop breast cancer later in life.",
142 | ]
143 | ```
144 |
145 | Queries ids are provided as a list of ids that follows the order of the queries:
146 |
147 | ```python
148 | [
149 | "0",
150 | "1",
151 | "2",
152 | "3",
153 | "4",
154 | "5",
155 | ]
156 | ```
157 |
158 | Qrels is a mapping with query ids as keys and, as values, a dict of relevant document ids mapped to 1:
159 |
160 | ```python
161 | {
162 | "1": {"document_0": 1},
163 | "3": {"document_10": 1},
164 | "5": {"document_5": 1},
165 | "13": {"document_22": 1},
166 | "36": {"document_23": 1, "document_0": 1},
167 | "42": {"document_2": 1},
168 | }
169 | ```
170 |
171 | ## Metrics
172 |
173 | We can evaluate our model with various metrics detailed [here](https://amenra.github.io/ranx/metrics/).
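For example, we can pass the metrics we want to `utils.evaluate`. A minimal sketch, reusing the `candidates`, `qrels` and `queries_ids` computed above (the metric names follow the ranx conventions):

```python
scores = utils.evaluate(
    scores=candidates["documents"],
    qrels=qrels,
    queries_ids=queries_ids,
    metrics=["map", "ndcg@10", "recall@10", "hits@1", "hits@10"],
)

print(scores)
```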
--------------------------------------------------------------------------------
/docs/existing_tree/.pages:
--------------------------------------------------------------------------------
1 | title: Existing tree
2 | nav:
3 | - existing_tree.md
4 |
--------------------------------------------------------------------------------
/docs/existing_tree/existing_tree.md:
--------------------------------------------------------------------------------
1 | # Build an index from an existing tree
2 |
3 | Neural-Tree can build a tree from an existing graph. This is useful when we have a specific use case where we want to retrieve the right leaf for a query.
4 |
5 | The tree we want to pass should follow some rules:
6 |
7 | - We should avoid nodes with a lot of children. The more children a node has, the more time it will take to explore this node.
8 |
9 | A node must have a single parent; otherwise the structure is not a tree. We can however duplicate a node to place it in multiple branches of the tree.
10 |
11 | Let's create a tree with one root node, two intermediate nodes and four leafs, each containing up to 3 documents.
12 |
13 | ```python
14 | graph = {
15 | "root": {
16 | "science": {
17 | "machine learning": [
18 | {"id": 0, "content": "bayern football team"},
19 | {"id": 1, "content": "toulouse rugby team"},
20 | ],
21 | "computer": [
22 | {"id": 2, "content": "Apple Macintosh"},
23 | {"id": 3, "content": "Microsoft Windows"},
24 | {"id": 4, "content": "Linux Ubuntu"},
25 | ],
26 | },
27 | "history": {
28 | "france": [
29 | {"id": 5, "content": "history of france"},
30 | {"id": 6, "content": "french revolution"},
31 | ],
32 | "italia": [
33 | {"id": 7, "content": "history of rome"},
34 | {"id": 8, "content": "history of venice"},
35 | ],
36 | },
37 | }
38 | }
39 | ```
40 |
41 | We can now initialize either a TfIdf, a SentenceTransformer or a ColBERT tree using the graph we have created.
42 |
43 | ```python
44 | import torch
45 | from neural_cherche import models
46 | from neural_tree import trees
46 |
47 | model = models.ColBERT(
48 | model_name_or_path="raphaelsty/neural-cherche-colbert",
49 | device="cuda" if torch.cuda.is_available() else "cpu",
50 | )
51 |
52 | tree = trees.ColBERT(
53 | key="id",
54 | on=["content"],
55 | model=model,
56 | graph=graph,
57 | n_jobs=-1,
58 | )
59 |
60 | print(tree)
61 | ```
62 |
63 | This will output:
64 |
65 | ```python
66 | node root
67 | node science
68 | leaf computer
69 | leaf machine learning
70 | node history
71 | leaf france
72 | leaf italia
73 | ```
74 |
75 | Once we have created our tree, we can export it back to JSON using `tree.to_json()`:
76 |
77 | ```python
78 | {
79 | "root": {
80 | "science": {
81 | "computer": [{"id": 2}, {"id": 3}, {"id": 4}],
82 | "machine learning": [{"id": 0}, {"id": 1}],
83 | },
84 | "history": {"france": [{"id": 5}, {"id": 6}], "italia": [{"id": 7}, {"id": 8}]},
85 | }
86 | }
87 | ```
88 |
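We can then query the tree built from the graph like any other Neural-Tree index. Here is a minimal sketch that reuses the call signature documented for the fine-tuned trees; the query strings and `k` values are only illustrative:

```python
# Minimal sketch: querying the tree built from the existing graph.
# The call signature mirrors the one used for the fine-tuned trees;
# query strings and k values are illustrative.
candidates = tree(
    queries=["french revolution", "microsoft windows"],
    k=3,        # number of documents to return per query
    k_leafs=1,  # number of leafs to explore per query
)

print(candidates["leafs"])      # leaf reached for each query
print(candidates["documents"])  # documents retrieved within those leafs
```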
--------------------------------------------------------------------------------
/docs/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raphaelsty/neural-tree/c7fd61f6ae37931482b03ab2d569b0309bcd726b/docs/img/logo.png
--------------------------------------------------------------------------------
/docs/img/neural_tree.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raphaelsty/neural-tree/c7fd61f6ae37931482b03ab2d569b0309bcd726b/docs/img/neural_tree.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Neural-Tree
2 |
3 | Neural Search
4 |
16 | ## Installation
17 |
18 | We can install neural-tree using:
19 |
20 | ```
21 | pip install neural-tree
22 | ```
23 |
24 | If we plan to evaluate our model while training, we should install:
25 |
26 | ```
27 | pip install "neural-tree[eval]"
28 | ```
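
To check that the installation worked, we can import the package's main entry point (a minimal sketch):

```python
# Quick sanity check that the package and its main module are importable.
from neural_tree import trees

print(trees)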
--------------------------------------------------------------------------------
/docs/javascripts/config.js:
--------------------------------------------------------------------------------
1 | window.MathJax = {
2 | tex: {
3 | inlineMath: [["\\(", "\\)"]],
4 | displayMath: [["\\[", "\\]"]],
5 | processEscapes: true,
6 | processEnvironments: true
7 | },
8 | options: {
9 | ignoreHtmlClass: ".*|",
10 | processHtmlClass: "arithmatex"
11 | }
12 | };
--------------------------------------------------------------------------------
/docs/js/version-select.js:
--------------------------------------------------------------------------------
1 | window.addEventListener("DOMContentLoaded", function () {
2 | // This is a bit hacky. Figure out the base URL from a known CSS file the
3 | // template refers to...
4 | var ex = new RegExp("/?css/version-select.css$");
5 | var sheet = document.querySelector('link[href$="version-select.css"]');
6 |
7 | var ABS_BASE_URL = sheet.href.replace(ex, "");
8 | var CURRENT_VERSION = ABS_BASE_URL.split("/").pop();
9 |
10 | function makeSelect(options, selected) {
11 | var select = document.createElement("select");
12 | select.classList.add("form-control");
13 |
14 | options.forEach(function (i) {
15 | var option = new Option(i.text, i.value, undefined,
16 | i.value === selected);
17 | select.add(option);
18 | });
19 |
20 | return select;
21 | }
22 |
23 | var xhr = new XMLHttpRequest();
24 | xhr.open("GET", ABS_BASE_URL + "/../versions.json");
25 | xhr.onload = function () {
26 | var versions = JSON.parse(this.responseText);
27 |
28 | var realVersion = versions.find(function (i) {
29 | return i.version === CURRENT_VERSION ||
30 | i.aliases.includes(CURRENT_VERSION);
31 | }).version;
32 |
33 | var select = makeSelect(versions.map(function (i) {
34 | return { text: i.title, value: i.version };
35 | }), realVersion);
36 | select.addEventListener("change", function (event) {
37 | window.location.href = ABS_BASE_URL + "/../" + this.value;
38 | });
39 |
40 | var container = document.createElement("div");
41 | container.id = "version-selector";
42 | container.className = "md-nav__item";
43 | container.appendChild(select);
44 |
45 | var sidebar = document.querySelector(".md-nav--primary > .md-nav__list");
46 | sidebar.parentNode.insertBefore(container, sidebar);
47 | };
48 | xhr.send();
49 | });
--------------------------------------------------------------------------------
/docs/scripts/index.py:
--------------------------------------------------------------------------------
1 | """This script is responsible for building the API reference. The API reference is located in
2 | docs/api. The script scans through all the modules, classes, and functions. It processes
3 | the __doc__ of each object and formats it so that MkDocs can process it in turn.
4 | """
5 | import functools
6 | import importlib
7 | import inspect
8 | import os
9 | import pathlib
10 | import re
11 | import shutil
12 |
13 | from numpydoc.docscrape import ClassDoc, FunctionDoc
14 |
15 | package = "neural_tree"
16 |
17 | # shutil.copy("README.md", "docs/index.md")
18 |
19 |
20 | def paragraph(text):
21 | return f"{text}\n"
22 |
23 |
24 | def h1(text):
25 | return paragraph(f"# {text}")
26 |
27 |
28 | def h2(text):
29 | return paragraph(f"## {text}")
30 |
31 |
32 | def h3(text):
33 | return paragraph(f"### {text}")
34 |
35 |
36 | def h4(text):
37 | return paragraph(f"#### {text}")
38 |
39 |
40 | def link(caption, href):
41 | return f"[{caption}]({href})"
42 |
43 |
44 | def code(text):
45 | return f"`{text}`"
46 |
47 |
48 | def li(text):
49 | return f"- {text}\n"
50 |
51 |
52 | def snake_to_kebab(text):
53 | return text.replace("_", "-")
54 |
55 |
56 | def inherit_docstring(c, meth):
57 | """Since Python 3.5, inspect.getdoc is supposed to return the docstring from a parent class
58 | if a class has none. However this doesn't seem to work for Cython classes.
59 | """
60 |
61 | doc = None
62 |
63 | for ancestor in inspect.getmro(c):
64 | try:
65 | ancestor_meth = getattr(ancestor, meth)
66 | except AttributeError:
67 | break
68 | doc = inspect.getdoc(ancestor_meth)
69 | if doc:
70 | break
71 |
72 | return doc
73 |
74 |
75 | def inherit_signature(c, method_name):
76 | m = getattr(c, method_name)
77 | sig = inspect.signature(m)
78 |
79 | params = []
80 |
81 | for param in sig.parameters.values():
82 | if param.name == "self" or param.annotation is not param.empty:
83 | params.append(param)
84 | continue
85 |
86 | for ancestor in inspect.getmro(c):
87 | try:
88 | ancestor_meth = inspect.signature(getattr(ancestor, m.__name__))
89 | except AttributeError:
90 | break
91 | try:
92 | ancestor_param = ancestor_meth.parameters[param.name]
93 | except KeyError:
94 | break
95 | if ancestor_param.annotation is not param.empty:
96 | param = param.replace(annotation=ancestor_param.annotation)
97 | break
98 |
99 | params.append(param)
100 |
101 | return_annotation = sig.return_annotation
102 | if return_annotation is inspect._empty:
103 | for ancestor in inspect.getmro(c):
104 | try:
105 | ancestor_meth = inspect.signature(getattr(ancestor, m.__name__))
106 | except AttributeError:
107 | break
108 | if ancestor_meth.return_annotation is not inspect._empty:
109 | return_annotation = ancestor_meth.return_annotation
110 | break
111 |
112 | return sig.replace(parameters=params, return_annotation=return_annotation)
113 |
114 |
115 | def snake_to_kebab(snake: str) -> str:
116 | return snake.replace("_", "-")
117 |
118 |
119 | def pascal_to_kebab(string):
120 | string = re.sub("(.)([A-Z][a-z]+)", r"\1-\2", string)
121 | string = re.sub("(.)([0-9]+)", r"\1-\2", string)
122 | return re.sub("([a-z0-9])([A-Z])", r"\1-\2", string).lower()
123 |
124 |
125 | class Linkifier:
126 | def __init__(self):
127 | path_index = {}
128 | name_index = {}
129 |
130 | modules = {
131 | module: importlib.import_module(f"{package}.{module}")
132 | for module in importlib.import_module(f"{package}").__all__
133 | }
134 |
135 | def index_module(mod_name, mod, path):
136 | path = os.path.join(path, mod_name)
137 | dotted_path = path.replace("/", ".")
138 |
139 | for func_name, func in inspect.getmembers(mod, inspect.isfunction):
140 | for e in (
141 | f"{mod_name}.{func_name}",
142 | f"{dotted_path}.{func_name}",
143 | f"{func.__module__}.{func_name}",
144 | ):
145 | path_index[e] = os.path.join(path, snake_to_kebab(func_name))
146 | name_index[e] = f"{dotted_path}.{func_name}"
147 |
148 | for klass_name, klass in inspect.getmembers(mod, inspect.isclass):
149 | for e in (
150 | f"{mod_name}.{klass_name}",
151 | f"{dotted_path}.{klass_name}",
152 | f"{klass.__module__}.{klass_name}",
153 | ):
154 | path_index[e] = os.path.join(path, klass_name)
155 | name_index[e] = f"{dotted_path}.{klass_name}"
156 |
157 | for submod_name, submod in inspect.getmembers(mod, inspect.ismodule):
158 | if submod_name not in mod.__all__ or submod_name == "typing":
159 | continue
160 | for e in (f"{mod_name}.{submod_name}", f"{dotted_path}.{submod_name}"):
161 | path_index[e] = os.path.join(path, snake_to_kebab(submod_name))
162 |
163 | # Recurse
164 | index_module(submod_name, submod, path=path)
165 |
166 | for mod_name, mod in modules.items():
167 | index_module(mod_name, mod, path="")
168 |
169 | # Prepend {package} to each index entry
170 | for k in list(path_index.keys()):
171 | path_index[f"{package}.{k}"] = path_index[k]
172 | for k in list(name_index.keys()):
173 | name_index[f"{package}.{k}"] = name_index[k]
174 |
175 | self.path_index = path_index
176 | self.name_index = name_index
177 |
178 | def linkify(self, text, use_fences, depth):
179 | path = self.path_index.get(text)
180 | name = self.name_index.get(text)
181 | if path and name:
182 | backwards = "../" * (depth + 1)
183 | if use_fences:
184 | return f"[`{name}`]({backwards}{path})"
185 | return f"[{name}]({backwards}{path})"
186 | return None
187 |
188 | def linkify_fences(self, text, depth):
189 | between_fences = re.compile("`[\w\.]+\.\w+`")
190 | return between_fences.sub(
191 | lambda x: self.linkify(x.group().strip("`"), True, depth) or x.group(), text
192 | )
193 |
194 | def linkify_dotted(self, text, depth):
195 | dotted = re.compile("\w+\.[\.\w]+")
196 | return dotted.sub(
197 | lambda x: self.linkify(x.group(), False, depth) or x.group(), text
198 | )
199 |
200 |
201 | def concat_lines(lines):
202 | return inspect.cleandoc(" ".join("\n\n" if line == "" else line for line in lines))
203 |
204 |
205 | def print_docstring(obj, file, depth):
206 | """Prints a classes's docstring to a file."""
207 |
208 | doc = ClassDoc(obj) if inspect.isclass(obj) else FunctionDoc(obj)
209 |
210 | printf = functools.partial(print, file=file)
211 |
212 | printf(h1(obj.__name__))
213 | printf(linkifier.linkify_fences(paragraph(concat_lines(doc["Summary"])), depth))
214 | printf(
215 | linkifier.linkify_fences(
216 | paragraph(concat_lines(doc["Extended Summary"])), depth
217 | )
218 | )
219 |
220 | # We infer the type annotations from the signatures, and therefore rely on the signature
221 | # instead of the docstring for documenting parameters
222 | try:
223 | signature = inspect.signature(obj)
224 | except ValueError:
225 | signature = (
226 | inspect.Signature()
227 | ) # TODO: this is necessary for Cython classes, but it's not correct
228 | params_desc = {param.name: " ".join(param.desc) for param in doc["Parameters"]}
229 |
230 | # Parameters
231 | if signature.parameters:
232 | printf(h2("Parameters"))
233 | for param in signature.parameters.values():
234 | # Name
235 | printf(f"- **{param.name}**", end="")
236 | # Type annotation
237 | if param.annotation is not param.empty:
238 | anno = inspect.formatannotation(param.annotation)
239 | anno = linkifier.linkify_dotted(anno, depth)
240 | printf(f" (*{anno}*)", end="")
241 | # Default value
242 | if param.default is not param.empty:
243 | printf(f" – defaults to `{param.default}`", end="")
244 | printf("\n", file=file)
245 | # Description
246 | if param.name in params_desc:
247 | desc = params_desc[param.name]
248 | if desc:
249 | printf(f" {desc}\n")
250 | printf("")
251 |
252 | # Attributes
253 | if doc["Attributes"]:
254 | printf(h2("Attributes"))
255 | for attr in doc["Attributes"]:
256 | # Name
257 | printf(f"- **{attr.name}**", end="")
258 | # Type annotation
259 | if attr.type:
260 | printf(f" (*{attr.type}*)", end="")
261 | printf("\n", file=file)
262 | # Description
263 | desc = " ".join(attr.desc)
264 | if desc:
265 | printf(f" {desc}\n")
266 | printf("")
267 |
268 | # Examples
269 | if doc["Examples"]:
270 | printf(h2("Examples"))
271 |
272 | in_code = False
273 | after_space = False
274 |
275 | for line in inspect.cleandoc("\n".join(doc["Examples"])).splitlines():
276 | if (
277 | in_code
278 | and after_space
279 | and line
280 | and not line.startswith(">>>")
281 | and not line.startswith("...")
282 | ):
283 | printf("```\n")
284 | in_code = False
285 | after_space = False
286 |
287 | if not in_code and line.startswith(">>>"):
288 | printf("```python")
289 | in_code = True
290 |
291 | after_space = False
292 | if not line:
293 | after_space = True
294 |
295 | printf(line)
296 |
297 | if in_code:
298 | printf("```")
299 | printf("")
300 |
301 | # Methods
302 | if inspect.isclass(obj) and doc["Methods"]:
303 | printf(h2("Methods"))
304 | printf_indent = lambda x, **kwargs: printf(f" {x}", **kwargs)
305 |
306 | for meth in doc["Methods"]:
307 | printf(paragraph(f'???- note "{meth.name}"'))
308 |
309 | # Parse method docstring
310 | docstring = inherit_docstring(c=obj, meth=meth.name)
311 | if not docstring:
312 | continue
313 | meth_doc = FunctionDoc(func=None, doc=docstring)
314 |
315 | printf_indent(paragraph(" ".join(meth_doc["Summary"])))
316 | if meth_doc["Extended Summary"]:
317 | printf_indent(paragraph(" ".join(meth_doc["Extended Summary"])))
318 |
319 | # We infer the type annotations from the signatures, and therefore rely on the signature
320 | # instead of the docstring for documenting parameters
321 | signature = inherit_signature(obj, meth.name)
322 | params_desc = {
323 | param.name: " ".join(param.desc) for param in doc["Parameters"]
324 | }
325 |
326 | # Parameters
327 | if (
328 | len(signature.parameters) > 1
329 | ): # signature is never empty, but self doesn't count
330 | printf_indent("**Parameters**\n")
331 | for param in signature.parameters.values():
332 | if param.name == "self":
333 | continue
334 | # Name
335 | printf_indent(f"- **{param.name}**", end="")
336 | # Type annotation
337 | if param.annotation is not param.empty:
338 | printf_indent(
339 | f" (*{inspect.formatannotation(param.annotation)}*)", end=""
340 | )
341 | # Default value
342 | if param.default is not param.empty:
343 | printf_indent(f" – defaults to `{param.default}`", end="")
344 | printf_indent("", file=file)
345 | # Description
346 | desc = params_desc.get(param.name)
347 | if desc:
348 | printf_indent(f" {desc}")
349 | printf_indent("")
350 |
351 | # Returns
352 | if meth_doc["Returns"]:
353 | printf_indent("**Returns**\n")
354 | return_val = meth_doc["Returns"][0]
355 | if signature.return_annotation is not inspect._empty:
356 | if inspect.isclass(signature.return_annotation):
357 | printf_indent(
358 | f"*{signature.return_annotation.__name__}*: ", end=""
359 | )
360 | else:
361 | printf_indent(f"*{signature.return_annotation}*: ", end="")
362 | printf_indent(return_val.type)
363 | printf_indent("")
364 |
365 | # Notes
366 | if doc["Notes"]:
367 | printf(h2("Notes"))
368 | printf(paragraph("\n".join(doc["Notes"])))
369 |
370 | # References
371 | if doc["References"]:
372 | printf(h2("References"))
373 | printf(paragraph("\n".join(doc["References"])))
374 |
375 |
376 | def print_module(mod, path, overview, is_submodule=False):
377 | mod_name = mod.__name__.split(".")[-1]
378 |
379 | # Create a directory for the module
380 | mod_slug = snake_to_kebab(mod_name)
381 | mod_path = path.joinpath(mod_slug)
382 | mod_short_path = str(mod_path).replace("docs/api/", "")
383 | os.makedirs(mod_path, exist_ok=True)
384 | with open(mod_path.joinpath(".pages"), "w") as f:
385 | f.write(f"title: {mod_name}")
386 |
387 | # Add the module to the overview
388 | if is_submodule:
389 | print(h3(mod_name), file=overview)
390 | else:
391 | print(h2(mod_name), file=overview)
392 | if mod.__doc__:
393 | print(paragraph(mod.__doc__), file=overview)
394 |
395 | # Extract all public classes and functions
396 | ispublic = lambda x: x.__name__ in mod.__all__ and not x.__name__.startswith("_")
397 | classes = inspect.getmembers(mod, lambda x: inspect.isclass(x) and ispublic(x))
398 | funcs = inspect.getmembers(mod, lambda x: inspect.isfunction(x) and ispublic(x))
399 |
400 | # Classes
401 |
402 | if classes and funcs:
403 | print("\n**Classes**\n", file=overview)
404 |
405 | for _, c in classes:
406 | print(f"{mod_name}.{c.__name__}")
407 |
408 | # Add the class to the overview
409 | slug = snake_to_kebab(c.__name__)
410 | print(
411 | li(link(c.__name__, f"../{mod_short_path}/{slug}")), end="", file=overview
412 | )
413 |
414 | # Write down the class' docstring
415 | with open(mod_path.joinpath(slug).with_suffix(".md"), "w") as file:
416 | print_docstring(obj=c, file=file, depth=mod_short_path.count("/") + 1)
417 |
418 | # Functions
419 |
420 | if classes and funcs:
421 | print("\n**Functions**\n", file=overview)
422 |
423 | for _, f in funcs:
424 | print(f"{mod_name}.{f.__name__}")
425 |
426 | # Add the function to the overview
427 | slug = snake_to_kebab(f.__name__)
428 | print(
429 | li(link(f.__name__, f"../{mod_short_path}/{slug}")), end="", file=overview
430 | )
431 |
432 | # Write down the function's docstring
433 | with open(mod_path.joinpath(slug).with_suffix(".md"), "w") as file:
434 | print_docstring(obj=f, file=file, depth=mod_short_path.count(".") + 1)
435 |
436 | # Sub-modules
437 | for name, submod in inspect.getmembers(mod, inspect.ismodule):
438 | # We only want to go through the public submodules, such as optim.schedulers
439 | if (
440 | name in ("tags", "typing", "inspect", "skmultiflow_utils")
441 | or name not in mod.__all__
442 | or name.startswith("_")
443 | ):
444 | continue
445 | print_module(mod=submod, path=mod_path, overview=overview, is_submodule=True)
446 |
447 | print("", file=overview)
448 |
449 |
450 | if __name__ == "__main__":
451 | api_path = pathlib.Path("docs/api")
452 |
453 | # Create a directory for the API reference
454 | shutil.rmtree(api_path, ignore_errors=True)
455 | os.makedirs(api_path, exist_ok=True)
456 | with open(api_path.joinpath(".pages"), "w") as f:
457 | f.write("title: API reference\narrange:\n - overview.md\n - ...\n")
458 |
459 | overview = open(api_path.joinpath("overview.md"), "w")
460 | print(h1("Overview"), file=overview)
461 |
462 | linkifier = Linkifier()
463 |
464 | for mod_name, mod in inspect.getmembers(
465 | importlib.import_module(f"{package}"), inspect.ismodule
466 | ):
467 | if mod_name.startswith("_"):
468 | continue
469 | print(mod_name)
470 | print_module(mod, path=api_path, overview=overview)
471 |
--------------------------------------------------------------------------------
/docs/stylesheets/extra.css:
--------------------------------------------------------------------------------
1 | .md-typeset h2 {
2 | margin: 1.5em 0;
3 | padding-bottom: .4rem;
4 | border-bottom: .04rem solid var(--md-default-fg-color--lighter);
5 | }
6 |
7 | .md-footer {
8 | margin-top: 2em;
9 | }
10 |
11 | .md-typeset pre > code {
12 | border-radius: 0.5em;
13 | }
--------------------------------------------------------------------------------
/docs/trees/.pages:
--------------------------------------------------------------------------------
1 | title: Fine-tune
2 | nav:
3 | - colbert.md
4 | - sentence_transformer.md
5 | - tfidf.md
6 |
--------------------------------------------------------------------------------
/docs/trees/colbert.md:
--------------------------------------------------------------------------------
1 | # ColBERT
2 |
3 | The ColBERT Tree utilizes hierarchical clustering with either TfIdf sparse vectors or SentenceTransformer embeddings to set up the tree's initial structure.
4 |
5 | Following the tree's initialization, documents are sorted into specific leaves according to the hierarchical clustering results. Subsequently, the average ColBERT embeddings of the documents are assigned to each node.
6 |
7 | During a query, the tree employs the ColBERT pre-trained model and its scoring function to identify the most relevant leaf or leaves for effective search results.
8 |
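To make this scoring step concrete, here is a small self-contained sketch of the ColBERT-style MaxSim operation applied to a query and a batch of node embeddings. It mirrors the einsum used by the ColBERT retriever in this repository, but the tensors are random and the shapes purely illustrative.

```python
import torch

# Illustrative shapes: 32 query tokens, 2 candidate nodes,
# 96 document tokens per node, embedding size 128.
query_embedding = torch.randn(32, 128)      # (query tokens, dim)
nodes_embeddings = torch.randn(2, 96, 128)  # (nodes, document tokens, dim)

# MaxSim: for each query token, keep its best-matching node token,
# then sum over query tokens to obtain one score per node.
similarities = torch.einsum("sh,bth->bst", query_embedding, nodes_embeddings)
scores = similarities.max(dim=2).values.sum(dim=1)  # shape: (nodes,)

# The child (node or leaf) with the highest score is explored next.
print(scores.argmax().item())
```
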
9 | ## Documents
10 |
11 | To create a tree-based index for ColBERT, we will need to:
12 |
13 | - Gather the whole set of documents we want to index.
14 | - Gather queries paired with their relevant documents.
15 | - Sample the training set in order to evaluate the index.
16 |
17 |
18 | ```python
19 | # Whole set of documents we want to index.
20 | documents = [
21 | {"id": 0, "content": "paris"},
22 | {"id": 1, "content": "london"},
23 | {"id": 2, "content": "berlin"},
24 | {"id": 3, "content": "rome"},
25 | {"id": 4, "content": "bordeaux"},
26 | {"id": 5, "content": "milan"},
27 | ]
28 |
29 | # Paired training documents
30 | train_documents = [
31 | {"id": 0, "content": "paris"},
32 | {"id": 1, "content": "london"},
33 | {"id": 2, "content": "berlin"},
34 | {"id": 3, "content": "rome"},
35 | ]
36 |
37 | # Paired training queries
38 | train_queries = [
39 | "paris is the capital of france",
40 | "london is the capital of england",
41 | "berlin is the capital of germany",
42 | "rome is the capital of italy",
43 | ]
44 | ```
45 |
46 | Let's train the index using the `documents`, `train_queries` and `train_documents` we have gathered.
47 |
48 | ```python
49 | import torch
50 | from neural_cherche import models
51 |
52 | from neural_tree import trees, utils
53 |
54 | model = models.ColBERT(
55 | model_name_or_path="raphaelsty/neural-cherche-colbert",
56 | device="cuda" if torch.cuda.is_available() else "cpu",
57 | )
58 |
59 | tree = trees.ColBERT(
60 | key="id", # The field to use as a key for the documents.
61 | on=["content"], # The fields to use for the model.
62 | model=model,
63 | documents=documents,
64 | leaf_balance_factor=100, # Minimum number of documents per leaf.
65 | branch_balance_factor=5, # Number of children per node.
66 | n_jobs=-1, # Set n_jobs to 1 when running on Google Colab.
67 | device="cuda" if torch.cuda.is_available() else "cpu",
68 | )
69 |
70 | optimizer = torch.optim.AdamW(lr=3e-3, params=list(tree.parameters()))
71 |
72 | for step, batch_queries, batch_documents in utils.iter(
73 | queries=train_queries,
74 | documents=train_documents,
75 | shuffle=True,
76 | epochs=50,
77 | batch_size=32,
78 | ):
79 | loss = tree.loss(
80 | queries=batch_queries,
81 | documents=batch_documents,
82 | )
83 |
84 | loss.backward()
85 | optimizer.step()
86 | optimizer.zero_grad(set_to_none=True)
87 | ```
88 |
89 | We can already use the `tree` to search for documents by calling it directly.
90 |
91 | Calling the tree returns a dictionary with several keys: the retrieved leafs under `leafs`, the score assigned to each leaf under `scores`, the explored nodes and leafs along with their scores under `tree_scores`, and the documents retrieved for each query under `documents`.
92 |
93 | ```python
94 | tree(
95 | queries=["history"],
96 | k=10, # Number of documents to return for each query.
97 | k_leafs=1, # The number of leafs to return for each query.
98 | )
99 | ```
100 |
101 | ```python
102 | {
103 | "leafs": array([["10"]], dtype=" tuple[torch.Tensor, list, list]:
16 | """Replace KMeans clustering with average clustering when an existing graph is provided.
17 |
18 | Examples
19 | --------
20 | >>> from neural_tree import clustering, scoring
21 | >>> import numpy as np
22 |
23 | >>> documents = [
24 | ... {"id": 0, "text": "Paris is the capital of France."},
25 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
26 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
27 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
28 | ... ]
29 |
30 | >>> documents_embeddings = np.array([
31 | ... [1, 1],
32 | ... [1, 2],
33 | ... [10, 10],
34 | ... [1, 3],
35 | ... ])
36 |
37 | >>> graph = {1: {11: {111: [{'id': 0}, {'id': 3}], 112: [{'id': 1}]}, 12: {121: [{'id': 2}], 122: [{'id': 3}]}}}
38 |
39 | >>> clustering.average(
40 | ... key="id",
41 | ... documents_embeddings=documents_embeddings,
42 | ... documents=documents,
43 | ... graph=graph[1],
44 | ... scoring=scoring.SentenceTransformer(key="id", on=["text"], model=None),
45 | ... )
46 |
47 | """
48 | mapping_documents_embeddings = {
49 | document[key]: embedding
50 | for document, embedding in zip(documents, documents_embeddings)
51 | }
52 |
53 | mapping_nodes_documents = {
54 | node: get_mapping_nodes_documents(graph=graph[node]) for node in graph.keys()
55 | }
56 |
57 | mappings_nodes_embeddings = {
58 | node: scoring.average(
59 | scoring.stack(
60 | [
61 | mapping_documents_embeddings[document[key]]
62 | for document in node_documents
63 | ]
64 | )
65 | )
66 | for node, node_documents in mapping_nodes_documents.items()
67 | }
68 |
69 | mapping_documents_ids = {document[key]: document for document in documents}
70 |
71 | mappings_nodes_embeddings = list(mappings_nodes_embeddings.values())
72 |
73 | if isinstance(mappings_nodes_embeddings[0], np.ndarray):
74 | node_embeddings = torch.tensor(
75 | data=np.stack(arrays=mappings_nodes_embeddings),
76 | device=device,
77 | dtype=torch.float32,
78 | requires_grad=True,
79 | )
80 | else:
81 | node_embeddings = torch.stack(tensors=mappings_nodes_embeddings, dim=0)
82 | node_embeddings = node_embeddings.to(device=device)
83 | node_embeddings.requires_grad = True
84 |
85 | extended_documents, extended_documents_embeddings, labels = [], [], []
86 | for node, node_documents in mapping_nodes_documents.items():
87 | extended_documents.extend(
88 | [mapping_documents_ids[document[key]] for document in node_documents]
89 | )
90 | extended_documents_embeddings.extend(
91 | [mapping_documents_embeddings[document[key]] for document in node_documents]
92 | )
93 | labels.extend([node] * len(node_documents))
94 |
95 | return (
96 | node_embeddings,
97 | labels,
98 | extended_documents,
99 | scoring.stack(extended_documents_embeddings),
100 | )
101 |
102 |
103 | def get_mapping_nodes_documents(graph: dict | list, documents: list | None = None):
104 | """Get documents from specific node."""
105 | if documents is None:
106 | documents = []
107 |
108 | if isinstance(graph, list):
109 | documents.extend(graph)
110 | return documents
111 |
112 | for node, child in graph.items():
113 | if isinstance(child, dict):
114 | documents = get_mapping_nodes_documents(graph=child, documents=documents)
115 | else:
116 | documents.extend(child)
117 | return documents
118 |
--------------------------------------------------------------------------------
/neural_tree/clustering/kmeans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn import cluster
3 |
4 | __all__ = ["KMeans"]
5 |
6 |
7 | def KMeans(
8 | documents_embeddings: torch.Tensor,
9 | n_clusters: int,
10 | max_iter: int,
11 | n_init: int,
12 | seed: int,
13 | device: str,
14 | ) -> tuple[torch.Tensor, list]:
15 | """KMeans clustering."""
16 | kmeans: cluster.KMeans = cluster.KMeans(
17 | n_clusters=n_clusters,
18 | max_iter=max_iter,
19 | n_init=n_init,
20 | random_state=seed,
21 | ).fit(X=documents_embeddings)
22 |
23 | node_embeddings = torch.tensor(
24 | data=kmeans.cluster_centers_,
25 | device=device,
26 | dtype=torch.float32,
27 | requires_grad=True,
28 | )
29 |
30 | return (
31 | node_embeddings,
32 | kmeans.labels_,
33 | )
34 |
--------------------------------------------------------------------------------
/neural_tree/clustering/optimize.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import numpy as np
4 | from cherche import retrieve
5 | from scipy import sparse
6 | from scipy.sparse import csr_matrix, dok_matrix
7 |
8 | __all__ = ["optimize_leafs"]
9 |
10 |
11 | def create_sparse_matrix_retriever(
12 | candidates: list[list], mapping_documents: dict, key: str
13 | ) -> csr_matrix:
14 | """Build a sparse matrix (queries, documents) with 1 when document is relevant to a
15 | query."""
16 | query_documents_mapping = collections.defaultdict(list)
17 | for query, query_candidates in enumerate(iterable=candidates):
18 | for document in query_candidates:
19 | query_documents_mapping[query].append(mapping_documents[document[key]])
20 |
21 | query_documents_matrix = dok_matrix(
22 | arg1=(len(candidates), len(mapping_documents)), dtype=np.int8
23 | )
24 |
25 | for query, query_documents in query_documents_mapping.items():
26 | query_documents_matrix[query, query_documents] = 1
27 |
28 | return query_documents_matrix.tocsr()
29 |
30 |
31 | def create_sparse_matrix_tree(
32 | candidates_tree: np.ndarray,
33 | ) -> tuple[sparse.csr_matrix, dict[int, list[int]]]:
34 | """Build a sparse matrix (queries, leafs) with 1 when leaf is relevant to a query."""
35 | leafs_to_queries = collections.defaultdict(list)
36 |
37 | for query, query_leafs in enumerate(iterable=candidates_tree):
38 | for leaf in query_leafs.tolist():
39 | leafs_to_queries[leaf].append(query)
40 |
41 | query_leafs_matrix = dok_matrix(
42 | arg1=(len(candidates_tree), len(leafs_to_queries)), dtype=np.int8
43 | )
44 |
45 | mapping_leafs = {
46 | leaf: index for index, leaf in enumerate(iterable=leafs_to_queries)
47 | }
48 |
49 | for leaf, leaf_queries in leafs_to_queries.items():
50 | for query in leaf_queries:
51 | query_leafs_matrix[query, mapping_leafs[leaf]] = 1
52 |
53 | return query_leafs_matrix.tocsr(), {
54 | index: leaf for leaf, index in mapping_leafs.items()
55 | }
56 |
57 |
58 | def top_k(similarities, k: int):
59 | """Return the top k documents for each query."""
60 | similarities *= -1
61 | matchs = []
62 | for row in similarities:
63 | _k = min(row.data.shape[0] - 1, k)
64 | ind = np.argpartition(a=row.data, kth=_k, axis=0)[:k]
65 | similarity = np.take_along_axis(arr=row.data, indices=ind, axis=0)
66 | indices = np.take_along_axis(arr=row.indices, indices=ind, axis=0)
67 | ind = np.argsort(a=similarity, axis=0)
68 | matchs.append(np.take_along_axis(arr=indices, indices=ind, axis=0))
69 | return matchs
70 |
71 |
72 | def optimize_leafs(
73 | tree,
74 | documents: list[dict],
75 | queries: list[str],
76 | k_tree: int = 2,
77 | k_retriever: int = 10,
78 | k_leafs: int = 2,
79 | **kwargs,
80 | ) -> dict:
81 | """Optimize the clusters."""
82 | mapping_documents = {
83 | document[tree.key]: index for index, document in enumerate(iterable=documents)
84 | }
85 |
86 | retriever = retrieve.TfIdf(key=tree.key, on=tree.scoring.on, documents=documents)
87 | query_documents_matrix = create_sparse_matrix_retriever(
88 | candidates=retriever(q=queries, k=k_retriever, batch_size=512, tqdm_bar=False),
89 | mapping_documents=mapping_documents,
90 | key=tree.key,
91 | )
92 |
93 | inverse_mapping_document = {
94 | index: document for document, index in mapping_documents.items()
95 | }
96 |
97 | query_leafs_matrix, inverse_mapping_leafs = create_sparse_matrix_tree(
98 | candidates_tree=tree(
99 | queries=queries,
100 | k=k_tree,
101 | score_documents=False,
102 | **kwargs,
103 | )["leafs"]
104 | )
105 |
106 | documents_to_leafs = collections.defaultdict(list)
107 | for document, leafs in enumerate(
108 | iterable=top_k(
109 | similarities=query_documents_matrix.T @ query_leafs_matrix,
110 | k=k_leafs,
111 | )
112 | ):
113 | for leaf in leafs.tolist():
114 | documents_to_leafs[inverse_mapping_document[document]].append(
115 | inverse_mapping_leafs[leaf]
116 | )
117 |
118 | return documents_to_leafs
119 |
--------------------------------------------------------------------------------
/neural_tree/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .beir import load_beir, load_beir_test, load_beir_train
2 |
3 | __all__ = ["load_beir", "load_beir_train", "load_beir_test"]
4 |
--------------------------------------------------------------------------------
/neural_tree/datasets/beir.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | __all__ = ["load_beir", "load_beir_train", "load_beir_test"]
4 |
5 |
6 | def _make_pairs(queries: dict, qrels: dict) -> tuple[list, list]:
7 | """Make pairs of queries and documents for training."""
8 | test_queries, test_documents = [], []
9 | for query, (_, documents_queries) in zip(queries, qrels.items()):
10 | for document_id in documents_queries:
11 | test_queries.append(query)
12 | test_documents.append({"id": document_id})
13 | return test_queries, test_documents
14 |
15 |
16 | def load_beir(dataset_name: str, split: str) -> tuple:
17 | """Load BEIR dataset."""
18 | from beir import util
19 | from beir.datasets.data_loader import GenericDataLoader
20 |
21 | path = f"./beir_datasets/{dataset_name}"
22 | if not os.path.isdir(s=path):
23 | path = util.download_and_unzip(
24 | url=f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip",
25 | out_dir="./beir_datasets/",
26 | )
27 |
28 | documents, queries, qrels = GenericDataLoader(data_folder=path).load(split=split)
29 |
30 | documents = [
31 | {
32 | "id": document_id,
33 | "title": document["title"],
34 | "text": document["text"],
35 | }
36 | for document_id, document in documents.items()
37 | ]
38 |
39 | return documents, queries, qrels
40 |
41 |
42 | def load_beir_train(dataset_name: str) -> tuple[list, list, list]:
43 | """Load training dataset.
44 |
45 | Parameters
46 | ----------
47 | dataset_name
48 | Dataset name
49 |
50 | Examples
51 | --------
52 | >>> from neural_tree import datasets
53 |
54 | >>> documents, train_queries, train_documents = datasets.load_beir_train(
55 | ... dataset_name="scifact",
56 | ... )
57 |
58 | >>> len(documents)
59 | 5183
60 |
61 | >>> assert len(train_queries) == len(train_documents)
62 |
63 | """
64 | documents, queries, qrels = load_beir(dataset_name=dataset_name, split="train")
65 |
66 | train_queries, train_documents = _make_pairs(
67 | queries=list(queries.values()), qrels=qrels
68 | )
69 |
70 | return documents, train_queries, train_documents
71 |
72 |
73 | def load_beir_test(dataset_name: str) -> tuple[list, list, dict]:
74 | """Load BEIR testing dataset.
75 |
76 | Parameters
77 | ----------
78 | dataset_name
79 | Dataset name.
80 |
81 | Examples
82 | --------
83 | >>> from neural_tree import datasets
84 |
85 | >>> documents, queries_ids, queries, qrels = datasets.load_beir_test(
86 | ... dataset_name="scifact",
87 | ... )
88 |
89 | >>> len(documents)
90 | 5183
91 |
92 | >>> assert len(queries_ids) == len(queries) == len(qrels)
93 | """
94 | documents, queries, qrels = load_beir(dataset_name=dataset_name, split="test")
95 | return documents, list(queries.keys()), list(queries.values()), qrels
96 |
--------------------------------------------------------------------------------
/neural_tree/leafs/__init__.py:
--------------------------------------------------------------------------------
1 | from .leaf import Leaf
2 |
3 | __all__ = ["Leaf"]
4 |
--------------------------------------------------------------------------------
/neural_tree/leafs/leaf.py:
--------------------------------------------------------------------------------
1 | __all__ = ["Leaf"]
2 |
3 | import collections
4 |
5 | import torch
6 |
7 | from ..scoring import SentenceTransformer, TfIdf
8 |
9 |
10 | class Leaf(torch.nn.Module):
11 | """Leaf class."""
12 |
13 | def __init__(
14 | self,
15 | key: str,
16 | level: int,
17 | documents: list,
18 | documents_embeddings: list,
19 | node_name: int,
20 | scoring: TfIdf | SentenceTransformer,
21 | parent: int = 0,
22 | create_retrievers: bool = True,
23 | **kwargs,
24 | ) -> None:
25 | super(Leaf, self).__init__()
26 | self.key = key
27 | self.level = level
28 | self.node_name = node_name
29 | self.parent = parent
30 | self.create_retrievers = create_retrievers
31 |
32 | self.documents = {}
33 |
34 | if self.create_retrievers:
35 | self.retriever = scoring.get_retriever()
36 |
37 | if scoring.distinct_documents_encoder:
38 | documents_embeddings = None
39 | elif self.create_retrievers:
40 | documents_embeddings = {
41 | document[self.key]: embedding
42 | for document, embedding in zip(documents, documents_embeddings)
43 | }
44 |
45 | self.add(
46 | scoring=scoring,
47 | documents=documents,
48 | documents_embeddings=documents_embeddings,
49 | )
50 |
51 | def __str__(self) -> str:
52 | """String representation of a leaf."""
53 | sep = "\t"
54 | return f"{self.level * sep} leaf {self.node_name}"
55 |
56 | def add(
57 | self,
58 | scoring: SentenceTransformer | TfIdf,
59 | documents: list,
60 | documents_embeddings: dict | None = None,
61 | ) -> "Leaf":
62 | """Add document to the leaf."""
63 | if not self.create_retrievers:
64 | # If we don't want to create retrievers for the leaves
65 | for document in documents:
66 | self.documents[document[self.key]] = True
67 | return self
68 |
69 | if documents_embeddings is None:
70 | documents_embeddings = self.retriever.encode_documents(
71 | documents=documents,
72 | model=scoring.model,
73 | )
74 |
75 | documents_embeddings = {
76 | document: embedding
77 | for document, embedding in documents_embeddings.items()
78 | if document not in self.documents
79 | }
80 |
81 | if not documents_embeddings:
82 | return self
83 |
84 | self.retriever.add(
85 | documents_embeddings=documents_embeddings,
86 | )
87 |
88 | for document in documents:
89 | self.documents[document[self.key]] = True
90 |
91 | return self
92 |
93 | def nodes_scores(
94 | self,
95 | scoring: SentenceTransformer | TfIdf,
96 | queries_embeddings: torch.Tensor,
97 | node_embedding: torch.Tensor,
98 | ) -> torch.Tensor:
99 | """Compute the scores between the queries and the leaf."""
100 | return scoring.leaf_scores(
101 | queries_embeddings=queries_embeddings, leaf_embedding=node_embedding
102 | )
103 |
104 | def __call__(
105 | self,
106 | queries_embeddings,
107 | k: int,
108 | ) -> torch.Tensor:
109 | """Return scores between query and leaf documents."""
110 | if not self.documents:
111 | return [[] for _ in range(len(queries_embeddings))]
112 |
113 | candidates = self.retriever(
114 | queries_embeddings=queries_embeddings,
115 | tqdm_bar=False,
116 | k=k,
117 | )
118 |
119 | return [
120 | [{**document, "leaf": self.node_name} for document in query_documents]
121 | for query_documents in candidates
122 | ]
123 |
124 | def search(
125 | self,
126 | tree_scores: collections.defaultdict,
127 | **kwargs,
128 | ) -> tuple[torch.Tensor, list]:
129 | """Return the documents in the leaf."""
130 | return tree_scores
131 |
132 | def to_json(self) -> list:
133 | """Return the leaf as a json."""
134 | return [{self.key: document} for document in self.documents]
135 |
--------------------------------------------------------------------------------
/neural_tree/nodes/__init__.py:
--------------------------------------------------------------------------------
1 | from .node import Node
2 |
3 | __all__ = ["Node"]
4 |
--------------------------------------------------------------------------------
/neural_tree/nodes/node.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import torch
4 | from joblib import Parallel, delayed
5 |
6 | from ..clustering import KMeans, average
7 | from ..leafs import Leaf
8 | from ..scoring import SentenceTransformer, TfIdf
9 |
10 | __all__ = ["Node"]
11 |
12 |
13 | class Node(torch.nn.Module):
14 | """Node of the tree."""
15 |
16 | def __init__(
17 | self,
18 | level: int,
19 | key: str,
20 | documents_embeddings: torch.Tensor,
21 | documents: list,
22 | leaf_balance_factor: int,
23 | branch_balance_factor: int,
24 | device: str,
25 | node_name: int | str,
26 | scoring: SentenceTransformer | TfIdf,
27 | seed: int,
28 | max_iter: int,
29 | n_init: int,
30 | parent: int,
31 | n_jobs: int,
32 | create_retrievers: bool,
33 | graph: dict | None,
34 | ) -> None:
35 | super(Node, self).__init__()
36 | self.level = level
37 | self.leaf_balance_factor = leaf_balance_factor
38 | self.branch_balance_factor = branch_balance_factor
39 | self.device = device
40 | self.seed = seed
41 | self.parent = parent
42 | self.node_name = node_name
43 |
44 | if graph is not None:
45 | self.nodes_embeddings, labels, documents, documents_embeddings = average(
46 | key=key,
47 | documents=documents,
48 | documents_embeddings=documents_embeddings,
49 | graph=graph,
50 | scoring=scoring,
51 | device=self.device,
52 | )
53 | else:
54 | self.nodes_embeddings, labels = KMeans(
55 | documents_embeddings=documents_embeddings,
56 | n_clusters=self.branch_balance_factor,
57 | max_iter=max_iter,
58 | n_init=n_init,
59 | seed=self.seed,
60 | device=self.device,
61 | )
62 |
63 | clusters = collections.defaultdict(list)
64 | for document, embedding, group in zip(documents, documents_embeddings, labels):
65 | clusters[group].append((document, embedding))
66 |
67 | if n_jobs == 1:
68 | self.childs = [
69 | self.create_child(
70 | level=self.level + 1,
71 | node_name=f"{self.node_name}{group}" if graph is None else group,
72 | key=key,
73 | documents=[document for document, _ in clusters[group]],
74 | documents_embeddings=[
75 | embedding for _, embedding in clusters[group]
76 | ],
77 | scoring=scoring,
78 | max_iter=max_iter,
79 | n_init=n_init,
80 | create_retrievers=create_retrievers,
81 | graph=graph[group] if graph is not None else None,
82 | n_jobs=n_jobs,
83 | seed=self.seed,
84 | )
85 | for group in sorted(
86 | clusters, key=lambda key: len(clusters[key]), reverse=True
87 | )
88 | ]
89 | else:
90 | self.childs = Parallel(n_jobs=n_jobs)(
91 | delayed(function=self.create_child)(
92 | level=self.level + 1,
93 | node_name=f"{self.node_name}{group}" if graph is None else group,
94 | key=key,
95 | documents=[document for document, _ in clusters[group]],
96 | documents_embeddings=[
97 | embedding for _, embedding in clusters[group]
98 | ],
99 | scoring=scoring,
100 | max_iter=max_iter,
101 | n_init=n_init,
102 | create_retrievers=create_retrievers,
103 | graph=graph[group] if graph is not None else None,
104 | n_jobs=n_jobs,
105 | seed=self.seed,
106 | )
107 | for group in sorted(
108 | clusters, key=lambda key: len(clusters[key]), reverse=True
109 | )
110 | )
111 |
112 | def create_child(
113 | self,
114 | level: int,
115 | node_name: str,
116 | key: str,
117 | documents: list,
118 | documents_embeddings: list,
119 | scoring: SentenceTransformer | TfIdf,
120 | max_iter: int,
121 | n_init: int,
122 | create_retrievers: bool,
123 | graph: dict | list | None,
124 | n_jobs: int,
125 | seed: int,
126 | ) -> "Leaf | Node":
127 | """Create a child."""
128 | child = Leaf if len(documents) <= self.leaf_balance_factor else Node
129 | if graph is not None and isinstance(graph, list):
130 | child = Leaf
131 |
132 | child = child(
133 | level=level,
134 | node_name=node_name,
135 | key=key,
136 | scoring=scoring,
137 | documents=documents,
138 | documents_embeddings=scoring.stack(embeddings=documents_embeddings),
139 | leaf_balance_factor=self.leaf_balance_factor,
140 | branch_balance_factor=self.branch_balance_factor,
141 | device=self.device,
142 | seed=seed,
143 | max_iter=max_iter,
144 | n_init=n_init,
145 | parent=self.node_name,
146 | create_retrievers=create_retrievers,
147 | graph=graph,
148 | n_jobs=n_jobs,
149 | )
150 | return child
151 |
152 | def __str__(self) -> str:
153 | """String representation of a"""
154 | sep = "\t"
155 | return f"{self.level * sep} node {self.node_name}"
156 |
157 | def nodes_scores(
158 | self,
159 | scoring: SentenceTransformer | TfIdf,
160 | queries_embeddings: torch.Tensor,
161 | **kwargs,
162 | ) -> torch.Tensor:
163 | """Return the scores of the embeddings."""
164 | return scoring.nodes_scores(
165 | queries_embeddings=queries_embeddings,
166 | nodes_embeddings=self.nodes_embeddings,
167 | )
168 |
169 | def get_childs_and_scores(
170 | self,
171 | queries: list,
172 | scores: torch.Tensor,
173 | tree_scores: collections.defaultdict,
174 | paths: list | None,
175 | k: int,
176 | ) -> tuple[torch.Tensor, torch.Tensor]:
177 | """Return the childs and scores given matrix of scores."""
178 | if paths is None:
179 | scores = torch.stack(tensors=scores, dim=1)
180 | scores, childs = torch.topk(input=scores, k=min(k, scores.shape[1]), dim=1)
181 | return childs, scores
182 |
183 | # If paths is not None, we go through the chosen path.
184 | path = [query_path.pop(0) if query_path else leaf for leaf, query_path in paths]
185 |
186 | child_node_names = []
187 | for node_name in path:
188 | for index, child in enumerate(iterable=self.childs):
189 | if node_name == child.node_name:
190 | child_node_names.append(index)
191 | break
192 |
193 | childs = torch.tensor(
194 | data=child_node_names,
195 | dtype=torch.long,
196 | device=self.device,
197 | ).unsqueeze(dim=1)
198 |
199 | scores = torch.stack(
200 | tensors=[tree_scores[query][node] for query, node in zip(queries, path)],
201 | dim=0,
202 | ).unsqueeze(dim=1)
203 |
204 | return childs, scores
205 |
206 | def search(
207 | self,
208 | queries: list[str],
209 | queries_embeddings: torch.Tensor,
210 | scoring: SentenceTransformer | TfIdf,
211 | k: int,
212 | beam_search_depth: int,
213 | paths: dict[list] | None = None,
214 | tree_scores: collections.defaultdict | None = None,
215 | **kwargs,
216 | ) -> tuple[torch.Tensor, list]:
217 | """Search for the closest embedding.
218 |
219 | Parameters
220 | ----------
221 | queries:
222 | Queries to search for.
223 | queries_embeddings:
224 | Embeddings of the queries to search with.
225 | tree_scores:
226 | Dictionary of scores already computed in the tree.
227 | documents:
228 | Documents added to the leafs.
229 | k:
230 | Number of closest embeddings to return.
231 | paths:
232 | Paths to explore.
233 | """
234 | # We follow the chosen path and do not explore the tree when paths are provided.
235 | if paths is not None:
236 | k = 1
237 |
238 | # Store childs scores:
239 | if tree_scores is None:
240 | tree_scores = collections.defaultdict(dict)
241 |
242 | scores = []
243 | for index, node in enumerate(iterable=self.childs):
244 | score = node.nodes_scores(
245 | scoring=scoring,
246 | queries_embeddings=queries_embeddings,
247 | node_embedding=self.nodes_embeddings[index],
248 | )
249 |
250 | scores.append(score)
251 | for query, query_score in zip(queries, score):
252 | tree_scores[query][node.node_name] = query_score
253 |
254 | childs, scores = self.get_childs_and_scores(
255 | queries=queries,
256 | scores=scores,
257 | tree_scores=tree_scores,
258 | paths=paths,
259 | k=k if self.level == beam_search_depth else 1,
260 | )
261 |
262 | # Aggregate embeddings by child.
263 | index_embeddings, index_queries, index_paths = (
264 | collections.defaultdict(list),
265 | collections.defaultdict(list),
266 | collections.defaultdict(list),
267 | )
268 |
269 | for index, query_childs in enumerate(iterable=childs):
270 | for child in query_childs.tolist():
271 | index_embeddings[child].append(queries_embeddings[index])
272 | index_queries[child].append(queries[index])
273 |
274 | if paths is not None:
275 | index_paths[child].append(paths[index])
276 |
277 | index_paths = dict(index_paths)
278 |
279 | for (child, embeddings), (_, queries_child) in zip(
280 | index_embeddings.items(),
281 | index_queries.items(),
282 | ):
283 | tree_scores = self.childs[child].search(
284 | queries=queries_child,
285 | queries_embeddings=torch.stack(tensors=embeddings, dim=0),
286 | scoring=scoring,
287 | tree_scores=tree_scores,
288 | paths=index_paths[child] if child in index_paths else None,
289 | k=k,
290 | beam_search_depth=beam_search_depth,
291 | )
292 |
293 | return tree_scores
294 |
295 | def to_json(self) -> dict:
296 | return {child.node_name: child.to_json() for child in self.childs}
297 |
--------------------------------------------------------------------------------
/neural_tree/retrievers/__init__.py:
--------------------------------------------------------------------------------
1 | from .colbert import ColBERT
2 | from .sentence_transformer import SentenceTransformer
3 | from .tfidf import TfIdf
4 |
5 | __all__ = ["ColBERT", "SentenceTransformer", "TfIdf"]
6 |
--------------------------------------------------------------------------------
/neural_tree/retrievers/colbert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from neural_cherche import rank
3 |
4 | from .. import utils
5 |
6 | __all__ = ["ColBERT"]
7 |
8 |
9 | class ColBERT(rank.ColBERT):
10 | """ColBERT retriever."""
11 |
12 | def __init__(
13 | self,
14 | key: str,
15 | on: str | list[str],
16 | device: str,
17 | ) -> None:
18 | self.key = key
19 | self.on = on if isinstance(on, list) else [on]
20 | self.embeddings = {}
21 | self.documents = []
22 | self.device = device
23 |
24 | def add(self, documents_embeddings: dict) -> "ColBERT":
25 | """Add documents embeddings."""
26 | documents_embeddings = {
27 | document_id: embedding
28 | for document_id, embedding in documents_embeddings.items()
29 | if document_id not in self.embeddings
30 | }
31 |
32 | self.embeddings.update(documents_embeddings)
33 | self.documents.extend(
34 | [{self.key: document_id} for document_id in documents_embeddings.keys()]
35 | )
36 |
37 | return self
38 |
39 | def __call__(
40 | self,
41 | queries_embeddings: dict[str, torch.Tensor],
42 | batch_size: int = 32,
43 | k: int = None,
44 | tqdm_bar: bool = False,
45 | ) -> list[list[str]]:
46 | """Rank documents givent queries.
47 |
48 | Parameters
49 | ----------
50 | queries
51 | Queries.
52 | documents
53 | Documents.
54 | queries_embeddings
55 | Queries embeddings.
56 | batch_size
57 | Batch size.
58 | tqdm_bar
59 | Show tqdm bar.
60 | k
61 | Number of documents to retrieve.
62 | """
63 | scores = []
64 |
65 | for query, embedding_query in queries_embeddings.items():
66 | query_scores = []
67 |
68 | embedding_query = embedding_query.to(device=self.device)
69 |
70 | for batch_documents in utils.batchify(
71 | X=self.documents,
72 | batch_size=batch_size,
73 | tqdm_bar=tqdm_bar,
74 | ):
75 | embeddings_batch_documents = torch.stack(
76 | tensors=[
77 | self.embeddings[document[self.key]]
78 | for document in batch_documents
79 | ],
80 | dim=0,
81 | )
82 |
83 | query_documents_scores = torch.einsum(
84 | "sh,bth->bst",
85 | embedding_query,
86 | embeddings_batch_documents,
87 | )
88 |
89 | query_scores.append(
90 | query_documents_scores.max(dim=2).values.sum(axis=1)
91 | )
92 |
93 | scores.append(torch.cat(tensors=query_scores, dim=0))
94 |
95 | return self._rank(scores=scores, documents=self.documents, k=k)
96 |
97 | def _rank(
98 | self, scores: torch.Tensor, documents: list[dict], k: int
99 | ) -> list[list[dict]]:
100 | """Rank documents by scores.
101 |
102 | Parameters
103 | ----------
104 | scores
105 | Scores.
106 | documents
107 | Documents.
108 | k
109 | Number of documents to retrieve.
110 | """
111 | ranked = []
112 |
113 | for query_scores in scores:
114 | top_k = torch.topk(
115 | input=query_scores,
116 | k=min(k, len(self.documents)) if k is not None else len(self.documents),
117 | dim=-1,
118 | )
119 |
120 | ranked.append(
121 | [
122 | {**self.documents[indice], "similarity": similarity}
123 | for indice, similarity in zip(top_k.indices, top_k.values.tolist())
124 | ]
125 | )
126 |
127 | return ranked
128 |
--------------------------------------------------------------------------------
/neural_tree/retrievers/sentence_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import numpy as np
4 |
5 | __all__ = ["SentenceTransformer"]
6 |
7 |
8 | class SentenceTransformer:
9 | """Sentence Transformer retriever.
10 |
11 | Examples
12 | --------
13 | >>> from neural_tree import retrievers
14 | >>> from sentence_transformers import SentenceTransformer
15 | >>> from pprint import pprint
16 |
17 | >>> model = SentenceTransformer("all-mpnet-base-v2")
18 |
19 | >>> retriever = retrievers.SentenceTransformer(key="id")
20 |
21 | >>> retriever = retriever.add(
22 | ... documents_embeddings={
23 | ... 0: model.encode("Paris is the capital of France."),
24 | ... 1: model.encode("Berlin is the capital of Germany."),
25 | ... 2: model.encode("Paris and Berlin are European cities."),
26 | ... 3: model.encode("Paris and Berlin are beautiful cities."),
27 | ... }
28 | ... )
29 |
30 | >>> queries_embeddings = {
31 | ... 0: model.encode("Paris"),
32 | ... 1: model.encode("Berlin"),
33 | ... }
34 |
35 | >>> candidates = retriever(queries_embeddings=queries_embeddings, k=2)
36 | >>> pprint(candidates)
37 | [[{'id': 0, 'similarity': 0.644777984318611},
38 | {'id': 3, 'similarity': 0.52865785276988}],
39 | [{'id': 1, 'similarity': 0.6901492368348436},
40 | {'id': 3, 'similarity': 0.5457692206973245}]]
41 |
42 | """
43 |
44 | def __init__(self, key: str, device: str = "cpu") -> None:
45 | self.key = key
46 | self.device = device
47 | self.index = None
48 | self.documents = []
49 |
50 | def _build(self, embeddings: np.ndarray) -> Any:
51 | """Build faiss index.
52 |
53 | Parameters
54 | ----------
55 | index
56 | faiss index.
57 | embeddings
58 | Embeddings of the documents.
59 |
60 | """
61 | if self.index is None:
62 | try:
63 | import faiss
64 | except ImportError:
65 | raise ImportError(
66 | 'Run pip install "neural-tree[cpu]" or pip install "neural-tree[gpu]" to run faiss on cpu / gpu.'
67 | )
68 | self.index = faiss.IndexFlatL2(embeddings.shape[1])
69 | if self.device == "cuda":
70 | try:
71 | self.index = faiss.index_cpu_to_gpu(
72 | faiss.StandardGpuResources(), 0, self.index
73 | )
74 | except Exception:
75 | raise ImportError(
76 | 'Run pip install "neural-tree[gpu]" to run faiss on gpu.'
77 | )
78 |
79 | if not self.index.is_trained and len(embeddings) > 0:
80 | self.index.train(embeddings)
81 |
82 | self.index.add(embeddings)
83 | return self.index
84 |
85 | def add(self, documents_embeddings: dict[int, np.ndarray]) -> "SentenceTransformer":
86 | """Add documents to the faiss index."""
87 | self.documents.extend(list(documents_embeddings.keys()))
88 | self.index = self._build(
89 | embeddings=np.array(object=list(documents_embeddings.values()))
90 | )
91 | return self
92 |
93 | def __call__(
94 | self,
95 | queries_embeddings: dict[int, np.ndarray],
96 | k: int | None = 100,
97 | **kwargs,
98 | ) -> list:
99 | """Retrieve documents."""
100 | if k is None:
101 | k = len(self.documents)
102 |
103 | k = min(k, len(self.documents))
104 | queries_embeddings = np.array(object=list(queries_embeddings.values()))
105 | distances, indexes = self.index.search(queries_embeddings, k)
106 | matchs = np.take(a=self.documents, indices=np.where(indexes < 0, 0, indexes))
107 | rank = []
108 | for distance, index, match in zip(distances, indexes, matchs):
109 | rank.append(
110 | [
111 | {
112 | self.key: m,
113 | "similarity": 1 / (1 + d),
114 | }
115 | for d, idx, m in zip(distance, index, match)
116 | if idx > -1
117 | ]
118 | )
119 |
120 | return rank
121 |
--------------------------------------------------------------------------------
/neural_tree/retrievers/tfidf.py:
--------------------------------------------------------------------------------
1 | from neural_cherche import retrieve
2 | from scipy.sparse import csr_matrix
3 | from sklearn.feature_extraction.text import TfidfVectorizer
4 |
5 | __all__ = ["TfIdf"]
6 |
7 |
8 | class TfIdf(retrieve.TfIdf):
9 | """TfIdf retriever"""
10 |
11 | def __init__(
12 | self,
13 | key: str,
14 | on: list[str],
15 | ) -> None:
16 | super().__init__(key=key, on=on, fit=False)
17 | self.tfidf = None
18 |
19 | def encode_documents(
20 | self, documents: list[dict], model: TfidfVectorizer
21 | ) -> dict[str, csr_matrix]:
22 | """Encode queries into sparse matrix."""
23 | content = [
24 | " ".join([doc.get(field, "") for field in self.on]) for doc in documents
25 | ]
26 |
27 | # matrix is a csr matrix of shape (n_documents, n_features)
28 | matrix = model.transform(raw_documents=content)
29 | return {document[self.key]: row for document, row in zip(documents, matrix)}
30 |
31 | def encode_queries(
32 | self, queries: list[str], model: TfidfVectorizer
33 | ) -> dict[str, csr_matrix]:
34 | """Encode queries into sparse matrix."""
35 | matrix = model.transform(raw_documents=queries)
36 | return {query: row for query, row in zip(queries, matrix)}
37 |
--------------------------------------------------------------------------------
/neural_tree/scoring/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseScore
2 | from .colbert import ColBERT
3 | from .sentence_transformer import SentenceTransformer
4 | from .tfidf import TfIdf
5 |
6 | __all__ = ["BaseScore", "ColBERT", "SentenceTransformer", "TfIdf"]
7 |
--------------------------------------------------------------------------------
/neural_tree/scoring/base.py:
--------------------------------------------------------------------------------
1 | """Base class for scoring functions."""
2 | from abc import ABC, abstractmethod
3 | from typing import Any
4 |
5 | import numpy as np
6 | import torch
7 | from scipy import sparse
8 |
9 | __all__ = ["BaseScore"]
10 |
11 |
12 | class BaseScore(ABC):
13 | """Base class for scoring functions."""
14 |
15 | def __init__(self) -> None:
16 | pass
17 |
18 | @abstractmethod
19 | def distinct_documents_encoder(self) -> bool:
20 | """Return True if the encoder is distinct for documents and nodes."""
21 |
22 | @abstractmethod
23 | def transform_queries(
24 | self, queries: list[str]
25 | ) -> sparse.csr_matrix | np.ndarray | dict:
26 | """Transform queries to embeddings."""
27 |
28 | @abstractmethod
29 | def transform_documents(
30 | self, documents: list[dict]
31 | ) -> sparse.csr_matrix | np.ndarray | dict:
32 | """Transform documents to embeddings."""
33 |
34 | @abstractmethod
35 | def get_retriever(self) -> Any:
36 | """Create a retriever"""
37 |
38 | @abstractmethod
39 | def encode_queries_for_retrieval(
40 | self, queries: list[str]
41 | ) -> sparse.csr_matrix | np.ndarray | dict:
42 | """Encode queries for retrieval."""
43 |
44 | @abstractmethod
45 | def convert_to_tensor(
46 | embeddings: sparse.csr_matrix | np.ndarray, device: str
47 | ) -> torch.Tensor:
48 | """Transform sparse matrix to tensor."""
49 |
50 | @abstractmethod
51 | def nodes_scores(
52 | queries_embeddings: torch.Tensor, nodes_embeddings: torch.Tensor
53 | ) -> torch.Tensor:
54 | """Score between queries and nodes embeddings."""
55 |
56 | @abstractmethod
57 | def leaf_scores(
58 | queries_embeddings: torch.Tensor, leaf_embedding: torch.Tensor
59 | ) -> torch.Tensor:
60 | """Return the scores of the embeddings."""
61 |
62 | @abstractmethod
63 | def stack(
64 | embeddings: list[sparse.csr_matrix | np.ndarray | dict],
65 | ) -> sparse.csr_matrix | np.ndarray | dict:
66 | """Stack list of embeddings."""
67 |
--------------------------------------------------------------------------------
/neural_tree/scoring/colbert.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from neural_cherche import models
4 |
5 | from ..retrievers import ColBERT as colbert_retriever
6 | from ..utils import batchify, set_env
7 | from .base import BaseScore
8 |
9 | __all__ = ["ColBERT"]
10 |
11 |
12 | class ColBERT(BaseScore):
13 | """TfIdf scoring function.
14 |
15 | Examples
16 | --------
17 | >>> from neural_tree import trees, scoring
18 | >>> from neural_cherche import models
19 | >>> from sklearn.feature_extraction.text import TfidfVectorizer
20 | >>> from pprint import pprint
21 | >>> import torch
22 |
23 | >>> _ = torch.manual_seed(42)
24 |
25 | >>> documents = [
26 | ... {"id": 0, "text": "Paris is the capital of France."},
27 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
28 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
29 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
30 | ... ]
31 |
32 | >>> model = models.ColBERT(
33 | ... model_name_or_path="sentence-transformers/all-mpnet-base-v2",
34 | ... embedding_size=128,
35 | ... max_length_document=96,
36 | ... max_length_query=32,
37 | ... )
38 |
39 | >>> tree = trees.ColBERTTree(
40 | ... key="id",
41 | ... on="text",
42 | ... model=model,
43 | ... documents=documents,
44 | ... leaf_balance_factor=1,
45 | ... branch_balance_factor=2,
46 | ... n_jobs=1,
47 | ... )
48 |
49 | >>> print(tree)
50 | node 1
51 | node 10
52 | leaf 100
53 | leaf 101
54 | node 11
55 | leaf 110
56 | leaf 111
57 |
58 | >>> tree.leafs_to_documents
59 | {'100': [0], '101': [1], '110': [2], '111': [3]}
60 |
61 | >>> candidates = tree(
62 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."],
63 | ... k_leafs=2,
64 | ... k=2,
65 | ... )
66 |
67 | >>> candidates["scores"]
68 | array([[28.12037659, 18.32332611],
69 | [29.28324509, 21.38923264]])
70 |
71 | >>> candidates["leafs"]
72 | array([['100', '101'],
73 | ['110', '111']], dtype='>> pprint(candidates["tree_scores"])
76 | [{'10': tensor(28.1204),
77 | '100': tensor(28.1204),
78 | '101': tensor(18.3233),
79 | '11': tensor(20.9327)},
80 | {'10': tensor(21.6886),
81 | '11': tensor(29.2832),
82 | '110': tensor(29.2832),
83 | '111': tensor(21.3892)}]
84 |
85 | >>> pprint(candidates["documents"])
86 | [[{'id': 0, 'leaf': '100', 'similarity': 28.120376586914062},
87 | {'id': 1, 'leaf': '101', 'similarity': 18.323326110839844}],
88 | [{'id': 2, 'leaf': '110', 'similarity': 29.283245086669922},
89 | {'id': 3, 'leaf': '111', 'similarity': 21.389232635498047}]]
90 |
91 | """
92 |
93 | def __init__(
94 | self,
95 | key: str,
96 | on: list | str,
97 | documents: list,
98 | model: models.ColBERT = None,
99 | device: str = "cpu",
100 | **kwargs,
101 | ) -> None:
102 | """Initialize the scoring function."""
103 | set_env()
104 | self.key = key
105 | self.on = [on] if isinstance(on, str) else on
106 | self.model = model
107 | self.device = device
108 |
109 | @property
110 | def distinct_documents_encoder(self) -> bool:
111 | """Return True if the encoder is distinct for documents and nodes."""
112 | return False
113 |
114 | def transform_queries(
115 | self, queries: list[str], batch_size: int, tqdm_bar: bool, *kwargs
116 | ) -> torch.Tensor:
117 | """Transform queries to embeddings."""
118 | queries_embeddings = []
119 |
120 | for batch in batchify(X=queries, batch_size=batch_size, tqdm_bar=tqdm_bar):
121 | queries_embeddings.append(
122 | self.model.encode(texts=batch, query_mode=True)["embeddings"]
123 | )
124 |
125 | return (
126 | queries_embeddings[0].to(self.device)
127 | if len(queries_embeddings) == 1
128 | else torch.cat(tensors=queries_embeddings, dim=0).to(device=self.device)
129 | )
130 |
131 | def transform_documents(
132 | self, documents: list[dict], batch_size: int, tqdm_bar: bool, **kwargs
133 | ) -> torch.Tensor:
134 | """Transform documents to embeddings."""
135 | documents_embeddings = []
136 |
137 | for batch in batchify(
138 | X=[
139 | " ".join([document[field] for field in self.on])
140 | for document in documents
141 | ],
142 | batch_size=batch_size,
143 | tqdm_bar=tqdm_bar,
144 | ):
145 | documents_embeddings.append(
146 | self.model.encode(texts=batch, query_mode=False)["embeddings"]
147 | )
148 |
149 | return (
150 | documents_embeddings[0].to(self.device)
151 | if len(documents_embeddings) == 1
152 | else torch.cat(tensors=documents_embeddings, dim=0).to(device=self.device)
153 | )
154 |
155 | def get_retriever(self) -> None:
156 | """Create a retriever"""
157 | return colbert_retriever(key=self.key, on=self.on, device=self.device)
158 |
159 | def encode_queries_for_retrieval(self, queries: list[str]) -> None:
160 | """Encode queries for retrieval."""
161 | pass
162 |
163 | @staticmethod
164 | def convert_to_tensor(
165 | embeddings: np.ndarray | torch.Tensor, device: str
166 | ) -> torch.Tensor:
167 | """Transform sparse matrix to tensor."""
168 | if isinstance(embeddings, np.ndarray):
169 | return torch.tensor(data=embeddings, device=device, dtype=torch.float32)
170 | return embeddings.to(device=device)
171 |
172 | @staticmethod
173 | def nodes_scores(
174 | queries_embeddings: torch.Tensor, nodes_embeddings: torch.Tensor
175 | ) -> torch.Tensor:
176 | """Score between queries and nodes embeddings."""
177 | return torch.stack(
178 | tensors=[
179 | torch.einsum(
180 | "sh,bth->bst",
181 | query_embedding,
182 | nodes_embeddings,
183 | )
184 | .max(dim=2)
185 | .values.sum(axis=1)
186 | .max(dim=0)
187 | .values
188 | for query_embedding in queries_embeddings
189 | ],
190 | dim=0,
191 | )
192 |
193 | @staticmethod
194 | def leaf_scores(
195 | queries_embeddings: torch.Tensor, leaf_embedding: torch.Tensor
196 | ) -> torch.Tensor:
197 | """Return the scores of the embeddings."""
198 | return torch.stack(
199 | tensors=[
200 | torch.einsum(
201 | "sh,th->st",
202 | query_embedding,
203 | leaf_embedding,
204 | )
205 | .max(dim=1)
206 | .values.sum()
207 | for query_embedding in queries_embeddings
208 | ],
209 | dim=0,
210 | )
211 |
212 | def stack(self, embeddings: list[torch.Tensor | np.ndarray]) -> torch.Tensor:
213 | """Stack list of embeddings."""
214 | if isinstance(embeddings, np.ndarray):
215 | return self.convert_to_tensor(
216 | embeddings=embeddings, device=self.model.device
217 | )
218 | return torch.stack(tensors=embeddings, dim=0)
219 |
220 | @staticmethod
221 | def average(embeddings: torch.Tensor) -> torch.Tensor:
222 | """Average embeddings."""
223 | return embeddings.mean(axis=0)
224 |
--------------------------------------------------------------------------------
/neural_tree/scoring/sentence_transformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sentence_transformers import SentenceTransformer
4 |
5 | from ..retrievers import SentenceTransformer as SentenceTransformerRetriever
6 | from ..utils import set_env
7 |
8 | __all__ = ["SentenceTransformer"]
9 |
10 |
11 | class SentenceTransformer:
12 | """Sentence Transformer scoring function.
13 |
14 | Examples
15 | --------
16 | >>> from neural_tree import trees, scoring
17 | >>> from sentence_transformers import SentenceTransformer
18 | >>> from pprint import pprint
19 |
20 | >>> documents = [
21 | ... {"id": 0, "text": "Paris is the capital of France."},
22 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
23 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
24 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
25 | ... ]
26 |
27 | >>> tree = trees.Tree(
28 | ... key="id",
29 | ... documents=documents,
30 | ... scoring=scoring.SentenceTransformer(key="id", on=["text"], model=SentenceTransformer("all-mpnet-base-v2")),
31 | ... leaf_balance_factor=1,
32 | ... branch_balance_factor=2,
33 | ... n_jobs=1,
34 | ... )
35 |
36 | >>> print(tree)
37 | node 1
38 | node 11
39 | node 110
40 | leaf 1100
41 | leaf 1101
42 | leaf 111
43 | leaf 10
44 |
45 | >>> candidates = tree(
46 | ... queries=["paris", "berlin"],
47 | ... k_leafs=2,
48 | ... )
49 |
50 | >>> candidates["scores"]
51 | array([[0.72453916, 0.60635257],
52 | [0.58386189, 0.57546711]])
53 |
54 | >>> candidates["leafs"]
55 | array([['111', '10'],
56 | ['1101', '1100']], dtype='>> pprint(candidates["tree_scores"])
59 | [{'10': tensor(0.6064),
60 | '11': tensor(0.7245),
61 | '110': tensor(0.5542),
62 | '1100': tensor(0.5403),
63 | '1101': tensor(0.5542),
64 | '111': tensor(0.7245)},
65 | {'10': tensor(0.5206),
66 | '11': tensor(0.5797),
67 | '110': tensor(0.5839),
68 | '1100': tensor(0.5755),
69 | '1101': tensor(0.5839),
70 | '111': tensor(0.4026)}]
71 |
72 | >>> pprint(candidates["documents"])
73 | [[{'id': 0, 'leaf': '111', 'similarity': 0.6447779347587058},
74 | {'id': 1, 'leaf': '10', 'similarity': 0.43175890864117644}],
75 | [{'id': 3, 'leaf': '1101', 'similarity': 0.545769273959571},
76 | {'id': 2, 'leaf': '1100', 'similarity': 0.54081365990618}]]
77 |
78 | """
79 |
80 | def __init__(
81 | self,
82 | key: str,
83 | on: str | list,
84 | model: SentenceTransformer,
85 | device: str = "cpu",
86 | faiss_device: str = "cpu",
87 | ) -> None:
88 | set_env()
89 | self.key = key
90 | self.on = [on] if isinstance(on, str) else on
91 | self.model = model
92 | self.device = device
93 | self.faiss_device = faiss_device
94 |
95 | @property
96 | def distinct_documents_encoder(self) -> bool:
97 | """Return True if the encoder is distinct for documents and nodes."""
98 | return False
99 |
100 | def transform_queries(
101 | self, queries: list[str], batch_size: int, **kwargs
102 | ) -> np.ndarray:
103 | """Transform queries to embeddings."""
104 | return self.model.encode(queries, batch_size=batch_size)
105 |
106 | def transform_documents(
107 | self, documents: list[dict], batch_size: int, **kwargs
108 | ) -> np.ndarray:
109 | """Transform documents to embeddings."""
110 | return self.model.encode(
111 | [
112 | " ".join([document[field] for field in self.on])
113 | for document in documents
114 | ],
115 | batch_size=batch_size,
116 | )
117 |
118 | def get_retriever(self) -> None:
119 | """Create a retriever"""
120 | return SentenceTransformerRetriever(key=self.key, device=self.faiss_device)
121 |
122 | @staticmethod
123 | def encode_queries_for_retrieval(queries: list[str]) -> None:
124 | """Encode queries for retrieval."""
125 | pass
126 |
127 | @staticmethod
128 | def convert_to_tensor(embeddings: np.ndarray, device: str) -> torch.Tensor:
129 | """Convert numpy array to torch tensor."""
130 | return torch.tensor(data=embeddings, device=device, dtype=torch.float32)
131 |
132 | @staticmethod
133 | def nodes_scores(
134 | queries_embeddings: torch.Tensor, nodes_embeddings: torch.Tensor
135 | ) -> torch.Tensor:
136 | """Score between queries and nodes embeddings."""
137 | return torch.max(
138 | input=torch.mm(input=queries_embeddings, mat2=nodes_embeddings.T),
139 | dim=1,
140 | ).values
141 |
142 | @staticmethod
143 | def leaf_scores(
144 | queries_embeddings: torch.Tensor, leaf_embedding: torch.Tensor
145 | ) -> torch.Tensor:
146 | """Computes scores between query and leaf embedding."""
147 | return queries_embeddings @ leaf_embedding.T
148 |
149 | @staticmethod
150 | def stack(embeddings: list[np.ndarray]) -> np.ndarray:
151 | """Stack embeddings."""
152 | return (
153 | np.vstack(tup=embeddings)
154 | if len(embeddings) > 1
155 | else embeddings[0].reshape(1, -1)
156 | )
157 |
158 | @staticmethod
159 | def average(embeddings: np.ndarray) -> np.ndarray:
160 | """Average embeddings."""
161 | return np.mean(a=embeddings, axis=0)
162 |
--------------------------------------------------------------------------------
/neural_tree/scoring/tfidf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy import sparse
4 | from sklearn.feature_extraction.text import TfidfVectorizer
5 |
6 | from ..retrievers import TfIdf as tfidf_retriever
7 | from ..utils import set_env
8 | from .base import BaseScore
9 |
10 | __all__ = ["TfIdf"]
11 |
12 |
13 | class TfIdf(BaseScore):
14 | """TfIdf scoring function.
15 |
16 | Examples
17 | --------
18 | >>> from neural_tree import trees, scoring
19 | >>> from sklearn.feature_extraction.text import TfidfVectorizer
20 | >>> from pprint import pprint
21 |
22 | >>> documents = [
23 | ... {"id": 0, "text": "Paris is the capital of France."},
24 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
25 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
26 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
27 | ... ]
28 |
29 | >>> tree = trees.Tree(
30 | ... key="id",
31 | ... documents=documents,
32 | ... scoring=scoring.TfIdf(key="id", on=["text"], documents=documents),
33 | ... leaf_balance_factor=1,
34 | ... branch_balance_factor=2,
35 | ... )
36 |
37 | >>> print(tree)
38 | node 1
39 | node 10
40 | leaf 100
41 | leaf 101
42 | node 11
43 | leaf 110
44 | leaf 111
45 |
46 | >>> tree.leafs_to_documents
47 | {'100': [0], '101': [1], '110': [2], '111': [3]}
48 |
49 | >>> candidates = tree(
50 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."],
51 | ... k_leafs=2,
52 | ... k=2,
53 | ... )
54 |
55 | >>> candidates["scores"]
56 | array([[0.99999994, 0.63854915],
57 | [0.99999994, 0.72823119]])
58 |
59 | >>> candidates["leafs"]
60 | array([['100', '101'],
61 | ['110', '111']], dtype='>> pprint(candidates["tree_scores"])
64 | [{'10': tensor(1.0000),
65 | '100': tensor(1.0000),
66 | '101': tensor(0.6385),
67 | '11': tensor(0.1076)},
68 | {'10': tensor(0.1076),
69 | '11': tensor(1.0000),
70 | '110': tensor(1.0000),
71 | '111': tensor(0.7282)}]
72 |
73 | >>> pprint(candidates["documents"])
74 | [[{'id': 0, 'leaf': '100', 'similarity': 0.9999999999999978},
75 | {'id': 1, 'leaf': '101', 'similarity': 0.39941742405759667}],
76 | [{'id': 2, 'leaf': '110', 'similarity': 0.9999999999999978},
77 | {'id': 3, 'leaf': '111', 'similarity': 0.5385719658738707}]]
78 |
79 | """
80 |
81 | def __init__(
82 | self,
83 | key: str,
84 | on: list | str,
85 | documents: list,
86 | tfidf_nodes: TfidfVectorizer | None = None,
87 | tfidf_documents: TfidfVectorizer | None = None,
88 | **kwargs,
89 | ) -> None:
90 | """Initialize the scoring function."""
91 | set_env()
92 |
93 | self.key = key
94 | self.on = [on] if isinstance(on, str) else on
95 |
96 | if tfidf_nodes is None:
97 | tfidf_nodes = TfidfVectorizer()
98 |
99 | if tfidf_documents is None:
100 | tfidf_documents = TfidfVectorizer(
101 | lowercase=True, ngram_range=(3, 7), analyzer="char_wb"
102 | )
103 |
104 | self.tfidf_nodes = tfidf_nodes.fit(
105 | raw_documents=[
106 | " ".join([document[field] for field in self.on])
107 | for document in documents
108 | ],
109 | )
110 |
111 | self.model = tfidf_documents.fit(
112 | raw_documents=[
113 | " ".join([document[field] for field in self.on])
114 | for document in documents
115 | ],
116 | )
117 |
118 | @property
119 | def distinct_documents_encoder(self) -> bool:
120 | """Return True if the encoder is distinct for documents and nodes."""
121 | return True
122 |
123 | def transform_queries(self, queries: list[str], **kwargs) -> sparse.csr_matrix:
124 | """Transform queries to embeddings."""
125 | return self.tfidf_nodes.transform(raw_documents=queries)
126 |
127 | def transform_documents(self, documents: list[dict], **kwargs) -> sparse.csr_matrix:
128 | """Transform documents to embeddings."""
129 | return self.tfidf_nodes.transform(
130 | raw_documents=[
131 | " ".join([document[field] for field in self.on])
132 | for document in documents
133 | ],
134 | )
135 |
136 | def get_retriever(self) -> tfidf_retriever:
137 | """Create a retriever"""
138 | return tfidf_retriever(
139 | key=self.key,
140 | on=self.on,
141 | )
142 |
143 | def encode_queries_for_retrieval(self, queries: list[str]) -> sparse.csr_matrix:
144 | """Encode queries for retrieval."""
145 | return self.model.transform(raw_documents=queries)
146 |
147 | @staticmethod
148 | def convert_to_tensor(embeddings: sparse.csr_matrix, device: str) -> torch.Tensor:
149 | """Transform sparse matrix to tensor."""
150 | embeddings = embeddings.tocoo()
151 | return torch.sparse_coo_tensor(
152 | indices=torch.tensor(
153 | data=np.vstack(tup=(embeddings.row, embeddings.col)),
154 | dtype=torch.long,
155 | device=device,
156 | ),
157 | values=torch.tensor(
158 | data=embeddings.data, dtype=torch.float32, device=device
159 | ),
160 | size=torch.Size(embeddings.shape),
161 | ).to(device=device)
162 |
163 | @staticmethod
164 | def nodes_scores(
165 | queries_embeddings: torch.Tensor, nodes_embeddings: torch.Tensor
166 | ) -> torch.Tensor:
167 | """Score between queries and nodes embeddings."""
168 | return torch.max(
169 | input=torch.mm(
170 | input=queries_embeddings,
171 | mat2=nodes_embeddings.T,
172 | ).to_dense(),
173 | dim=1,
174 | ).values
175 |
176 | @staticmethod
177 | def leaf_scores(
178 | queries_embeddings: torch.Tensor, leaf_embedding: torch.Tensor
179 | ) -> torch.Tensor:
180 | """Return the scores of the embeddings."""
181 | return (
182 | torch.mm(input=queries_embeddings, mat2=leaf_embedding.unsqueeze(dim=0).T)
183 | .to_dense()
184 | .flatten()
185 | )
186 |
187 | @staticmethod
188 | def stack(embeddings: list[sparse.csr_matrix]) -> sparse.csr_matrix:
189 | """Stack list of embeddings."""
190 | return (
191 | sparse.vstack(blocks=embeddings) if len(embeddings) > 1 else embeddings[0]
192 | )
193 |
194 | @staticmethod
195 | def average(embeddings: sparse.csr_matrix) -> sparse.csr_matrix:
196 | """Average embeddings."""
197 | return embeddings.mean(axis=0)
198 |
--------------------------------------------------------------------------------
/neural_tree/trees/__init__.py:
--------------------------------------------------------------------------------
1 | from .colbert import ColBERT
2 | from .sentence_transformer import SentenceTransformer
3 | from .tfidf import TfIdf
4 | from .tree import Tree
5 |
6 | __all__ = ["ColBERT", "SentenceTransformer", "TfIdf", "Tree"]
7 |
--------------------------------------------------------------------------------
/neural_tree/trees/colbert.py:
--------------------------------------------------------------------------------
1 | from neural_cherche import models
2 | from sentence_transformers import SentenceTransformer as SentenceTransformerModel
3 |
4 | from ..clustering import get_mapping_nodes_documents
5 | from ..scoring import ColBERT as scoring_ColBERT
6 | from .sentence_transformer import SentenceTransformer
7 | from .tfidf import TfIdf
8 | from .tree import Tree
9 |
10 | __all__ = ["ColBERT"]
11 |
12 |
13 | class ColBERT(Tree):
14 | """ColBERT retriever.
15 |
16 | Parameters
17 | ----------
18 | key
19 | Key to identify the documents.
20 | on
21 | List of columns to use for the retrieval.
22 | model
23 | ColBERT model.
24 | sentence_transformer
25 | SentenceTransformer model in order to perform the hierarchical clustering. If
26 | None, the hierarchical clustering is performed with a TfIdf model.
27 | documents
28 | List of documents to index.
29 | graph
30 | Existing graph to initialize the tree.
31 | leaf_balance_factor
32 | Balance factor for the leafs. Once there is less than `leaf_balance_factor`
33 | documents in a node, the node becomes a leaf.
34 | branch_balance_factor
35 | Balance factor for the branches. The number of children of a node is limited to
36 | `branch_balance_factor`.
37 | device
38 | Device to use for the retrieval.
39 | n_jobs
40 | Number of jobs to use when creating the tree. If -1, all CPUs are used.
41 | batch_size
42 | Batch size to use when creating the tree.
43 | max_iter
44 | Maximum number of iterations to perform with Kmeans algorithm when creating the
45 | tree.
46 | n_init
47 | Number of time the KMeans algorithm will be run with different centroid seeds.
48 | create_retrievers
49 | Whether to create the retrievers or not. If False, the tree is only created and
50 | the __call__ method will only output relevant leafs and scores rather than
51 | ranked documents.
52 | tqdm_bar
53 | Whether to show the tqdm bar when creating the tree.
54 | seed
55 | Random seed.
56 |
57 | """
58 |
59 | def __init__(
60 | self,
61 | key: str,
62 | on: str | list[str],
63 | model: models.ColBERT,
64 | sentence_transformer: SentenceTransformerModel | None = None,
65 | documents: list[dict] | None = None,
66 | graph: dict | None = None,
67 | leaf_balance_factor: int = 100,
68 | branch_balance_factor: int = 5,
69 | device: str = "cpu",
70 | n_jobs: int = -1,
71 | batch_size: int = 32,
72 | max_iter: int = 3000,
73 | n_init: int = 100,
74 | create_retrievers: bool = True,
75 | tqdm_bar: bool = True,
76 | seed: int = 42,
77 | ) -> None:
78 | """Create a tree with the TfIdf scoring."""
79 | if graph is not None:
80 | documents = get_mapping_nodes_documents(graph=graph)
81 | elif graph is None and sentence_transformer is None:
82 | index = TfIdf(
83 | key=key,
84 | on=on,
85 | documents=documents,
86 | leaf_balance_factor=leaf_balance_factor,
87 | branch_balance_factor=branch_balance_factor,
88 | create_retrievers=False,
89 | n_jobs=n_jobs,
90 | max_iter=max_iter,
91 | n_init=n_init,
92 | seed=seed,
93 | )
94 | graph = index.to_json()
95 | else:
96 | index = SentenceTransformer(
97 | key=key,
98 | on=on,
99 | documents=documents,
100 | model=sentence_transformer,
101 | leaf_balance_factor=leaf_balance_factor,
102 | branch_balance_factor=branch_balance_factor,
103 | n_jobs=n_jobs,
104 | create_retrievers=False,
105 | max_iter=max_iter,
106 | n_init=n_init,
107 | batch_size=batch_size,
108 | seed=seed,
109 | )
110 | graph = index.to_json()
111 |
112 | scoring = scoring_ColBERT(
113 | key=key,
114 | on=on,
115 | documents=documents,
116 | model=model,
117 | device=device,
118 | )
119 |
120 | # We computes embeddings here because we need documents contents.
121 | documents_embeddings = scoring.transform_documents(
122 | documents=documents,
123 | model=model,
124 | device=device,
125 | batch_size=batch_size,
126 | tqdm_bar=tqdm_bar,
127 | )
128 |
129 | documents_embeddings = {
130 | document[key]: embedding
131 | for document, embedding in zip(documents, documents_embeddings)
132 | }
133 |
134 | super().__init__(
135 | key=key,
136 | graph=graph,
137 | documents=documents,
138 | scoring=scoring,
139 | documents_embeddings=documents_embeddings,
140 | leaf_balance_factor=leaf_balance_factor,
141 | branch_balance_factor=branch_balance_factor,
142 | device=device,
143 | n_jobs=1,
144 | create_retrievers=create_retrievers,
145 | max_iter=max_iter,
146 | n_init=n_init,
147 | seed=seed,
148 | )
149 |
--------------------------------------------------------------------------------
/neural_tree/trees/sentence_transformer.py:
--------------------------------------------------------------------------------
1 | from sentence_transformers import SentenceTransformer
2 |
3 | from ..clustering import get_mapping_nodes_documents
4 | from ..scoring import SentenceTransformer as scoring_SenrenceTransformer
5 | from .tree import Tree
6 |
7 | __all__ = ["SentenceTransformer"]
8 |
9 |
10 | class SentenceTransformer(Tree):
11 | """Tree with Sentence Transformer scoring function.
12 |
13 | Parameters
14 | ----------
15 | key
16 | Key to identify the documents.
17 | on
18 | List of columns to use for the retrieval.
19 | model
20 | Sentence Transformer model.
21 | documents
22 | List of documents to index.
23 | graph
24 | Existing graph to initialize the tree.
25 | leaf_balance_factor
26 | Balance factor for the leafs. Once there is less than `leaf_balance_factor`
27 | documents in a node, the node becomes a leaf.
28 | branch_balance_factor
29 | Balance factor for the branches. The number of children of a node is limited to
30 | `branch_balance_factor`.
31 | device
32 | Device to use for the retrieval.
33 | n_jobs
34 | Number of jobs to use when creating the tree. If -1, all CPUs are used.
35 | batch_size
36 | Batch size to use when creating the tree.
37 | max_iter
38 | Maximum number of iterations to perform with Kmeans algorithm when creating the
39 | tree.
40 | n_init
41 | Number of time the KMeans algorithm will be run with different centroid seeds.
42 | create_retrievers
43 | Whether to create the retrievers or not. If False, the tree is only created and
44 | the __call__ method will only output relevant leafs and scores rather than
45 | ranked documents.
46 | tqdm_bar
47 | Whether to show the tqdm bar when creating the tree.
48 | seed
49 | Random seed.
50 |
51 | Examples
52 | --------
53 | >>> from neural_tree import trees
54 | >>> from sentence_transformers import SentenceTransformer
55 | >>> from pprint import pprint
56 |
57 | >>> documents = [
58 | ... {"id": 0, "text": "Paris is the capital of France."},
59 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
60 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
61 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
62 | ... ]
63 |
64 | >>> tree = trees.SentenceTransformer(
65 | ... key="id",
66 | ... on=["text"],
67 | ... documents=documents,
68 | ... model=SentenceTransformer("all-mpnet-base-v2"),
69 | ... leaf_balance_factor=2,
70 | ... branch_balance_factor=2,
71 | ... device="cpu",
72 | ... )
73 |
74 | >>> tree = tree.add(documents=documents)
75 |
76 | >>> print(tree)
77 | node 1
78 | node 11
79 | leaf 110
80 | leaf 111
81 | leaf 10
82 |
83 | >>> tree.leafs_to_documents
84 | {'110': [2, 3, 1], '111': [0], '10': [1]}
85 |
86 | >>> candidates = tree(
87 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."],
88 | ... k_leafs=2,
89 | ... k=1,
90 | ... )
91 |
92 | >>> candidates["scores"]
93 | array([[1. , 0.76908004],
94 | [0.88792843, 0.82272887]])
95 |
96 | >>> candidates["leafs"]
97 | array([['111', '10'],
98 | ['110', '10']], dtype='>> pprint(candidates["tree_scores"])
101 | [{'10': tensor(0.7691, device='mps:0'),
102 | '11': tensor(1., device='mps:0'),
103 | '110': tensor(0.6536, device='mps:0'),
104 | '111': tensor(1., device='mps:0')},
105 | {'10': tensor(0.8227, device='mps:0'),
106 | '11': tensor(0.8879, device='mps:0'),
107 | '110': tensor(0.8879, device='mps:0'),
108 | '111': tensor(0.6923, device='mps:0')}]
109 |
110 | >>> pprint(candidates["documents"])
111 | [[{'id': 0, 'leaf': '111', 'similarity': 1.0}],
112 | [{'id': 2, 'leaf': '110', 'similarity': 1.0}]]
113 |
114 | """
115 |
116 | def __init__(
117 | self,
118 | key: str,
119 | on: str | list[str],
120 | model: SentenceTransformer,
121 | documents: list[dict] | None = None,
122 | documents_embeddings: dict | None = None,
123 | graph: dict | None = None,
124 | leaf_balance_factor: int = 100,
125 | branch_balance_factor: int = 5,
126 | device: str = "cpu",
127 | faiss_device: str = "cpu",
128 | batch_size: int = 32,
129 | n_jobs: int = -1,
130 | max_iter: int = 3000,
131 | n_init: int = 100,
132 | create_retrievers: bool = True,
133 | seed: int = 42,
134 | ) -> None:
135 | """Create a tree with the TfIdf scoring."""
136 | if graph is not None:
137 | documents = get_mapping_nodes_documents(graph=graph)
138 |
139 | super().__init__(
140 | key=key,
141 | documents=documents,
142 | graph=graph,
143 | documents_embeddings=documents_embeddings,
144 | scoring=scoring_SenrenceTransformer(
145 | key=key,
146 | on=on,
147 | model=model,
148 | device=device,
149 | faiss_device=faiss_device,
150 | ),
151 | leaf_balance_factor=leaf_balance_factor,
152 | branch_balance_factor=branch_balance_factor,
153 | device=device,
154 | batch_size=batch_size,
155 | n_jobs=n_jobs,
156 | max_iter=max_iter,
157 | n_init=n_init,
158 | create_retrievers=create_retrievers,
159 | seed=seed,
160 | )
161 |
--------------------------------------------------------------------------------
/neural_tree/trees/tfidf.py:
--------------------------------------------------------------------------------
1 | from sklearn.feature_extraction.text import TfidfVectorizer
2 |
3 | from ..clustering import get_mapping_nodes_documents
4 | from ..scoring import TfIdf as scoring_TfIdf
5 | from .tree import Tree
6 |
7 | __all__ = ["TfIdf"]
8 |
9 |
10 | class TfIdf(Tree):
11 | """Tree with tfidf scoring function.
12 |
13 | Parameters
14 | ----------
15 | key
16 | Key to identify the documents.
17 | on
18 | List of columns to use for the retrieval.
19 | tfidf_nodes
20 | TfidfVectorizer for the nodes.
21 | tfidf_documents
22 | TfidfVectorizer for the documents.
23 | documents
24 | List of documents to index.
25 | graph
26 | Existing graph to initialize the tree.
27 | leaf_balance_factor
28 | Balance factor for the leafs. Once there is less than `leaf_balance_factor`
29 | documents in a node, the node becomes a leaf.
30 | branch_balance_factor
31 | Balance factor for the branches. The number of children of a node is limited to
32 | `branch_balance_factor`.
33 | device
34 | Device to use for the retrieval.
35 | n_jobs
36 | Number of jobs to use when creating the tree. If -1, all CPUs are used.
37 | max_iter
38 | Maximum number of iterations to perform with Kmeans algorithm when creating the
39 | tree.
40 | n_init
41 | Number of time the KMeans algorithm will be run with different centroid seeds.
42 | create_retrievers
43 | Whether to create the retrievers or not. If False, the tree is only created and
44 | the __call__ method will only output relevant leafs and scores rather than
45 | ranked documents.
46 | tqdm_bar
47 | Whether to show the tqdm bar when creating the tree.
48 | seed
49 | Random seed.
50 |
51 | Examples
52 | --------
53 | >>> from neural_tree import trees
54 | >>> from pprint import pprint
55 |
56 | >>> documents = [
57 | ... {"id": 0, "text": "Paris is the capital of France."},
58 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
59 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
60 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
61 | ... ]
62 |
63 | >>> tree = trees.TfIdf(
64 | ... key="id",
65 | ... on="text",
66 | ... documents=documents,
67 | ... leaf_balance_factor=2,
68 | ... branch_balance_factor=2,
69 | ... )
70 |
71 | >>> tree = tree.add(documents=documents)
72 |
73 | >>> print(tree)
74 | node 1
75 | leaf 10
76 | leaf 11
77 |
78 | >>> tree.leafs_to_documents
79 | {'10': [0, 1], '11': [2, 3]}
80 |
81 | >>> candidates = tree(
82 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."],
83 | ... k_leafs=2,
84 | ... k=2,
85 | ... )
86 |
87 | >>> candidates["scores"]
88 | array([[0.81927449, 0.10763316],
89 | [0.8641156 , 0.10763316]])
90 |
91 | >>> candidates["leafs"]
92 | array([['10', '11'],
93 | ['11', '10']], dtype='>> pprint(candidates["tree_scores"])
96 | [{'10': tensor(0.8193), '11': tensor(0.1076)},
97 | {'10': tensor(0.1076), '11': tensor(0.8641)}]
98 |
99 | >>> pprint(candidates["documents"])
100 | [[{'id': 0, 'leaf': '10', 'similarity': 0.9999999999999978},
101 | {'id': 1, 'leaf': '10', 'similarity': 0.39941742405759667}],
102 | [{'id': 2, 'leaf': '11', 'similarity': 0.9999999999999978},
103 | {'id': 3, 'leaf': '11', 'similarity': 0.5385719658738707}]]
104 |
105 | """
106 |
107 | def __init__(
108 | self,
109 | key: str,
110 | on: str | list[str],
111 | documents: list[dict] | None = None,
112 | graph: dict | None = None,
113 | leaf_balance_factor: int = 100,
114 | branch_balance_factor: int = 5,
115 | tfidf_nodes: TfidfVectorizer | None = None,
116 | tfidf_documents: TfidfVectorizer | None = None,
117 | device: str = "cpu",
118 | n_jobs: int = -1,
119 | max_iter: int = 3000,
120 | n_init: int = 100,
121 | create_retrievers: bool = True,
122 | seed: int = 42,
123 | ) -> None:
124 | """Create a tree with the TfIdf scoring."""
125 | if graph is not None:
126 | documents = get_mapping_nodes_documents(graph=graph)
127 |
128 | super().__init__(
129 | key=key,
130 | documents=documents,
131 | graph=graph,
132 | scoring=scoring_TfIdf(
133 | key=key,
134 | on=on,
135 | documents=documents,
136 | tfidf_nodes=tfidf_nodes,
137 | tfidf_documents=tfidf_documents,
138 | device=device,
139 | ),
140 | leaf_balance_factor=leaf_balance_factor,
141 | branch_balance_factor=branch_balance_factor,
142 | device=device,
143 | n_jobs=n_jobs,
144 | create_retrievers=create_retrievers,
145 | max_iter=max_iter,
146 | n_init=n_init,
147 | seed=seed,
148 | )
149 |
--------------------------------------------------------------------------------
/neural_tree/trees/tree.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import copy
3 | import random
4 | from functools import lru_cache
5 | from typing import Generator
6 |
7 | import numpy as np
8 | import torch
9 | from scipy import sparse
10 |
11 | from ..leafs import Leaf
12 | from ..nodes import Node
13 | from ..scoring import SentenceTransformer, TfIdf
14 | from ..utils import sanity_check
15 |
16 | __all__ = ["Tree"]
17 |
18 |
19 | class Tree(torch.nn.Module):
20 | """Tree based index for information retrieval.
21 |
22 | Examples
23 | --------
24 | >>> from neural_tree import trees, scoring, clustering
25 | >>> from pprint import pprint
26 |
27 | >>> device = "cpu"
28 |
29 | >>> queries = [
30 | ... "Paris is the capital of France.",
31 | ... "Berlin",
32 | ... "Berlin",
33 | ... "Paris is the capital of France."
34 | ... ]
35 |
36 | >>> documents = [
37 | ... {"id": 0, "text": "Paris is the capital of France."},
38 | ... {"id": 1, "text": "Berlin is the capital of Germany."},
39 | ... {"id": 2, "text": "Paris and Berlin are European cities."},
40 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."},
41 | ... ]
42 |
43 | >>> tree = trees.Tree(
44 | ... key="id",
45 | ... documents=documents,
46 | ... scoring=scoring.TfIdf(key="id", on=["text"], documents=documents),
47 | ... leaf_balance_factor=1,
48 | ... branch_balance_factor=2,
49 | ... device=device,
50 | ... n_jobs=1,
51 | ... )
52 |
53 | >>> print(tree)
54 | node 1
55 | node 10
56 | leaf 100
57 | leaf 101
58 | node 11
59 | leaf 110
60 | leaf 111
61 |
62 | >>> tree.documents_to_leafs
63 | {0: ['100'], 1: ['101'], 2: ['110'], 3: ['111']}
64 |
65 | >>> tree.leafs_to_documents
66 | {'100': [0], '101': [1], '110': [2], '111': [3]}
67 |
68 | >>> candidates = tree(
69 | ... queries=queries,
70 | ... k=2,
71 | ... k_leafs=2,
72 | ... )
73 |
74 | >>> pprint(candidates["documents"])
75 | [[{'id': 0, 'leaf': '100', 'similarity': 0.9999999999999978},
76 | {'id': 1, 'leaf': '101', 'similarity': 0.39941742405759667}],
77 | [{'id': 3, 'leaf': '111', 'similarity': 0.3523828592933607},
78 | {'id': 2, 'leaf': '110', 'similarity': 0.348413283355546}],
79 | [{'id': 3, 'leaf': '111', 'similarity': 0.3523828592933607},
80 | {'id': 2, 'leaf': '110', 'similarity': 0.348413283355546}],
81 | [{'id': 0, 'leaf': '100', 'similarity': 0.9999999999999978},
82 | {'id': 1, 'leaf': '101', 'similarity': 0.39941742405759667}]]
83 |
84 | >>> pprint(candidates["tree_scores"])
85 | [{'10': tensor(1.0000),
86 | '100': tensor(1.0000),
87 | '101': tensor(0.6385),
88 | '11': tensor(0.1076)},
89 | {'10': tensor(0.3235),
90 | '11': tensor(0.3327),
91 | '110': tensor(0.3327),
92 | '111': tensor(0.3327)},
93 | {'10': tensor(0.3235),
94 | '11': tensor(0.3327),
95 | '110': tensor(0.3327),
96 | '111': tensor(0.3327)},
97 | {'10': tensor(1.0000),
98 | '100': tensor(1.0000),
99 | '101': tensor(0.6385),
100 | '11': tensor(0.1076)}]
101 |
102 |
103 | >>> candidates = tree(
104 | ... queries=queries,
105 | ... leafs=["110", "111", "111", "111"],
106 | ... )
107 |
108 | >>> pprint(candidates["documents"])
109 | [[{'id': 2, 'leaf': '110', 'similarity': 0.1036216271728989}],
110 | [{'id': 3, 'leaf': '111', 'similarity': 0.3523828592933607}],
111 | [{'id': 3, 'leaf': '111', 'similarity': 0.3523828592933607}],
112 | [{'id': 3, 'leaf': '111', 'similarity': 0.09981163726061484}]]
113 |
114 | >>> optimizer = torch.optim.AdamW(lr=3e-5, params=list(tree.parameters()))
115 |
116 | >>> loss = tree.loss(
117 | ... queries=queries,
118 | ... documents=documents,
119 | ... )
120 |
121 | >>> loss.backward()
122 | >>> optimizer.step()
123 | >>> assert loss.item() > 0
124 |
125 | >>> graph = tree.to_json()
126 | >>> pprint(graph)
127 | {1: {'10': {'100': [{'id': 0}], '101': [{'id': 1}]},
128 | '11': {'110': [{'id': 2}], '111': [{'id': 3}]}}}
129 |
130 | >>> graph = {'sport': {'football': {'bayern': [{'id': 2, 'text': 'bayern football team'}],
131 | ... 'psg': [{'id': 1, 'text': 'psg football team'}]},
132 | ... 'rugby': {'toulouse': [{'id': 3, 'text': 'toulouse rugby team'}],
133 | ... 'ville rose': [{'id': 3, 'text': 'toulouse rugby team'},
134 | ... {'id': 4, 'text': 'tfc football team'}]}}}
135 |
136 | >>> documents = clustering.get_mapping_nodes_documents(graph=graph)
137 |
138 | >>> tree = trees.Tree(
139 | ... key="id",
140 | ... documents=documents,
141 | ... scoring=scoring.TfIdf(key="id", on=["text"], documents=documents),
142 | ... leaf_balance_factor=1,
143 | ... branch_balance_factor=2,
144 | ... device=device,
145 | ... graph=graph,
146 | ... n_jobs=1,
147 | ... )
148 |
149 | >>> tree.documents_to_leafs
150 | {3: ['ville rose', 'toulouse'], 4: ['ville rose'], 2: ['bayern'], 1: ['psg']}
151 |
152 | >>> tree.leafs_to_documents
153 | {'ville rose': [3, 4], 'toulouse': [3], 'bayern': [2], 'psg': [1]}
154 |
155 | >>> print(tree)
156 | node sport
157 | node rugby
158 | leaf ville rose
159 | leaf toulouse
160 | node football
161 | leaf bayern
162 | leaf psg
163 |
164 | >>> candidates = tree(
165 | ... queries=["psg", "toulouse"],
166 | ... k=2,
167 | ... k_leafs=2,
168 | ... )
169 |
170 | >>> pprint(candidates["documents"])
171 | [[{'id': 1, 'leaf': 'psg', 'similarity': 0.5255159378077358}],
172 | [{'id': 3, 'leaf': 'ville rose', 'similarity': 0.7865788511708137},
173 | {'id': 3, 'leaf': 'toulouse', 'similarity': 0.7865788511708137}]]
174 |
175 | References
176 | ----------
177 | [Li et al., 2023](https://arxiv.org/pdf/2206.02743.pdf)
178 |
179 | """
180 |
181 | def __init__(
182 | self,
183 | key: str,
184 | scoring: SentenceTransformer | TfIdf,
185 | documents: list,
186 | leaf_balance_factor: int,
187 | branch_balance_factor: int,
188 | device,
189 | seed: int,
190 | max_iter: int,
191 | n_init: int,
192 | n_jobs: int,
193 | batch_size: int = None,
194 | create_retrievers: bool = True,
195 | graph: dict | None = None,
196 | documents_embeddings: dict | None = None,
197 | ) -> None:
198 | super(Tree, self).__init__()
199 | self.key = key
200 | self.device = device
201 | self.seed = seed
202 | self.scoring = scoring
203 | self.create_retrievers = create_retrievers
204 | self.node_name = 1
205 |
206 | # Sanity check over input parameters
207 | sanity_check(
208 | branch_balance_factor=branch_balance_factor,
209 | leaf_balance_factor=leaf_balance_factor,
210 | graph=graph,
211 | documents=documents,
212 | )
213 |
214 | if graph is not None:
215 | for node_name in graph.keys():
216 | self.node_name = node_name
217 | break
218 |
219 | if documents_embeddings is None:
220 | documents_embeddings = self.scoring.transform_documents(
221 | documents=documents, batch_size=batch_size
222 | )
223 | else:
224 | documents_embeddings = self.scoring.stack(
225 | embeddings=[
226 | documents_embeddings[document[self.key]] for document in documents
227 | ]
228 | )
229 |
230 | self.tree = Node(
231 | level=0,
232 | node_name=self.node_name,
233 | key=self.key,
234 | documents=documents,
235 | documents_embeddings=documents_embeddings,
236 | scoring=scoring,
237 | leaf_balance_factor=leaf_balance_factor,
238 | branch_balance_factor=branch_balance_factor,
239 | device=self.device,
240 | seed=self.seed,
241 | n_jobs=n_jobs,
242 | create_retrievers=create_retrievers,
243 | graph=graph[self.node_name] if graph is not None else None,
244 | parent=0,
245 | max_iter=max_iter,
246 | n_init=n_init,
247 | )
248 |
249 | self.documents_to_leafs, self.leafs_to_documents = self.get_documents_leafs()
250 | self.negative_samples = self.get_negative_samples()
251 | self._paths = self.get_paths()
252 | self.mapping_leafs = self.get_mapping_leafs()
253 |
254 | def __str__(self) -> str:
255 | """Return the tree as string."""
256 | repr = ""
257 | for node in self.nodes():
258 | repr += f"{node}\n"
259 | return repr[:-1]
260 |
261 | def get_mapping_leafs(self) -> dict:
262 | """Returns mapping between leafs and their number."""
263 | mapping_leafs = {}
264 | for leaf in self.nodes():
265 | if isinstance(leaf, Leaf):
266 | mapping_leafs[leaf.node_name] = leaf
267 | return mapping_leafs
268 |
269 | def get_documents_leafs(self) -> dict:
270 | """Returns mapping between documents ids and leafs and vice versa."""
271 | documents_to_leafs, leafs_to_documents = (
272 | collections.defaultdict(list),
273 | collections.defaultdict(list),
274 | )
275 | for node in self.nodes():
276 | if isinstance(node, Leaf):
277 | for document in node.documents:
278 | documents_to_leafs[document].append(node.node_name)
279 | leafs_to_documents[node.node_name].append(document)
280 |
281 | return dict(documents_to_leafs), dict(leafs_to_documents)
282 |
283 | def get_paths(self) -> list[torch.Tensor]:
284 | """Map leafs to their nodes."""
285 | self.paths.cache_clear()
286 | paths = collections.defaultdict(list)
287 | for node in self.nodes():
288 | if isinstance(node, Leaf):
289 | leaf = node
290 | for _ in range(leaf.level):
291 | node = self.get_parent(node_name=node.node_name)
292 | if node.level != 0:
293 | paths[leaf.node_name].append(node.node_name)
294 | paths[leaf.node_name].reverse()
295 | return dict(paths)
296 |
297 | @lru_cache(maxsize=1000)
298 | def paths(self, leaf: int) -> dict:
299 | return copy.deepcopy(x=self._paths[leaf])
300 |
301 | def get_negative_samples(self) -> dict:
302 | """Return negative samples build from the tree."""
303 | levels = collections.defaultdict(list)
304 | for node in self.nodes():
305 | if node.node_name != self.node_name:
306 | levels[node.level].append((node.node_name, node.parent))
307 | negatives = {}
308 | for _, nodes in levels.items():
309 | for node, node_parent in nodes:
310 | negatives[node] = [
311 | negative_node
312 | for (negative_node, negative_node_parent) in nodes
313 | if (negative_node != node) and (negative_node_parent == node_parent)
314 | ]
315 | return negatives
316 |
317 | def parameters(self) -> Generator:
318 | """Return the parameters of the tree."""
319 | for node in self.nodes():
320 | if isinstance(node, Leaf):
321 | continue
322 | yield node.nodes_embeddings
323 |
324 | def nodes(
325 | self,
326 | node: Node | Leaf = None,
327 | ) -> Generator:
328 | """Iterate over the nodes of the tree."""
329 | if node is None:
330 | node = self.tree
331 | yield node
332 |
333 | for node in node.childs:
334 | yield node
335 |
336 | if not isinstance(node, Leaf):
337 | yield from self.nodes(node=node)
338 |
339 | def get_parent(self, node_name: int | str) -> int | str:
340 | """Get parent nodes of a specifc node.
341 |
342 | Parameters
343 | ----------
344 | number:
345 | Number of the node.
346 | """
347 | for node in self.nodes():
348 | if isinstance(node, Leaf):
349 | continue
350 |
351 | for child in node.childs:
352 | if child.node_name == node_name:
353 | return node
354 |
355 | return None
356 |
357 | @torch.no_grad()
358 | def __call__(
359 | self,
360 | queries: list[str],
361 | k: int = 100,
362 | k_leafs: int = 1,
363 | leafs: list[int] | None = None,
364 | score_documents: bool = True,
365 | beam_search_depth: int = 1,
366 | queries_embeddings: torch.Tensor | np.ndarray | dict = None,
367 | batch_size: int = 32,
368 | tqdm_bar: bool = True,
369 | ) -> tuple[torch.Tensor, list, list]:
370 | """Search for the closest embedding.
371 |
372 | Parameters
373 | ----------
374 | queries:
375 | Queries to search for.
376 | embeddings:
377 | Embeddings to search for.
378 | k:
379 | Number of leafs to search for.
380 | leafs:
381 | Leaf to search for.
382 | score_documents:
383 | Weather to score documents or not.
384 | """
385 | if queries_embeddings is None:
386 | queries_embeddings = self.scoring.transform_queries(
387 | queries=queries,
388 | batch_size=batch_size,
389 | tqdm_bar=tqdm_bar,
390 | )
391 |
392 | tree_scores = self._search(
393 | queries=queries,
394 | queries_embeddings=queries_embeddings,
395 | k_leafs=k_leafs,
396 | leafs=leafs,
397 | beam_search_depth=beam_search_depth,
398 | batch_size=batch_size,
399 | tqdm_bar=tqdm_bar,
400 | )
401 |
402 | if not leafs:
403 | leafs_scores = [
404 | {
405 | leaf: score.item()
406 | for leaf, score in sorted(
407 | query_scores.items(),
408 | key=lambda item: item[1].item(),
409 | reverse=True,
410 | )
411 | if leaf in self.mapping_leafs
412 | }
413 | for query_scores in tree_scores
414 | ]
415 | else:
416 | leafs_scores = [
417 | {leaf: query_scores[leaf].item()}
418 | for leaf, query_scores in zip(leafs, tree_scores)
419 | ]
420 |
421 | # We may not have k leafs for each query, so we take the minimum.
422 | if k_leafs > 1:
423 | k_leafs = min(
424 | min([len(query_leafs_scores) for query_leafs_scores in leafs_scores]),
425 | k_leafs,
426 | )
427 |
428 | candidates = {
429 | "leafs": np.array(
430 | object=[
431 | list(query_leafs_scores.keys())[:k_leafs]
432 | for query_leafs_scores in leafs_scores
433 | ]
434 | ),
435 | "scores": np.array(
436 | object=[
437 | list(query_leafs_scores.values())[:k_leafs]
438 | for query_leafs_scores in leafs_scores
439 | ]
440 | ),
441 | "tree_scores": tree_scores,
442 | }
443 |
444 | if not score_documents or not self.create_retrievers:
445 | return candidates
446 |
447 | if self.scoring.distinct_documents_encoder:
448 | queries_embeddings = self.scoring.encode_queries_for_retrieval(
449 | queries=queries,
450 | )
451 |
452 | leafs_queries, leafs_embeddings = (
453 | collections.defaultdict(list),
454 | collections.defaultdict(list),
455 | )
456 | for query, (query_leafs, embedding) in enumerate(
457 | iterable=zip(candidates["leafs"], queries_embeddings)
458 | ):
459 | for leaf in query_leafs:
460 | leafs_queries[leaf].append(query)
461 | leafs_embeddings[leaf].append(embedding)
462 |
463 | documents = collections.defaultdict(list)
464 | for leaf in leafs_queries:
465 | leaf_documents = self.mapping_leafs[leaf](
466 | queries_embeddings={
467 | query: embedding
468 | for query, embedding in enumerate(iterable=leafs_embeddings[leaf])
469 | },
470 | k=k,
471 | )
472 | for query, query_documents in zip(leafs_queries[leaf], leaf_documents):
473 | documents[query].extend(query_documents)
474 |
475 | # Sort documents if k > 1
476 | candidates["documents"] = [
477 | sorted(
478 | documents[query],
479 | key=lambda document: document["similarity"],
480 | reverse=True,
481 | )[: min(k, len(documents[query]))]
482 | if k_leafs > 1
483 | else documents[query]
484 | for query in range(len(queries))
485 | ]
486 |
487 | return candidates
488 |
489 | def empty(self) -> "Tree":
490 | """Empty the tree."""
491 | for node in self.nodes():
492 | if isinstance(node, Leaf):
493 | node.empty()
494 | return self
495 |
496 | def _search(
497 | self,
498 | queries: list[str],
499 | k_leafs: int = 1,
500 | leafs: list[int] | None = None,
501 | beam_search_depth: int = 1,
502 | queries_embeddings: torch.Tensor | np.ndarray | dict = None,
503 | batch_size: int = 32,
504 | tqdm_bar: bool = True,
505 | ) -> tuple[torch.Tensor, list, list]:
506 | """Search for the closest embedding with gradient.
507 |
508 | Parameters
509 | ----------
510 | queries:
511 | Queries to search for.
512 | embeddings:
513 | Embeddings to search for.
514 | """
515 | if queries_embeddings is None:
516 | queries_embeddings = self.scoring.transform_queries(
517 | queries=queries,
518 | batch_size=batch_size,
519 | tqdm_bar=tqdm_bar,
520 | )
521 |
522 | queries_embeddings = self.scoring.convert_to_tensor(
523 | embeddings=queries_embeddings, device=self.device
524 | )
525 |
526 | paths = (
527 | [(leaf, copy.copy(self.paths(leaf=leaf))) for leaf in leafs]
528 | if leafs is not None
529 | else None
530 | )
531 |
532 | tree_scores = self.tree.search(
533 | queries=[index for index, _ in enumerate(iterable=queries)],
534 | queries_embeddings=queries_embeddings,
535 | scoring=self.scoring,
536 | k=k_leafs,
537 | paths=paths,
538 | beam_search_depth=beam_search_depth,
539 | )
540 |
541 | return list(tree_scores.values())
542 |
543 | @torch.no_grad()
544 | def add(
545 | self,
546 | documents: list,
547 | documents_embeddings: np.ndarray | sparse.csr_matrix | dict = None,
548 | k: int = 1,
549 | documents_to_leafs: dict = None,
550 | batch_size: int = 32,
551 | tqdm_bar: bool = True,
552 | ) -> "Tree":
553 | """Add documents to the tree.
554 |
555 | Parameters
556 | ----------
557 | documents:
558 | Documents to add to the tree.
559 | embeddings:
560 | Embeddings of the documents.
561 | k:
562 | Number of leafs to add the documents to.
563 | """
564 | if documents_embeddings is None:
565 | documents_embeddings = self.scoring.transform_documents(
566 | documents=documents,
567 | batch_size=batch_size,
568 | tqdm_bar=tqdm_bar,
569 | )
570 |
571 | if documents_to_leafs is None:
572 | leafs = self(
573 | queries=[document[self.key] for document in documents],
574 | queries_embeddings=documents_embeddings,
575 | k=k,
576 | score_documents=False,
577 | tqdm_bar=False,
578 | )["leafs"].tolist()
579 | else:
580 | leafs = [documents_to_leafs[document[self.key]] for document in documents]
581 |
582 | documents_leafs, embeddings_leafs = (
583 | collections.defaultdict(list),
584 | collections.defaultdict(dict),
585 | )
586 |
587 | for document, embedding, document_leafs in zip(
588 | documents, documents_embeddings, leafs
589 | ):
590 | for leaf in document_leafs:
591 | documents_leafs[leaf].append(document)
592 | embeddings_leafs[leaf][document[self.key]] = embedding
593 |
594 | for leaf, embeddings in embeddings_leafs.items():
595 | self.mapping_leafs[leaf].add(
596 | documents=documents_leafs[leaf],
597 | documents_embeddings=None
598 | if self.scoring.distinct_documents_encoder
599 | else embeddings_leafs[leaf],
600 | scoring=self.scoring,
601 | )
602 |
603 | self.documents_to_leafs, self.leafs_to_documents = self.get_documents_leafs()
604 | self.negative_samples = self.get_negative_samples()
605 | return self
606 |
607 | def loss(
608 | self,
609 | queries: list[str],
610 | documents: list[dict],
611 | batch_size: int = 32,
612 | ) -> None:
613 | """Computes the loss of the tree given the input batch.
614 |
615 | Parameters
616 | ----------
617 | queries_embeddings:
618 | Embeddings of the queries.
619 | documents:
620 | Documents ids that where added to the tree.
621 | """
622 | leafs = [
623 | random.choice(seq=list(self.documents_to_leafs[document[self.key]]))
624 | for document in documents
625 | ]
626 |
627 | tree_scores = self._search(
628 | queries=queries,
629 | k_leafs=1,
630 | leafs=leafs,
631 | batch_size=batch_size,
632 | tqdm_bar=False,
633 | )
634 |
635 | loss, size = 0, 0
636 | cross_entropy = torch.nn.CrossEntropyLoss()
637 | for leaf, query_scores in zip(leafs, tree_scores):
638 | query_level_scores = [query_scores[leaf]]
639 |
640 | for negative_node in self.negative_samples[leaf]:
641 | query_level_scores.append(query_scores[negative_node])
642 |
643 | query_level_scores = torch.stack(
644 | tensors=query_level_scores, dim=0
645 | ).unsqueeze(dim=0)
646 |
647 | size += 1
648 | loss += cross_entropy(
649 | query_level_scores,
650 | torch.zeros(
651 | query_level_scores.shape[0],
652 | device=self.device,
653 | dtype=torch.long,
654 | ),
655 | )
656 |
657 | for node in copy.copy(self.paths(leaf=leaf)):
658 | query_level_scores = [query_scores[node]]
659 | for negative_node in self.negative_samples[node]:
660 | query_level_scores.append(query_scores[negative_node])
661 |
662 | query_level_scores = torch.stack(
663 | tensors=query_level_scores, dim=0
664 | ).unsqueeze(dim=0)
665 |
666 | size += 1
667 | loss += cross_entropy(
668 | query_level_scores,
669 | torch.zeros(
670 | query_level_scores.shape[0],
671 | device=self.device,
672 | dtype=torch.long,
673 | ),
674 | )
675 |
676 | return loss / size
677 |
678 | def to_json(self) -> dict:
679 | """Return the tree as a graph."""
680 | return {self.node_name: self.tree.to_json()}
681 |
--------------------------------------------------------------------------------
/neural_tree/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .batchify import batchify
2 | from .evaluate import evaluate, leafs_precision
3 | from .iter import iter
4 | from .sanity_check import sanity_check
5 | from .set_env import set_env
6 |
7 | __all__ = [
8 | "batchify",
9 | "evaluate",
10 | "leafs_precision",
11 | "iter",
12 | "sanity_check",
13 | "set_env",
14 | ]
15 |
--------------------------------------------------------------------------------
/neural_tree/utils/batchify.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 |
3 | __all__ = ["batchify"]
4 |
5 |
6 | def batchify(
7 | X: list[str], batch_size: int, desc: str = "", tqdm_bar: bool = True
8 | ) -> list:
9 | batchs = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]
10 |
11 | if tqdm_bar:
12 | for batch in tqdm.tqdm(
13 | batchs,
14 | position=0,
15 | total=1 + len(X) // batch_size,
16 | desc=desc,
17 | ):
18 | yield batch
19 | else:
20 | yield from batchs
21 |
--------------------------------------------------------------------------------
/neural_tree/utils/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | __all__ = ["evaluate", "leafs_precision"]
4 |
5 |
6 | def leafs_precision(
7 | key: str,
8 | documents: list,
9 | leafs: np.ndarray,
10 | documents_to_leaf: dict,
11 | ) -> float:
12 | """Calculate the precision of the leafs."""
13 | recall = 0
14 | for leafs_query, document in zip(leafs.tolist(), documents):
15 | for leaf_document in documents_to_leaf[document[key]]:
16 | if leafs_query[0] == leaf_document:
17 | recall += 1
18 | break
19 | return recall / len(leafs)
20 |
21 |
22 | def evaluate(
23 | scores: list[list[dict]],
24 | qrels: dict,
25 | queries_ids: list[str],
26 | metrics: list = [],
27 | key: str = "id",
28 | ) -> dict[str, float]:
29 | """Evaluate candidates matchs.
30 |
31 | Parameters
32 | ----------
33 | matchs
34 | Matchs.
35 | qrels
36 | Qrels.
37 | queries
38 | index of queries of qrels.
39 | k
40 | Number of documents to retrieve.
41 | metrics
42 | Metrics to compute.
43 |
44 | Examples
45 | --------
46 | >>> from neural_cherche import models, retrieve, utils
47 | >>> import torch
48 |
49 | >>> _ = torch.manual_seed(42)
50 |
51 | >>> model = models.Splade(
52 | ... model_name_or_path="distilbert-base-uncased",
53 | ... device="cpu",
54 | ... )
55 |
56 | >>> documents, queries_ids, queries, qrels = utils.load_beir(
57 | ... "scifact",
58 | ... split="test",
59 | ... )
60 |
61 | >>> documents = documents[:10]
62 |
63 | >>> retriever = retrieve.Splade(
64 | ... key="id",
65 | ... on=["title", "text"],
66 | ... model=model
67 | ... )
68 |
69 | >>> documents_embeddings = retriever.encode_documents(
70 | ... documents=documents,
71 | ... batch_size=1,
72 | ... )
73 |
74 | >>> documents_embeddings = retriever.add(
75 | ... documents_embeddings=documents_embeddings,
76 | ... )
77 |
78 | >>> queries_embeddings = retriever.encode_queries(
79 | ... queries=queries,
80 | ... batch_size=1,
81 | ... )
82 |
83 | >>> scores = retriever(
84 | ... queries_embeddings=queries_embeddings,
85 | ... k=30,
86 | ... batch_size=1,
87 | ... )
88 |
89 | >>> utils.evaluate(
90 | ... scores=scores,
91 | ... qrels=qrels,
92 | ... queries_ids=queries_ids,
93 | ... metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"]
94 | ... )
95 | {'map': 0.0033333333333333335, 'ndcg@10': 0.0033333333333333335, 'ndcg@100': 0.0033333333333333335, 'recall@10': 0.0033333333333333335, 'recall@100': 0.0033333333333333335}
96 |
97 | """
98 | from ranx import Qrels, Run
99 | from ranx import evaluate as ranx_evaluate
100 |
101 | qrels = Qrels(qrels=qrels)
102 |
103 | run_dict = {
104 | id_query: {
105 | match[key]: 1 - (rank / len(query_matchs))
106 | for rank, match in enumerate(iterable=query_matchs)
107 | }
108 | for id_query, query_matchs in zip(queries_ids, scores)
109 | }
110 |
111 | run = Run(run=run_dict)
112 |
113 | if not metrics:
114 | metrics = ["ndcg@10"] + [f"hits@{k}" for k in [1, 2, 3, 4, 5, 10]]
115 |
116 | return ranx_evaluate(
117 | qrels=qrels,
118 | run=run,
119 | metrics=metrics,
120 | make_comparable=True,
121 | )
122 |
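To make the scoring scheme explicit: evaluate turns each query's ranked candidates into a ranx Run by assigning decreasing scores 1 - rank / len(query_matchs), so the first match always receives 1.0. A small sketch of that mapping on toy data (document ids and the "id" key are illustrative only):

# Rank-to-score mapping used when building the run dictionary above.
query_matchs = [{"id": "doc-a"}, {"id": "doc-b"}, {"id": "doc-c"}]

run_entry = {
    match["id"]: 1 - (rank / len(query_matchs))
    for rank, match in enumerate(query_matchs)
}

print(run_entry)
# {'doc-a': 1.0, 'doc-b': 0.666..., 'doc-c': 0.333...} (approximately)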
--------------------------------------------------------------------------------
/neural_tree/utils/iter.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | import typing
4 |
5 | import tqdm
6 |
7 | from .batchify import batchify
8 |
9 | __all__ = ["iter"]
10 |
11 |
12 | def iter(
13 | queries, documents, batch_size=512, epochs: int = 1, shuffle=True, tqdm_bar=True
14 | ) -> typing.Generator:
15 |     """Iterate over the dataset.
16 |
17 |     Parameters
18 |     ----------
19 |     queries
20 |         List of queries paired with documents.
21 |     documents
22 |         List of documents paired with queries.
23 |     batch_size
24 |         Size of the batch.
25 |     epochs
26 |         Number of epochs.
27 |     shuffle
28 |         Whether to shuffle the query / document pairs at the start of each epoch.
29 |     tqdm_bar
30 |         Whether to display a progress bar over the epochs.
31 |     """
32 |     step = 0
33 |     queries = copy.deepcopy(x=queries)
34 |     documents = copy.deepcopy(x=documents)
35 |
36 |     bar = tqdm.tqdm(iterable=range(epochs), position=0) if tqdm_bar else range(epochs)
37 |
38 |     for _ in bar:
39 |         if shuffle:
40 |             queries_documents = list(zip(queries, documents))
41 |             random.shuffle(x=queries_documents)
42 |             queries, documents = zip(*queries_documents)
43 |
44 |         for batch_queries, batch_documents in zip(
45 |             batchify(X=queries, batch_size=batch_size, tqdm_bar=False),
46 |             batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
47 |         ):
48 |             yield step, batch_queries, batch_documents
49 |             step += 1
50 |
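A minimal sketch of how iter can be driven (toy data; the alias simply avoids shadowing the built-in iter):

from neural_tree.utils import iter as iter_batches

queries = ["q1", "q2", "q3", "q4"]
documents = [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}]

# Two epochs of batches of size two; step is a single counter across epochs.
for step, batch_queries, batch_documents in iter_batches(
    queries, documents, batch_size=2, epochs=2, shuffle=False, tqdm_bar=False
):
    print(step, batch_queries, batch_documents)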
--------------------------------------------------------------------------------
/neural_tree/utils/sanity_check.py:
--------------------------------------------------------------------------------
1 | __all__ = ["sanity_check"]
2 |
3 |
4 | def sanity_check(
5 | branch_balance_factor: int, leaf_balance_factor: int, graph: dict, documents: list
6 | ) -> None:
7 | """Check if the input is valid."""
8 | if branch_balance_factor < 2:
9 | raise ValueError("Branch balance factor must be greater than 1.")
10 |
11 | if leaf_balance_factor < 1:
12 | raise ValueError("Leaf balance factor must be greater than 0.")
13 |
14 | if graph is not None:
15 | if len(graph.keys()) > 1:
16 | raise ValueError("Graph must have only one root node.")
17 |
18 | if documents is None and graph is None:
19 | raise ValueError("You must provide either documents or an existing graph.")
20 |
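A short sketch of the checks in action (the document fields are illustrative):

from neural_tree.utils import sanity_check

# Valid configuration: binary branching, at least one document per leaf,
# no pre-built graph, and a list of documents to index.
sanity_check(
    branch_balance_factor=2,
    leaf_balance_factor=100,
    graph=None,
    documents=[{"id": 0, "text": "..."}],
)

# A branching factor below 2 is rejected.
try:
    sanity_check(
        branch_balance_factor=1, leaf_balance_factor=100, graph=None, documents=[]
    )
except ValueError as error:
    print(error)  # Branch balance factor must be greater than 1.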
--------------------------------------------------------------------------------
/neural_tree/utils/set_env.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | __all__ = ["set_env"]
4 |
5 |
6 | def set_env() -> None:
7 | """Set environment variables."""
8 |     os.environ["HF_HOME"] = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
9 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
10 |
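set_env would typically be called once, before any model or tokenizer is instantiated, so that the cache location and parallelism setting take effect up front; a minimal sketch:

from neural_tree.utils import set_env

# Point the Hugging Face cache at ~/.cache/huggingface and silence the
# tokenizers parallelism warning before loading any model.
set_env()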
--------------------------------------------------------------------------------
/neural_tree/version.txt:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | filterwarnings =
3 | ignore::DeprecationWarning
4 | ignore::RuntimeWarning
5 | ignore::UserWarning
6 | addopts =
7 | --doctest-modules
8 | --verbose
9 | -ra
10 | --cov-config=.coveragerc
11 | -m "not web and not slow"
12 | doctest_optionflags = NORMALIZE_WHITESPACE NUMBER
13 | norecursedirs =
14 | build
15 | docs
16 | node_modules
17 | markers =
18 | web: tests that require using the Internet
19 | slow: tests that take a long time to run
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | # Inside of setup.cfg
2 | [metadata]
3 | description_file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | from neural_tree.__version__ import __version__
4 |
5 | with open(file="README.md", mode="r", encoding="utf-8") as fh:
6 | long_description = fh.read()
7 |
8 | base_packages = [
9 | "torch >= 1.13",
10 | "tqdm >= 4.66",
11 | "transformers >= 4.34.0",
12 | "sentence-transformers >= 2.2.2",
13 | "neural-cherche >= 1.1.0",
14 | "scikit-learn >= 1.4.0",
15 | ]
16 |
17 | eval = ["ranx >= 0.3.16", "beir >= 2.0.0"]
18 |
19 | dev = ["mkdocs-material == 9.2.8"]
20 |
21 | setuptools.setup(
22 | name="neural-tree",
23 | version=f"{__version__}",
24 | license="MIT",
25 | author="Raphael Sourty",
26 | author_email="raphael.sourty@gmail.com",
27 | description="Neural-Tree",
28 | long_description=long_description,
29 | long_description_content_type="text/markdown",
30 | url="https://github.com/raphaelsty/neural-tree",
31 |     download_url="https://github.com/raphaelsty/neural-tree/archive/v_01.tar.gz",
32 | keywords=[
33 | "tree search",
34 | "neural search",
35 | "information retrieval",
36 | "semantic search",
37 | "colbert",
38 | "tree",
39 | ],
40 | packages=setuptools.find_packages(),
41 | install_requires=base_packages,
42 | extras_require={"eval": base_packages + eval, "dev": base_packages + eval + dev},
43 | classifiers=[
44 | "Programming Language :: Python :: 3",
45 | "License :: OSI Approved :: MIT License",
46 | "Operating System :: OS Independent",
47 | ],
48 |     python_requires=">=3.9",
49 | )
50 |
--------------------------------------------------------------------------------