├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── docs ├── .pages ├── CNAME ├── api │ ├── .pages │ ├── clustering │ │ ├── .pages │ │ ├── KMeans.md │ │ ├── average.md │ │ ├── get-mapping-nodes-documents.md │ │ └── optimize-leafs.md │ ├── datasets │ │ ├── .pages │ │ ├── load-beir-test.md │ │ ├── load-beir-train.md │ │ └── load-beir.md │ ├── leafs │ │ ├── .pages │ │ └── Leaf.md │ ├── nodes │ │ ├── .pages │ │ └── Node.md │ ├── overview.md │ ├── retrievers │ │ ├── .pages │ │ ├── ColBERT.md │ │ ├── SentenceTransformer.md │ │ └── TfIdf.md │ ├── scoring │ │ ├── .pages │ │ ├── BaseScore.md │ │ ├── ColBERT.md │ │ ├── SentenceTransformer.md │ │ └── TfIdf.md │ ├── trees │ │ ├── .pages │ │ ├── ColBERT.md │ │ ├── SentenceTransformer.md │ │ ├── TfIdf.md │ │ └── Tree.md │ └── utils │ │ ├── .pages │ │ ├── batchify.md │ │ ├── evaluate.md │ │ ├── iter.md │ │ ├── leafs-precision.md │ │ ├── sanity-check.md │ │ └── set-env.md ├── css │ └── version-select.css ├── evaluate │ ├── .pages │ └── evaluate.md ├── existing_tree │ ├── .pages │ └── existing_tree.md ├── img │ ├── logo.png │ └── neural_tree.png ├── index.md ├── javascripts │ └── config.js ├── js │ └── version-select.js ├── scripts │ └── index.py ├── stylesheets │ └── extra.css └── trees │ ├── .pages │ ├── colbert.md │ ├── sentence_transformer.md │ └── tfidf.md ├── mkdocs.yml ├── neural_tree ├── __init__.py ├── __version__.py ├── clustering │ ├── __init__.py │ ├── average.py │ ├── kmeans.py │ └── optimize.py ├── datasets │ ├── __init__.py │ └── beir.py ├── leafs │ ├── __init__.py │ └── leaf.py ├── nodes │ ├── __init__.py │ └── node.py ├── retrievers │ ├── __init__.py │ ├── colbert.py │ ├── sentence_transformer.py │ └── tfidf.py ├── scoring │ ├── __init__.py │ ├── base.py │ ├── colbert.py │ ├── sentence_transformer.py │ └── tfidf.py ├── trees │ ├── __init__.py │ ├── colbert.py │ ├── sentence_transformer.py │ ├── tfidf.py │ └── tree.py ├── utils │ ├── __init__.py │ ├── batchify.py │ ├── evaluate.py │ ├── iter.py │ ├── sanity_check.py │ └── set_env.py └── version.txt ├── pytest.ini ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Raphael Sourty 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | livedoc: 2 | mkdocs build --clean 3 | mkdocs serve --dirtyreload 4 | 5 | deploydoc: 6 | mkdocs gh-deploy --force 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 |

Neural-Tree

4 |

Neural Search

5 |
6 | 7 |

8 | 9 |
10 | 11 | documentation 12 | 13 | license 14 |
15 | 16 |

17 | 
18 | Are tree-based indexes the counterpart of standard ANN algorithms for token-level embedding IR models? Neural-Tree replicates the SIGIR 2023 publication [Constructing Tree-based Index for Efficient and Effective Dense Retrieval](https://dl.acm.org/doi/10.1145/3539618.3591651) in order to accelerate ColBERT. Neural-Tree is compatible with Sentence Transformers and TfIdf models as in the original paper. 
19 | 
20 | Neural-Tree creates a tree using hierarchical clustering of documents and then learns embeddings in each node of the tree using paired queries and documents. Additionally, there is the flexibility to input an existing tree structure in JSON format to build the index. 
21 | 
22 | The optimization of the index by Neural-Tree is geared towards maintaining the performance level of the original model while significantly speeding up the search process. It is important to note that Neural-Tree does not modify the underlying model; therefore, it is advisable to initiate tree creation with a model that has already been fine-tuned. Given that Neural-Tree does not alter the model, the index training process is relatively quick. 
23 | 
24 | ## Installation 
25 | 
26 | We can install neural-tree using: 
27 | 
28 | ``` 
29 | pip install neural-tree 
30 | ``` 
31 | 
32 | If we plan to evaluate our model while training, install: 
33 | 
34 | ``` 
35 | pip install "neural-tree[eval]" 
36 | ``` 
37 | 
38 | ## Documentation 
39 | 
40 | The complete documentation is available [here](https://raphaelsty.github.io/neural-tree/). 
41 | 
42 | 
43 | ## Quick Start 
44 | 
45 | The following code shows how to train a tree index. Let's start by creating a fictional dataset: 
46 | 
47 | ```python 
48 | documents = [ 
49 |     {"id": 0, "content": "paris"}, 
50 |     {"id": 1, "content": "london"}, 
51 |     {"id": 2, "content": "berlin"}, 
52 |     {"id": 3, "content": "rome"}, 
53 |     {"id": 4, "content": "bordeaux"}, 
54 |     {"id": 5, "content": "milan"}, 
55 | ] 
56 | 
57 | train_queries = [ 
58 |     "paris is the capital of france", 
59 |     "london is the capital of england", 
60 |     "berlin is the capital of germany", 
61 |     "rome is the capital of italy", 
62 | ] 
63 | 
64 | train_documents = [ 
65 |     {"id": 0, "content": "paris"}, 
66 |     {"id": 1, "content": "london"}, 
67 |     {"id": 2, "content": "berlin"}, 
68 |     {"id": 3, "content": "rome"}, 
69 | ] 
70 | 
71 | test_queries = [ 
72 |     "bordeaux is the capital of france", 
73 |     "milan is the capital of italy", 
74 | ] 
75 | ``` 
76 | 
77 | Let's train the index using the `documents`, `train_queries` and `train_documents` we have gathered. 
78 | 79 | ```python 80 | import torch 81 | from neural_cherche import models 82 | from neural_tree import clustering, trees, utils 83 | 84 | model = models.ColBERT( 85 | model_name_or_path="raphaelsty/neural-cherche-colbert", 86 | device="cuda" if torch.cuda.is_available() else "cpu", 87 | ) 88 | 89 | tree = trees.ColBERT( 90 | key="id", 91 | on=["content"], 92 | model=model, 93 | documents=documents, 94 | leaf_balance_factor=100, # Number of documents per leaf 95 | branch_balance_factor=5, # Number children per node 96 | n_jobs=-1, # set to 1 with Google Colab 97 | ) 98 | 99 | optimizer = torch.optim.AdamW(lr=3e-3, params=list(tree.parameters())) 100 | 101 | for step, batch_queries, batch_documents in utils.iter( 102 | queries=train_queries, 103 | documents=train_documents, 104 | shuffle=True, 105 | epochs=50, 106 | batch_size=32, 107 | ): 108 | loss = tree.loss( 109 | queries=batch_queries, 110 | documents=batch_documents, 111 | ) 112 | 113 | loss.backward() 114 | optimizer.step() 115 | optimizer.zero_grad(set_to_none=True) 116 | ``` 117 | 118 | 119 | Let's now duplicate some documents of the tree in order to increase accuracy. 120 | 121 | ```python 122 | documents_to_leafs = clustering.optimize_leafs( 123 | tree=tree, 124 | queries=train_queries + test_queries, 125 | documents=documents, 126 | ) 127 | 128 | tree = tree.add( 129 | documents=documents, 130 | documents_to_leafs=documents_to_leafs, 131 | ) 132 | ``` 133 | 134 | We are now ready to retrieve documents: 135 | 136 | ```python 137 | scores = tree( 138 | queries=["bordeaux", "milan"], 139 | k_leafs=2, 140 | k=2, 141 | ) 142 | 143 | print(scores["documents"]) 144 | ``` 145 | 146 | ```python 147 | [ 148 | [ 149 | {"id": 4, "similarity": 5.28, "leaf": "12"}, 150 | {"id": 0, "similarity": 3.17, "leaf": "12"}, 151 | ], 152 | [ 153 | {"id": 5, "similarity": 5.11, "leaf": "10"}, 154 | {"id": 2, "similarity": 3.57, "leaf": "10"}, 155 | ], 156 | ] 157 | ``` 158 | 159 | ## Benchmarks 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 |
**Scifact Dataset**

| Model | HuggingFace Checkpoint | Vanilla ndcg@10 | Vanilla hits@10 | Vanilla hits@1 | Vanilla queries/s | Neural-Tree ndcg@10 | Neural-Tree hits@10 | Neural-Tree hits@1 | Neural-Tree queries/s | Acceleration |
|:--|:--|--:|--:|--:|--:|--:|--:|--:|--:|--:|
| TfIdf (Cherche) | - | 0.61 | 0.85 | 0.47 | 760 | 0.56 | 0.82 | 0.42 | 1080 | +42.11% |
| SentenceTransformer GPU + Faiss.IndexFlatL2 CPU | sentence-transformers/all-mpnet-base-v2 | 0.66 | 0.89 | 0.53 | 475 | 0.66 | 0.88 | 0.53 | 518 | +9.05% |
| ColBERT (Neural-Cherche) GPU | raphaelsty/neural-cherche-colbert | 0.70 | 0.92 | 0.58 | 3 | 0.70 | 0.91 | 0.59 | 256 | x85 |
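The figures above can be reproduced with the evaluation utilities shipped with the library. Below is a condensed sketch, not a verbatim benchmark script: it assumes a `tree` that has already been built and trained on the Scifact training split (see the evaluation guide in the documentation), and the metric list is an illustrative choice.

```python
# Hedged sketch: score a trained Neural-Tree index on the Scifact test split.
# Assumes `tree` was built and trained beforehand (see docs/evaluate for the
# complete training script); the dataset name and metric list are assumptions.
from neural_tree import datasets, utils

documents, queries_ids, test_queries, qrels = datasets.load_beir_test(
    dataset_name="scifact",
)

candidates = tree(
    queries=test_queries,
    k_leafs=2,  # number of leafs explored per query
    k=10,  # number of documents returned per query
)

scores = utils.evaluate(
    scores=candidates["documents"],
    qrels=qrels,
    queries_ids=queries_ids,
    metrics=["ndcg@10", "hits@10", "hits@1"],
)

print(scores)
```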
227 | 228 | Note that this benchmark do not implement [ColBERTV2](https://arxiv.org/abs/2112.01488) efficient retrieval but rather compare ColBERT raw retrieval with Neural-Tree. We could accelerate SentenceTransformer vanilla by using optimized Faiss index. 229 | 230 | ## Contributing 231 | 232 | We welcome contributions to Neural-Tree. Our focus includes improving the clustering of ColBERT embeddings which is currently handled by TfIdf. Neural-Cherche will also be a tool designed to enhance tree visualization, extract nodes topics, and leverage the tree structure to accelerate Large Language Model (LLM) retrieval. 233 | 234 | ## License 235 | 236 | This project is licensed under the terms of the MIT license. 237 | 238 | ## References 239 | 240 | - [Constructing Tree-based Index for Efficient and Effective Dense Retrieval, Github](https://github.com/cshaitao/jtr) 241 | 242 | - [ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT](https://arxiv.org/abs/2004.12832) 243 | 244 | - [Myriade](https://github.com/MaxHalford/myriade) 245 | 246 | 247 | -------------------------------------------------------------------------------- /docs/.pages: -------------------------------------------------------------------------------- 1 | nav: 2 | - Home: trees 3 | - existing_tree 4 | - evaluate 5 | - api -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | raphaelsty.github.io/neural-tree/ -------------------------------------------------------------------------------- /docs/api/.pages: -------------------------------------------------------------------------------- 1 | title: API reference 2 | arrange: 3 | - overview.md 4 | - ... 5 | -------------------------------------------------------------------------------- /docs/api/clustering/.pages: -------------------------------------------------------------------------------- 1 | title: clustering -------------------------------------------------------------------------------- /docs/api/clustering/KMeans.md: -------------------------------------------------------------------------------- 1 | # KMeans 2 | 3 | KMeans clustering. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **documents_embeddings** (*torch.Tensor*) 10 | 11 | - **n_clusters** (*int*) 12 | 13 | - **max_iter** (*int*) 14 | 15 | - **n_init** (*int*) 16 | 17 | - **seed** (*int*) 18 | 19 | - **device** (*str*) 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /docs/api/clustering/average.md: -------------------------------------------------------------------------------- 1 | # average 2 | 3 | Replace KMeans clustering with average clustering when an existing graph is provided. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | - **documents** (*list*) 12 | 13 | - **documents_embeddings** (*numpy.ndarray | scipy.sparse._csr.csr_matrix*) 14 | 15 | - **graph** 16 | 17 | - **scoring** 18 | 19 | - **device** (*str*) 20 | 21 | 22 | 23 | ## Examples 24 | 25 | ```python 26 | >>> from neural_tree import clustering, scoring 27 | >>> import numpy as np 28 | 29 | >>> documents = [ 30 | ... {"id": 0, "text": "Paris is the capital of France."}, 31 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 32 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 33 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 34 | ... 
] 35 | 36 | >>> documents_embeddings = np.array([ 37 | ... [1, 1], 38 | ... [1, 2], 39 | ... [10, 10], 40 | ... [1, 3], 41 | ... ]) 42 | 43 | >>> graph = {1: {11: {111: [{'id': 0}, {'id': 3}], 112: [{'id': 1}]}, 12: {121: [{'id': 2}], 122: [{'id': 3}]}}} 44 | 45 | >>> clustering.average( 46 | ... key="id", 47 | ... documents_embeddings=documents_embeddings, 48 | ... documents=documents, 49 | ... graph=graph[1], 50 | ... scoring=scoring.SentenceTransformer(key="id", on=["text"], model=None), 51 | ... ) 52 | ``` 53 | 54 | -------------------------------------------------------------------------------- /docs/api/clustering/get-mapping-nodes-documents.md: -------------------------------------------------------------------------------- 1 | # get_mapping_nodes_documents 2 | 3 | Get documents from specific node. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **graph** (*dict | list*) 10 | 11 | - **documents** (*list | None*) – defaults to `None` 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /docs/api/clustering/optimize-leafs.md: -------------------------------------------------------------------------------- 1 | # optimize_leafs 2 | 3 | Optimize the clusters. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **tree** 10 | 11 | - **documents** (*list[dict]*) 12 | 13 | - **queries** (*list[str]*) 14 | 15 | - **k_tree** (*int*) – defaults to `2` 16 | 17 | - **k_retriever** (*int*) – defaults to `10` 18 | 19 | - **k_leafs** (*int*) – defaults to `2` 20 | 21 | - **kwargs** 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/api/datasets/.pages: -------------------------------------------------------------------------------- 1 | title: datasets -------------------------------------------------------------------------------- /docs/api/datasets/load-beir-test.md: -------------------------------------------------------------------------------- 1 | # load_beir_test 2 | 3 | Load BEIR testing dataset. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **dataset_name** (*str*) 10 | 11 | Dataset name. 12 | 13 | 14 | 15 | ## Examples 16 | 17 | ```python 18 | >>> from neural_tree import datasets 19 | 20 | >>> documents, queries_ids, queries, qrels = datasets.load_beir_test( 21 | ... dataset_name="scifact", 22 | ... ) 23 | 24 | >>> len(documents) 25 | 5183 26 | 27 | >>> assert len(queries_ids) == len(queries) == len(qrels) 28 | ``` 29 | 30 | -------------------------------------------------------------------------------- /docs/api/datasets/load-beir-train.md: -------------------------------------------------------------------------------- 1 | # load_beir_train 2 | 3 | Load training dataset. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **dataset_name** (*str*) 10 | 11 | Dataset name 12 | 13 | 14 | 15 | ## Examples 16 | 17 | ```python 18 | >>> from neural_tree import datasets 19 | 20 | >>> documents, train_queries, train_documents = datasets.load_beir_train( 21 | ... dataset_name="scifact", 22 | ... ) 23 | 24 | >>> len(documents) 25 | 5183 26 | 27 | >>> assert len(train_queries) == len(train_documents) 28 | ``` 29 | 30 | -------------------------------------------------------------------------------- /docs/api/datasets/load-beir.md: -------------------------------------------------------------------------------- 1 | # load_beir 2 | 3 | Load BEIR dataset. 
4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **dataset_name** (*str*) 10 | 11 | - **split** (*str*) 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /docs/api/leafs/.pages: -------------------------------------------------------------------------------- 1 | title: leafs -------------------------------------------------------------------------------- /docs/api/nodes/.pages: -------------------------------------------------------------------------------- 1 | title: nodes -------------------------------------------------------------------------------- /docs/api/overview.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | ## clustering 4 | 5 | - [KMeans](../clustering/KMeans) 6 | - [average](../clustering/average) 7 | - [get_mapping_nodes_documents](../clustering/get-mapping-nodes-documents) 8 | - [optimize_leafs](../clustering/optimize-leafs) 9 | 10 | ## datasets 11 | 12 | - [load_beir](../datasets/load-beir) 13 | - [load_beir_test](../datasets/load-beir-test) 14 | - [load_beir_train](../datasets/load-beir-train) 15 | 16 | ## leafs 17 | 18 | - [Leaf](../leafs/Leaf) 19 | 20 | ## nodes 21 | 22 | - [Node](../nodes/Node) 23 | 24 | ## retrievers 25 | 26 | - [ColBERT](../retrievers/ColBERT) 27 | - [SentenceTransformer](../retrievers/SentenceTransformer) 28 | - [TfIdf](../retrievers/TfIdf) 29 | 30 | ## scoring 31 | 32 | - [BaseScore](../scoring/BaseScore) 33 | - [ColBERT](../scoring/ColBERT) 34 | - [SentenceTransformer](../scoring/SentenceTransformer) 35 | - [TfIdf](../scoring/TfIdf) 36 | 37 | ## trees 38 | 39 | - [ColBERT](../trees/ColBERT) 40 | - [SentenceTransformer](../trees/SentenceTransformer) 41 | - [TfIdf](../trees/TfIdf) 42 | - [Tree](../trees/Tree) 43 | 44 | ## utils 45 | 46 | - [batchify](../utils/batchify) 47 | - [evaluate](../utils/evaluate) 48 | - [iter](../utils/iter) 49 | - [leafs_precision](../utils/leafs-precision) 50 | - [sanity_check](../utils/sanity-check) 51 | - [set_env](../utils/set-env) 52 | 53 | -------------------------------------------------------------------------------- /docs/api/retrievers/.pages: -------------------------------------------------------------------------------- 1 | title: retrievers -------------------------------------------------------------------------------- /docs/api/retrievers/ColBERT.md: -------------------------------------------------------------------------------- 1 | # ColBERT 2 | 3 | ColBERT retriever. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | - **on** (*str | list[str]*) 12 | 13 | - **device** (*str*) 14 | 15 | 16 | 17 | 18 | ## Methods 19 | 20 | ???- note "__call__" 21 | 22 | Rank documents givent queries. 23 | 24 | **Parameters** 25 | 26 | - **queries_embeddings** (*dict[str, torch.Tensor]*) 27 | - **batch_size** (*int*) – defaults to `32` 28 | - **k** (*int*) – defaults to `None` 29 | - **tqdm_bar** (*bool*) – defaults to `False` 30 | 31 | ???- note "add" 32 | 33 | Add documents embeddings. 34 | 35 | **Parameters** 36 | 37 | - **documents_embeddings** (*dict*) 38 | 39 | ???- note "encode_documents" 40 | 41 | Encode documents. 42 | 43 | **Parameters** 44 | 45 | - **documents** (*list[str]*) 46 | - **batch_size** (*int*) – defaults to `32` 47 | - **tqdm_bar** (*bool*) – defaults to `True` 48 | - **query_mode** (*bool*) – defaults to `False` 49 | - **kwargs** 50 | 51 | ???- note "encode_queries" 52 | 53 | Encode queries. 
54 | 55 | **Parameters** 56 | 57 | - **queries** (*list[str]*) 58 | - **batch_size** (*int*) – defaults to `32` 59 | - **tqdm_bar** (*bool*) – defaults to `True` 60 | - **query_mode** (*bool*) – defaults to `True` 61 | - **kwargs** 62 | 63 | -------------------------------------------------------------------------------- /docs/api/retrievers/SentenceTransformer.md: -------------------------------------------------------------------------------- 1 | # SentenceTransformer 2 | 3 | Sentence Transformer retriever. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | - **device** (*str*) – defaults to `cpu` 12 | 13 | 14 | 15 | ## Examples 16 | 17 | ```python 18 | >>> from neural_tree import retrievers 19 | >>> from sentence_transformers import SentenceTransformer 20 | >>> from pprint import pprint 21 | 22 | >>> model = SentenceTransformer("all-mpnet-base-v2") 23 | 24 | >>> retriever = retrievers.SentenceTransformer(key="id") 25 | 26 | >>> retriever = retriever.add( 27 | ... documents_embeddings={ 28 | ... 0: model.encode("Paris is the capital of France."), 29 | ... 1: model.encode("Berlin is the capital of Germany."), 30 | ... 2: model.encode("Paris and Berlin are European cities."), 31 | ... 3: model.encode("Paris and Berlin are beautiful cities."), 32 | ... } 33 | ... ) 34 | 35 | >>> queries_embeddings = { 36 | ... 0: model.encode("Paris"), 37 | ... 1: model.encode("Berlin"), 38 | ... } 39 | 40 | >>> candidates = retriever(queries_embeddings=queries_embeddings, k=2) 41 | >>> pprint(candidates) 42 | [[{'id': 0, 'similarity': 0.644777984318611}, 43 | {'id': 3, 'similarity': 0.52865785276988}], 44 | [{'id': 1, 'similarity': 0.6901492368348436}, 45 | {'id': 3, 'similarity': 0.5457692206973245}]] 46 | ``` 47 | 48 | ## Methods 49 | 50 | ???- note "__call__" 51 | 52 | Retrieve documents. 53 | 54 | **Parameters** 55 | 56 | - **queries_embeddings** (*dict[int, numpy.ndarray]*) 57 | - **k** (*int | None*) – defaults to `100` 58 | - **kwargs** 59 | 60 | ???- note "add" 61 | 62 | Add documents to the faiss index. 63 | 64 | **Parameters** 65 | 66 | - **documents_embeddings** (*dict[int, numpy.ndarray]*) 67 | 68 | -------------------------------------------------------------------------------- /docs/api/retrievers/TfIdf.md: -------------------------------------------------------------------------------- 1 | # TfIdf 2 | 3 | TfIdf retriever 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | - **on** (*list[str]*) 12 | 13 | 14 | 15 | 16 | ## Methods 17 | 18 | ???- note "__call__" 19 | 20 | Retrieve documents from batch of queries. 21 | 22 | **Parameters** 23 | 24 | - **queries_embeddings** (*dict[str, scipy.sparse._csr.csr_matrix]*) 25 | - **k** (*int*) – defaults to `None` 26 | - **batch_size** (*int*) – defaults to `2000` 27 | - **tqdm_bar** (*bool*) – defaults to `True` 28 | 29 | ???- note "add" 30 | 31 | Add new documents to the TFIDF retriever. The tfidf won't be refitted. 32 | 33 | **Parameters** 34 | 35 | - **documents_embeddings** (*dict[str, scipy.sparse._csr.csr_matrix]*) 36 | 37 | ???- note "encode_documents" 38 | 39 | Encode queries into sparse matrix. 40 | 41 | **Parameters** 42 | 43 | - **documents** (*list[dict]*) 44 | - **model** (*sklearn.feature_extraction.text.TfidfVectorizer*) 45 | 46 | ???- note "encode_queries" 47 | 48 | Encode queries into sparse matrix. 
49 | 50 | **Parameters** 51 | 52 | - **queries** (*list[str]*) 53 | - **model** (*sklearn.feature_extraction.text.TfidfVectorizer*) 54 | 55 | ???- note "top_k" 56 | 57 | Return the top k documents for each query. 58 | 59 | **Parameters** 60 | 61 | - **similarities** (*scipy.sparse._csc.csc_matrix*) 62 | - **k** (*int*) 63 | 64 | -------------------------------------------------------------------------------- /docs/api/scoring/.pages: -------------------------------------------------------------------------------- 1 | title: scoring -------------------------------------------------------------------------------- /docs/api/scoring/BaseScore.md: -------------------------------------------------------------------------------- 1 | # BaseScore 2 | 3 | Base class for scoring functions. 4 | 5 | 6 | 7 | 8 | 9 | 10 | ## Methods 11 | 12 | ???- note "convert_to_tensor" 13 | 14 | Transform sparse matrix to tensor. 15 | 16 | **Parameters** 17 | 18 | - **embeddings** (*scipy.sparse._csr.csr_matrix | numpy.ndarray*) 19 | - **device** (*str*) 20 | 21 | ???- note "distinct_documents_encoder" 22 | 23 | Return True if the encoder is distinct for documents and nodes. 24 | 25 | 26 | ???- note "encode_queries_for_retrieval" 27 | 28 | Encode queries for retrieval. 29 | 30 | **Parameters** 31 | 32 | - **queries** (*list[str]*) 33 | 34 | ???- note "get_retriever" 35 | 36 | Create a retriever 37 | 38 | 39 | ???- note "leaf_scores" 40 | 41 | Return the scores of the embeddings. 42 | 43 | **Parameters** 44 | 45 | - **queries_embeddings** (*torch.Tensor*) 46 | - **leaf_embedding** (*torch.Tensor*) 47 | 48 | ???- note "nodes_scores" 49 | 50 | Score between queries and nodes embeddings. 51 | 52 | **Parameters** 53 | 54 | - **queries_embeddings** (*torch.Tensor*) 55 | - **nodes_embeddings** (*torch.Tensor*) 56 | 57 | ???- note "stack" 58 | 59 | Stack list of embeddings. 60 | 61 | - **embeddings** (*list[scipy.sparse._csr.csr_matrix | numpy.ndarray | dict]*) 62 | 63 | ???- note "transform_documents" 64 | 65 | Transform documents to embeddings. 66 | 67 | **Parameters** 68 | 69 | - **documents** (*list[dict]*) 70 | 71 | ???- note "transform_queries" 72 | 73 | Transform queries to embeddings. 74 | 75 | **Parameters** 76 | 77 | - **queries** (*list[str]*) 78 | 79 | -------------------------------------------------------------------------------- /docs/api/scoring/ColBERT.md: -------------------------------------------------------------------------------- 1 | # ColBERT 2 | 3 | TfIdf scoring function. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | - **on** (*list | str*) 12 | 13 | - **documents** (*list*) 14 | 15 | - **model** (*neural_cherche.models.colbert.ColBERT*) – defaults to `None` 16 | 17 | - **device** (*str*) – defaults to `cpu` 18 | 19 | - **kwargs** 20 | 21 | 22 | ## Attributes 23 | 24 | - **distinct_documents_encoder** 25 | 26 | Return True if the encoder is distinct for documents and nodes. 27 | 28 | 29 | ## Examples 30 | 31 | ```python 32 | >>> from neural_tree import trees, scoring 33 | >>> from neural_cherche import models 34 | >>> from sklearn.feature_extraction.text import TfidfVectorizer 35 | >>> from pprint import pprint 36 | >>> import torch 37 | 38 | >>> _ = torch.manual_seed(42) 39 | 40 | >>> documents = [ 41 | ... {"id": 0, "text": "Paris is the capital of France."}, 42 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 43 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 44 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 45 | ... 
] 46 | 47 | >>> model = models.ColBERT( 48 | ... model_name_or_path="sentence-transformers/all-mpnet-base-v2", 49 | ... embedding_size=128, 50 | ... max_length_document=96, 51 | ... max_length_query=32, 52 | ... ) 53 | 54 | >>> tree = trees.ColBERTTree( 55 | ... key="id", 56 | ... on="text", 57 | ... model=model, 58 | ... documents=documents, 59 | ... leaf_balance_factor=1, 60 | ... branch_balance_factor=2, 61 | ... n_jobs=1, 62 | ... ) 63 | 64 | >>> print(tree) 65 | node 1 66 | node 10 67 | leaf 100 68 | leaf 101 69 | node 11 70 | leaf 110 71 | leaf 111 72 | 73 | >>> tree.leafs_to_documents 74 | {'100': [0], '101': [1], '110': [2], '111': [3]} 75 | 76 | >>> candidates = tree( 77 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."], 78 | ... k_leafs=2, 79 | ... k=2, 80 | ... ) 81 | 82 | >>> candidates["scores"] 83 | array([[28.12037659, 18.32332611], 84 | [29.28324509, 21.38923264]]) 85 | 86 | >>> candidates["leafs"] 87 | array([['100', '101'], 88 | ['110', '111']], dtype='>> pprint(candidates["tree_scores"]) 91 | [{'10': tensor(28.1204), 92 | '100': tensor(28.1204), 93 | '101': tensor(18.3233), 94 | '11': tensor(20.9327)}, 95 | {'10': tensor(21.6886), 96 | '11': tensor(29.2832), 97 | '110': tensor(29.2832), 98 | '111': tensor(21.3892)}] 99 | 100 | >>> pprint(candidates["documents"]) 101 | [[{'id': 0, 'leaf': '100', 'similarity': 28.120376586914062}, 102 | {'id': 1, 'leaf': '101', 'similarity': 18.323326110839844}], 103 | [{'id': 2, 'leaf': '110', 'similarity': 29.283245086669922}, 104 | {'id': 3, 'leaf': '111', 'similarity': 21.389232635498047}]] 105 | ``` 106 | 107 | ## Methods 108 | 109 | ???- note "average" 110 | 111 | Average embeddings. 112 | 113 | - **embeddings** (*torch.Tensor*) 114 | 115 | ???- note "convert_to_tensor" 116 | 117 | Transform sparse matrix to tensor. 118 | 119 | **Parameters** 120 | 121 | - **embeddings** (*numpy.ndarray | torch.Tensor*) 122 | - **device** (*str*) 123 | 124 | ???- note "encode_queries_for_retrieval" 125 | 126 | Encode queries for retrieval. 127 | 128 | **Parameters** 129 | 130 | - **queries** (*list[str]*) 131 | 132 | ???- note "get_retriever" 133 | 134 | Create a retriever 135 | 136 | 137 | ???- note "leaf_scores" 138 | 139 | Return the scores of the embeddings. 140 | 141 | **Parameters** 142 | 143 | - **queries_embeddings** (*torch.Tensor*) 144 | - **leaf_embedding** (*torch.Tensor*) 145 | 146 | ???- note "nodes_scores" 147 | 148 | Score between queries and nodes embeddings. 149 | 150 | **Parameters** 151 | 152 | - **queries_embeddings** (*torch.Tensor*) 153 | - **nodes_embeddings** (*torch.Tensor*) 154 | 155 | ???- note "stack" 156 | 157 | Stack list of embeddings. 158 | 159 | **Parameters** 160 | 161 | - **embeddings** (*list[torch.Tensor | numpy.ndarray]*) 162 | 163 | ???- note "transform_documents" 164 | 165 | Transform documents to embeddings. 166 | 167 | **Parameters** 168 | 169 | - **documents** (*list[dict]*) 170 | - **batch_size** (*int*) 171 | - **tqdm_bar** (*bool*) 172 | - **kwargs** 173 | 174 | ???- note "transform_queries" 175 | 176 | Transform queries to embeddings. 177 | 178 | **Parameters** 179 | 180 | - **queries** (*list[str]*) 181 | - **batch_size** (*int*) 182 | - **tqdm_bar** (*bool*) 183 | - **kwargs** 184 | 185 | -------------------------------------------------------------------------------- /docs/api/scoring/SentenceTransformer.md: -------------------------------------------------------------------------------- 1 | # SentenceTransformer 2 | 3 | Sentence Transformer scoring function. 
4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | - **on** (*str | list*) 12 | 13 | - **model** (*sentence_transformers.SentenceTransformer.SentenceTransformer*) 14 | 15 | - **device** (*str*) – defaults to `cpu` 16 | 17 | - **faiss_device** (*str*) – defaults to `cpu` 18 | 19 | 20 | ## Attributes 21 | 22 | - **distinct_documents_encoder** 23 | 24 | Return True if the encoder is distinct for documents and nodes. 25 | 26 | 27 | ## Examples 28 | 29 | ```python 30 | >>> from neural_tree import trees, scoring 31 | >>> from sentence_transformers import SentenceTransformer 32 | >>> from pprint import pprint 33 | 34 | >>> documents = [ 35 | ... {"id": 0, "text": "Paris is the capital of France."}, 36 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 37 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 38 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 39 | ... ] 40 | 41 | >>> tree = trees.Tree( 42 | ... key="id", 43 | ... documents=documents, 44 | ... scoring=scoring.SentenceTransformer(key="id", on=["text"], model=SentenceTransformer("all-mpnet-base-v2")), 45 | ... leaf_balance_factor=1, 46 | ... branch_balance_factor=2, 47 | ... n_jobs=1, 48 | ... ) 49 | 50 | >>> print(tree) 51 | node 1 52 | node 11 53 | node 110 54 | leaf 1100 55 | leaf 1101 56 | leaf 111 57 | leaf 10 58 | 59 | >>> candidates = tree( 60 | ... queries=["paris", "berlin"], 61 | ... k_leafs=2, 62 | ... ) 63 | 64 | >>> candidates["scores"] 65 | array([[0.72453916, 0.60635257], 66 | [0.58386189, 0.57546711]]) 67 | 68 | >>> candidates["leafs"] 69 | array([['111', '10'], 70 | ['1101', '1100']], dtype='>> pprint(candidates["tree_scores"]) 73 | [{'10': tensor(0.6064), 74 | '11': tensor(0.7245), 75 | '110': tensor(0.5542), 76 | '1100': tensor(0.5403), 77 | '1101': tensor(0.5542), 78 | '111': tensor(0.7245)}, 79 | {'10': tensor(0.5206), 80 | '11': tensor(0.5797), 81 | '110': tensor(0.5839), 82 | '1100': tensor(0.5755), 83 | '1101': tensor(0.5839), 84 | '111': tensor(0.4026)}] 85 | 86 | >>> pprint(candidates["documents"]) 87 | [[{'id': 0, 'leaf': '111', 'similarity': 0.6447779347587058}, 88 | {'id': 1, 'leaf': '10', 'similarity': 0.43175890864117644}], 89 | [{'id': 3, 'leaf': '1101', 'similarity': 0.545769273959571}, 90 | {'id': 2, 'leaf': '1100', 'similarity': 0.54081365990618}]] 91 | ``` 92 | 93 | ## Methods 94 | 95 | ???- note "average" 96 | 97 | Average embeddings. 98 | 99 | - **embeddings** (*numpy.ndarray*) 100 | 101 | ???- note "convert_to_tensor" 102 | 103 | Convert numpy array to torch tensor. 104 | 105 | **Parameters** 106 | 107 | - **embeddings** (*numpy.ndarray*) 108 | - **device** (*str*) 109 | 110 | ???- note "encode_queries_for_retrieval" 111 | 112 | Encode queries for retrieval. 113 | 114 | - **queries** (*list[str]*) 115 | 116 | ???- note "get_retriever" 117 | 118 | Create a retriever 119 | 120 | 121 | ???- note "leaf_scores" 122 | 123 | Computes scores between query and leaf embedding. 124 | 125 | **Parameters** 126 | 127 | - **queries_embeddings** (*torch.Tensor*) 128 | - **leaf_embedding** (*torch.Tensor*) 129 | 130 | ???- note "nodes_scores" 131 | 132 | Score between queries and nodes embeddings. 133 | 134 | **Parameters** 135 | 136 | - **queries_embeddings** (*torch.Tensor*) 137 | - **nodes_embeddings** (*torch.Tensor*) 138 | 139 | ???- note "stack" 140 | 141 | Stack embeddings. 142 | 143 | - **embeddings** (*list[numpy.ndarray]*) 144 | 145 | ???- note "transform_documents" 146 | 147 | Transform documents to embeddings. 
148 | 149 | **Parameters** 150 | 151 | - **documents** (*list[dict]*) 152 | - **batch_size** (*int*) 153 | - **kwargs** 154 | 155 | ???- note "transform_queries" 156 | 157 | Transform queries to embeddings. 158 | 159 | **Parameters** 160 | 161 | - **queries** (*list[str]*) 162 | - **batch_size** (*int*) 163 | - **kwargs** 164 | 165 | -------------------------------------------------------------------------------- /docs/api/scoring/TfIdf.md: -------------------------------------------------------------------------------- 1 | # TfIdf 2 | 3 | TfIdf scoring function. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | - **on** (*list | str*) 12 | 13 | - **documents** (*list*) 14 | 15 | - **tfidf_nodes** (*sklearn.feature_extraction.text.TfidfVectorizer | None*) – defaults to `None` 16 | 17 | - **tfidf_documents** (*sklearn.feature_extraction.text.TfidfVectorizer | None*) – defaults to `None` 18 | 19 | - **kwargs** 20 | 21 | 22 | ## Attributes 23 | 24 | - **distinct_documents_encoder** 25 | 26 | Return True if the encoder is distinct for documents and nodes. 27 | 28 | 29 | ## Examples 30 | 31 | ```python 32 | >>> from neural_tree import trees, scoring 33 | >>> from sklearn.feature_extraction.text import TfidfVectorizer 34 | >>> from pprint import pprint 35 | 36 | >>> documents = [ 37 | ... {"id": 0, "text": "Paris is the capital of France."}, 38 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 39 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 40 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 41 | ... ] 42 | 43 | >>> tree = trees.Tree( 44 | ... key="id", 45 | ... documents=documents, 46 | ... scoring=scoring.TfIdf(key="id", on=["text"], documents=documents), 47 | ... leaf_balance_factor=1, 48 | ... branch_balance_factor=2, 49 | ... ) 50 | 51 | >>> print(tree) 52 | node 1 53 | node 10 54 | leaf 100 55 | leaf 101 56 | node 11 57 | leaf 110 58 | leaf 111 59 | 60 | >>> tree.leafs_to_documents 61 | {'100': [0], '101': [1], '110': [2], '111': [3]} 62 | 63 | >>> candidates = tree( 64 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."], 65 | ... k_leafs=2, 66 | ... k=2, 67 | ... ) 68 | 69 | >>> candidates["scores"] 70 | array([[0.99999994, 0.63854915], 71 | [0.99999994, 0.72823119]]) 72 | 73 | >>> candidates["leafs"] 74 | array([['100', '101'], 75 | ['110', '111']], dtype='>> pprint(candidates["tree_scores"]) 78 | [{'10': tensor(1.0000), 79 | '100': tensor(1.0000), 80 | '101': tensor(0.6385), 81 | '11': tensor(0.1076)}, 82 | {'10': tensor(0.1076), 83 | '11': tensor(1.0000), 84 | '110': tensor(1.0000), 85 | '111': tensor(0.7282)}] 86 | 87 | >>> pprint(candidates["documents"]) 88 | [[{'id': 0, 'leaf': '100', 'similarity': 0.9999999999999978}, 89 | {'id': 1, 'leaf': '101', 'similarity': 0.39941742405759667}], 90 | [{'id': 2, 'leaf': '110', 'similarity': 0.9999999999999978}, 91 | {'id': 3, 'leaf': '111', 'similarity': 0.5385719658738707}]] 92 | ``` 93 | 94 | ## Methods 95 | 96 | ???- note "average" 97 | 98 | Average embeddings. 99 | 100 | - **embeddings** (*scipy.sparse._csr.csr_matrix*) 101 | 102 | ???- note "convert_to_tensor" 103 | 104 | Transform sparse matrix to tensor. 105 | 106 | **Parameters** 107 | 108 | - **embeddings** (*scipy.sparse._csr.csr_matrix*) 109 | - **device** (*str*) 110 | 111 | ???- note "encode_queries_for_retrieval" 112 | 113 | Encode queries for retrieval. 
114 | 115 | **Parameters** 116 | 117 | - **queries** (*list[str]*) 118 | 119 | ???- note "get_retriever" 120 | 121 | Create a retriever 122 | 123 | 124 | ???- note "leaf_scores" 125 | 126 | Return the scores of the embeddings. 127 | 128 | **Parameters** 129 | 130 | - **queries_embeddings** (*torch.Tensor*) 131 | - **leaf_embedding** (*torch.Tensor*) 132 | 133 | ???- note "nodes_scores" 134 | 135 | Score between queries and nodes embeddings. 136 | 137 | **Parameters** 138 | 139 | - **queries_embeddings** (*torch.Tensor*) 140 | - **nodes_embeddings** (*torch.Tensor*) 141 | 142 | ???- note "stack" 143 | 144 | Stack list of embeddings. 145 | 146 | - **embeddings** (*list[scipy.sparse._csr.csr_matrix]*) 147 | 148 | ???- note "transform_documents" 149 | 150 | Transform documents to embeddings. 151 | 152 | **Parameters** 153 | 154 | - **documents** (*list[dict]*) 155 | - **kwargs** 156 | 157 | ???- note "transform_queries" 158 | 159 | Transform queries to embeddings. 160 | 161 | **Parameters** 162 | 163 | - **queries** (*list[str]*) 164 | - **kwargs** 165 | 166 | -------------------------------------------------------------------------------- /docs/api/trees/.pages: -------------------------------------------------------------------------------- 1 | title: trees -------------------------------------------------------------------------------- /docs/api/utils/.pages: -------------------------------------------------------------------------------- 1 | title: utils -------------------------------------------------------------------------------- /docs/api/utils/batchify.md: -------------------------------------------------------------------------------- 1 | # batchify 2 | 3 | 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **X** (*list[str]*) 10 | 11 | - **batch_size** (*int*) 12 | 13 | - **desc** (*str*) – defaults to `` 14 | 15 | - **tqdm_bar** (*bool*) – defaults to `True` 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /docs/api/utils/evaluate.md: -------------------------------------------------------------------------------- 1 | # evaluate 2 | 3 | Evaluate candidates matchs. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **scores** (*list[list[dict]]*) 10 | 11 | - **qrels** (*dict*) 12 | 13 | Qrels. 14 | 15 | - **queries_ids** (*list[str]*) 16 | 17 | - **metrics** (*list*) – defaults to `[]` 18 | 19 | Metrics to compute. 20 | 21 | - **key** (*str*) – defaults to `id` 22 | 23 | 24 | 25 | ## Examples 26 | 27 | ```python 28 | >>> from neural_cherche import models, retrieve, utils 29 | >>> import torch 30 | 31 | >>> _ = torch.manual_seed(42) 32 | 33 | >>> model = models.Splade( 34 | ... model_name_or_path="distilbert-base-uncased", 35 | ... device="cpu", 36 | ... ) 37 | 38 | >>> documents, queries_ids, queries, qrels = utils.load_beir( 39 | ... "scifact", 40 | ... split="test", 41 | ... ) 42 | 43 | >>> documents = documents[:10] 44 | 45 | >>> retriever = retrieve.Splade( 46 | ... key="id", 47 | ... on=["title", "text"], 48 | ... model=model 49 | ... ) 50 | 51 | >>> documents_embeddings = retriever.encode_documents( 52 | ... documents=documents, 53 | ... batch_size=1, 54 | ... ) 55 | 56 | >>> documents_embeddings = retriever.add( 57 | ... documents_embeddings=documents_embeddings, 58 | ... ) 59 | 60 | >>> queries_embeddings = retriever.encode_queries( 61 | ... queries=queries, 62 | ... batch_size=1, 63 | ... ) 64 | 65 | >>> scores = retriever( 66 | ... queries_embeddings=queries_embeddings, 67 | ... k=30, 68 | ... batch_size=1, 69 | ... 
) 70 | 71 | >>> utils.evaluate( 72 | ... scores=scores, 73 | ... qrels=qrels, 74 | ... queries_ids=queries_ids, 75 | ... metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"] 76 | ... ) 77 | {'map': 0.0033333333333333335, 'ndcg@10': 0.0033333333333333335, 'ndcg@100': 0.0033333333333333335, 'recall@10': 0.0033333333333333335, 'recall@100': 0.0033333333333333335} 78 | ``` 79 | 80 | -------------------------------------------------------------------------------- /docs/api/utils/iter.md: -------------------------------------------------------------------------------- 1 | # iter 2 | 3 | Iterate over the dataset. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **queries** 10 | 11 | List of queries paired with documents. 12 | 13 | - **documents** 14 | 15 | List of documents paired with queries. 16 | 17 | - **batch_size** – defaults to `512` 18 | 19 | Size of the batch. 20 | 21 | - **epochs** (*int*) – defaults to `1` 22 | 23 | Number of epochs. 24 | 25 | - **shuffle** – defaults to `True` 26 | 27 | - **tqdm_bar** – defaults to `True` 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/api/utils/leafs-precision.md: -------------------------------------------------------------------------------- 1 | # leafs_precision 2 | 3 | Calculate the precision of the leafs. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | - **documents** (*list*) 12 | 13 | - **leafs** (*numpy.ndarray*) 14 | 15 | - **documents_to_leaf** (*dict*) 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /docs/api/utils/sanity-check.md: -------------------------------------------------------------------------------- 1 | # sanity_check 2 | 3 | Check if the input is valid. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **branch_balance_factor** (*int*) 10 | 11 | - **leaf_balance_factor** (*int*) 12 | 13 | - **graph** (*dict*) 14 | 15 | - **documents** (*list*) 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /docs/api/utils/set-env.md: -------------------------------------------------------------------------------- 1 | # set_env 2 | 3 | Set environment variables. 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/css/version-select.css: -------------------------------------------------------------------------------- 1 | @media only screen and (max-width:76.1875em) { 2 | #version-selector { 3 | padding: .6rem .8rem; 4 | } 5 | } -------------------------------------------------------------------------------- /docs/evaluate/.pages: -------------------------------------------------------------------------------- 1 | title: Evaluate 2 | nav: 3 | - evaluate.md 4 | 5 | -------------------------------------------------------------------------------- /docs/evaluate/evaluate.md: -------------------------------------------------------------------------------- 1 | # Evaluate 2 | 3 | Neural-tree evaluation is based on [RANX](https://github.com/AmenRa/ranx). We can also download datasets of [BEIR Benchmark](https://github.com/beir-cellar/beir) with the `utils.load_beir` function. 4 | 5 | 6 | ## Installation 7 | 8 | ```bash 9 | pip install "neural-tree[eval]" 10 | ``` 11 | 12 | ## Usage 13 | 14 | Here is an example of how to train a tree-based index using the `scifact` dataset and how to evaluate it. 
15 | 16 | ```python 17 | import torch 18 | from neural_cherche import models 19 | from sentence_transformers import SentenceTransformer 20 | 21 | from neural_tree import clustering, datasets, trees, utils 22 | 23 | documents, train_queries, train_documents = datasets.load_beir_train( 24 | dataset_name="scifact", 25 | ) 26 | 27 | 28 | model = models.ColBERT( 29 | model_name_or_path="raphaelsty/neural-cherche-colbert", 30 | device="cuda", 31 | ) 32 | 33 | # We intialize a ColBERT index from a 34 | # SentenceTransformer-based hierarchical clustering. 35 | tree = trees.ColBERT( 36 | key="id", 37 | on=["title", "text"], 38 | model=model, 39 | sentence_transformer=SentenceTransformer(model_name_or_path="all-mpnet-base-v2"), 40 | documents=documents, 41 | leaf_balance_factor=100, 42 | branch_balance_factor=5, 43 | n_jobs=-1, 44 | device="cuda", 45 | faiss_device="cuda", 46 | ) 47 | 48 | optimizer = torch.optim.AdamW(lr=3e-3, params=list(tree.parameters())) 49 | 50 | 51 | for step, batch_queries, batch_documents in utils.iter( 52 | queries=train_queries, 53 | documents=train_documents, 54 | shuffle=True, 55 | epochs=50, 56 | batch_size=128, 57 | ): 58 | loss = tree.loss( 59 | queries=batch_queries, 60 | documents=batch_documents, 61 | ) 62 | 63 | loss.backward() 64 | optimizer.step() 65 | optimizer.zero_grad(set_to_none=True) 66 | 67 | 68 | documents, queries_ids, test_queries, qrels = datasets.load_beir_test( 69 | dataset_name="scifact", 70 | ) 71 | 72 | documents_to_leafs = clustering.optimize_leafs( 73 | tree=tree, 74 | queries=train_queries + test_queries, 75 | documents=documents, 76 | ) 77 | 78 | tree = tree.add( 79 | documents=documents, 80 | documents_to_leafs=documents_to_leafs, 81 | ) 82 | 83 | candidates = tree( 84 | queries=test_queries, 85 | k_leafs=2, # number of leafs to search 86 | k=10, # number of documents to retrieve 87 | ) 88 | 89 | documents, queries_ids, test_queries, qrels = datasets.load_beir_test( 90 | dataset_name="scifact", 91 | ) 92 | 93 | candidates = tree( 94 | queries=test_queries, 95 | k_leafs=2, 96 | k=10, 97 | ) 98 | 99 | 100 | scores = utils.evaluate( 101 | scores=candidates["documents"], 102 | qrels=qrels, 103 | queries_ids=queries_ids, 104 | ) 105 | 106 | print(scores) 107 | ``` 108 | 109 | ```python 110 | {"ndcg@10": 0.6957728027724698, "hits@1": 0.59, "hits@2": 0.69, "hits@3": 0.76, "hits@4": 0.8133333333333334, "hits@5": 0.8533333333333334, "hits@10": 0.91} 111 | ``` 112 | 113 | ## Evaluation dataset 114 | 115 | Here are what documents should looks like (an id with multiples fields): 116 | 117 | ```python 118 | [ 119 | { 120 | "id": "document_0", 121 | "title": "Bayesian measures of model complexity and fit", 122 | "text": "Summary. We consider the problem of comparing complex hierarchical models in which the number of parameters is not clearly defined. Using an information theoretic argument we derive a measure pD for the effective number of parameters in a model as the difference between the posterior mean of the deviance and the deviance at the posterior means of the parameters of interest. In general pD approximately corresponds to the trace of the product of Fisher's information and the posterior covariance, which in normal models is the trace of the ‘hat’ matrix projecting observations onto fitted values. Its properties in exponential families are explored. 
The posterior mean deviance is suggested as a Bayesian measure of fit or adequacy, and the contributions of individual observations to the fit and complexity can give rise to a diagnostic plot of deviance residuals against leverages. Adding pD to the posterior mean deviance gives a deviance information criterion for comparing models, which is related to other information criteria and has an approximate decision theoretic justification. The procedure is illustrated in some examples, and comparisons are drawn with alternative Bayesian and classical proposals. Throughout it is emphasized that the quantities required are trivial to compute in a Markov chain Monte Carlo analysis.", 123 | }, 124 | { 125 | "id": "document_1", 126 | "title": "Simplifying likelihood ratios", 127 | "text": "Likelihood ratios are one of the best measures of diagnostic accuracy, although they are seldom used, because interpreting them requires a calculator to convert back and forth between “probability” and “odds” of disease. This article describes a simpler method of interpreting likelihood ratios, one that avoids calculators, nomograms, and conversions to “odds” of disease. Several examples illustrate how the clinician can use this method to refine diagnostic decisions at the bedside.", 128 | }, 129 | ] 130 | ``` 131 | 132 | Queries is a list of strings: 133 | 134 | ```python 135 | [ 136 | "Varenicline monotherapy is more effective after 12 weeks of treatment compared to combination nicotine replacement therapies with varenicline or bupropion.", 137 | "Venules have a larger lumen diameter than arterioles.", 138 | "Venules have a thinner or absent smooth layer compared to arterioles.", 139 | "Vitamin D deficiency effects the term of delivery.", 140 | "Vitamin D deficiency is unrelated to birth weight.", 141 | "Women with a higher birth weight are more likely to develop breast cancer later in life.", 142 | ] 143 | ``` 144 | 145 | QueriesIds is a list of ids with respect to the order of queries: 146 | 147 | ```python 148 | [ 149 | "0", 150 | "1", 151 | "2", 152 | "3", 153 | "4", 154 | "5", 155 | ] 156 | ``` 157 | 158 | Qrels is the mapping between queries ids as key and dict of relevant documents with 1 as value: 159 | 160 | ```python 161 | { 162 | "1": {"document_0": 1}, 163 | "3": {"document_10": 1}, 164 | "5": {"document_5": 1}, 165 | "13": {"document_22": 1}, 166 | "36": {"document_23": 1, "document_0": 1}, 167 | "42": {"document_2": 1}, 168 | } 169 | ``` 170 | 171 | ## Metrics 172 | 173 | We can evaluate our model with various metrics detailed [here](https://amenra.github.io/ranx/metrics/). -------------------------------------------------------------------------------- /docs/existing_tree/.pages: -------------------------------------------------------------------------------- 1 | title: Existing tree 2 | nav: 3 | - existing_tree.md 4 | -------------------------------------------------------------------------------- /docs/existing_tree/existing_tree.md: -------------------------------------------------------------------------------- 1 | # Build an index from an existing tree 2 | 3 | Neural-Tree can build a tree from an existing graph. This is useful when we have a specific use case where we want to retrieve the right leaf for a query. 4 | 5 | The tree we want to pass should follow some rules: 6 | 7 | - We should avoid nodes with a lot of children. The more children a node has, the more time it will take to explore this node. 8 | 9 | - A node must have only one parent. This is a rule for the tree to be a tree. 
You can somehow duplicate a node to have it in multiple places in the tree. 10 | 11 | Let's create a tree which has one root node, two children nodes and two leafs nodes which contains up to 3 documents. 12 | 13 | ```python 14 | graph = { 15 | "root": { 16 | "science": { 17 | "machine learning": [ 18 | {"id": 0, "content": "bayern football team"}, 19 | {"id": 1, "content": "toulouse rugby team"}, 20 | ], 21 | "computer": [ 22 | {"id": 2, "content": "Apple Macintosh"}, 23 | {"id": 3, "content": "Microsoft Windows"}, 24 | {"id": 4, "content": "Linux Ubuntu"}, 25 | ], 26 | }, 27 | "history": { 28 | "france": [ 29 | {"id": 5, "content": "history of france"}, 30 | {"id": 6, "content": "french revolution"}, 31 | ], 32 | "italia": [ 33 | {"id": 7, "content": "history of rome"}, 34 | {"id": 8, "content": "history of venice"}, 35 | ], 36 | }, 37 | } 38 | } 39 | ``` 40 | 41 | We can now initialize either a TfIdf, a SentenceTransformer or a ColBERT tree using the graph we have created. 42 | 43 | ```python 44 | from neural_tree import trees 45 | from neural_cherche import models 46 | 47 | model = models.ColBERT( 48 | model_name_or_path="raphaelsty/neural-cherche-colbert", 49 | device="cuda" if torch.cuda.is_available() else "cpu", 50 | ) 51 | 52 | tree = trees.ColBERT( 53 | key="id", 54 | on=["content"], 55 | model=model, 56 | graph=graph, 57 | n_jobs=-1, 58 | ) 59 | 60 | print(tree) 61 | ``` 62 | 63 | This will output: 64 | 65 | ```python 66 | node root 67 | node science 68 | leaf computer 69 | leaf machine learning 70 | node history 71 | leaf france 72 | leaf italia 73 | ``` 74 | 75 | Once we have created our tree we can export it back to json using the `tree.to_json()`: 76 | 77 | ```python 78 | { 79 | "root": { 80 | "science": { 81 | "computer": [{"id": 2}, {"id": 3}, {"id": 4}], 82 | "machine learning": [{"id": 0}, {"id": 1}], 83 | }, 84 | "history": {"france": [{"id": 5}, {"id": 6}], "italia": [{"id": 7}, {"id": 8}]}, 85 | } 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /docs/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raphaelsty/neural-tree/c7fd61f6ae37931482b03ab2d569b0309bcd726b/docs/img/logo.png -------------------------------------------------------------------------------- /docs/img/neural_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raphaelsty/neural-tree/c7fd61f6ae37931482b03ab2d569b0309bcd726b/docs/img/neural_tree.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | Neural-Tree
4 | Neural Search
5 | 
11 | [documentation badge]
12 | 
13 | [license badge]
15 | 16 | ## Installation 17 | 18 | We can install neural-tree using: 19 | 20 | ``` 21 | pip install neural-tree 22 | ``` 23 | 24 | If we plan to evaluate our model while training install: 25 | 26 | ``` 27 | pip install "neural-tree[eval]" 28 | ``` -------------------------------------------------------------------------------- /docs/javascripts/config.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; -------------------------------------------------------------------------------- /docs/js/version-select.js: -------------------------------------------------------------------------------- 1 | window.addEventListener("DOMContentLoaded", function () { 2 | // This is a bit hacky. Figure out the base URL from a known CSS file the 3 | // template refers to... 4 | var ex = new RegExp("/?css/version-select.css$"); 5 | var sheet = document.querySelector('link[href$="version-select.css"]'); 6 | 7 | var ABS_BASE_URL = sheet.href.replace(ex, ""); 8 | var CURRENT_VERSION = ABS_BASE_URL.split("/").pop(); 9 | 10 | function makeSelect(options, selected) { 11 | var select = document.createElement("select"); 12 | select.classList.add("form-control"); 13 | 14 | options.forEach(function (i) { 15 | var option = new Option(i.text, i.value, undefined, 16 | i.value === selected); 17 | select.add(option); 18 | }); 19 | 20 | return select; 21 | } 22 | 23 | var xhr = new XMLHttpRequest(); 24 | xhr.open("GET", ABS_BASE_URL + "/../versions.json"); 25 | xhr.onload = function () { 26 | var versions = JSON.parse(this.responseText); 27 | 28 | var realVersion = versions.find(function (i) { 29 | return i.version === CURRENT_VERSION || 30 | i.aliases.includes(CURRENT_VERSION); 31 | }).version; 32 | 33 | var select = makeSelect(versions.map(function (i) { 34 | return { text: i.title, value: i.version }; 35 | }), realVersion); 36 | select.addEventListener("change", function (event) { 37 | window.location.href = ABS_BASE_URL + "/../" + this.value; 38 | }); 39 | 40 | var container = document.createElement("div"); 41 | container.id = "version-selector"; 42 | container.className = "md-nav__item"; 43 | container.appendChild(select); 44 | 45 | var sidebar = document.querySelector(".md-nav--primary > .md-nav__list"); 46 | sidebar.parentNode.insertBefore(container, sidebar); 47 | }; 48 | xhr.send(); 49 | }); -------------------------------------------------------------------------------- /docs/scripts/index.py: -------------------------------------------------------------------------------- 1 | """This script is responsible for building the API reference. The API reference is located in 2 | docs/api. The script scans through all the modules, classes, and functions. It processes 3 | the __doc__ of each object and formats it so that MkDocs can process it in turn. 
4 | """ 5 | import functools 6 | import importlib 7 | import inspect 8 | import os 9 | import pathlib 10 | import re 11 | import shutil 12 | 13 | from numpydoc.docscrape import ClassDoc, FunctionDoc 14 | 15 | package = "neural_tree" 16 | 17 | # shutil.copy("README.md", "docs/index.md") 18 | 19 | 20 | def paragraph(text): 21 | return f"{text}\n" 22 | 23 | 24 | def h1(text): 25 | return paragraph(f"# {text}") 26 | 27 | 28 | def h2(text): 29 | return paragraph(f"## {text}") 30 | 31 | 32 | def h3(text): 33 | return paragraph(f"### {text}") 34 | 35 | 36 | def h4(text): 37 | return paragraph(f"#### {text}") 38 | 39 | 40 | def link(caption, href): 41 | return f"[{caption}]({href})" 42 | 43 | 44 | def code(text): 45 | return f"`{text}`" 46 | 47 | 48 | def li(text): 49 | return f"- {text}\n" 50 | 51 | 52 | def snake_to_kebab(text): 53 | return text.replace("_", "-") 54 | 55 | 56 | def inherit_docstring(c, meth): 57 | """Since Python 3.5, inspect.getdoc is supposed to return the docstring from a parent class 58 | if a class has none. However this doesn't seem to work for Cython classes. 59 | """ 60 | 61 | doc = None 62 | 63 | for ancestor in inspect.getmro(c): 64 | try: 65 | ancestor_meth = getattr(ancestor, meth) 66 | except AttributeError: 67 | break 68 | doc = inspect.getdoc(ancestor_meth) 69 | if doc: 70 | break 71 | 72 | return doc 73 | 74 | 75 | def inherit_signature(c, method_name): 76 | m = getattr(c, method_name) 77 | sig = inspect.signature(m) 78 | 79 | params = [] 80 | 81 | for param in sig.parameters.values(): 82 | if param.name == "self" or param.annotation is not param.empty: 83 | params.append(param) 84 | continue 85 | 86 | for ancestor in inspect.getmro(c): 87 | try: 88 | ancestor_meth = inspect.signature(getattr(ancestor, m.__name__)) 89 | except AttributeError: 90 | break 91 | try: 92 | ancestor_param = ancestor_meth.parameters[param.name] 93 | except KeyError: 94 | break 95 | if ancestor_param.annotation is not param.empty: 96 | param = param.replace(annotation=ancestor_param.annotation) 97 | break 98 | 99 | params.append(param) 100 | 101 | return_annotation = sig.return_annotation 102 | if return_annotation is inspect._empty: 103 | for ancestor in inspect.getmro(c): 104 | try: 105 | ancestor_meth = inspect.signature(getattr(ancestor, m.__name__)) 106 | except AttributeError: 107 | break 108 | if ancestor_meth.return_annotation is not inspect._empty: 109 | return_annotation = ancestor_meth.return_annotation 110 | break 111 | 112 | return sig.replace(parameters=params, return_annotation=return_annotation) 113 | 114 | 115 | def snake_to_kebab(snake: str) -> str: 116 | return snake.replace("_", "-") 117 | 118 | 119 | def pascal_to_kebab(string): 120 | string = re.sub("(.)([A-Z][a-z]+)", r"\1-\2", string) 121 | string = re.sub("(.)([0-9]+)", r"\1-\2", string) 122 | return re.sub("([a-z0-9])([A-Z])", r"\1-\2", string).lower() 123 | 124 | 125 | class Linkifier: 126 | def __init__(self): 127 | path_index = {} 128 | name_index = {} 129 | 130 | modules = { 131 | module: importlib.import_module(f"{package}.{module}") 132 | for module in importlib.import_module(f"{package}").__all__ 133 | } 134 | 135 | def index_module(mod_name, mod, path): 136 | path = os.path.join(path, mod_name) 137 | dotted_path = path.replace("/", ".") 138 | 139 | for func_name, func in inspect.getmembers(mod, inspect.isfunction): 140 | for e in ( 141 | f"{mod_name}.{func_name}", 142 | f"{dotted_path}.{func_name}", 143 | f"{func.__module__}.{func_name}", 144 | ): 145 | path_index[e] = os.path.join(path, 
snake_to_kebab(func_name)) 146 | name_index[e] = f"{dotted_path}.{func_name}" 147 | 148 | for klass_name, klass in inspect.getmembers(mod, inspect.isclass): 149 | for e in ( 150 | f"{mod_name}.{klass_name}", 151 | f"{dotted_path}.{klass_name}", 152 | f"{klass.__module__}.{klass_name}", 153 | ): 154 | path_index[e] = os.path.join(path, klass_name) 155 | name_index[e] = f"{dotted_path}.{klass_name}" 156 | 157 | for submod_name, submod in inspect.getmembers(mod, inspect.ismodule): 158 | if submod_name not in mod.__all__ or submod_name == "typing": 159 | continue 160 | for e in (f"{mod_name}.{submod_name}", f"{dotted_path}.{submod_name}"): 161 | path_index[e] = os.path.join(path, snake_to_kebab(submod_name)) 162 | 163 | # Recurse 164 | index_module(submod_name, submod, path=path) 165 | 166 | for mod_name, mod in modules.items(): 167 | index_module(mod_name, mod, path="") 168 | 169 | # Prepend {package} to each index entry 170 | for k in list(path_index.keys()): 171 | path_index[f"{package}.{k}"] = path_index[k] 172 | for k in list(name_index.keys()): 173 | name_index[f"{package}.{k}"] = name_index[k] 174 | 175 | self.path_index = path_index 176 | self.name_index = name_index 177 | 178 | def linkify(self, text, use_fences, depth): 179 | path = self.path_index.get(text) 180 | name = self.name_index.get(text) 181 | if path and name: 182 | backwards = "../" * (depth + 1) 183 | if use_fences: 184 | return f"[`{name}`]({backwards}{path})" 185 | return f"[{name}]({backwards}{path})" 186 | return None 187 | 188 | def linkify_fences(self, text, depth): 189 | between_fences = re.compile("`[\w\.]+\.\w+`") 190 | return between_fences.sub( 191 | lambda x: self.linkify(x.group().strip("`"), True, depth) or x.group(), text 192 | ) 193 | 194 | def linkify_dotted(self, text, depth): 195 | dotted = re.compile("\w+\.[\.\w]+") 196 | return dotted.sub( 197 | lambda x: self.linkify(x.group(), False, depth) or x.group(), text 198 | ) 199 | 200 | 201 | def concat_lines(lines): 202 | return inspect.cleandoc(" ".join("\n\n" if line == "" else line for line in lines)) 203 | 204 | 205 | def print_docstring(obj, file, depth): 206 | """Prints a classes's docstring to a file.""" 207 | 208 | doc = ClassDoc(obj) if inspect.isclass(obj) else FunctionDoc(obj) 209 | 210 | printf = functools.partial(print, file=file) 211 | 212 | printf(h1(obj.__name__)) 213 | printf(linkifier.linkify_fences(paragraph(concat_lines(doc["Summary"])), depth)) 214 | printf( 215 | linkifier.linkify_fences( 216 | paragraph(concat_lines(doc["Extended Summary"])), depth 217 | ) 218 | ) 219 | 220 | # We infer the type annotations from the signatures, and therefore rely on the signature 221 | # instead of the docstring for documenting parameters 222 | try: 223 | signature = inspect.signature(obj) 224 | except ValueError: 225 | signature = ( 226 | inspect.Signature() 227 | ) # TODO: this is necessary for Cython classes, but it's not correct 228 | params_desc = {param.name: " ".join(param.desc) for param in doc["Parameters"]} 229 | 230 | # Parameters 231 | if signature.parameters: 232 | printf(h2("Parameters")) 233 | for param in signature.parameters.values(): 234 | # Name 235 | printf(f"- **{param.name}**", end="") 236 | # Type annotation 237 | if param.annotation is not param.empty: 238 | anno = inspect.formatannotation(param.annotation) 239 | anno = linkifier.linkify_dotted(anno, depth) 240 | printf(f" (*{anno}*)", end="") 241 | # Default value 242 | if param.default is not param.empty: 243 | printf(f" – defaults to `{param.default}`", end="") 244 | 
printf("\n", file=file) 245 | # Description 246 | if param.name in params_desc: 247 | desc = params_desc[param.name] 248 | if desc: 249 | printf(f" {desc}\n") 250 | printf("") 251 | 252 | # Attributes 253 | if doc["Attributes"]: 254 | printf(h2("Attributes")) 255 | for attr in doc["Attributes"]: 256 | # Name 257 | printf(f"- **{attr.name}**", end="") 258 | # Type annotation 259 | if attr.type: 260 | printf(f" (*{attr.type}*)", end="") 261 | printf("\n", file=file) 262 | # Description 263 | desc = " ".join(attr.desc) 264 | if desc: 265 | printf(f" {desc}\n") 266 | printf("") 267 | 268 | # Examples 269 | if doc["Examples"]: 270 | printf(h2("Examples")) 271 | 272 | in_code = False 273 | after_space = False 274 | 275 | for line in inspect.cleandoc("\n".join(doc["Examples"])).splitlines(): 276 | if ( 277 | in_code 278 | and after_space 279 | and line 280 | and not line.startswith(">>>") 281 | and not line.startswith("...") 282 | ): 283 | printf("```\n") 284 | in_code = False 285 | after_space = False 286 | 287 | if not in_code and line.startswith(">>>"): 288 | printf("```python") 289 | in_code = True 290 | 291 | after_space = False 292 | if not line: 293 | after_space = True 294 | 295 | printf(line) 296 | 297 | if in_code: 298 | printf("```") 299 | printf("") 300 | 301 | # Methods 302 | if inspect.isclass(obj) and doc["Methods"]: 303 | printf(h2("Methods")) 304 | printf_indent = lambda x, **kwargs: printf(f" {x}", **kwargs) 305 | 306 | for meth in doc["Methods"]: 307 | printf(paragraph(f'???- note "{meth.name}"')) 308 | 309 | # Parse method docstring 310 | docstring = inherit_docstring(c=obj, meth=meth.name) 311 | if not docstring: 312 | continue 313 | meth_doc = FunctionDoc(func=None, doc=docstring) 314 | 315 | printf_indent(paragraph(" ".join(meth_doc["Summary"]))) 316 | if meth_doc["Extended Summary"]: 317 | printf_indent(paragraph(" ".join(meth_doc["Extended Summary"]))) 318 | 319 | # We infer the type annotations from the signatures, and therefore rely on the signature 320 | # instead of the docstring for documenting parameters 321 | signature = inherit_signature(obj, meth.name) 322 | params_desc = { 323 | param.name: " ".join(param.desc) for param in doc["Parameters"] 324 | } 325 | 326 | # Parameters 327 | if ( 328 | len(signature.parameters) > 1 329 | ): # signature is never empty, but self doesn't count 330 | printf_indent("**Parameters**\n") 331 | for param in signature.parameters.values(): 332 | if param.name == "self": 333 | continue 334 | # Name 335 | printf_indent(f"- **{param.name}**", end="") 336 | # Type annotation 337 | if param.annotation is not param.empty: 338 | printf_indent( 339 | f" (*{inspect.formatannotation(param.annotation)}*)", end="" 340 | ) 341 | # Default value 342 | if param.default is not param.empty: 343 | printf_indent(f" – defaults to `{param.default}`", end="") 344 | printf_indent("", file=file) 345 | # Description 346 | desc = params_desc.get(param.name) 347 | if desc: 348 | printf_indent(f" {desc}") 349 | printf_indent("") 350 | 351 | # Returns 352 | if meth_doc["Returns"]: 353 | printf_indent("**Returns**\n") 354 | return_val = meth_doc["Returns"][0] 355 | if signature.return_annotation is not inspect._empty: 356 | if inspect.isclass(signature.return_annotation): 357 | printf_indent( 358 | f"*{signature.return_annotation.__name__}*: ", end="" 359 | ) 360 | else: 361 | printf_indent(f"*{signature.return_annotation}*: ", end="") 362 | printf_indent(return_val.type) 363 | printf_indent("") 364 | 365 | # Notes 366 | if doc["Notes"]: 367 | printf(h2("Notes")) 
368 | printf(paragraph("\n".join(doc["Notes"]))) 369 | 370 | # References 371 | if doc["References"]: 372 | printf(h2("References")) 373 | printf(paragraph("\n".join(doc["References"]))) 374 | 375 | 376 | def print_module(mod, path, overview, is_submodule=False): 377 | mod_name = mod.__name__.split(".")[-1] 378 | 379 | # Create a directory for the module 380 | mod_slug = snake_to_kebab(mod_name) 381 | mod_path = path.joinpath(mod_slug) 382 | mod_short_path = str(mod_path).replace("docs/api/", "") 383 | os.makedirs(mod_path, exist_ok=True) 384 | with open(mod_path.joinpath(".pages"), "w") as f: 385 | f.write(f"title: {mod_name}") 386 | 387 | # Add the module to the overview 388 | if is_submodule: 389 | print(h3(mod_name), file=overview) 390 | else: 391 | print(h2(mod_name), file=overview) 392 | if mod.__doc__: 393 | print(paragraph(mod.__doc__), file=overview) 394 | 395 | # Extract all public classes and functions 396 | ispublic = lambda x: x.__name__ in mod.__all__ and not x.__name__.startswith("_") 397 | classes = inspect.getmembers(mod, lambda x: inspect.isclass(x) and ispublic(x)) 398 | funcs = inspect.getmembers(mod, lambda x: inspect.isfunction(x) and ispublic(x)) 399 | 400 | # Classes 401 | 402 | if classes and funcs: 403 | print("\n**Classes**\n", file=overview) 404 | 405 | for _, c in classes: 406 | print(f"{mod_name}.{c.__name__}") 407 | 408 | # Add the class to the overview 409 | slug = snake_to_kebab(c.__name__) 410 | print( 411 | li(link(c.__name__, f"../{mod_short_path}/{slug}")), end="", file=overview 412 | ) 413 | 414 | # Write down the class' docstring 415 | with open(mod_path.joinpath(slug).with_suffix(".md"), "w") as file: 416 | print_docstring(obj=c, file=file, depth=mod_short_path.count("/") + 1) 417 | 418 | # Functions 419 | 420 | if classes and funcs: 421 | print("\n**Functions**\n", file=overview) 422 | 423 | for _, f in funcs: 424 | print(f"{mod_name}.{f.__name__}") 425 | 426 | # Add the function to the overview 427 | slug = snake_to_kebab(f.__name__) 428 | print( 429 | li(link(f.__name__, f"../{mod_short_path}/{slug}")), end="", file=overview 430 | ) 431 | 432 | # Write down the function' docstring 433 | with open(mod_path.joinpath(slug).with_suffix(".md"), "w") as file: 434 | print_docstring(obj=f, file=file, depth=mod_short_path.count(".") + 1) 435 | 436 | # Sub-modules 437 | for name, submod in inspect.getmembers(mod, inspect.ismodule): 438 | # We only want to go through the public submodules, such as optim.schedulers 439 | if ( 440 | name in ("tags", "typing", "inspect", "skmultiflow_utils") 441 | or name not in mod.__all__ 442 | or name.startswith("_") 443 | ): 444 | continue 445 | print_module(mod=submod, path=mod_path, overview=overview, is_submodule=True) 446 | 447 | print("", file=overview) 448 | 449 | 450 | if __name__ == "__main__": 451 | api_path = pathlib.Path("docs/api") 452 | 453 | # Create a directory for the API reference 454 | shutil.rmtree(api_path, ignore_errors=True) 455 | os.makedirs(api_path, exist_ok=True) 456 | with open(api_path.joinpath(".pages"), "w") as f: 457 | f.write("title: API reference\narrange:\n - overview.md\n - ...\n") 458 | 459 | overview = open(api_path.joinpath("overview.md"), "w") 460 | print(h1("Overview"), file=overview) 461 | 462 | linkifier = Linkifier() 463 | 464 | for mod_name, mod in inspect.getmembers( 465 | importlib.import_module(f"{package}"), inspect.ismodule 466 | ): 467 | if mod_name.startswith("_"): 468 | continue 469 | print(mod_name) 470 | print_module(mod, path=api_path, overview=overview) 471 | 
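    # Flush and close the overview page once every public module has been documented.
    overview.close()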
-------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | .md-typeset h2 { 2 | margin: 1.5em 0; 3 | padding-bottom: .4rem; 4 | border-bottom: .04rem solid var(--md-default-fg-color--lighter); 5 | } 6 | 7 | .md-footer { 8 | margin-top: 2em; 9 | } 10 | 11 | .md-typeset pre > code { 12 | border-radius: 0.5em; 13 | } -------------------------------------------------------------------------------- /docs/trees/.pages: -------------------------------------------------------------------------------- 1 | title: Fine-tune 2 | nav: 3 | - colbert.md 4 | - sentence_transformer.md 5 | - tfidf.md 6 | -------------------------------------------------------------------------------- /docs/trees/colbert.md: -------------------------------------------------------------------------------- 1 | # Colbert 2 | 3 | The ColBERT Tree utilizes hierarchical clustering with either TfIdf sparse vectors or SentenceTransformer embeddings to set up the tree's initial structure. 4 | 5 | Following the tree's initialization, documents are sorted into specific leaves according to the hierarchical clustering results. Subsequently, the average ColBERT embeddings of the documents are assigned to each node. 6 | 7 | During a query, the tree employs the ColBERT pre-trained model and its scoring function to identify the most relevant leaf or leaves for effective search results. 8 | 9 | ## Documents 10 | 11 | To create a tree-based index for ColBERT, we will need to: 12 | 13 | - Gather the whole set of documents we want to index. 14 | - Gather queries paired to documents. 15 | - Sample the training set in order to evaluate the index. 16 | 17 | 18 | ```python 19 | # Whole set of documents we want to index. 20 | documents = [ 21 | {"id": 0, "content": "paris"}, 22 | {"id": 1, "content": "london"}, 23 | {"id": 2, "content": "berlin"}, 24 | {"id": 3, "content": "rome"}, 25 | {"id": 4, "content": "bordeaux"}, 26 | {"id": 5, "content": "milan"}, 27 | ] 28 | 29 | # Paired training documents 30 | train_documents = [ 31 | {"id": 0, "content": "paris"}, 32 | {"id": 1, "content": "london"}, 33 | {"id": 2, "content": "berlin"}, 34 | {"id": 3, "content": "rome"}, 35 | ] 36 | 37 | # Paired training queries 38 | train_queries = [ 39 | "paris is the capital of france", 40 | "london is the capital of england", 41 | "berlin is the capital of germany", 42 | "rome is the capital of italy", 43 | ] 44 | ``` 45 | 46 | Let's train the index using the `documents`, `train_queries` and `train_documents` we have gathered. 47 | 48 | ```python 49 | import torch 50 | from neural_cherche import models 51 | 52 | from neural_tree import trees, utils 53 | 54 | model = models.ColBERT( 55 | model_name_or_path="raphaelsty/neural-cherche-colbert", 56 | device="cuda" if torch.cuda.is_available() else "cpu", 57 | ) 58 | 59 | tree = trees.ColBERT( 60 | key="id", # The field to use as a key for the documents. 61 | on=["content"], # The fields to use for the model. 62 | model=model, 63 | documents=documents, 64 | leaf_balance_factor=100, # Minimum number of documents per leaf. 65 | branch_balance_factor=5, # Number of childs per node. 66 | n_jobs=-1, # We want to set it to 1 when using Google Colab. 
67 | device="cuda" if torch.cuda.is_available() else "cpu", 68 | ) 69 | 70 | optimizer = torch.optim.AdamW(lr=3e-3, params=list(tree.parameters())) 71 | 72 | for step, batch_queries, batch_documents in utils.iter( 73 | queries=train_queries, 74 | documents=train_documents, 75 | shuffle=True, 76 | epochs=50, 77 | batch_size=32, 78 | ): 79 | loss = tree.loss( 80 | queries=batch_queries, 81 | documents=batch_documents, 82 | ) 83 | 84 | loss.backward() 85 | optimizer.step() 86 | optimizer.zero_grad(set_to_none=True) 87 | ``` 88 | 89 | We can already use the `tree` to search for documents using the `tree` method. 90 | 91 | The `call` method of the tree outputs a dictionary containing several key pieces of information: the retrieved leaves under leafs, the score assigned to each leaf under scores, a record of the explored nodes and leaves along with their scores under tree_scores, and the documents retrieved for each query listed under documents. 92 | 93 | ```python 94 | tree( 95 | queries=["history"], 96 | k=10, # Number of documents to return for each query. 97 | k_leafs=1, # The number of leafs to return for each query. 98 | ) 99 | ``` 100 | 101 | ```python 102 | { 103 | "leafs": array([["10"]], dtype=" tuple[torch.Tensor, list, list]: 16 | """Replace KMeans clustering with average clustering when an existing graph is provided. 17 | 18 | Examples 19 | -------- 20 | >>> from neural_tree import clustering, scoring 21 | >>> import numpy as np 22 | 23 | >>> documents = [ 24 | ... {"id": 0, "text": "Paris is the capital of France."}, 25 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 26 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 27 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 28 | ... ] 29 | 30 | >>> documents_embeddings = np.array([ 31 | ... [1, 1], 32 | ... [1, 2], 33 | ... [10, 10], 34 | ... [1, 3], 35 | ... ]) 36 | 37 | >>> graph = {1: {11: {111: [{'id': 0}, {'id': 3}], 112: [{'id': 1}]}, 12: {121: [{'id': 2}], 122: [{'id': 3}]}}} 38 | 39 | >>> clustering.average( 40 | ... key="id", 41 | ... documents_embeddings=documents_embeddings, 42 | ... documents=documents, 43 | ... graph=graph[1], 44 | ... scoring=scoring.SentenceTransformer(key="id", on=["text"], model=None), 45 | ... 
) 46 | 47 | """ 48 | mapping_documents_embeddings = { 49 | document[key]: embedding 50 | for document, embedding in zip(documents, documents_embeddings) 51 | } 52 | 53 | mapping_nodes_documents = { 54 | node: get_mapping_nodes_documents(graph=graph[node]) for node in graph.keys() 55 | } 56 | 57 | mappings_nodes_embeddings = { 58 | node: scoring.average( 59 | scoring.stack( 60 | [ 61 | mapping_documents_embeddings[document[key]] 62 | for document in node_documents 63 | ] 64 | ) 65 | ) 66 | for node, node_documents in mapping_nodes_documents.items() 67 | } 68 | 69 | mapping_documents_ids = {document[key]: document for document in documents} 70 | 71 | mappings_nodes_embeddings = list(mappings_nodes_embeddings.values()) 72 | 73 | if isinstance(mappings_nodes_embeddings[0], np.ndarray): 74 | node_embeddings = torch.tensor( 75 | data=np.stack(arrays=mappings_nodes_embeddings), 76 | device=device, 77 | dtype=torch.float32, 78 | requires_grad=True, 79 | ) 80 | else: 81 | node_embeddings = torch.stack(tensors=mappings_nodes_embeddings, dim=0) 82 | node_embeddings = node_embeddings.to(device=device) 83 | node_embeddings.requires_grad = True 84 | 85 | extended_documents, extended_documents_embeddings, labels = [], [], [] 86 | for node, node_documents in mapping_nodes_documents.items(): 87 | extended_documents.extend( 88 | [mapping_documents_ids[document[key]] for document in node_documents] 89 | ) 90 | extended_documents_embeddings.extend( 91 | [mapping_documents_embeddings[document[key]] for document in node_documents] 92 | ) 93 | labels.extend([node] * len(node_documents)) 94 | 95 | return ( 96 | node_embeddings, 97 | labels, 98 | extended_documents, 99 | scoring.stack(extended_documents_embeddings), 100 | ) 101 | 102 | 103 | def get_mapping_nodes_documents(graph: dict | list, documents: list | None = None): 104 | """Get documents from specific node.""" 105 | if documents is None: 106 | documents = [] 107 | 108 | if isinstance(graph, list): 109 | documents.extend(graph) 110 | return documents 111 | 112 | for node, child in graph.items(): 113 | if isinstance(child, dict): 114 | documents = get_mapping_nodes_documents(graph=child, documents=documents) 115 | else: 116 | documents.extend(child) 117 | return documents 118 | -------------------------------------------------------------------------------- /neural_tree/clustering/kmeans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn import cluster 3 | 4 | __all__ = ["KMeans"] 5 | 6 | 7 | def KMeans( 8 | documents_embeddings: torch.Tensor, 9 | n_clusters: int, 10 | max_iter: int, 11 | n_init: int, 12 | seed: int, 13 | device: str, 14 | ) -> tuple[torch.Tensor, list]: 15 | """KMeans clustering.""" 16 | kmeans: cluster.KMeans = cluster.KMeans( 17 | n_clusters=n_clusters, 18 | max_iter=max_iter, 19 | n_init=n_init, 20 | random_state=seed, 21 | ).fit(X=documents_embeddings) 22 | 23 | node_embeddings = torch.tensor( 24 | data=kmeans.cluster_centers_, 25 | device=device, 26 | dtype=torch.float32, 27 | requires_grad=True, 28 | ) 29 | 30 | return ( 31 | node_embeddings, 32 | kmeans.labels_, 33 | ) 34 | -------------------------------------------------------------------------------- /neural_tree/clustering/optimize.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | from cherche import retrieve 5 | from scipy import sparse 6 | from scipy.sparse import csr_matrix, dok_matrix 7 | 8 | __all__ = ["optimize_leafs"] 
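# optimize_leafs re-assigns documents to the leaves of an already-built tree: a sparse
# (queries x documents) matrix produced by a TfIdf retriever is combined with a sparse
# (queries x leafs) matrix produced by the tree, and each document is then duplicated into
# the k_leafs leaves with which it shares the most training queries.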
9 | 10 | 11 | def create_sparse_matrix_retriever( 12 | candidates: list[list], mapping_documents: dict, key: str 13 | ) -> csr_matrix: 14 | """Build a sparse matrix (queries, documents) with 1 when document is relevant to a 15 | query.""" 16 | query_documents_mapping = collections.defaultdict(list) 17 | for query, query_candidates in enumerate(iterable=candidates): 18 | for document in query_candidates: 19 | query_documents_mapping[query].append(mapping_documents[document[key]]) 20 | 21 | query_documents_matrix = dok_matrix( 22 | arg1=(len(candidates), len(mapping_documents)), dtype=np.int8 23 | ) 24 | 25 | for query, query_documents in query_documents_mapping.items(): 26 | query_documents_matrix[query, query_documents] = 1 27 | 28 | return query_documents_matrix.tocsr() 29 | 30 | 31 | def create_sparse_matrix_tree( 32 | candidates_tree: np.ndarray, 33 | ) -> tuple[sparse.csr_matrix, dict[int, list[int]]]: 34 | """Build a sparse matrix (queries, leafs) with 1 when leaf is relevant to a query.""" 35 | leafs_to_queries = collections.defaultdict(list) 36 | 37 | for query, query_leafs in enumerate(iterable=candidates_tree): 38 | for leaf in query_leafs.tolist(): 39 | leafs_to_queries[leaf].append(query) 40 | 41 | query_leafs_matrix = dok_matrix( 42 | arg1=(len(candidates_tree), len(leafs_to_queries)), dtype=np.int8 43 | ) 44 | 45 | mapping_leafs = { 46 | leaf: index for index, leaf in enumerate(iterable=leafs_to_queries) 47 | } 48 | 49 | for leaf, leaf_queries in leafs_to_queries.items(): 50 | for query in leaf_queries: 51 | query_leafs_matrix[query, mapping_leafs[leaf]] = 1 52 | 53 | return query_leafs_matrix.tocsr(), { 54 | index: leaf for leaf, index in mapping_leafs.items() 55 | } 56 | 57 | 58 | def top_k(similarities, k: int): 59 | """Return the top k documents for each query.""" 60 | similarities *= -1 61 | matchs = [] 62 | for row in similarities: 63 | _k = min(row.data.shape[0] - 1, k) 64 | ind = np.argpartition(a=row.data, kth=_k, axis=0)[:k] 65 | similarity = np.take_along_axis(arr=row.data, indices=ind, axis=0) 66 | indices = np.take_along_axis(arr=row.indices, indices=ind, axis=0) 67 | ind = np.argsort(a=similarity, axis=0) 68 | matchs.append(np.take_along_axis(arr=indices, indices=ind, axis=0)) 69 | return matchs 70 | 71 | 72 | def optimize_leafs( 73 | tree, 74 | documents: list[dict], 75 | queries: list[str], 76 | k_tree: int = 2, 77 | k_retriever: int = 10, 78 | k_leafs: int = 2, 79 | **kwargs, 80 | ) -> dict: 81 | """Optimize the clusters.""" 82 | mapping_documents = { 83 | document[tree.key]: index for index, document in enumerate(iterable=documents) 84 | } 85 | 86 | retriever = retrieve.TfIdf(key=tree.key, on=tree.scoring.on, documents=documents) 87 | query_documents_matrix = create_sparse_matrix_retriever( 88 | candidates=retriever(q=queries, k=k_retriever, batch_size=512, tqdm_bar=False), 89 | mapping_documents=mapping_documents, 90 | key=tree.key, 91 | ) 92 | 93 | inverse_mapping_document = { 94 | index: document for document, index in mapping_documents.items() 95 | } 96 | 97 | query_leafs_matrix, inverse_mapping_leafs = create_sparse_matrix_tree( 98 | candidates_tree=tree( 99 | queries=queries, 100 | k=k_tree, 101 | score_documents=False, 102 | **kwargs, 103 | )["leafs"] 104 | ) 105 | 106 | documents_to_leafs = collections.defaultdict(list) 107 | for document, leafs in enumerate( 108 | iterable=top_k( 109 | similarities=query_documents_matrix.T @ query_leafs_matrix, 110 | k=k_leafs, 111 | ) 112 | ): 113 | for leaf in leafs.tolist(): 114 | 
documents_to_leafs[inverse_mapping_document[document]].append( 115 | inverse_mapping_leafs[leaf] 116 | ) 117 | 118 | return documents_to_leafs 119 | -------------------------------------------------------------------------------- /neural_tree/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .beir import load_beir, load_beir_test, load_beir_train 2 | 3 | __all__ = ["load_beir", "load_beir_train", "load_beir_test"] 4 | -------------------------------------------------------------------------------- /neural_tree/datasets/beir.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __all__ = ["load_beir", "load_beir_train", "load_beir_test"] 4 | 5 | 6 | def _make_pairs(queries: dict, qrels: dict) -> tuple[list, list]: 7 | """Make pairs of queries and documents for training.""" 8 | test_queries, test_documents = [], [] 9 | for query, (_, documents_queries) in zip(queries, qrels.items()): 10 | for document_id in documents_queries: 11 | test_queries.append(query) 12 | test_documents.append({"id": document_id}) 13 | return test_queries, test_documents 14 | 15 | 16 | def load_beir(dataset_name: str, split: str) -> tuple: 17 | """Load BEIR dataset.""" 18 | from beir import util 19 | from beir.datasets.data_loader import GenericDataLoader 20 | 21 | path = f"./beir_datasets/{dataset_name}" 22 | if not os.path.isdir(s=path): 23 | path = util.download_and_unzip( 24 | url=f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip", 25 | out_dir="./beir_datasets/", 26 | ) 27 | 28 | documents, queries, qrels = GenericDataLoader(data_folder=path).load(split=split) 29 | 30 | documents = [ 31 | { 32 | "id": document_id, 33 | "title": document["title"], 34 | "text": document["text"], 35 | } 36 | for document_id, document in documents.items() 37 | ] 38 | 39 | return documents, queries, qrels 40 | 41 | 42 | def load_beir_train(dataset_name: str) -> tuple[list, list, list]: 43 | """Load training dataset. 44 | 45 | Parameters 46 | ---------- 47 | dataset_name 48 | Dataset name 49 | 50 | Examples 51 | -------- 52 | >>> from neural_tree import datasets 53 | 54 | >>> documents, train_queries, train_documents = datasets.load_beir_train( 55 | ... dataset_name="scifact", 56 | ... ) 57 | 58 | >>> len(documents) 59 | 5183 60 | 61 | >>> assert len(train_queries) == len(train_documents) 62 | 63 | """ 64 | documents, queries, qrels = load_beir(dataset_name=dataset_name, split="train") 65 | 66 | train_queries, train_documents = _make_pairs( 67 | queries=list(queries.values()), qrels=qrels 68 | ) 69 | 70 | return documents, train_queries, train_documents 71 | 72 | 73 | def load_beir_test(dataset_name: str) -> tuple[list, list, dict]: 74 | """Load BEIR testing dataset. 75 | 76 | Parameters 77 | ---------- 78 | dataset_name 79 | Dataset name. 80 | 81 | Examples 82 | -------- 83 | >>> from neural_tree import datasets 84 | 85 | >>> documents, queries_ids, queries, qrels = datasets.load_beir_test( 86 | ... dataset_name="scifact", 87 | ... 
) 88 | 89 | >>> len(documents) 90 | 5183 91 | 92 | >>> assert len(queries_ids) == len(queries) == len(qrels) 93 | """ 94 | documents, queries, qrels = load_beir(dataset_name=dataset_name, split="test") 95 | return documents, list(queries.keys()), list(queries.values()), qrels 96 | -------------------------------------------------------------------------------- /neural_tree/leafs/__init__.py: -------------------------------------------------------------------------------- 1 | from .leaf import Leaf 2 | 3 | __all__ = ["Leaf"] 4 | -------------------------------------------------------------------------------- /neural_tree/leafs/leaf.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Leaf"] 2 | 3 | import collections 4 | 5 | import torch 6 | 7 | from ..scoring import SentenceTransformer, TfIdf 8 | 9 | 10 | class Leaf(torch.nn.Module): 11 | """Leaf class.""" 12 | 13 | def __init__( 14 | self, 15 | key: str, 16 | level: int, 17 | documents: list, 18 | documents_embeddings: list, 19 | node_name: int, 20 | scoring: TfIdf | SentenceTransformer, 21 | parent: int = 0, 22 | create_retrievers: bool = True, 23 | **kwargs, 24 | ) -> None: 25 | super(Leaf, self).__init__() 26 | self.key = key 27 | self.level = level 28 | self.node_name = node_name 29 | self.parent = parent 30 | self.create_retrievers = create_retrievers 31 | 32 | self.documents = {} 33 | 34 | if self.create_retrievers: 35 | self.retriever = scoring.get_retriever() 36 | 37 | if scoring.distinct_documents_encoder: 38 | documents_embeddings = None 39 | elif self.create_retrievers: 40 | documents_embeddings = { 41 | document[self.key]: embedding 42 | for document, embedding in zip(documents, documents_embeddings) 43 | } 44 | 45 | self.add( 46 | scoring=scoring, 47 | documents=documents, 48 | documents_embeddings=documents_embeddings, 49 | ) 50 | 51 | def __str__(self) -> str: 52 | """String representation of a leaf.""" 53 | sep = "\t" 54 | return f"{self.level * sep} leaf {self.node_name}" 55 | 56 | def add( 57 | self, 58 | scoring: SentenceTransformer | TfIdf, 59 | documents: list, 60 | documents_embeddings: dict | None = None, 61 | ) -> "Leaf": 62 | """Add document to the leaf.""" 63 | if not self.create_retrievers: 64 | # If we don't want to create retrievers for the leaves 65 | for document in documents: 66 | self.documents[document[self.key]] = True 67 | return self 68 | 69 | if documents_embeddings is None: 70 | documents_embeddings = self.retriever.encode_documents( 71 | documents=documents, 72 | model=scoring.model, 73 | ) 74 | 75 | documents_embeddings = { 76 | document: embedding 77 | for document, embedding in documents_embeddings.items() 78 | if document not in self.documents 79 | } 80 | 81 | if not documents_embeddings: 82 | return self 83 | 84 | self.retriever.add( 85 | documents_embeddings=documents_embeddings, 86 | ) 87 | 88 | for document in documents: 89 | self.documents[document[self.key]] = True 90 | 91 | return self 92 | 93 | def nodes_scores( 94 | self, 95 | scoring: SentenceTransformer | TfIdf, 96 | queries_embeddings: torch.Tensor, 97 | node_embedding: torch.Tensor, 98 | ) -> torch.Tensor: 99 | """Compute the scores between the queries and the leaf.""" 100 | return scoring.leaf_scores( 101 | queries_embeddings=queries_embeddings, leaf_embedding=node_embedding 102 | ) 103 | 104 | def __call__( 105 | self, 106 | queries_embeddings, 107 | k: int, 108 | ) -> torch.Tensor: 109 | """Return scores between query and leaf documents.""" 110 | if not self.documents: 111 | return 
[[] for _ in range(len(queries_embeddings))] 112 | 113 | candidates = self.retriever( 114 | queries_embeddings=queries_embeddings, 115 | tqdm_bar=False, 116 | k=k, 117 | ) 118 | 119 | return [ 120 | [{**document, "leaf": self.node_name} for document in query_documents] 121 | for query_documents in candidates 122 | ] 123 | 124 | def search( 125 | self, 126 | tree_scores: collections.defaultdict, 127 | **kwargs, 128 | ) -> tuple[torch.Tensor, list]: 129 | """Return the documents in the leaf.""" 130 | return tree_scores 131 | 132 | def to_json(self) -> dict: 133 | """Return the leaf as a json.""" 134 | return [{self.key: document} for document in self.documents] 135 | -------------------------------------------------------------------------------- /neural_tree/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | __all__ = ["Node"] 4 | -------------------------------------------------------------------------------- /neural_tree/nodes/node.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import torch 4 | from joblib import Parallel, delayed 5 | 6 | from ..clustering import KMeans, average 7 | from ..leafs import Leaf 8 | from ..scoring import SentenceTransformer, TfIdf 9 | 10 | __all__ = ["Node"] 11 | 12 | 13 | class Node(torch.nn.Module): 14 | """Node of the tree.""" 15 | 16 | def __init__( 17 | self, 18 | level: int, 19 | key: str, 20 | documents_embeddings: torch.Tensor, 21 | documents: list, 22 | leaf_balance_factor: int, 23 | branch_balance_factor: int, 24 | device: str, 25 | node_name: int | str, 26 | scoring: SentenceTransformer | TfIdf, 27 | seed: int, 28 | max_iter: int, 29 | n_init: int, 30 | parent: int, 31 | n_jobs: int, 32 | create_retrievers: bool, 33 | graph: dict | None, 34 | ) -> None: 35 | super(Node, self).__init__() 36 | self.level = level 37 | self.leaf_balance_factor = leaf_balance_factor 38 | self.branch_balance_factor = branch_balance_factor 39 | self.device = device 40 | self.seed = seed 41 | self.parent = parent 42 | self.node_name = node_name 43 | 44 | if graph is not None: 45 | self.nodes_embeddings, labels, documents, documents_embeddings = average( 46 | key=key, 47 | documents=documents, 48 | documents_embeddings=documents_embeddings, 49 | graph=graph, 50 | scoring=scoring, 51 | device=self.device, 52 | ) 53 | else: 54 | self.nodes_embeddings, labels = KMeans( 55 | documents_embeddings=documents_embeddings, 56 | n_clusters=self.branch_balance_factor, 57 | max_iter=max_iter, 58 | n_init=n_init, 59 | seed=self.seed, 60 | device=self.device, 61 | ) 62 | 63 | clusters = collections.defaultdict(list) 64 | for document, embedding, group in zip(documents, documents_embeddings, labels): 65 | clusters[group].append((document, embedding)) 66 | 67 | if n_jobs == 1: 68 | self.childs = [ 69 | self.create_child( 70 | level=self.level + 1, 71 | node_name=f"{self.node_name}{group}" if graph is None else group, 72 | key=key, 73 | documents=[document for document, _ in clusters[group]], 74 | documents_embeddings=[ 75 | embedding for _, embedding in clusters[group] 76 | ], 77 | scoring=scoring, 78 | max_iter=max_iter, 79 | n_init=n_init, 80 | create_retrievers=create_retrievers, 81 | graph=graph[group] if graph is not None else None, 82 | n_jobs=n_jobs, 83 | seed=self.seed, 84 | ) 85 | for group in sorted( 86 | clusters, key=lambda key: len(clusters[key]), reverse=True 87 | ) 88 | ] 89 | else: 90 | self.childs = Parallel(n_jobs=n_jobs)( 
91 | delayed(function=self.create_child)( 92 | level=self.level + 1, 93 | node_name=f"{self.node_name}{group}" if graph is None else group, 94 | key=key, 95 | documents=[document for document, _ in clusters[group]], 96 | documents_embeddings=[ 97 | embedding for _, embedding in clusters[group] 98 | ], 99 | scoring=scoring, 100 | max_iter=max_iter, 101 | n_init=n_init, 102 | create_retrievers=create_retrievers, 103 | graph=graph[group] if graph is not None else None, 104 | n_jobs=n_jobs, 105 | seed=self.seed, 106 | ) 107 | for group in sorted( 108 | clusters, key=lambda key: len(clusters[key]), reverse=True 109 | ) 110 | ) 111 | 112 | def create_child( 113 | self, 114 | level: int, 115 | node_name: str, 116 | key: str, 117 | documents: list, 118 | documents_embeddings: list, 119 | scoring: SentenceTransformer | TfIdf, 120 | max_iter: int, 121 | n_init: int, 122 | create_retrievers: bool, 123 | graph: dict | list | None, 124 | n_jobs: int, 125 | seed: int, 126 | ) -> None: 127 | """Create a child.""" 128 | child = Leaf if len(documents) <= self.leaf_balance_factor else Node 129 | if graph is not None and isinstance(graph, list): 130 | child = Leaf 131 | 132 | child = child( 133 | level=level, 134 | node_name=node_name, 135 | key=key, 136 | scoring=scoring, 137 | documents=documents, 138 | documents_embeddings=scoring.stack(embeddings=documents_embeddings), 139 | leaf_balance_factor=self.leaf_balance_factor, 140 | branch_balance_factor=self.branch_balance_factor, 141 | device=self.device, 142 | seed=seed, 143 | max_iter=max_iter, 144 | n_init=n_init, 145 | parent=self.node_name, 146 | create_retrievers=create_retrievers, 147 | graph=graph, 148 | n_jobs=n_jobs, 149 | ) 150 | return child 151 | 152 | def __str__(self) -> str: 153 | """String representation of a""" 154 | sep = "\t" 155 | return f"{self.level * sep} node {self.node_name}" 156 | 157 | def nodes_scores( 158 | self, 159 | scoring: SentenceTransformer | TfIdf, 160 | queries_embeddings: torch.Tensor, 161 | **kwargs, 162 | ) -> torch.Tensor: 163 | """Return the scores of the embeddings.""" 164 | return scoring.nodes_scores( 165 | queries_embeddings=queries_embeddings, 166 | nodes_embeddings=self.nodes_embeddings, 167 | ) 168 | 169 | def get_childs_and_scores( 170 | self, 171 | queries: list, 172 | scores: torch.Tensor, 173 | tree_scores: collections.defaultdict, 174 | paths: list | None, 175 | k: int, 176 | ) -> tuple[torch.Tensor, torch.Tensor]: 177 | """Return the childs and scores given matrix of scores.""" 178 | if paths is None: 179 | scores = torch.stack(tensors=scores, dim=1) 180 | scores, childs = torch.topk(input=scores, k=min(k, scores.shape[1]), dim=1) 181 | return childs, scores 182 | 183 | # If paths is not None, we go through the choosen path. 
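        # `paths` holds one (leaf, remaining_nodes) pair per query: we pop the next node to
        # visit and fall back to the leaf itself once the path is exhausted, so the score of
        # the pre-selected child is read from `tree_scores` rather than re-ranked with top-k.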
184 | path = [query_path.pop(0) if query_path else leaf for leaf, query_path in paths] 185 | 186 | child_node_names = [] 187 | for node_name in path: 188 | for index, child in enumerate(iterable=self.childs): 189 | if node_name == child.node_name: 190 | child_node_names.append(index) 191 | break 192 | 193 | childs = torch.tensor( 194 | data=child_node_names, 195 | dtype=torch.long, 196 | device=self.device, 197 | ).unsqueeze(dim=1) 198 | 199 | scores = torch.stack( 200 | tensors=[tree_scores[query][node] for query, node in zip(queries, path)], 201 | dim=0, 202 | ).unsqueeze(dim=1) 203 | 204 | return childs, scores 205 | 206 | def search( 207 | self, 208 | queries: list[str], 209 | queries_embeddings: torch.Tensor, 210 | scoring: SentenceTransformer | TfIdf, 211 | k: int, 212 | beam_search_depth: int, 213 | paths: dict[list] | None = None, 214 | tree_scores: collections.defaultdict | None = None, 215 | **kwargs, 216 | ) -> tuple[torch.Tensor, list]: 217 | """Search for the closest embedding. 218 | 219 | Parameters 220 | ---------- 221 | queries: 222 | Queries to search for. 223 | embeddings: 224 | Embeddings to search for. 225 | tree_scores: 226 | Dictionnary of already computed scores in the tree. 227 | documents: 228 | Documents add to the leafs. 229 | k: 230 | node_name of closest embeddings to return. 231 | paths: 232 | Paths to explore. 233 | """ 234 | # We go through the choosen path and we do not explore the tree if paths. 235 | if paths is not None: 236 | k = 1 237 | 238 | # Store childs scores: 239 | if tree_scores is None: 240 | tree_scores = collections.defaultdict(dict) 241 | 242 | scores = [] 243 | for index, node in enumerate(iterable=self.childs): 244 | score = node.nodes_scores( 245 | scoring=scoring, 246 | queries_embeddings=queries_embeddings, 247 | node_embedding=self.nodes_embeddings[index], 248 | ) 249 | 250 | scores.append(score) 251 | for query, query_score in zip(queries, score): 252 | tree_scores[query][node.node_name] = query_score 253 | 254 | childs, scores = self.get_childs_and_scores( 255 | queries=queries, 256 | scores=scores, 257 | tree_scores=tree_scores, 258 | paths=paths, 259 | k=k if self.level == beam_search_depth else 1, 260 | ) 261 | 262 | # Aggregate embeddings by child. 
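        # Queries routed to the same child are grouped together so that each child is searched
        # only once; the recursive call below then receives only the queries, embeddings and
        # (optionally) paths assigned to that child.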
263 | index_embeddings, index_queries, index_paths = ( 264 | collections.defaultdict(list), 265 | collections.defaultdict(list), 266 | collections.defaultdict(list), 267 | ) 268 | 269 | for index, query_childs in enumerate(iterable=childs): 270 | for child in query_childs.tolist(): 271 | index_embeddings[child].append(queries_embeddings[index]) 272 | index_queries[child].append(queries[index]) 273 | 274 | if paths is not None: 275 | index_paths[child].append(paths[index]) 276 | 277 | index_paths = dict(index_paths) 278 | 279 | for (child, embeddings), (_, queries_child) in zip( 280 | index_embeddings.items(), 281 | index_queries.items(), 282 | ): 283 | tree_scores = self.childs[child].search( 284 | queries=queries_child, 285 | queries_embeddings=torch.stack(tensors=embeddings, axis=0), 286 | scoring=scoring, 287 | tree_scores=tree_scores, 288 | paths=index_paths[child] if child in index_paths else None, 289 | k=k, 290 | beam_search_depth=beam_search_depth, 291 | ) 292 | 293 | return tree_scores 294 | 295 | def to_json(self) -> dict: 296 | return {child.node_name: child.to_json() for child in self.childs} 297 | -------------------------------------------------------------------------------- /neural_tree/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | from .colbert import ColBERT 2 | from .sentence_transformer import SentenceTransformer 3 | from .tfidf import TfIdf 4 | 5 | __all__ = ["ColBERT", "SentenceTransformer", "TfIdf"] 6 | -------------------------------------------------------------------------------- /neural_tree/retrievers/colbert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from neural_cherche import rank 3 | 4 | from .. import utils 5 | 6 | __all__ = ["ColBERT"] 7 | 8 | 9 | class ColBERT(rank.ColBERT): 10 | """ColBERT retriever.""" 11 | 12 | def __init__( 13 | self, 14 | key: str, 15 | on: str | list[str], 16 | device: str, 17 | ) -> None: 18 | self.key = key 19 | self.on = on if isinstance(on, list) else [on] 20 | self.embeddings = {} 21 | self.documents = [] 22 | self.device = device 23 | 24 | def add(self, documents_embeddings: dict) -> "ColBERT": 25 | """Add documents embeddings.""" 26 | documents_embeddings = { 27 | document_id: embedding 28 | for document_id, embedding in documents_embeddings.items() 29 | if document_id not in self.embeddings 30 | } 31 | 32 | self.embeddings.update(documents_embeddings) 33 | self.documents.extend( 34 | [{self.key: document_id} for document_id in documents_embeddings.keys()] 35 | ) 36 | 37 | return self 38 | 39 | def __call__( 40 | self, 41 | queries_embeddings: dict[str, torch.Tensor], 42 | batch_size: int = 32, 43 | k: int = None, 44 | tqdm_bar: bool = False, 45 | ) -> list[list[str]]: 46 | """Rank documents givent queries. 47 | 48 | Parameters 49 | ---------- 50 | queries 51 | Queries. 52 | documents 53 | Documents. 54 | queries_embeddings 55 | Queries embeddings. 56 | batch_size 57 | Batch size. 58 | tqdm_bar 59 | Show tqdm bar. 60 | k 61 | Number of documents to retrieve. 
62 | """ 63 | scores = [] 64 | 65 | for query, embedding_query in queries_embeddings.items(): 66 | query_scores = [] 67 | 68 | embedding_query = embedding_query.to(device=self.device) 69 | 70 | for batch_documents in utils.batchify( 71 | X=self.documents, 72 | batch_size=batch_size, 73 | tqdm_bar=tqdm_bar, 74 | ): 75 | embeddings_batch_documents = torch.stack( 76 | tensors=[ 77 | self.embeddings[document[self.key]] 78 | for document in batch_documents 79 | ], 80 | dim=0, 81 | ) 82 | 83 | query_documents_scores = torch.einsum( 84 | "sh,bth->bst", 85 | embedding_query, 86 | embeddings_batch_documents, 87 | ) 88 | 89 | query_scores.append( 90 | query_documents_scores.max(dim=2).values.sum(axis=1) 91 | ) 92 | 93 | scores.append(torch.cat(tensors=query_scores, dim=0)) 94 | 95 | return self._rank(scores=scores, documents=self.documents, k=k) 96 | 97 | def _rank( 98 | self, scores: torch.Tensor, documents: list[dict], k: int 99 | ) -> list[list[dict]]: 100 | """Rank documents by scores. 101 | 102 | Parameters 103 | ---------- 104 | scores 105 | Scores. 106 | documents 107 | Documents. 108 | k 109 | Number of documents to retrieve. 110 | """ 111 | ranked = [] 112 | 113 | for query_scores in scores: 114 | top_k = torch.topk( 115 | input=query_scores, 116 | k=min(k, len(self.documents)) if k is not None else len(self.documents), 117 | dim=-1, 118 | ) 119 | 120 | ranked.append( 121 | [ 122 | {**self.documents[indice], "similarity": similarity} 123 | for indice, similarity in zip(top_k.indices, top_k.values.tolist()) 124 | ] 125 | ) 126 | 127 | return ranked 128 | -------------------------------------------------------------------------------- /neural_tree/retrievers/sentence_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | 5 | __all__ = ["SentenceTransformer"] 6 | 7 | 8 | class SentenceTransformer: 9 | """Sentence Transformer retriever. 10 | 11 | Examples 12 | -------- 13 | >>> from neural_tree import retrievers 14 | >>> from sentence_transformers import SentenceTransformer 15 | >>> from pprint import pprint 16 | 17 | >>> model = SentenceTransformer("all-mpnet-base-v2") 18 | 19 | >>> retriever = retrievers.SentenceTransformer(key="id") 20 | 21 | >>> retriever = retriever.add( 22 | ... documents_embeddings={ 23 | ... 0: model.encode("Paris is the capital of France."), 24 | ... 1: model.encode("Berlin is the capital of Germany."), 25 | ... 2: model.encode("Paris and Berlin are European cities."), 26 | ... 3: model.encode("Paris and Berlin are beautiful cities."), 27 | ... } 28 | ... ) 29 | 30 | >>> queries_embeddings = { 31 | ... 0: model.encode("Paris"), 32 | ... 1: model.encode("Berlin"), 33 | ... } 34 | 35 | >>> candidates = retriever(queries_embeddings=queries_embeddings, k=2) 36 | >>> pprint(candidates) 37 | [[{'id': 0, 'similarity': 0.644777984318611}, 38 | {'id': 3, 'similarity': 0.52865785276988}], 39 | [{'id': 1, 'similarity': 0.6901492368348436}, 40 | {'id': 3, 'similarity': 0.5457692206973245}]] 41 | 42 | """ 43 | 44 | def __init__(self, key: str, device: str = "cpu") -> None: 45 | self.key = key 46 | self.device = device 47 | self.index = None 48 | self.documents = [] 49 | 50 | def _build(self, embeddings: np.ndarray) -> Any: 51 | """Build faiss index. 52 | 53 | Parameters 54 | ---------- 55 | index 56 | faiss index. 57 | embeddings 58 | Embeddings of the documents. 
59 | 60 | """ 61 | if self.index is None: 62 | try: 63 | import faiss 64 | except: 65 | raise ImportError( 66 | 'Run pip install "neural-tree[cpu]" or pip install "neural-tree[gpu]" to run faiss on cpu / gpu.' 67 | ) 68 | self.index = faiss.IndexFlatL2(embeddings.shape[1]) 69 | if self.device == "cuda": 70 | try: 71 | self.index = faiss.index_cpu_to_gpu( 72 | faiss.StandardGpuResources(), 0, self.index 73 | ) 74 | except: 75 | raise ImportError( 76 | 'Run pip install "neural-tree[gpu]" to run faiss on gpu.' 77 | ) 78 | 79 | if not self.index.is_trained and embeddings: 80 | self.index.train(embeddings) 81 | 82 | self.index.add(embeddings) 83 | return self.index 84 | 85 | def add(self, documents_embeddings: dict[int, np.ndarray]) -> "SentenceTransformer": 86 | """Add documents to the faiss index.""" 87 | self.documents.extend(list(documents_embeddings.keys())) 88 | self.index = self._build( 89 | embeddings=np.array(object=list(documents_embeddings.values())) 90 | ) 91 | return self 92 | 93 | def __call__( 94 | self, 95 | queries_embeddings: dict[int, np.ndarray], 96 | k: int | None = 100, 97 | **kwargs, 98 | ) -> list: 99 | """Retrieve documents.""" 100 | if k is None: 101 | k = len(self.documents) 102 | 103 | k = min(k, len(self.documents)) 104 | queries_embeddings = np.array(object=list(queries_embeddings.values())) 105 | distances, indexes = self.index.search(queries_embeddings, k) 106 | matchs = np.take(a=self.documents, indices=np.where(indexes < 0, 0, indexes)) 107 | rank = [] 108 | for distance, index, match in zip(distances, indexes, matchs): 109 | rank.append( 110 | [ 111 | { 112 | self.key: m, 113 | "similarity": 1 / (1 + d), 114 | } 115 | for d, idx, m in zip(distance, index, match) 116 | if idx > -1 117 | ] 118 | ) 119 | 120 | return rank 121 | -------------------------------------------------------------------------------- /neural_tree/retrievers/tfidf.py: -------------------------------------------------------------------------------- 1 | from neural_cherche import retrieve 2 | from scipy.sparse import csr_matrix 3 | from sklearn.feature_extraction.text import TfidfVectorizer 4 | 5 | __all__ = ["TfIdf"] 6 | 7 | 8 | class TfIdf(retrieve.TfIdf): 9 | """TfIdf retriever""" 10 | 11 | def __init__( 12 | self, 13 | key: str, 14 | on: list[str], 15 | ) -> None: 16 | super().__init__(key=key, on=on, fit=False) 17 | self.tfidf = None 18 | 19 | def encode_documents( 20 | self, documents: list[dict], model: TfidfVectorizer 21 | ) -> dict[str, csr_matrix]: 22 | """Encode queries into sparse matrix.""" 23 | content = [ 24 | " ".join([doc.get(field, "") for field in self.on]) for doc in documents 25 | ] 26 | 27 | # matrix is a csr matrix of shape (n_documents, n_features) 28 | matrix = model.transform(raw_documents=content) 29 | return {document[self.key]: row for document, row in zip(documents, matrix)} 30 | 31 | def encode_queries( 32 | self, queries: list[str], model: TfidfVectorizer 33 | ) -> dict[str, csr_matrix]: 34 | """Encode queries into sparse matrix.""" 35 | matrix = model.transform(raw_documents=queries) 36 | return {query: row for query, row in zip(queries, matrix)} 37 | -------------------------------------------------------------------------------- /neural_tree/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseScore 2 | from .colbert import ColBERT 3 | from .sentence_transformer import SentenceTransformer 4 | from .tfidf import TfIdf 5 | 6 | __all__ = ["BaseScore", "ColBERT", "SentenceTransformer", 
"TfIdf"] 7 | -------------------------------------------------------------------------------- /neural_tree/scoring/base.py: -------------------------------------------------------------------------------- 1 | """Base class for scoring functions.""" 2 | from abc import ABC, abstractmethod 3 | from typing import Any 4 | 5 | import numpy as np 6 | import torch 7 | from scipy import sparse 8 | 9 | __all__ = ["BaseScore"] 10 | 11 | 12 | class BaseScore(ABC): 13 | """Base class for scoring functions.""" 14 | 15 | def __init__(self) -> None: 16 | pass 17 | 18 | @abstractmethod 19 | def distinct_documents_encoder(self) -> bool: 20 | """Return True if the encoder is distinct for documents and nodes.""" 21 | 22 | @abstractmethod 23 | def transform_queries( 24 | self, queries: list[str] 25 | ) -> sparse.csr_matrix | np.ndarray | dict: 26 | """Transform queries to embeddings.""" 27 | 28 | @abstractmethod 29 | def transform_documents( 30 | self, documents: list[dict] 31 | ) -> sparse.csr_matrix | np.ndarray | dict: 32 | """Transform documents to embeddings.""" 33 | 34 | @abstractmethod 35 | def get_retriever(self) -> Any: 36 | """Create a retriever""" 37 | 38 | @abstractmethod 39 | def encode_queries_for_retrieval( 40 | self, queries: list[str] 41 | ) -> sparse.csr_matrix | np.ndarray | dict: 42 | """Encode queries for retrieval.""" 43 | 44 | @abstractmethod 45 | def convert_to_tensor( 46 | embeddings: sparse.csr_matrix | np.ndarray, device: str 47 | ) -> torch.Tensor: 48 | """Transform sparse matrix to tensor.""" 49 | 50 | @abstractmethod 51 | def nodes_scores( 52 | queries_embeddings: torch.Tensor, nodes_embeddings: torch.Tensor 53 | ) -> torch.Tensor: 54 | """Score between queries and nodes embeddings.""" 55 | 56 | @abstractmethod 57 | def leaf_scores( 58 | queries_embeddings: torch.Tensor, leaf_embedding: torch.Tensor 59 | ) -> torch.Tensor: 60 | """Return the scores of the embeddings.""" 61 | 62 | @abstractmethod 63 | def stack( 64 | embeddings: list[sparse.csr_matrix | np.ndarray | dict], 65 | ) -> sparse.csr_matrix | np.ndarray | dict: 66 | """Stack list of embeddings.""" 67 | -------------------------------------------------------------------------------- /neural_tree/scoring/colbert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from neural_cherche import models 4 | 5 | from ..retrievers import ColBERT as colbert_retriever 6 | from ..utils import batchify, set_env 7 | from .base import BaseScore 8 | 9 | __all__ = ["ColBERT"] 10 | 11 | 12 | class ColBERT(BaseScore): 13 | """TfIdf scoring function. 14 | 15 | Examples 16 | -------- 17 | >>> from neural_tree import trees, scoring 18 | >>> from neural_cherche import models 19 | >>> from sklearn.feature_extraction.text import TfidfVectorizer 20 | >>> from pprint import pprint 21 | >>> import torch 22 | 23 | >>> _ = torch.manual_seed(42) 24 | 25 | >>> documents = [ 26 | ... {"id": 0, "text": "Paris is the capital of France."}, 27 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 28 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 29 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 30 | ... ] 31 | 32 | >>> model = models.ColBERT( 33 | ... model_name_or_path="sentence-transformers/all-mpnet-base-v2", 34 | ... embedding_size=128, 35 | ... max_length_document=96, 36 | ... max_length_query=32, 37 | ... ) 38 | 39 | >>> tree = trees.ColBERTTree( 40 | ... key="id", 41 | ... on="text", 42 | ... model=model, 43 | ... 
documents=documents, 44 | ... leaf_balance_factor=1, 45 | ... branch_balance_factor=2, 46 | ... n_jobs=1, 47 | ... ) 48 | 49 | >>> print(tree) 50 | node 1 51 | node 10 52 | leaf 100 53 | leaf 101 54 | node 11 55 | leaf 110 56 | leaf 111 57 | 58 | >>> tree.leafs_to_documents 59 | {'100': [0], '101': [1], '110': [2], '111': [3]} 60 | 61 | >>> candidates = tree( 62 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."], 63 | ... k_leafs=2, 64 | ... k=2, 65 | ... ) 66 | 67 | >>> candidates["scores"] 68 | array([[28.12037659, 18.32332611], 69 | [29.28324509, 21.38923264]]) 70 | 71 | >>> candidates["leafs"] 72 | array([['100', '101'], 73 | ['110', '111']], dtype='>> pprint(candidates["tree_scores"]) 76 | [{'10': tensor(28.1204), 77 | '100': tensor(28.1204), 78 | '101': tensor(18.3233), 79 | '11': tensor(20.9327)}, 80 | {'10': tensor(21.6886), 81 | '11': tensor(29.2832), 82 | '110': tensor(29.2832), 83 | '111': tensor(21.3892)}] 84 | 85 | >>> pprint(candidates["documents"]) 86 | [[{'id': 0, 'leaf': '100', 'similarity': 28.120376586914062}, 87 | {'id': 1, 'leaf': '101', 'similarity': 18.323326110839844}], 88 | [{'id': 2, 'leaf': '110', 'similarity': 29.283245086669922}, 89 | {'id': 3, 'leaf': '111', 'similarity': 21.389232635498047}]] 90 | 91 | """ 92 | 93 | def __init__( 94 | self, 95 | key: str, 96 | on: list | str, 97 | documents: list, 98 | model: models.ColBERT = None, 99 | device: str = "cpu", 100 | **kwargs, 101 | ) -> None: 102 | """Initialize the scoring function.""" 103 | set_env() 104 | self.key = key 105 | self.on = [on] if isinstance(on, str) else on 106 | self.model = model 107 | self.device = device 108 | 109 | @property 110 | def distinct_documents_encoder(self) -> bool: 111 | """Return True if the encoder is distinct for documents and nodes.""" 112 | return False 113 | 114 | def transform_queries( 115 | self, queries: list[str], batch_size: int, tqdm_bar: bool, *kwargs 116 | ) -> torch.Tensor: 117 | """Transform queries to embeddings.""" 118 | queries_embeddings = [] 119 | 120 | for batch in batchify(X=queries, batch_size=batch_size, tqdm_bar=tqdm_bar): 121 | queries_embeddings.append( 122 | self.model.encode(texts=batch, query_mode=True)["embeddings"] 123 | ) 124 | 125 | return ( 126 | queries_embeddings[0].to(self.device) 127 | if len(queries_embeddings) == 1 128 | else torch.cat(tensors=queries_embeddings, dim=0).to(device=self.device) 129 | ) 130 | 131 | def transform_documents( 132 | self, documents: list[dict], batch_size: int, tqdm_bar: bool, **kwargs 133 | ) -> torch.Tensor: 134 | """Transform documents to embeddings.""" 135 | documents_embeddings = [] 136 | 137 | for batch in batchify( 138 | X=[ 139 | " ".join([document[field] for field in self.on]) 140 | for document in documents 141 | ], 142 | batch_size=batch_size, 143 | tqdm_bar=tqdm_bar, 144 | ): 145 | documents_embeddings.append( 146 | self.model.encode(texts=batch, query_mode=False)["embeddings"] 147 | ) 148 | 149 | return ( 150 | documents_embeddings[0].to(self.device) 151 | if len(documents_embeddings) == 1 152 | else torch.cat(tensors=documents_embeddings, dim=0).to(device=self.device) 153 | ) 154 | 155 | def get_retriever(self) -> None: 156 | """Create a retriever""" 157 | return colbert_retriever(key=self.key, on=self.on, device=self.device) 158 | 159 | def encode_queries_for_retrieval(self, queries: list[str]) -> None: 160 | """Encode queries for retrieval.""" 161 | pass 162 | 163 | @staticmethod 164 | def convert_to_tensor( 165 | embeddings: np.ndarray | torch.Tensor, 
device: str 166 | ) -> torch.Tensor: 167 | """Transform sparse matrix to tensor.""" 168 | if isinstance(embeddings, np.ndarray): 169 | return torch.tensor(data=embeddings, device=device, dtype=torch.float32) 170 | return embeddings.to(device=device) 171 | 172 | @staticmethod 173 | def nodes_scores( 174 | queries_embeddings: torch.Tensor, nodes_embeddings: torch.Tensor 175 | ) -> torch.Tensor: 176 | """Score between queries and nodes embeddings.""" 177 | return torch.stack( 178 | tensors=[ 179 | torch.einsum( 180 | "sh,bth->bst", 181 | query_embedding, 182 | nodes_embeddings, 183 | ) 184 | .max(dim=2) 185 | .values.sum(axis=1) 186 | .max(dim=0) 187 | .values 188 | for query_embedding in queries_embeddings 189 | ], 190 | dim=0, 191 | ) 192 | 193 | @staticmethod 194 | def leaf_scores( 195 | queries_embeddings: torch.Tensor, leaf_embedding: torch.Tensor 196 | ) -> torch.Tensor: 197 | """Return the scores of the embeddings.""" 198 | return torch.stack( 199 | tensors=[ 200 | torch.einsum( 201 | "sh,th->st", 202 | query_embedding, 203 | leaf_embedding, 204 | ) 205 | .max(dim=1) 206 | .values.sum() 207 | for query_embedding in queries_embeddings 208 | ], 209 | dim=0, 210 | ) 211 | 212 | def stack(self, embeddings: list[torch.Tensor | np.ndarray]) -> torch.Tensor: 213 | """Stack list of embeddings.""" 214 | if isinstance(embeddings, np.ndarray): 215 | return self.convert_to_tensor( 216 | embeddings=embeddings, device=self.model.device 217 | ) 218 | return torch.stack(tensors=embeddings, dim=0) 219 | 220 | @staticmethod 221 | def average(embeddings: torch.Tensor) -> torch.Tensor: 222 | """Average embeddings.""" 223 | return embeddings.mean(axis=0) 224 | -------------------------------------------------------------------------------- /neural_tree/scoring/sentence_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sentence_transformers import SentenceTransformer 4 | 5 | from ..retrievers import SentenceTransformer as SentenceTransformerRetriever 6 | from ..utils import set_env 7 | 8 | __all__ = ["SentenceTransformer"] 9 | 10 | 11 | class SentenceTransformer: 12 | """Sentence Transformer scoring function. 13 | 14 | Examples 15 | -------- 16 | >>> from neural_tree import trees, scoring 17 | >>> from sentence_transformers import SentenceTransformer 18 | >>> from pprint import pprint 19 | 20 | >>> documents = [ 21 | ... {"id": 0, "text": "Paris is the capital of France."}, 22 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 23 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 24 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 25 | ... ] 26 | 27 | >>> tree = trees.Tree( 28 | ... key="id", 29 | ... documents=documents, 30 | ... scoring=scoring.SentenceTransformer(key="id", on=["text"], model=SentenceTransformer("all-mpnet-base-v2")), 31 | ... leaf_balance_factor=1, 32 | ... branch_balance_factor=2, 33 | ... n_jobs=1, 34 | ... ) 35 | 36 | >>> print(tree) 37 | node 1 38 | node 11 39 | node 110 40 | leaf 1100 41 | leaf 1101 42 | leaf 111 43 | leaf 10 44 | 45 | >>> candidates = tree( 46 | ... queries=["paris", "berlin"], 47 | ... k_leafs=2, 48 | ... 
) 49 | 50 | >>> candidates["scores"] 51 | array([[0.72453916, 0.60635257], 52 | [0.58386189, 0.57546711]]) 53 | 54 | >>> candidates["leafs"] 55 | array([['111', '10'], 56 | ['1101', '1100']], dtype='>> pprint(candidates["tree_scores"]) 59 | [{'10': tensor(0.6064), 60 | '11': tensor(0.7245), 61 | '110': tensor(0.5542), 62 | '1100': tensor(0.5403), 63 | '1101': tensor(0.5542), 64 | '111': tensor(0.7245)}, 65 | {'10': tensor(0.5206), 66 | '11': tensor(0.5797), 67 | '110': tensor(0.5839), 68 | '1100': tensor(0.5755), 69 | '1101': tensor(0.5839), 70 | '111': tensor(0.4026)}] 71 | 72 | >>> pprint(candidates["documents"]) 73 | [[{'id': 0, 'leaf': '111', 'similarity': 0.6447779347587058}, 74 | {'id': 1, 'leaf': '10', 'similarity': 0.43175890864117644}], 75 | [{'id': 3, 'leaf': '1101', 'similarity': 0.545769273959571}, 76 | {'id': 2, 'leaf': '1100', 'similarity': 0.54081365990618}]] 77 | 78 | """ 79 | 80 | def __init__( 81 | self, 82 | key: str, 83 | on: str | list, 84 | model: SentenceTransformer, 85 | device: str = "cpu", 86 | faiss_device: str = "cpu", 87 | ) -> None: 88 | set_env() 89 | self.key = key 90 | self.on = [on] if isinstance(on, str) else on 91 | self.model = model 92 | self.device = device 93 | self.faiss_device = faiss_device 94 | 95 | @property 96 | def distinct_documents_encoder(self) -> bool: 97 | """Return True if the encoder is distinct for documents and nodes.""" 98 | return False 99 | 100 | def transform_queries( 101 | self, queries: list[str], batch_size: int, **kwargs 102 | ) -> np.ndarray: 103 | """Transform queries to embeddings.""" 104 | return self.model.encode(queries, batch_size=batch_size) 105 | 106 | def transform_documents( 107 | self, documents: list[dict], batch_size: int, **kwargs 108 | ) -> np.ndarray: 109 | """Transform documents to embeddings.""" 110 | return self.model.encode( 111 | [ 112 | " ".join([document[field] for field in self.on]) 113 | for document in documents 114 | ], 115 | batch_size=batch_size, 116 | ) 117 | 118 | def get_retriever(self) -> None: 119 | """Create a retriever""" 120 | return SentenceTransformerRetriever(key=self.key, device=self.faiss_device) 121 | 122 | @staticmethod 123 | def encode_queries_for_retrieval(queries: list[str]) -> None: 124 | """Encode queries for retrieval.""" 125 | pass 126 | 127 | @staticmethod 128 | def convert_to_tensor(embeddings: np.ndarray, device: str) -> torch.Tensor: 129 | """Convert numpy array to torch tensor.""" 130 | return torch.tensor(data=embeddings, device=device, dtype=torch.float32) 131 | 132 | @staticmethod 133 | def nodes_scores( 134 | queries_embeddings: torch.Tensor, nodes_embeddings: torch.Tensor 135 | ) -> torch.Tensor: 136 | """Score between queries and nodes embeddings.""" 137 | return torch.max( 138 | input=torch.mm(input=queries_embeddings, mat2=nodes_embeddings.T), 139 | dim=1, 140 | ).values 141 | 142 | @staticmethod 143 | def leaf_scores( 144 | queries_embeddings: torch.Tensor, leaf_embedding: torch.Tensor 145 | ) -> torch.Tensor: 146 | """Computes scores between query and leaf embedding.""" 147 | return queries_embeddings @ leaf_embedding.T 148 | 149 | @staticmethod 150 | def stack(embeddings: list[np.ndarray]) -> np.ndarray: 151 | """Stack embeddings.""" 152 | return ( 153 | np.vstack(tup=embeddings) 154 | if len(embeddings) > 1 155 | else embeddings[0].reshape(1, -1) 156 | ) 157 | 158 | @staticmethod 159 | def average(embeddings: np.ndarray) -> np.ndarray: 160 | """Average embeddings.""" 161 | return np.mean(a=embeddings, axis=0) 162 | 
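A minimal, illustrative sketch (not taken from the repository's own tests) of how the SentenceTransformer scoring function above works together with the faiss-backed retriever from retrievers/sentence_transformer.py: the scoring object encodes documents and queries, the retriever indexes the document embeddings and returns ids ranked by similarity. It assumes faiss is installed (pip install "neural-tree[cpu]"), that the retriever constructor takes the same key / device arguments used by get_retriever above, and uses "all-mpnet-base-v2" purely as an example model.

from sentence_transformers import SentenceTransformer as SentenceTransformerModel
from neural_tree import retrievers, scoring

documents = [
    {"id": 0, "text": "Paris is the capital of France."},
    {"id": 1, "text": "Berlin is the capital of Germany."},
]

score = scoring.SentenceTransformer(
    key="id", on=["text"], model=SentenceTransformerModel("all-mpnet-base-v2")
)

# Dense embeddings, one row per document / query.
documents_embeddings = score.transform_documents(documents=documents, batch_size=32)
queries_embeddings = score.transform_queries(queries=["paris", "berlin"], batch_size=32)

# The retriever stores the ids and adds the embeddings to a faiss IndexFlatL2.
retriever = retrievers.SentenceTransformer(key="id", device="cpu")
retriever = retriever.add(
    documents_embeddings={
        document["id"]: embedding
        for document, embedding in zip(documents, documents_embeddings)
    }
)

# L2 distances are converted to similarities with 1 / (1 + distance).
candidates = retriever(
    queries_embeddings={
        index: embedding for index, embedding in enumerate(queries_embeddings)
    },
    k=2,
)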
-------------------------------------------------------------------------------- /neural_tree/scoring/tfidf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy import sparse 4 | from sklearn.feature_extraction.text import TfidfVectorizer 5 | 6 | from ..retrievers import TfIdf as tfidf_retriever 7 | from ..utils import set_env 8 | from .base import BaseScore 9 | 10 | __all__ = ["TfIdf"] 11 | 12 | 13 | class TfIdf(BaseScore): 14 | """TfIdf scoring function. 15 | 16 | Examples 17 | -------- 18 | >>> from neural_tree import trees, scoring 19 | >>> from sklearn.feature_extraction.text import TfidfVectorizer 20 | >>> from pprint import pprint 21 | 22 | >>> documents = [ 23 | ... {"id": 0, "text": "Paris is the capital of France."}, 24 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 25 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 26 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 27 | ... ] 28 | 29 | >>> tree = trees.Tree( 30 | ... key="id", 31 | ... documents=documents, 32 | ... scoring=scoring.TfIdf(key="id", on=["text"], documents=documents), 33 | ... leaf_balance_factor=1, 34 | ... branch_balance_factor=2, 35 | ... ) 36 | 37 | >>> print(tree) 38 | node 1 39 | node 10 40 | leaf 100 41 | leaf 101 42 | node 11 43 | leaf 110 44 | leaf 111 45 | 46 | >>> tree.leafs_to_documents 47 | {'100': [0], '101': [1], '110': [2], '111': [3]} 48 | 49 | >>> candidates = tree( 50 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."], 51 | ... k_leafs=2, 52 | ... k=2, 53 | ... ) 54 | 55 | >>> candidates["scores"] 56 | array([[0.99999994, 0.63854915], 57 | [0.99999994, 0.72823119]]) 58 | 59 | >>> candidates["leafs"] 60 | array([['100', '101'], 61 | ['110', '111']], dtype='>> pprint(candidates["tree_scores"]) 64 | [{'10': tensor(1.0000), 65 | '100': tensor(1.0000), 66 | '101': tensor(0.6385), 67 | '11': tensor(0.1076)}, 68 | {'10': tensor(0.1076), 69 | '11': tensor(1.0000), 70 | '110': tensor(1.0000), 71 | '111': tensor(0.7282)}] 72 | 73 | >>> pprint(candidates["documents"]) 74 | [[{'id': 0, 'leaf': '100', 'similarity': 0.9999999999999978}, 75 | {'id': 1, 'leaf': '101', 'similarity': 0.39941742405759667}], 76 | [{'id': 2, 'leaf': '110', 'similarity': 0.9999999999999978}, 77 | {'id': 3, 'leaf': '111', 'similarity': 0.5385719658738707}]] 78 | 79 | """ 80 | 81 | def __init__( 82 | self, 83 | key: str, 84 | on: list | str, 85 | documents: list, 86 | tfidf_nodes: TfidfVectorizer | None = None, 87 | tfidf_documents: TfidfVectorizer | None = None, 88 | **kwargs, 89 | ) -> None: 90 | """Initialize the scoring function.""" 91 | set_env() 92 | 93 | self.key = key 94 | self.on = [on] if isinstance(on, str) else on 95 | 96 | if tfidf_nodes is None: 97 | tfidf_nodes = TfidfVectorizer() 98 | 99 | if tfidf_documents is None: 100 | tfidf_documents = TfidfVectorizer( 101 | lowercase=True, ngram_range=(3, 7), analyzer="char_wb" 102 | ) 103 | 104 | self.tfidf_nodes = tfidf_nodes.fit( 105 | raw_documents=[ 106 | " ".join([document[field] for field in self.on]) 107 | for document in documents 108 | ], 109 | ) 110 | 111 | self.model = tfidf_documents.fit( 112 | raw_documents=[ 113 | " ".join([document[field] for field in self.on]) 114 | for document in documents 115 | ], 116 | ) 117 | 118 | @property 119 | def distinct_documents_encoder(self) -> bool: 120 | """Return True if the encoder is distinct for documents and nodes.""" 121 | return True 122 | 123 | def 
transform_queries(self, queries: list[str], **kwargs) -> sparse.csr_matrix: 124 | """Transform queries to embeddings.""" 125 | return self.tfidf_nodes.transform(raw_documents=queries) 126 | 127 | def transform_documents(self, documents: list[dict], **kwargs) -> sparse.csr_matrix: 128 | """Transform documents to embeddings.""" 129 | return self.tfidf_nodes.transform( 130 | raw_documents=[ 131 | " ".join([document[field] for field in self.on]) 132 | for document in documents 133 | ], 134 | ) 135 | 136 | def get_retriever(self) -> tfidf_retriever: 137 | """Create a retriever""" 138 | return tfidf_retriever( 139 | key=self.key, 140 | on=self.on, 141 | ) 142 | 143 | def encode_queries_for_retrieval(self, queries: list[str]) -> sparse.csr_matrix: 144 | """Encode queries for retrieval.""" 145 | return self.model.transform(raw_documents=queries) 146 | 147 | @staticmethod 148 | def convert_to_tensor(embeddings: sparse.csr_matrix, device: str) -> torch.Tensor: 149 | """Transform sparse matrix to tensor.""" 150 | embeddings = embeddings.tocoo() 151 | return torch.sparse_coo_tensor( 152 | indices=torch.tensor( 153 | data=np.vstack(tup=(embeddings.row, embeddings.col)), 154 | dtype=torch.long, 155 | device=device, 156 | ), 157 | values=torch.tensor( 158 | data=embeddings.data, dtype=torch.float32, device=device 159 | ), 160 | size=torch.Size(embeddings.shape), 161 | ).to(device=device) 162 | 163 | @staticmethod 164 | def nodes_scores( 165 | queries_embeddings: torch.Tensor, nodes_embeddings: torch.Tensor 166 | ) -> torch.Tensor: 167 | """Score between queries and nodes embeddings.""" 168 | return torch.max( 169 | input=torch.mm( 170 | input=queries_embeddings, 171 | mat2=nodes_embeddings.T, 172 | ).to_dense(), 173 | dim=1, 174 | ).values 175 | 176 | @staticmethod 177 | def leaf_scores( 178 | queries_embeddings: torch.Tensor, leaf_embedding: torch.Tensor 179 | ) -> torch.Tensor: 180 | """Return the scores of the embeddings.""" 181 | return ( 182 | torch.mm(input=queries_embeddings, mat2=leaf_embedding.unsqueeze(dim=0).T) 183 | .to_dense() 184 | .flatten() 185 | ) 186 | 187 | @staticmethod 188 | def stack(embeddings: list[sparse.csr_matrix]) -> sparse.csr_matrix: 189 | """Stack list of embeddings.""" 190 | return ( 191 | sparse.vstack(blocks=embeddings) if len(embeddings) > 1 else embeddings[0] 192 | ) 193 | 194 | @staticmethod 195 | def average(embeddings: sparse.csr_matrix) -> sparse.csr_matrix: 196 | """Average embeddings.""" 197 | return embeddings.mean(axis=0) 198 | -------------------------------------------------------------------------------- /neural_tree/trees/__init__.py: -------------------------------------------------------------------------------- 1 | from .colbert import ColBERT 2 | from .sentence_transformer import SentenceTransformer 3 | from .tfidf import TfIdf 4 | from .tree import Tree 5 | 6 | __all__ = ["ColBERT", "SentenceTransformer", "TfIdf", "Tree"] 7 | -------------------------------------------------------------------------------- /neural_tree/trees/colbert.py: -------------------------------------------------------------------------------- 1 | from neural_cherche import models 2 | from sentence_transformers import SentenceTransformer as SentenceTransformerModel 3 | 4 | from ..clustering import get_mapping_nodes_documents 5 | from ..scoring import ColBERT as scoring_ColBERT 6 | from .sentence_transformer import SentenceTransformer 7 | from .tfidf import TfIdf 8 | from .tree import Tree 9 | 10 | __all__ = ["ColBERT"] 11 | 12 | 13 | class ColBERT(Tree): 14 | """ColBERT retriever. 
15 | 16 | Parameters 17 | ---------- 18 | key 19 | Key to identify the documents. 20 | on 21 | List of columns to use for the retrieval. 22 | model 23 | ColBERT model. 24 | sentence_transformer 25 | SentenceTransformer model in order to perform the hierarchical clustering. If 26 | None, the hierarchical clustering is performed with a TfIdf model. 27 | documents 28 | List of documents to index. 29 | graph 30 | Existing graph to initialize the tree. 31 | leaf_balance_factor 32 | Balance factor for the leafs. Once there are fewer than `leaf_balance_factor` 33 | documents in a node, the node becomes a leaf. 34 | branch_balance_factor 35 | Balance factor for the branches. The number of children of a node is limited to 36 | `branch_balance_factor`. 37 | device 38 | Device to use for the retrieval. 39 | n_jobs 40 | Number of jobs to use when creating the tree. If -1, all CPUs are used. 41 | batch_size 42 | Batch size to use when creating the tree. 43 | max_iter 44 | Maximum number of iterations to perform with the KMeans algorithm when creating the 45 | tree. 46 | n_init 47 | Number of times the KMeans algorithm will be run with different centroid seeds. 48 | create_retrievers 49 | Whether to create the retrievers or not. If False, the tree is only created and 50 | the __call__ method will only output relevant leafs and scores rather than 51 | ranked documents. 52 | tqdm_bar 53 | Whether to show the tqdm bar when creating the tree. 54 | seed 55 | Random seed. 56 | 57 | """ 58 | 59 | def __init__( 60 | self, 61 | key: str, 62 | on: str | list[str], 63 | model: models.ColBERT, 64 | sentence_transformer: SentenceTransformerModel | None = None, 65 | documents: list[dict] | None = None, 66 | graph: dict | None = None, 67 | leaf_balance_factor: int = 100, 68 | branch_balance_factor: int = 5, 69 | device: str = "cpu", 70 | n_jobs: int = -1, 71 | batch_size: int = 32, 72 | max_iter: int = 3000, 73 | n_init: int = 100, 74 | create_retrievers: bool = True, 75 | tqdm_bar: bool = True, 76 | seed: int = 42, 77 | ) -> None: 78 | """Create a tree with the ColBERT scoring.""" 79 | if graph is not None: 80 | documents = get_mapping_nodes_documents(graph=graph) 81 | elif graph is None and sentence_transformer is None: 82 | index = TfIdf( 83 | key=key, 84 | on=on, 85 | documents=documents, 86 | leaf_balance_factor=leaf_balance_factor, 87 | branch_balance_factor=branch_balance_factor, 88 | create_retrievers=False, 89 | n_jobs=n_jobs, 90 | max_iter=max_iter, 91 | n_init=n_init, 92 | seed=seed, 93 | ) 94 | graph = index.to_json() 95 | else: 96 | index = SentenceTransformer( 97 | key=key, 98 | on=on, 99 | documents=documents, 100 | model=sentence_transformer, 101 | leaf_balance_factor=leaf_balance_factor, 102 | branch_balance_factor=branch_balance_factor, 103 | n_jobs=n_jobs, 104 | create_retrievers=False, 105 | max_iter=max_iter, 106 | n_init=n_init, 107 | batch_size=batch_size, 108 | seed=seed, 109 | ) 110 | graph = index.to_json() 111 | 112 | scoring = scoring_ColBERT( 113 | key=key, 114 | on=on, 115 | documents=documents, 116 | model=model, 117 | device=device, 118 | ) 119 | 120 | # We compute the embeddings here because we need the documents' contents.
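# The ColBERT embeddings are computed once, up front, and passed to the Tree
# constructor as a mapping {document key: embedding}, so the documents do not
# have to be re-encoded while the leafs of the tree are being populated.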
121 | documents_embeddings = scoring.transform_documents( 122 | documents=documents, 123 | model=model, 124 | device=device, 125 | batch_size=batch_size, 126 | tqdm_bar=tqdm_bar, 127 | ) 128 | 129 | documents_embeddings = { 130 | document[key]: embedding 131 | for document, embedding in zip(documents, documents_embeddings) 132 | } 133 | 134 | super().__init__( 135 | key=key, 136 | graph=graph, 137 | documents=documents, 138 | scoring=scoring, 139 | documents_embeddings=documents_embeddings, 140 | leaf_balance_factor=leaf_balance_factor, 141 | branch_balance_factor=branch_balance_factor, 142 | device=device, 143 | n_jobs=1, 144 | create_retrievers=create_retrievers, 145 | max_iter=max_iter, 146 | n_init=n_init, 147 | seed=seed, 148 | ) 149 | -------------------------------------------------------------------------------- /neural_tree/trees/sentence_transformer.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer 2 | 3 | from ..clustering import get_mapping_nodes_documents 4 | from ..scoring import SentenceTransformer as scoring_SenrenceTransformer 5 | from .tree import Tree 6 | 7 | __all__ = ["SentenceTransformer"] 8 | 9 | 10 | class SentenceTransformer(Tree): 11 | """Tree with Sentence Transformer scoring function. 12 | 13 | Parameters 14 | ---------- 15 | key 16 | Key to identify the documents. 17 | on 18 | List of columns to use for the retrieval. 19 | model 20 | Sentence Transformer model. 21 | documents 22 | List of documents to index. 23 | graph 24 | Existing graph to initialize the tree. 25 | leaf_balance_factor 26 | Balance factor for the leafs. Once there is less than `leaf_balance_factor` 27 | documents in a node, the node becomes a leaf. 28 | branch_balance_factor 29 | Balance factor for the branches. The number of children of a node is limited to 30 | `branch_balance_factor`. 31 | device 32 | Device to use for the retrieval. 33 | n_jobs 34 | Number of jobs to use when creating the tree. If -1, all CPUs are used. 35 | batch_size 36 | Batch size to use when creating the tree. 37 | max_iter 38 | Maximum number of iterations to perform with Kmeans algorithm when creating the 39 | tree. 40 | n_init 41 | Number of time the KMeans algorithm will be run with different centroid seeds. 42 | create_retrievers 43 | Whether to create the retrievers or not. If False, the tree is only created and 44 | the __call__ method will only output relevant leafs and scores rather than 45 | ranked documents. 46 | tqdm_bar 47 | Whether to show the tqdm bar when creating the tree. 48 | seed 49 | Random seed. 50 | 51 | Examples 52 | -------- 53 | >>> from neural_tree import trees 54 | >>> from sentence_transformers import SentenceTransformer 55 | >>> from pprint import pprint 56 | 57 | >>> documents = [ 58 | ... {"id": 0, "text": "Paris is the capital of France."}, 59 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 60 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 61 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 62 | ... ] 63 | 64 | >>> tree = trees.SentenceTransformer( 65 | ... key="id", 66 | ... on=["text"], 67 | ... documents=documents, 68 | ... model=SentenceTransformer("all-mpnet-base-v2"), 69 | ... leaf_balance_factor=2, 70 | ... branch_balance_factor=2, 71 | ... device="cpu", 72 | ... 
) 73 | 74 | >>> tree = tree.add(documents=documents) 75 | 76 | >>> print(tree) 77 | node 1 78 | node 11 79 | leaf 110 80 | leaf 111 81 | leaf 10 82 | 83 | >>> tree.leafs_to_documents 84 | {'110': [2, 3, 1], '111': [0], '10': [1]} 85 | 86 | >>> candidates = tree( 87 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."], 88 | ... k_leafs=2, 89 | ... k=1, 90 | ... ) 91 | 92 | >>> candidates["scores"] 93 | array([[1. , 0.76908004], 94 | [0.88792843, 0.82272887]]) 95 | 96 | >>> candidates["leafs"] 97 | array([['111', '10'], 98 | ['110', '10']], dtype='>> pprint(candidates["tree_scores"]) 101 | [{'10': tensor(0.7691, device='mps:0'), 102 | '11': tensor(1., device='mps:0'), 103 | '110': tensor(0.6536, device='mps:0'), 104 | '111': tensor(1., device='mps:0')}, 105 | {'10': tensor(0.8227, device='mps:0'), 106 | '11': tensor(0.8879, device='mps:0'), 107 | '110': tensor(0.8879, device='mps:0'), 108 | '111': tensor(0.6923, device='mps:0')}] 109 | 110 | >>> pprint(candidates["documents"]) 111 | [[{'id': 0, 'leaf': '111', 'similarity': 1.0}], 112 | [{'id': 2, 'leaf': '110', 'similarity': 1.0}]] 113 | 114 | """ 115 | 116 | def __init__( 117 | self, 118 | key: str, 119 | on: str | list[str], 120 | model: SentenceTransformer, 121 | documents: list[dict] | None = None, 122 | documents_embeddings: dict | None = None, 123 | graph: dict | None = None, 124 | leaf_balance_factor: int = 100, 125 | branch_balance_factor: int = 5, 126 | device: str = "cpu", 127 | faiss_device: str = "cpu", 128 | batch_size: int = 32, 129 | n_jobs: int = -1, 130 | max_iter: int = 3000, 131 | n_init: int = 100, 132 | create_retrievers: bool = True, 133 | seed: int = 42, 134 | ) -> None: 135 | """Create a tree with the TfIdf scoring.""" 136 | if graph is not None: 137 | documents = get_mapping_nodes_documents(graph=graph) 138 | 139 | super().__init__( 140 | key=key, 141 | documents=documents, 142 | graph=graph, 143 | documents_embeddings=documents_embeddings, 144 | scoring=scoring_SenrenceTransformer( 145 | key=key, 146 | on=on, 147 | model=model, 148 | device=device, 149 | faiss_device=faiss_device, 150 | ), 151 | leaf_balance_factor=leaf_balance_factor, 152 | branch_balance_factor=branch_balance_factor, 153 | device=device, 154 | batch_size=batch_size, 155 | n_jobs=n_jobs, 156 | max_iter=max_iter, 157 | n_init=n_init, 158 | create_retrievers=create_retrievers, 159 | seed=seed, 160 | ) 161 | -------------------------------------------------------------------------------- /neural_tree/trees/tfidf.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import TfidfVectorizer 2 | 3 | from ..clustering import get_mapping_nodes_documents 4 | from ..scoring import TfIdf as scoring_TfIdf 5 | from .tree import Tree 6 | 7 | __all__ = ["TfIdf"] 8 | 9 | 10 | class TfIdf(Tree): 11 | """Tree with tfidf scoring function. 12 | 13 | Parameters 14 | ---------- 15 | key 16 | Key to identify the documents. 17 | on 18 | List of columns to use for the retrieval. 19 | tfidf_nodes 20 | TfidfVectorizer for the nodes. 21 | tfidf_documents 22 | TfidfVectorizer for the documents. 23 | documents 24 | List of documents to index. 25 | graph 26 | Existing graph to initialize the tree. 27 | leaf_balance_factor 28 | Balance factor for the leafs. Once there is less than `leaf_balance_factor` 29 | documents in a node, the node becomes a leaf. 30 | branch_balance_factor 31 | Balance factor for the branches. 
The number of children of a node is limited to 32 | `branch_balance_factor`. 33 | device 34 | Device to use for the retrieval. 35 | n_jobs 36 | Number of jobs to use when creating the tree. If -1, all CPUs are used. 37 | max_iter 38 | Maximum number of iterations to perform with Kmeans algorithm when creating the 39 | tree. 40 | n_init 41 | Number of time the KMeans algorithm will be run with different centroid seeds. 42 | create_retrievers 43 | Whether to create the retrievers or not. If False, the tree is only created and 44 | the __call__ method will only output relevant leafs and scores rather than 45 | ranked documents. 46 | tqdm_bar 47 | Whether to show the tqdm bar when creating the tree. 48 | seed 49 | Random seed. 50 | 51 | Examples 52 | -------- 53 | >>> from neural_tree import trees 54 | >>> from pprint import pprint 55 | 56 | >>> documents = [ 57 | ... {"id": 0, "text": "Paris is the capital of France."}, 58 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 59 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 60 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 61 | ... ] 62 | 63 | >>> tree = trees.TfIdf( 64 | ... key="id", 65 | ... on="text", 66 | ... documents=documents, 67 | ... leaf_balance_factor=2, 68 | ... branch_balance_factor=2, 69 | ... ) 70 | 71 | >>> tree = tree.add(documents=documents) 72 | 73 | >>> print(tree) 74 | node 1 75 | leaf 10 76 | leaf 11 77 | 78 | >>> tree.leafs_to_documents 79 | {'10': [0, 1], '11': [2, 3]} 80 | 81 | >>> candidates = tree( 82 | ... queries=["Paris is the capital of France.", "Paris and Berlin are European cities."], 83 | ... k_leafs=2, 84 | ... k=2, 85 | ... ) 86 | 87 | >>> candidates["scores"] 88 | array([[0.81927449, 0.10763316], 89 | [0.8641156 , 0.10763316]]) 90 | 91 | >>> candidates["leafs"] 92 | array([['10', '11'], 93 | ['11', '10']], dtype='>> pprint(candidates["tree_scores"]) 96 | [{'10': tensor(0.8193), '11': tensor(0.1076)}, 97 | {'10': tensor(0.1076), '11': tensor(0.8641)}] 98 | 99 | >>> pprint(candidates["documents"]) 100 | [[{'id': 0, 'leaf': '10', 'similarity': 0.9999999999999978}, 101 | {'id': 1, 'leaf': '10', 'similarity': 0.39941742405759667}], 102 | [{'id': 2, 'leaf': '11', 'similarity': 0.9999999999999978}, 103 | {'id': 3, 'leaf': '11', 'similarity': 0.5385719658738707}]] 104 | 105 | """ 106 | 107 | def __init__( 108 | self, 109 | key: str, 110 | on: str | list[str], 111 | documents: list[dict] | None = None, 112 | graph: dict | None = None, 113 | leaf_balance_factor: int = 100, 114 | branch_balance_factor: int = 5, 115 | tfidf_nodes: TfidfVectorizer | None = None, 116 | tfidf_documents: TfidfVectorizer | None = None, 117 | device: str = "cpu", 118 | n_jobs: int = -1, 119 | max_iter: int = 3000, 120 | n_init: int = 100, 121 | create_retrievers: bool = True, 122 | seed: int = 42, 123 | ) -> None: 124 | """Create a tree with the TfIdf scoring.""" 125 | if graph is not None: 126 | documents = get_mapping_nodes_documents(graph=graph) 127 | 128 | super().__init__( 129 | key=key, 130 | documents=documents, 131 | graph=graph, 132 | scoring=scoring_TfIdf( 133 | key=key, 134 | on=on, 135 | documents=documents, 136 | tfidf_nodes=tfidf_nodes, 137 | tfidf_documents=tfidf_documents, 138 | device=device, 139 | ), 140 | leaf_balance_factor=leaf_balance_factor, 141 | branch_balance_factor=branch_balance_factor, 142 | device=device, 143 | n_jobs=n_jobs, 144 | create_retrievers=create_retrievers, 145 | max_iter=max_iter, 146 | n_init=n_init, 147 | seed=seed, 148 | ) 149 | 
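The `graph` argument documented above makes it possible to skip the KMeans clustering entirely and build the tree from an existing hierarchy, as the trees.Tree example further down also shows. Below is a small illustrative sketch (not part of the repository) with a hand-written graph whose leafs carry the full documents; the graph content and balance factors are arbitrary.

from neural_tree import trees

graph = {
    "sport": {
        "football": {
            "psg": [{"id": 1, "text": "psg football team"}],
            "bayern": [{"id": 2, "text": "bayern football team"}],
        },
        "rugby": {
            "toulouse": [{"id": 3, "text": "toulouse rugby team"}],
        },
    }
}

# Documents are read from the leafs of the graph, no clustering is performed.
tree = trees.TfIdf(
    key="id",
    on="text",
    graph=graph,
    leaf_balance_factor=1,
    branch_balance_factor=2,
    n_jobs=1,
)

candidates = tree(queries=["psg", "toulouse"], k=2, k_leafs=2)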
-------------------------------------------------------------------------------- /neural_tree/trees/tree.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import random 4 | from functools import lru_cache 5 | from typing import Generator 6 | 7 | import numpy as np 8 | import torch 9 | from scipy import sparse 10 | 11 | from ..leafs import Leaf 12 | from ..nodes import Node 13 | from ..scoring import SentenceTransformer, TfIdf 14 | from ..utils import sanity_check 15 | 16 | __all__ = ["Tree"] 17 | 18 | 19 | class Tree(torch.nn.Module): 20 | """Tree based index for information retrieval. 21 | 22 | Examples 23 | -------- 24 | >>> from neural_tree import trees, scoring, clustering 25 | >>> from pprint import pprint 26 | 27 | >>> device = "cpu" 28 | 29 | >>> queries = [ 30 | ... "Paris is the capital of France.", 31 | ... "Berlin", 32 | ... "Berlin", 33 | ... "Paris is the capital of France." 34 | ... ] 35 | 36 | >>> documents = [ 37 | ... {"id": 0, "text": "Paris is the capital of France."}, 38 | ... {"id": 1, "text": "Berlin is the capital of Germany."}, 39 | ... {"id": 2, "text": "Paris and Berlin are European cities."}, 40 | ... {"id": 3, "text": "Paris and Berlin are beautiful cities."}, 41 | ... ] 42 | 43 | >>> tree = trees.Tree( 44 | ... key="id", 45 | ... documents=documents, 46 | ... scoring=scoring.TfIdf(key="id", on=["text"], documents=documents), 47 | ... leaf_balance_factor=1, 48 | ... branch_balance_factor=2, 49 | ... device=device, 50 | ... n_jobs=1, 51 | ... ) 52 | 53 | >>> print(tree) 54 | node 1 55 | node 10 56 | leaf 100 57 | leaf 101 58 | node 11 59 | leaf 110 60 | leaf 111 61 | 62 | >>> tree.documents_to_leafs 63 | {0: ['100'], 1: ['101'], 2: ['110'], 3: ['111']} 64 | 65 | >>> tree.leafs_to_documents 66 | {'100': [0], '101': [1], '110': [2], '111': [3]} 67 | 68 | >>> candidates = tree( 69 | ... queries=queries, 70 | ... k=2, 71 | ... k_leafs=2, 72 | ... ) 73 | 74 | >>> pprint(candidates["documents"]) 75 | [[{'id': 0, 'leaf': '100', 'similarity': 0.9999999999999978}, 76 | {'id': 1, 'leaf': '101', 'similarity': 0.39941742405759667}], 77 | [{'id': 3, 'leaf': '111', 'similarity': 0.3523828592933607}, 78 | {'id': 2, 'leaf': '110', 'similarity': 0.348413283355546}], 79 | [{'id': 3, 'leaf': '111', 'similarity': 0.3523828592933607}, 80 | {'id': 2, 'leaf': '110', 'similarity': 0.348413283355546}], 81 | [{'id': 0, 'leaf': '100', 'similarity': 0.9999999999999978}, 82 | {'id': 1, 'leaf': '101', 'similarity': 0.39941742405759667}]] 83 | 84 | >>> pprint(candidates["tree_scores"]) 85 | [{'10': tensor(1.0000), 86 | '100': tensor(1.0000), 87 | '101': tensor(0.6385), 88 | '11': tensor(0.1076)}, 89 | {'10': tensor(0.3235), 90 | '11': tensor(0.3327), 91 | '110': tensor(0.3327), 92 | '111': tensor(0.3327)}, 93 | {'10': tensor(0.3235), 94 | '11': tensor(0.3327), 95 | '110': tensor(0.3327), 96 | '111': tensor(0.3327)}, 97 | {'10': tensor(1.0000), 98 | '100': tensor(1.0000), 99 | '101': tensor(0.6385), 100 | '11': tensor(0.1076)}] 101 | 102 | 103 | >>> candidates = tree( 104 | ... queries=queries, 105 | ... leafs=["110", "111", "111", "111"], 106 | ... 
) 107 | 108 | >>> pprint(candidates["documents"]) 109 | [[{'id': 2, 'leaf': '110', 'similarity': 0.1036216271728989}], 110 | [{'id': 3, 'leaf': '111', 'similarity': 0.3523828592933607}], 111 | [{'id': 3, 'leaf': '111', 'similarity': 0.3523828592933607}], 112 | [{'id': 3, 'leaf': '111', 'similarity': 0.09981163726061484}]] 113 | 114 | >>> optimizer = torch.optim.AdamW(lr=3e-5, params=list(tree.parameters())) 115 | 116 | >>> loss = tree.loss( 117 | ... queries=queries, 118 | ... documents=documents, 119 | ... ) 120 | 121 | >>> loss.backward() 122 | >>> optimizer.step() 123 | >>> assert loss.item() > 0 124 | 125 | >>> graph = tree.to_json() 126 | >>> pprint(graph) 127 | {1: {'10': {'100': [{'id': 0}], '101': [{'id': 1}]}, 128 | '11': {'110': [{'id': 2}], '111': [{'id': 3}]}}} 129 | 130 | >>> graph = {'sport': {'football': {'bayern': [{'id': 2, 'text': 'bayern football team'}], 131 | ... 'psg': [{'id': 1, 'text': 'psg football team'}]}, 132 | ... 'rugby': {'toulouse': [{'id': 3, 'text': 'toulouse rugby team'}], 133 | ... 'ville rose': [{'id': 3, 'text': 'toulouse rugby team'}, 134 | ... {'id': 4, 'text': 'tfc football team'}]}}} 135 | 136 | >>> documents = clustering.get_mapping_nodes_documents(graph=graph) 137 | 138 | >>> tree = trees.Tree( 139 | ... key="id", 140 | ... documents=documents, 141 | ... scoring=scoring.TfIdf(key="id", on=["text"], documents=documents), 142 | ... leaf_balance_factor=1, 143 | ... branch_balance_factor=2, 144 | ... device=device, 145 | ... graph=graph, 146 | ... n_jobs=1, 147 | ... ) 148 | 149 | >>> tree.documents_to_leafs 150 | {3: ['ville rose', 'toulouse'], 4: ['ville rose'], 2: ['bayern'], 1: ['psg']} 151 | 152 | >>> tree.leafs_to_documents 153 | {'ville rose': [3, 4], 'toulouse': [3], 'bayern': [2], 'psg': [1]} 154 | 155 | >>> print(tree) 156 | node sport 157 | node rugby 158 | leaf ville rose 159 | leaf toulouse 160 | node football 161 | leaf bayern 162 | leaf psg 163 | 164 | >>> candidates = tree( 165 | ... queries=["psg", "toulouse"], 166 | ... k=2, 167 | ... k_leafs=2, 168 | ... 
) 169 | 170 | >>> pprint(candidates["documents"]) 171 | [[{'id': 1, 'leaf': 'psg', 'similarity': 0.5255159378077358}], 172 | [{'id': 3, 'leaf': 'ville rose', 'similarity': 0.7865788511708137}, 173 | {'id': 3, 'leaf': 'toulouse', 'similarity': 0.7865788511708137}]] 174 | 175 | References 176 | ---------- 177 | [Li et al., 2023](https://arxiv.org/pdf/2206.02743.pdf) 178 | 179 | """ 180 | 181 | def __init__( 182 | self, 183 | key: str, 184 | scoring: SentenceTransformer | TfIdf, 185 | documents: list, 186 | leaf_balance_factor: int, 187 | branch_balance_factor: int, 188 | device, 189 | seed: int, 190 | max_iter: int, 191 | n_init: int, 192 | n_jobs: int, 193 | batch_size: int = None, 194 | create_retrievers: bool = True, 195 | graph: dict | None = None, 196 | documents_embeddings: dict | None = None, 197 | ) -> None: 198 | super(Tree, self).__init__() 199 | self.key = key 200 | self.device = device 201 | self.seed = seed 202 | self.scoring = scoring 203 | self.create_retrievers = create_retrievers 204 | self.node_name = 1 205 | 206 | # Sanity check over input parameters 207 | sanity_check( 208 | branch_balance_factor=branch_balance_factor, 209 | leaf_balance_factor=leaf_balance_factor, 210 | graph=graph, 211 | documents=documents, 212 | ) 213 | 214 | if graph is not None: 215 | for node_name in graph.keys(): 216 | self.node_name = node_name 217 | break 218 | 219 | if documents_embeddings is None: 220 | documents_embeddings = self.scoring.transform_documents( 221 | documents=documents, batch_size=batch_size 222 | ) 223 | else: 224 | documents_embeddings = self.scoring.stack( 225 | embeddings=[ 226 | documents_embeddings[document[self.key]] for document in documents 227 | ] 228 | ) 229 | 230 | self.tree = Node( 231 | level=0, 232 | node_name=self.node_name, 233 | key=self.key, 234 | documents=documents, 235 | documents_embeddings=documents_embeddings, 236 | scoring=scoring, 237 | leaf_balance_factor=leaf_balance_factor, 238 | branch_balance_factor=branch_balance_factor, 239 | device=self.device, 240 | seed=self.seed, 241 | n_jobs=n_jobs, 242 | create_retrievers=create_retrievers, 243 | graph=graph[self.node_name] if graph is not None else None, 244 | parent=0, 245 | max_iter=max_iter, 246 | n_init=n_init, 247 | ) 248 | 249 | self.documents_to_leafs, self.leafs_to_documents = self.get_documents_leafs() 250 | self.negative_samples = self.get_negative_samples() 251 | self._paths = self.get_paths() 252 | self.mapping_leafs = self.get_mapping_leafs() 253 | 254 | def __str__(self) -> str: 255 | """Return the tree as string.""" 256 | repr = "" 257 | for node in self.nodes(): 258 | repr += f"{node}\n" 259 | return repr[:-1] 260 | 261 | def get_mapping_leafs(self) -> dict: 262 | """Returns mapping between leafs and their number.""" 263 | mapping_leafs = {} 264 | for leaf in self.nodes(): 265 | if isinstance(leaf, Leaf): 266 | mapping_leafs[leaf.node_name] = leaf 267 | return mapping_leafs 268 | 269 | def get_documents_leafs(self) -> dict: 270 | """Returns mapping between documents ids and leafs and vice versa.""" 271 | documents_to_leafs, leafs_to_documents = ( 272 | collections.defaultdict(list), 273 | collections.defaultdict(list), 274 | ) 275 | for node in self.nodes(): 276 | if isinstance(node, Leaf): 277 | for document in node.documents: 278 | documents_to_leafs[document].append(node.node_name) 279 | leafs_to_documents[node.node_name].append(document) 280 | 281 | return dict(documents_to_leafs), dict(leafs_to_documents) 282 | 283 | def get_paths(self) -> list[torch.Tensor]: 284 | """Map leafs to 
their nodes.""" 285 | self.paths.cache_clear() 286 | paths = collections.defaultdict(list) 287 | for node in self.nodes(): 288 | if isinstance(node, Leaf): 289 | leaf = node 290 | for _ in range(leaf.level): 291 | node = self.get_parent(node_name=node.node_name) 292 | if node.level != 0: 293 | paths[leaf.node_name].append(node.node_name) 294 | paths[leaf.node_name].reverse() 295 | return dict(paths) 296 | 297 | @lru_cache(maxsize=1000) 298 | def paths(self, leaf: int) -> dict: 299 | return copy.deepcopy(x=self._paths[leaf]) 300 | 301 | def get_negative_samples(self) -> dict: 302 | """Return negative samples built from the tree.""" 303 | levels = collections.defaultdict(list) 304 | for node in self.nodes(): 305 | if node.node_name != self.node_name: 306 | levels[node.level].append((node.node_name, node.parent)) 307 | negatives = {} 308 | for _, nodes in levels.items(): 309 | for node, node_parent in nodes: 310 | negatives[node] = [ 311 | negative_node 312 | for (negative_node, negative_node_parent) in nodes 313 | if (negative_node != node) and (negative_node_parent == node_parent) 314 | ] 315 | return negatives 316 | 317 | def parameters(self) -> Generator: 318 | """Return the parameters of the tree.""" 319 | for node in self.nodes(): 320 | if isinstance(node, Leaf): 321 | continue 322 | yield node.nodes_embeddings 323 | 324 | def nodes( 325 | self, 326 | node: Node | Leaf = None, 327 | ) -> Generator: 328 | """Iterate over the nodes of the tree.""" 329 | if node is None: 330 | node = self.tree 331 | yield node 332 | 333 | for node in node.childs: 334 | yield node 335 | 336 | if not isinstance(node, Leaf): 337 | yield from self.nodes(node=node) 338 | 339 | def get_parent(self, node_name: int | str) -> int | str: 340 | """Get the parent node of a specific node. 341 | 342 | Parameters 343 | ---------- 344 | node_name: 345 | Name of the node. 346 | """ 347 | for node in self.nodes(): 348 | if isinstance(node, Leaf): 349 | continue 350 | 351 | for child in node.childs: 352 | if child.node_name == node_name: 353 | return node 354 | 355 | return None 356 | 357 | @torch.no_grad() 358 | def __call__( 359 | self, 360 | queries: list[str], 361 | k: int = 100, 362 | k_leafs: int = 1, 363 | leafs: list[int] | None = None, 364 | score_documents: bool = True, 365 | beam_search_depth: int = 1, 366 | queries_embeddings: torch.Tensor | np.ndarray | dict = None, 367 | batch_size: int = 32, 368 | tqdm_bar: bool = True, 369 | ) -> tuple[torch.Tensor, list, list]: 370 | """Search for the closest embedding. 371 | 372 | Parameters 373 | ---------- 374 | queries: 375 | Queries to search for. 376 | queries_embeddings: 377 | Pre-computed embeddings of the queries. 378 | k: 379 | Number of documents to retrieve. 380 | leafs: 381 | Leafs to restrict the search to, one per query. 382 | score_documents: 383 | Whether to score documents or not.
384 | """ 385 | if queries_embeddings is None: 386 | queries_embeddings = self.scoring.transform_queries( 387 | queries=queries, 388 | batch_size=batch_size, 389 | tqdm_bar=tqdm_bar, 390 | ) 391 | 392 | tree_scores = self._search( 393 | queries=queries, 394 | queries_embeddings=queries_embeddings, 395 | k_leafs=k_leafs, 396 | leafs=leafs, 397 | beam_search_depth=beam_search_depth, 398 | batch_size=batch_size, 399 | tqdm_bar=tqdm_bar, 400 | ) 401 | 402 | if not leafs: 403 | leafs_scores = [ 404 | { 405 | leaf: score.item() 406 | for leaf, score in sorted( 407 | query_scores.items(), 408 | key=lambda item: item[1].item(), 409 | reverse=True, 410 | ) 411 | if leaf in self.mapping_leafs 412 | } 413 | for query_scores in tree_scores 414 | ] 415 | else: 416 | leafs_scores = [ 417 | {leaf: query_scores[leaf].item()} 418 | for leaf, query_scores in zip(leafs, tree_scores) 419 | ] 420 | 421 | # We may not have k leafs for each query, so we take the minimum. 422 | if k_leafs > 1: 423 | k_leafs = min( 424 | min([len(query_leafs_scores) for query_leafs_scores in leafs_scores]), 425 | k_leafs, 426 | ) 427 | 428 | candidates = { 429 | "leafs": np.array( 430 | object=[ 431 | list(query_leafs_scores.keys())[:k_leafs] 432 | for query_leafs_scores in leafs_scores 433 | ] 434 | ), 435 | "scores": np.array( 436 | object=[ 437 | list(query_leafs_scores.values())[:k_leafs] 438 | for query_leafs_scores in leafs_scores 439 | ] 440 | ), 441 | "tree_scores": tree_scores, 442 | } 443 | 444 | if not score_documents or not self.create_retrievers: 445 | return candidates 446 | 447 | if self.scoring.distinct_documents_encoder: 448 | queries_embeddings = self.scoring.encode_queries_for_retrieval( 449 | queries=queries, 450 | ) 451 | 452 | leafs_queries, leafs_embeddings = ( 453 | collections.defaultdict(list), 454 | collections.defaultdict(list), 455 | ) 456 | for query, (query_leafs, embedding) in enumerate( 457 | iterable=zip(candidates["leafs"], queries_embeddings) 458 | ): 459 | for leaf in query_leafs: 460 | leafs_queries[leaf].append(query) 461 | leafs_embeddings[leaf].append(embedding) 462 | 463 | documents = collections.defaultdict(list) 464 | for leaf in leafs_queries: 465 | leaf_documents = self.mapping_leafs[leaf]( 466 | queries_embeddings={ 467 | query: embedding 468 | for query, embedding in enumerate(iterable=leafs_embeddings[leaf]) 469 | }, 470 | k=k, 471 | ) 472 | for query, query_documents in zip(leafs_queries[leaf], leaf_documents): 473 | documents[query].extend(query_documents) 474 | 475 | # Sort documents if k > 1 476 | candidates["documents"] = [ 477 | sorted( 478 | documents[query], 479 | key=lambda document: document["similarity"], 480 | reverse=True, 481 | )[: min(k, len(documents[query]))] 482 | if k_leafs > 1 483 | else documents[query] 484 | for query in range(len(queries)) 485 | ] 486 | 487 | return candidates 488 | 489 | def empty(self) -> "Tree": 490 | """Empty the tree.""" 491 | for node in self.nodes(): 492 | if isinstance(node, Leaf): 493 | node.empty() 494 | return self 495 | 496 | def _search( 497 | self, 498 | queries: list[str], 499 | k_leafs: int = 1, 500 | leafs: list[int] | None = None, 501 | beam_search_depth: int = 1, 502 | queries_embeddings: torch.Tensor | np.ndarray | dict = None, 503 | batch_size: int = 32, 504 | tqdm_bar: bool = True, 505 | ) -> tuple[torch.Tensor, list, list]: 506 | """Search for the closest embedding with gradient. 507 | 508 | Parameters 509 | ---------- 510 | queries: 511 | Queries to search for. 512 | embeddings: 513 | Embeddings to search for. 
514 | """ 515 | if queries_embeddings is None: 516 | queries_embeddings = self.scoring.transform_queries( 517 | queries=queries, 518 | batch_size=batch_size, 519 | tqdm_bar=tqdm_bar, 520 | ) 521 | 522 | queries_embeddings = self.scoring.convert_to_tensor( 523 | embeddings=queries_embeddings, device=self.device 524 | ) 525 | 526 | paths = ( 527 | [(leaf, copy.copy(self.paths(leaf=leaf))) for leaf in leafs] 528 | if leafs is not None 529 | else None 530 | ) 531 | 532 | tree_scores = self.tree.search( 533 | queries=[index for index, _ in enumerate(iterable=queries)], 534 | queries_embeddings=queries_embeddings, 535 | scoring=self.scoring, 536 | k=k_leafs, 537 | paths=paths, 538 | beam_search_depth=beam_search_depth, 539 | ) 540 | 541 | return list(tree_scores.values()) 542 | 543 | @torch.no_grad() 544 | def add( 545 | self, 546 | documents: list, 547 | documents_embeddings: np.ndarray | sparse.csr_matrix | dict = None, 548 | k: int = 1, 549 | documents_to_leafs: dict = None, 550 | batch_size: int = 32, 551 | tqdm_bar: bool = True, 552 | ) -> "Tree": 553 | """Add documents to the tree. 554 | 555 | Parameters 556 | ---------- 557 | documents: 558 | Documents to add to the tree. 559 | embeddings: 560 | Embeddings of the documents. 561 | k: 562 | Number of leafs to add the documents to. 563 | """ 564 | if documents_embeddings is None: 565 | documents_embeddings = self.scoring.transform_documents( 566 | documents=documents, 567 | batch_size=batch_size, 568 | tqdm_bar=tqdm_bar, 569 | ) 570 | 571 | if documents_to_leafs is None: 572 | leafs = self( 573 | queries=[document[self.key] for document in documents], 574 | queries_embeddings=documents_embeddings, 575 | k=k, 576 | score_documents=False, 577 | tqdm_bar=False, 578 | )["leafs"].tolist() 579 | else: 580 | leafs = [documents_to_leafs[document[self.key]] for document in documents] 581 | 582 | documents_leafs, embeddings_leafs = ( 583 | collections.defaultdict(list), 584 | collections.defaultdict(dict), 585 | ) 586 | 587 | for document, embedding, document_leafs in zip( 588 | documents, documents_embeddings, leafs 589 | ): 590 | for leaf in document_leafs: 591 | documents_leafs[leaf].append(document) 592 | embeddings_leafs[leaf][document[self.key]] = embedding 593 | 594 | for leaf, embeddings in embeddings_leafs.items(): 595 | self.mapping_leafs[leaf].add( 596 | documents=documents_leafs[leaf], 597 | documents_embeddings=None 598 | if self.scoring.distinct_documents_encoder 599 | else embeddings_leafs[leaf], 600 | scoring=self.scoring, 601 | ) 602 | 603 | self.documents_to_leafs, self.leafs_to_documents = self.get_documents_leafs() 604 | self.negative_samples = self.get_negative_samples() 605 | return self 606 | 607 | def loss( 608 | self, 609 | queries: list[str], 610 | documents: list[dict], 611 | batch_size: int = 32, 612 | ) -> None: 613 | """Computes the loss of the tree given the input batch. 614 | 615 | Parameters 616 | ---------- 617 | queries_embeddings: 618 | Embeddings of the queries. 619 | documents: 620 | Documents ids that where added to the tree. 
621 | """ 622 | leafs = [ 623 | random.choice(seq=list(self.documents_to_leafs[document[self.key]])) 624 | for document in documents 625 | ] 626 | 627 | tree_scores = self._search( 628 | queries=queries, 629 | k_leafs=1, 630 | leafs=leafs, 631 | batch_size=batch_size, 632 | tqdm_bar=False, 633 | ) 634 | 635 | loss, size = 0, 0 636 | cross_entropy = torch.nn.CrossEntropyLoss() 637 | for leaf, query_scores in zip(leafs, tree_scores): 638 | query_level_scores = [query_scores[leaf]] 639 | 640 | for negative_node in self.negative_samples[leaf]: 641 | query_level_scores.append(query_scores[negative_node]) 642 | 643 | query_level_scores = torch.stack( 644 | tensors=query_level_scores, dim=0 645 | ).unsqueeze(dim=0) 646 | 647 | size += 1 648 | loss += cross_entropy( 649 | query_level_scores, 650 | torch.zeros( 651 | query_level_scores.shape[0], 652 | device=self.device, 653 | dtype=torch.long, 654 | ), 655 | ) 656 | 657 | for node in copy.copy(self.paths(leaf=leaf)): 658 | query_level_scores = [query_scores[node]] 659 | for negative_node in self.negative_samples[node]: 660 | query_level_scores.append(query_scores[negative_node]) 661 | 662 | query_level_scores = torch.stack( 663 | tensors=query_level_scores, dim=0 664 | ).unsqueeze(dim=0) 665 | 666 | size += 1 667 | loss += cross_entropy( 668 | query_level_scores, 669 | torch.zeros( 670 | query_level_scores.shape[0], 671 | device=self.device, 672 | dtype=torch.long, 673 | ), 674 | ) 675 | 676 | return loss / size 677 | 678 | def to_json(self) -> dict: 679 | """Return the tree as a graph.""" 680 | return {self.node_name: self.tree.to_json()} 681 | -------------------------------------------------------------------------------- /neural_tree/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .batchify import batchify 2 | from .evaluate import evaluate, leafs_precision 3 | from .iter import iter 4 | from .sanity_check import sanity_check 5 | from .set_env import set_env 6 | 7 | __all__ = [ 8 | "batchify", 9 | "evaluate", 10 | "leafs_precision", 11 | "iter", 12 | "sanity_check", 13 | "set_env", 14 | ] 15 | -------------------------------------------------------------------------------- /neural_tree/utils/batchify.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | __all__ = ["batchify"] 4 | 5 | 6 | def batchify( 7 | X: list[str], batch_size: int, desc: str = "", tqdm_bar: bool = True 8 | ) -> list: 9 | batchs = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)] 10 | 11 | if tqdm_bar: 12 | for batch in tqdm.tqdm( 13 | batchs, 14 | position=0, 15 | total=1 + len(X) // batch_size, 16 | desc=desc, 17 | ): 18 | yield batch 19 | else: 20 | yield from batchs 21 | -------------------------------------------------------------------------------- /neural_tree/utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | __all__ = ["evaluate", "leafs_precision"] 4 | 5 | 6 | def leafs_precision( 7 | key: str, 8 | documents: list, 9 | leafs: np.ndarray, 10 | documents_to_leaf: dict, 11 | ) -> float: 12 | """Calculate the precision of the leafs.""" 13 | recall = 0 14 | for leafs_query, document in zip(leafs.tolist(), documents): 15 | for leaf_document in documents_to_leaf[document[key]]: 16 | if leafs_query[0] == leaf_document: 17 | recall += 1 18 | break 19 | return recall / len(leafs) 20 | 21 | 22 | def evaluate( 23 | scores: list[list[dict]], 24 | qrels: dict, 25 | 
queries_ids: list[str], 26 | metrics: list = [], 27 | key: str = "id", 28 | ) -> dict[str, float]: 29 | """Evaluate candidates matchs. 30 | 31 | Parameters 32 | ---------- 33 | matchs 34 | Matchs. 35 | qrels 36 | Qrels. 37 | queries 38 | index of queries of qrels. 39 | k 40 | Number of documents to retrieve. 41 | metrics 42 | Metrics to compute. 43 | 44 | Examples 45 | -------- 46 | >>> from neural_cherche import models, retrieve, utils 47 | >>> import torch 48 | 49 | >>> _ = torch.manual_seed(42) 50 | 51 | >>> model = models.Splade( 52 | ... model_name_or_path="distilbert-base-uncased", 53 | ... device="cpu", 54 | ... ) 55 | 56 | >>> documents, queries_ids, queries, qrels = utils.load_beir( 57 | ... "scifact", 58 | ... split="test", 59 | ... ) 60 | 61 | >>> documents = documents[:10] 62 | 63 | >>> retriever = retrieve.Splade( 64 | ... key="id", 65 | ... on=["title", "text"], 66 | ... model=model 67 | ... ) 68 | 69 | >>> documents_embeddings = retriever.encode_documents( 70 | ... documents=documents, 71 | ... batch_size=1, 72 | ... ) 73 | 74 | >>> documents_embeddings = retriever.add( 75 | ... documents_embeddings=documents_embeddings, 76 | ... ) 77 | 78 | >>> queries_embeddings = retriever.encode_queries( 79 | ... queries=queries, 80 | ... batch_size=1, 81 | ... ) 82 | 83 | >>> scores = retriever( 84 | ... queries_embeddings=queries_embeddings, 85 | ... k=30, 86 | ... batch_size=1, 87 | ... ) 88 | 89 | >>> utils.evaluate( 90 | ... scores=scores, 91 | ... qrels=qrels, 92 | ... queries_ids=queries_ids, 93 | ... metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"] 94 | ... ) 95 | {'map': 0.0033333333333333335, 'ndcg@10': 0.0033333333333333335, 'ndcg@100': 0.0033333333333333335, 'recall@10': 0.0033333333333333335, 'recall@100': 0.0033333333333333335} 96 | 97 | """ 98 | from ranx import Qrels, Run 99 | from ranx import evaluate as ranx_evaluate 100 | 101 | qrels = Qrels(qrels=qrels) 102 | 103 | run_dict = { 104 | id_query: { 105 | match[key]: 1 - (rank / len(query_matchs)) 106 | for rank, match in enumerate(iterable=query_matchs) 107 | } 108 | for id_query, query_matchs in zip(queries_ids, scores) 109 | } 110 | 111 | run = Run(run=run_dict) 112 | 113 | if not metrics: 114 | metrics = ["ndcg@10"] + [f"hits@{k}" for k in [1, 2, 3, 4, 5, 10]] 115 | 116 | return ranx_evaluate( 117 | qrels=qrels, 118 | run=run, 119 | metrics=metrics, 120 | make_comparable=True, 121 | ) 122 | -------------------------------------------------------------------------------- /neural_tree/utils/iter.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import typing 4 | 5 | import tqdm 6 | 7 | from .batchify import batchify 8 | 9 | __all__ = ["iter"] 10 | 11 | 12 | def iter( 13 | queries, documents, batch_size=512, epochs: int = 1, shuffle=True, tqdm_bar=True 14 | ) -> typing.Generator: 15 | """Iterate over the dataset. 16 | 17 | Parameters 18 | ---------- 19 | queries 20 | List of queries paired with documents. 21 | documents 22 | List of documents paired with queries. 23 | batch_size 24 | Size of the batch. 25 | epochs 26 | Number of epochs. 
27 | """ 28 | step = 0 29 | queries = copy.deepcopy(x=queries) 30 | documents = copy.deepcopy(x=documents) 31 | 32 | bar = tqdm.tqdm(iterable=range(epochs), position=0) if tqdm_bar else range(epochs) 33 | 34 | for _ in bar: 35 | if shuffle: 36 | queries_documents = list(zip(queries, documents)) 37 | random.shuffle(x=queries_documents) 38 | queries, documents = zip(*queries_documents) 39 | 40 | for batch_queries, batch_documents in zip( 41 | batchify(X=queries, batch_size=batch_size, tqdm_bar=False), 42 | batchify(X=documents, batch_size=batch_size, tqdm_bar=False), 43 | ): 44 | yield step, batch_queries, batch_documents 45 | step += 1 46 | -------------------------------------------------------------------------------- /neural_tree/utils/sanity_check.py: -------------------------------------------------------------------------------- 1 | __all__ = ["sanity_check"] 2 | 3 | 4 | def sanity_check( 5 | branch_balance_factor: int, leaf_balance_factor: int, graph: dict, documents: list 6 | ) -> None: 7 | """Check if the input is valid.""" 8 | if branch_balance_factor < 2: 9 | raise ValueError("Branch balance factor must be greater than 1.") 10 | 11 | if leaf_balance_factor < 1: 12 | raise ValueError("Leaf balance factor must be greater than 0.") 13 | 14 | if graph is not None: 15 | if len(graph.keys()) > 1: 16 | raise ValueError("Graph must have only one root node.") 17 | 18 | if documents is None and graph is None: 19 | raise ValueError("You must provide either documents or an existing graph.") 20 | -------------------------------------------------------------------------------- /neural_tree/utils/set_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __all__ = ["set_env"] 4 | 5 | 6 | def set_env() -> None: 7 | """Set environment variables.""" 8 | os.environ["HF_HOME"] = os.environ["HOME"] + "/.cache/huggingface" 9 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 10 | -------------------------------------------------------------------------------- /neural_tree/version.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::DeprecationWarning 4 | ignore::RuntimeWarning 5 | ignore::UserWarning 6 | addopts = 7 | --doctest-modules 8 | --verbose 9 | -ra 10 | --cov-config=.coveragerc 11 | -m "not web and not slow" 12 | doctest_optionflags = NORMALIZE_WHITESPACE NUMBER 13 | norecursedirs = 14 | build 15 | docs 16 | node_modules 17 | markers = 18 | web: tests that require using the Internet 19 | slow: tests that take a long time to run -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Inside of setup.cfg 2 | [metadata] 3 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | from neural_tree.__version__ import __version__ 4 | 5 | with open(file="README.md", mode="r", encoding="utf-8") as fh: 6 | long_description = fh.read() 7 | 8 | base_packages = [ 9 | "torch >= 1.13", 10 | "tqdm >= 4.66", 11 | "transformers >= 4.34.0", 12 | "sentence-transformers >= 2.2.2", 13 | "neural-cherche >= 
1.1.0", 14 | "scikit-learn >= 1.4.0", 15 | ] 16 | 17 | eval = ["ranx >= 0.3.16", "beir >= 2.0.0"] 18 | 19 | dev = ["mkdocs-material == 9.2.8"] 20 | 21 | setuptools.setup( 22 | name="neural-tree", 23 | version=f"{__version__}", 24 | license="MIT", 25 | author="Raphael Sourty", 26 | author_email="raphael.sourty@gmail.com", 27 | description="Neural-Tree", 28 | long_description=long_description, 29 | long_description_content_type="text/markdown", 30 | url="https://github.com/raphaelsty/neural-tree", 31 | download_url="https://github.com/user/neural-tree/archive/v_01.tar.gz", 32 | keywords=[ 33 | "tree search", 34 | "neural search", 35 | "information retrieval", 36 | "semantic search", 37 | "colbert", 38 | "tree", 39 | ], 40 | packages=setuptools.find_packages(), 41 | install_requires=base_packages, 42 | extras_require={"eval": base_packages + eval, "dev": base_packages + eval + dev}, 43 | classifiers=[ 44 | "Programming Language :: Python :: 3", 45 | "License :: OSI Approved :: MIT License", 46 | "Operating System :: OS Independent", 47 | ], 48 | python_requires=">=3.10", 49 | ) 50 | --------------------------------------------------------------------------------
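To close, an illustrative sketch (not taken from the repository) of how the pieces above fit together for fine-tuning: utils.iter batches (query, relevant document) pairs and Tree.loss back-propagates through the node embeddings, as in the trees.Tree doctest. The queries, documents, balance factors and hyper-parameters below are arbitrary, and every training document is assumed to be already indexed in the tree.

import torch
from neural_tree import trees, utils

documents = [
    {"id": 0, "text": "Paris is the capital of France."},
    {"id": 1, "text": "Berlin is the capital of Germany."},
]

# Each query is paired with the identifier of the document it should reach.
train_queries = ["paris", "berlin"]
train_documents = [{"id": 0}, {"id": 1}]

tree = trees.TfIdf(
    key="id",
    on="text",
    documents=documents,
    leaf_balance_factor=1,
    branch_balance_factor=2,
    n_jobs=1,
)
tree = tree.add(documents=documents)

optimizer = torch.optim.AdamW(params=list(tree.parameters()), lr=3e-5)

for step, batch_queries, batch_documents in utils.iter(
    queries=train_queries,
    documents=train_documents,
    batch_size=32,
    epochs=10,
    tqdm_bar=False,
):
    loss = tree.loss(queries=batch_queries, documents=batch_documents)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()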