├── .gitignore ├── CITATION.cff ├── LICENSE ├── Makefile ├── README.md ├── _typos.toml ├── changelog.md ├── conftest.py ├── dev_env.yml ├── docs ├── advanced_retriever.md ├── dense_retriever.md ├── faq.md ├── filters.md ├── hybrid_retriever.md ├── sparse_retriever.md ├── speed.md └── text_preprocessing.md ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── retriv ├── __init__.py ├── autotune │ ├── __init__.py │ ├── bm25_autotune.py │ └── merger_autotune.py ├── base_retriever.py ├── dense_retriever │ ├── __init__.py │ ├── ann_searcher.py │ ├── dense_retriever.py │ └── encoder.py ├── experimental │ ├── __init__.py │ └── advanced_retriever.py ├── hybrid_retriever.py ├── merger │ ├── __init__.py │ ├── merger.py │ └── normalization.py ├── paths.py ├── sparse_retriever │ ├── __init__.py │ ├── build_inverted_index.py │ ├── preprocessing │ │ ├── __init__.py │ │ ├── normalization.py │ │ ├── stemmer.py │ │ ├── stopwords.py │ │ ├── tokenizer.py │ │ └── utils.py │ ├── sparse_retrieval_models │ │ ├── __init__.py │ │ ├── bm25.py │ │ └── tf_idf.py │ └── sparse_retriever.py └── utils │ ├── __init__.py │ └── numba_utils.py ├── setup.py ├── test_env.yml └── tests ├── advanced_retriever └── advanced_retriever_test.py ├── dense_retriever └── encoder_test.py ├── merger ├── merger_test.py └── score_normalization_test.py ├── numba_utils_test.py └── sparse_retriever ├── preprocessing_test.py ├── search_engine_test.py ├── stemmer_test.py ├── stopwords_test.py ├── text_normalization_test.py └── tokenizer_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | /*.ipynb 2 | /beir 3 | /dictionaries 4 | /collections 5 | /datasets 6 | /bkp 7 | *.npy 8 | *.tar.gz 9 | *.csv 10 | *.tsv 11 | *.jsonl 12 | /compute_runs.py 13 | .vscode 14 | 15 | raw_collection.jsonl 16 | .DS_Store 17 | 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Bassani" 5 | given-names: "Elias" 6 | orcid: "https://orcid.org/0000-0001-7922-2578" 7 | title: "retriv: A Python Search Engine for the Common Man" 8 | version: 0.2.1 9 | doi: 10.5281/zenodo.7978820 10 | date-released: 2023-05-28 11 | url: "https://github.com/AmenRa/retriv" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Elias Bassani 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # from https://github.com/alexander-beedie/polars/blob/a083a26ca092b3821bf4044aed74d91ee127bad1/py-polars/Makefile 2 | .DEFAULT_GOAL := help 3 | 4 | SHELL := /bin/bash 5 | .SHELLFLAGS := -eu -o pipefail -c 6 | PYTHONPATH= 7 | VENV = .venv 8 | MAKE = make 9 | 10 | VENV_BIN=$(VENV)/bin 11 | 12 | .venv: ## Set up virtual environment and install requirements 13 | python3 -m venv $(VENV) 14 | $(MAKE) requirements 15 | 16 | .PHONY: requirements 17 | requirements: .venv ## Install/refresh all project requirements 18 | $(VENV_BIN)/python -m pip install --upgrade pip 19 | $(VENV_BIN)/pip install -r requirements-dev.txt 20 | $(VENV_BIN)/pip install -r requirements.txt 21 | 22 | .PHONY: build 23 | build: .venv ## Compile and install retriv 24 | . $(VENV_BIN)/activate 25 | 26 | 27 | .PHONY: fix-lint 28 | fix-lint: .venv ## Fix linting 29 | . $(VENV_BIN)/activate 30 | $(VENV_BIN)/black . 31 | $(VENV_BIN)/isort . 32 | 33 | .PHONY: lint 34 | lint: .venv ## Check linting 35 | . $(VENV_BIN)/activate 36 | $(VENV_BIN)/isort --check . 37 | $(VENV_BIN)/black --check . 38 | $(VENV_BIN)/blackdoc . 39 | $(VENV_BIN)/ruff . 40 | $(VENV_BIN)/typos . 41 | # $(VENV_BIN)/mypy retriv 42 | 43 | .PHONY: test 44 | test: .venv build ## Run unittest 45 | . $(VENV_BIN)/activate 46 | $(VENV_BIN)/pytest tests 47 | 48 | .PHONY: coverage 49 | coverage: .venv build ## Run tests and report coverage 50 | . $(VENV_BIN)/activate 51 | $(VENV_BIN)/pytest tests --cov -n auto --dist worksteal -m "not benchmark" 52 | 53 | .PHONY: release 54 | release: .venv build ## Release a new version 55 | . $(VENV_BIN)/activate 56 | $(VENV_BIN)/python setup.py sdist bdist_wheel 57 | $(VENV_BIN)/twine upload --repository-url https://upload.pypi.org/legacy/ dist/* 58 | 59 | .PHONY: clean 60 | clean: ## Clean up caches and build artifacts 61 | @rm -rf .venv/ 62 | @rm -rf target/ 63 | @rm -rf .pytest_cache/ 64 | @rm -rf .coverage 65 | @rm -rf retriv.egg-info 66 | @rm -rf dist 67 | @rm -rf build 68 | 69 | .PHONY: help 70 | help: ## Display this help screen 71 | @echo -e "\033[1mAvailable commands:\033[0m\n" 72 | @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | <!-- Header images: retriv logo, PyPI version badge, MIT license badge -->
25 | 
26 | ## 🔥 News
27 | - [August 23, 2023] `retriv` 0.2.2 is out!
28 | This release adds _experimental_ support for multi-field documents and filters.
29 | Please refer to the [Advanced Retriever](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md) documentation.
30 | 
31 | - [February 18, 2023] `retriv` 0.2.0 is out!
32 | This release adds support for Dense and Hybrid Retrieval.
33 | Dense Retrieval leverages the semantic similarity of the queries' and documents' vector representations, which can be computed directly by `retriv` or imported from other sources.
34 | Hybrid Retrieval mixes the results of traditional retrieval, informally called Sparse Retrieval, and Dense Retrieval to further improve retrieval effectiveness.
35 | As the library was almost completely redone, indices built with previous versions are no longer supported.
36 | 
37 | ## ⚡️ Introduction
38 | 
39 | [retriv](https://github.com/AmenRa/retriv) is a user-friendly and efficient [search engine](https://en.wikipedia.org/wiki/Search_engine) implemented in [Python](https://en.wikipedia.org/wiki/Python_(programming_language)) supporting Sparse retrieval (traditional search with [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) or [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf)), Dense retrieval ([semantic search](https://en.wikipedia.org/wiki/Semantic_search)), and Hybrid retrieval (a mix of Sparse and Dense Retrieval).
40 | It allows you to build a search engine in a __single line of code__.
41 | 
42 | [retriv](https://github.com/AmenRa/retriv) is built upon [Numba](https://github.com/numba/numba) for high-speed [vector operations](https://en.wikipedia.org/wiki/Automatic_vectorization) and [automatic parallelization](https://en.wikipedia.org/wiki/Automatic_parallelization), [PyTorch](https://pytorch.org) and [Transformers](https://huggingface.co/docs/transformers/index) for easy access to and usage of [Transformer-based Language Models](https://web.stanford.edu/~jurafsky/slp3/10.pdf), and [Faiss](https://github.com/facebookresearch/faiss) for approximate [nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search).
43 | In addition, it provides automatic tuning functionalities that let you tune its internal components with minimal intervention.
44 | 
45 | 
46 | ## ✨ Main Features
47 | 
48 | ### Retrievers
49 | - [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md): a standard searcher based on lexical matching.
50 | [retriv](https://github.com/AmenRa/retriv) implements [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) as its main retrieval model.
51 | [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf) is also supported for educational purposes.
52 | The sparse retriever comes armed with multiple [stemmers](https://en.wikipedia.org/wiki/Stemming), [tokenizers](https://en.wikipedia.org/wiki/Lexical_analysis#Tokenization), and [stop-word](https://en.wikipedia.org/wiki/Stop_word) lists for multiple languages.
53 | Click [here](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md) to learn more.
54 | - [Dense Retriever](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md): a dense retriever is a retrieval model that performs [semantic search](https://en.wikipedia.org/wiki/Semantic_search).
55 | Click [here](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md) to learn more.
56 | - [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md): a hybrid retriever is a retrieval model built on top of a sparse and a dense retriever.
57 | Click [here](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) to learn more.
58 | - [Advanced Retriever](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md): an advanced sparse retriever supporting filters. This is an experimental feature.
59 | Click [here](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md) to learn more.
60 | 
61 | ### Unified Search Interface
62 | All the supported retrievers share the same search interface:
63 | - [search](#search): standard search functionality, what you expect from a search engine.
64 | - [msearch](#multi-search): computes the results for multiple queries at once.
65 | It leverages [automatic parallelization](https://en.wikipedia.org/wiki/Automatic_parallelization) whenever possible.
66 | - [bsearch](#batch-search): similar to [msearch](#multi-search) but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format. [bsearch](#batch-search) is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM. Pre-computed results can be leveraged for negative sampling during the training of [Neural Models](https://en.wikipedia.org/wiki/Artificial_neural_network) for [Information Retrieval](https://en.wikipedia.org/wiki/Information_retrieval).
67 | 
68 | ### AutoTune
69 | [retriv](https://github.com/AmenRa/retriv) automatically tunes the [Faiss](https://github.com/facebookresearch/faiss) configuration for approximate nearest neighbors search by leveraging [AutoFaiss](https://github.com/criteo/autofaiss), guaranteeing a 10ms response time based on your available hardware.
70 | Moreover, it offers an automatic tuning functionality for [BM25](https://en.wikipedia.org/wiki/Okapi_BM25)'s parameters, which requires minimal user intervention.
71 | Under the hood, [retriv](https://github.com/AmenRa/retriv) leverages [Optuna](https://optuna.org), a [hyperparameter optimization](https://en.wikipedia.org/wiki/Hyperparameter_optimization) framework, and [ranx](https://github.com/AmenRa/ranx), an [Information Retrieval](https://en.wikipedia.org/wiki/Information_retrieval) evaluation library, to test several parameter configurations for [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) and choose the best one.
72 | Finally, it can automatically balance the importance of the lexical and semantic relevance scores computed by the [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) to maximize retrieval effectiveness.
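
For example, tuning BM25 for a `SparseRetriever` instance `sr` is a single call. The following is a sketch with placeholder training data; the full signature is documented in [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md):

```python
# Sketch: auto-tuning BM25; `queries` and `qrels` are placeholder training data.
sr.autotune(
    queries=[{"q_id": "q_1", "text": "..."}],  # Train queries
    qrels={"q_1": {"doc_1": 1}},               # Train qrels (relevance judgments)
    metric="ndcg",                             # Metric to maximize
)
```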
73 | 
74 | ## 📚 Documentation
75 | 
76 | - [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md)
77 | - [Dense Retriever](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md)
78 | - [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md)
79 | - [Text Pre-Processing](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md)
80 | - [FAQ](https://github.com/AmenRa/retriv/blob/main/docs/faq.md)
81 | 
82 | ## 🔌 Requirements
83 | ```
84 | python>=3.8
85 | ```
86 | 
87 | ## 💾 Installation
88 | ```bash
89 | pip install retriv
90 | ```
91 | 
92 | ## 💡 Minimal Working Example
93 | 
94 | ```python
95 | # Note: SearchEngine is an alias for the SparseRetriever
96 | from retriv import SearchEngine
97 | 
98 | collection = [
99 |     {"id": "doc_1", "text": "Generals gathered in their masses"},
100 |     {"id": "doc_2", "text": "Just like witches at black masses"},
101 |     {"id": "doc_3", "text": "Evil minds that plot destruction"},
102 |     {"id": "doc_4", "text": "Sorcerer of death's construction"},
103 | ]
104 | 
105 | se = SearchEngine("new-index").index(collection)
106 | 
107 | se.search("witches masses")
108 | ```
109 | Output:
110 | ```json
111 | [
112 |     {
113 |         "id": "doc_2",
114 |         "text": "Just like witches at black masses",
115 |         "score": 1.7536403
116 |     },
117 |     {
118 |         "id": "doc_1",
119 |         "text": "Generals gathered in their masses",
120 |         "score": 0.6931472
121 |     }
122 | ]
123 | ```
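
The same engine can also answer many queries at once via `msearch`. A quick sketch (signature and output format as described in the [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md) documentation):

```python
# Multi-query search with the engine built above.
se.msearch(
    queries=[
        {"id": "q_1", "text": "witches masses"},
        {"id": "q_2", "text": "minds plot destruction"},
    ],
    cutoff=100,
)
# Returns a dict of dicts: {"q_1": {"doc_2": 1.75..., "doc_1": 0.69...}, ...}
```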
124 | 
125 | 
126 | ## 🎁 Feature Requests
127 | Would you like to see other features implemented? Please open a [feature request](https://github.com/AmenRa/retriv/issues/new?assignees=&labels=enhancement&template=feature_request.md&title=%5BFeature+Request%5D+title).
128 | 
129 | 
130 | ## 🤘 Want to contribute?
131 | Would you like to contribute? Please drop me an [e-mail](mailto:elias.bssn@gmail.com?subject=[GitHub]%20retriv).
132 | 
133 | 
134 | ## 📄 License
135 | [retriv](https://github.com/AmenRa/retriv) is open-source software licensed under the [MIT license](LICENSE).
-------------------------------------------------------------------------------- /_typos.toml: --------------------------------------------------------------------------------
1 | [default]
2 | extend-ignore-identifiers-re = [
3 |     ".*Hsi",
4 | ]
5 | 
-------------------------------------------------------------------------------- /changelog.md: --------------------------------------------------------------------------------
1 | # Changelog
2 | All notable changes to this project will be documented in this file.
3 | 
4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
6 | 
7 | ## [0.2.2] - 2023-08-23
8 | ### Added
9 | - Added `advanced_retriever.py`.
10 | 
11 | ### Changed
12 | - Text preprocessing refactoring.
13 | 
14 | ## [0.2.1] - 2023-05-16
15 | ### Added
16 | - Added doc strings to `sparse_retriever.py`.
17 | - Added doc strings to `dense_retriever.py`.
18 | - Added doc strings to `hybrid_retriever.py`.
19 | 
20 | ## [0.2.0] - 2023-02-19
21 | ### Added
22 | - Added `sparse_retriever.py`.
23 | - Added `dense_retriever.py`.
24 | - Added `hybrid_retriever.py`.
25 | - Added `encoder.py`.
26 | - Added `ann_searcher.py`.
27 | - Added `merger.py`.
28 | 
29 | ### Changed
30 | - Almost everything.
31 | 
32 | ### Removed
33 | - Removed dependency on `cyhunspell`.
34 | 
35 | ## [0.1.5] - 2023-01-26
36 | ### Changed
37 | - Search efficiency improvements.
-------------------------------------------------------------------------------- /conftest.py: --------------------------------------------------------------------------------
1 | # NEEDED BY PYTEST
2 | 
-------------------------------------------------------------------------------- /dev_env.yml: --------------------------------------------------------------------------------
1 | name: retriv
2 | dependencies:
3 |   - python=3.8
4 |   - black
5 |   - conda-forge::icecream
6 |   - isort
7 |   - ipykernel
8 |   - ipywidgets
9 |   - conda-forge::mkdocs-material
10 |   - conda-forge::mkdocs-autorefs
11 |   - conda-forge::mkdocstrings
12 |   - conda-forge::mkdocstrings-python
13 |   - pygments>=2.12
14 |   - notebook
15 |   - pytest
16 |   # Pip
17 |   - pip
18 |   - pip:
19 |       - ranx
20 |       - krovetzstemmer
21 |       - twine
22 |       - oneliner_utils
23 |       - indxr
24 |       - numpy
25 |       - nltk
26 |       - numba>=0.54.1
27 |       - tqdm
28 |       - optuna
29 |       - pystemmer==2.0.1
30 |       - unidecode
31 |       - scikit-learn
32 |       - torch
33 |       - torchvision
34 |       - torchaudio
35 |       - transformers[torch]
36 |       - faiss-cpu
37 |       - autofaiss
38 |       - multipipe
39 | 
-------------------------------------------------------------------------------- /docs/advanced_retriever.md: --------------------------------------------------------------------------------
1 | # Advanced Retriever
2 | 
3 | ⚠️ This is an experimental feature.
4 | 
5 | The Advanced Retriever is a searcher based on lexical matching and search filters.
6 | It supports [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) and [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf), like the [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md), and provides the same resources for multi-lingual [text pre-processing](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md). In addition, it supports search filters, i.e., a set of rules that can be used to filter out documents from the search results.
7 | 
8 | In the following, we show how to build a search engine employing an advanced retriever, index a document collection, and search it.
9 | 
10 | ## Schema
11 | 
12 | The first step to building an Advanced Retriever is to define the `schema` of the document collection.
13 | The `schema` is a dictionary describing the documents' `fields` and their `data types`.
14 | Based on the `data types`, search `filters` can be defined and applied to the search results.
15 | 
16 | [retriv](https://github.com/AmenRa/retriv) supports the following data types:
17 | - __id:__ field used for the document IDs.
18 | - __text:__ text field used for lexical matching.
19 | - __number:__ numeric value.
20 | - __bool:__ boolean value (True or False).
21 | - __keyword:__ string or number representing a keyword or a category.
22 | - __keywords:__ list of keywords.
23 | 
24 | An example of `schema` for a collection of books is shown below.
25 | NB: At the time of writing, [retriv](https://github.com/AmenRa/retriv) supports only one text field per schema.
26 | Therefore, the `content` field is used for both the title and the abstract of the books.
27 | 
28 | ```python
29 | schema = {
30 |     "isbn": "id",
31 |     "content": "text",
32 |     "year": "number",
33 |     "is_english": "bool",
34 |     "author": "keyword",
35 |     "genres": "keywords",
36 | }
37 | ```
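
For concreteness, a document conforming to this schema could look as follows (the values are made up for illustration):

```python
# A hypothetical book entry matching the schema above; all values are illustrative.
doc = {
    "isbn": "978-0-00-000000-0",  # id (placeholder ISBN)
    "content": "The Name of the Rose. A murder mystery set in a medieval monastery.",  # text
    "year": 1980,                 # number
    "is_english": False,          # bool (originally published in Italian)
    "author": "Umberto Eco",      # keyword
    "genres": ["mystery", "historical fiction"],  # keywords
}
```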
38 | 
39 | ## Build
40 | 
41 | The Advanced Retriever provides several options to tailor its functioning to your preferences, as shown below.
42 | 
43 | ```python
44 | from retriv.experimental import AdvancedRetriever
45 | 
46 | ar = AdvancedRetriever(
47 |     schema=schema,
48 |     index_name="new-index",
49 |     model="bm25",
50 |     min_df=1,
51 |     tokenizer="whitespace",
52 |     stemmer="english",
53 |     stopwords="english",
54 |     do_lowercasing=True,
55 |     do_ampersand_normalization=True,
56 |     do_special_chars_normalization=True,
57 |     do_acronyms_normalization=True,
58 |     do_punctuation_removal=True,
59 | )
60 | ```
61 | 
62 | - `schema`: the documents' schema.
63 | - `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index.
64 | - `model`: defines the retrieval model to use for searching (`bm25` or `tf-idf`).
65 | - `min_df`: terms that appear in fewer than `min_df` documents will be ignored.
66 | If integer, the parameter indicates the absolute count.
67 | If float, it represents a proportion of documents.
68 | - `tokenizer`: [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization by setting the parameter to `None`.
69 | - `stemmer`: [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming by setting the parameter to `None`.
70 | - `stopwords`: [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`.
71 | - `do_lowercasing`: whether to lowercase texts.
72 | - `do_ampersand_normalization`: whether to convert `&` into `and` during pre-processing.
73 | - `do_special_chars_normalization`: whether to convert special characters to plain letters, e.g., `übermensch` → `ubermensch`.
74 | - `do_acronyms_normalization`: whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`.
75 | - `do_punctuation_removal`: whether to remove punctuation.
76 | 
77 | __Note:__ text pre-processing is equally applied to documents during indexing and to queries at search time.
78 | 
79 | ## Index
80 | 
81 | ### Create
82 | You can index a document collection from JSONl, CSV, or TSV files.
83 | CSV and TSV files must have a header.
84 | [retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it.
85 | Use the `callback` parameter to pass a function for converting your documents into the format defined by your `schema` on the fly.
86 | Indexes are automatically persisted on disk at the end of the process.
87 | 
88 | ```python
89 | ar = ar.index_file(
90 |     path="path/to/collection",  # File kind is automatically inferred
91 |     show_progress=True,         # Default value
92 |     callback=lambda doc: {      # Callback defaults to None.
93 |         "id": doc["id"],
94 |         "text": doc["title"] + ". " + doc["text"],
95 |     },  # ... add any other fields defined by your schema
96 | )
97 | ```
98 | 
99 | ### Load
100 | ```python
101 | ar = AdvancedRetriever.load("index-name")
102 | ```
103 | 
104 | ### Delete
105 | ```python
106 | AdvancedRetriever.delete("index-name")
107 | ```
108 | 
109 | ## Search
110 | 
111 | ### Query & Filters
112 | 
113 | The Advanced Retriever's search query can be either a string or a dictionary.
114 | In the former case, the string is used as the query text and no filters are applied.
115 | In the latter case, the dictionary defines the query text and the filters to apply to the search results.
116 | If the query text is omitted from the dictionary, documents matching the filters will be returned.
117 | 
118 | [retriv](https://github.com/AmenRa/retriv) supports two ways of filtering the search results (`where` and `where_not`) and several type-specific operators.
119 | 
120 | - `where` means that only the documents matching the filter will be considered during search.
121 | - `where_not` means that the documents matching the filter will be ignored during search.
122 | 
123 | Below we describe the effects of the supported operators for each data type and way of filtering.
124 | 
125 | #### Where
126 | 
127 | | Field Type | Operator | Value | Meaning |
128 | | ---------- | --------- | -------------------------- | ------- |
129 | | number | `eq` | number | Only the documents whose field value is **equal to** the provided value will be considered during search. |
130 | | number | `gt` | number | Only the documents whose field value is **greater than** the provided value will be considered during search. |
131 | | number | `gte` | number | Only the documents whose field value is **greater or equal to** the provided value will be considered during search. |
132 | | number | `lt` | number | Only the documents whose field value is **less than** the provided value will be considered during search. |
133 | | number | `lte` | number | Only the documents whose field value is **less or equal to** the provided value will be considered during search. |
134 | | number | `between` | pair of numbers | Only the documents whose field value is **between** the provided values (included) will be considered during search. |
135 | | bool | | True / False | Only the documents whose field value is **equal to** the provided value will be considered during search. |
136 | | keyword | | any value / list of values | Only the documents whose field value is **equal to** the provided value or **among** the provided values will be considered during search. |
137 | | keywords | `or` | any value / list of values | Only the documents whose field value **contains** the provided value or **contains one of** the provided values will be considered during search. |
138 | | keywords | `and` | any value / list of values | Only the documents whose field value **contains all** the provided values will be considered during search. |
139 | 
140 | Query example:
141 | ```python
142 | query = {
143 |     "text": "search terms",
144 |     "where": {
145 |         "numeric_field_name": ("gte", 1970),
146 |         "boolean_field_name": True,
147 |         "keyword_field_name": "kw_1",
148 |         "keywords_field_name": ("or", ["kws_23", "kws_666"]),
149 |     }
150 | }
151 | ```
152 | 
153 | Alternatively, you can omit the `where` key and use the following syntax:
154 | ```python
155 | query = {
156 |     "text": "search terms",
157 |     "numeric_field_name": ("gte", 1970),
158 |     "boolean_field_name": True,
159 |     "keyword_field_name": "kw_1",
160 |     "keywords_field_name": ("or", ["kws_23", "kws_666"]),
161 | }
162 | ```
163 | 
164 | 
165 | #### Where not
166 | 
167 | | Field Type | Operator | Value | Meaning |
168 | | ---------- | --------- | -------------------------- | ------- |
169 | | number | `eq` | number | The documents whose field value is **equal to** the provided value will be ignored. |
170 | | number | `gt` | number | The documents whose field value is **greater than** the provided value will be ignored. |
171 | | number | `gte` | number | The documents whose field value is **greater or equal to** the provided value will be ignored. |
172 | | number | `lt` | number | The documents whose field value is **less than** the provided value will be ignored. |
173 | | number | `lte` | number | The documents whose field value is **less or equal to** the provided value will be ignored. |
174 | | number | `between` | pair of numbers | The documents whose field value is **between** the provided values (included) will be ignored. |
175 | | bool | | True / False | The documents whose field value is **equal to** the provided value will be ignored. |
176 | | keyword | | any value / list of values | The documents whose field value is **equal to** the provided value or **among** the provided values will be ignored. |
177 | | keywords | `or` | any value / list of values | The documents whose field value **contains** the provided value or **contains one of** the provided values will be ignored. |
178 | | keywords | `and` | any value / list of values | The documents whose field value **contains all** the provided values will be ignored. |
179 | 
180 | Query example:
181 | ```python
182 | query = {
183 |     "text": "search terms",
184 |     "where_not": {
185 |         "numeric_field_name": ("gte", 1970),
186 |         "boolean_field_name": True,
187 |         "keyword_field_name": "kw_1",
188 |         "keywords_field_name": ("or", ["kws_23", "kws_666"]),
189 |     }
190 | }
191 | ```
192 | 
193 | ### Search
194 | 
195 | ```python
196 | ar.search(
197 |     query=query,          # What to search for and which filters to apply
198 |     return_docs=True,     # Default value.
199 |     cutoff=100,           # Default value.
200 |     operator="OR",        # Default value.
201 |     subset_doc_ids=None,  # Default value.
202 | )
203 | ```
204 | 
205 | - `query`: what to search for and which filters to apply. See the section [Query & Filters](#query--filters) for more details.
206 | - `return_docs`: whether to return documents or only their IDs.
207 | - `cutoff`: number of results to return.
208 | - `operator`: whether to perform conjunctive (`AND`) or disjunctive (`OR`) search. Conjunctive search retrieves documents that contain **all** the query terms. Disjunctive search retrieves documents that contain **at least one** of the query terms.
209 | - `subset_doc_ids`: restrict the search to the subset of documents having the provided IDs.
210 | 
211 | Sample output:
212 | ```json
213 | [
214 |     {
215 |         "id": "doc_2",
216 |         "text": "Just like witches at black masses",
217 |         "score": 1.7536403
218 |     },
219 |     {
220 |         "id": "doc_1",
221 |         "text": "Generals gathered in their masses",
222 |         "score": 0.6931472
223 |     }
224 | ]
225 | ```
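
Putting query construction and search together, here is a sketch. The field names follow the book schema defined earlier, and passing `where_not` alongside `where` is assumed to mirror the two filtering ways described above:

```python
# Hypothetical filtered search against the book index built above.
results = ar.search(
    query={
        "text": "medieval murder mystery",
        "where": {"year": ("gte", 1970), "author": "Umberto Eco"},
        "where_not": {"genres": ("or", ["romance"])},  # assumes `where_not` mirrors `where`
    },
    cutoff=10,
)
```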
-------------------------------------------------------------------------------- /docs/dense_retriever.md: --------------------------------------------------------------------------------
1 | # Dense Retriever
2 | 
3 | The Dense Retriever performs [semantic search](https://en.wikipedia.org/wiki/Semantic_search), i.e., it compares vector representations of queries and documents to compute the relevance scores of the latter.
4 | The Dense Retriever comprises two components: the `Encoder` and the `ANN Searcher`, described below.
5 | 
6 | In the following, we show how to build a search engine for [semantic search](https://en.wikipedia.org/wiki/Semantic_search) based on dense retrieval, index a document collection, and search it.
7 | 
8 | ## Build
9 | 
10 | Building a Dense Retriever is as simple as shown below.
11 | Default parameter values are shown.
12 | 
13 | ```python
14 | from retriv import DenseRetriever
15 | 
16 | dr = DenseRetriever(
17 |     index_name="new-index",
18 |     model="sentence-transformers/all-MiniLM-L6-v2",
19 |     normalize=True,
20 |     max_length=128,
21 |     use_ann=True,
22 | )
23 | ```
24 | 
25 | - `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index.
26 | - `model`: defines the encoder model used to encode queries and documents into vectors. You can use a [Hugging Face Transformers](https://huggingface.co/models) pre-trained model by providing its ID or load a local model by providing its path.
27 | In the case of local models, the path must point to the directory containing the data saved with the [`PreTrainedModel.save_pretrained`](https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method.
28 | Note that the representations are computed with `mean pooling` over the `last_hidden_state`.
29 | - `normalize`: whether to L2-normalize the vector representations.
30 | - `max_length`: texts longer than `max_length` will be automatically truncated. Choose this parameter based on how the employed model was trained or is generally used.
31 | - `use_ann`: whether to use approximate nearest neighbors search. Set it to `False` to use nearest neighbors search without approximation. If you have fewer than 20k documents in your collection, you probably want to disable approximation.
32 | 
33 | ## Index
34 | 
35 | ### Create
36 | You can index a document collection from JSONl, CSV, or TSV files.
37 | CSV and TSV files must have a header.
38 | [retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it.
39 | Use the `callback` parameter to pass a function for converting your documents into the format supported by [retriv](https://github.com/AmenRa/retriv) on the fly.
40 | Documents must have a single `text` field and an `id`.
41 | Indexes are automatically persisted on disk at the end of the process.
42 | To speed up the indexing process, you can activate GPU-based encoding.
43 | 
44 | ```python
45 | dr = dr.index_file(
46 |     path="path/to/collection",  # File kind is automatically inferred
47 |     embeddings_path=None,       # Default value
48 |     use_gpu=False,              # Default value
49 |     batch_size=512,             # Default value
50 |     show_progress=True,         # Default value
51 |     callback=lambda doc: {      # Callback defaults to None.
52 |         "id": doc["id"],
53 |         "text": doc["title"] + ". " + doc["text"],
54 |     },
55 | )
56 | ```
57 | 
58 | - `embeddings_path`: in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file (see the sketch below). Embeddings must be in the same order as the documents in the collection file.
59 | - `use_gpu`: whether to use the GPU for document encoding.
60 | - `batch_size`: how many documents to encode at once. Regulate it if you run into memory usage issues or want to maximize throughput.
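
A minimal sketch of supplying pre-computed embeddings; the array shape and file name below are illustrative, and only the row order (one vector per document, in collection order) matters:

```python
import numpy as np

# Hypothetical pre-computed vectors, row-aligned with the documents in the collection file.
embeddings = np.random.rand(4, 384).astype("float32")  # (n_docs, embedding_dim)
np.save("my_embeddings.npy", embeddings)

dr = dr.index_file(
    path="path/to/collection",
    embeddings_path="my_embeddings.npy",  # Use these vectors instead of encoding the texts
)
```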
61 | 
62 | 
63 | ### Load
64 | ```python
65 | dr = DenseRetriever.load("index-name")
66 | ```
67 | 
68 | ### Delete
69 | ```python
70 | DenseRetriever.delete("index-name")
71 | ```
72 | 
73 | ## Search
74 | 
75 | ### Search
76 | 
77 | Standard search functionality.
78 | 
79 | ```python
80 | dr.search(
81 |     query="witches masses",  # What to search for
82 |     return_docs=True,        # Default value, return the text of the documents
83 |     cutoff=100,              # Default value, number of results to return
84 | )
85 | ```
86 | Output:
87 | ```json
88 | [
89 |     {
90 |         "id": "doc_2",
91 |         "text": "Just like witches at black masses",
92 |         "score": 0.9536403
93 |     },
94 |     {
95 |         "id": "doc_1",
96 |         "text": "Generals gathered in their masses",
97 |         "score": 0.6931472
98 |     }
99 | ]
100 | ```
101 | 
102 | ### Multi-Search
103 | 
104 | Compute results for multiple queries at once.
105 | 
106 | ```python
107 | dr.msearch(
108 |     queries=[{"id": "q_1", "text": "witches masses"}, ...],
109 |     cutoff=100,     # Default value, number of results
110 |     batch_size=32,  # Default value.
111 | )
112 | ```
113 | Output:
114 | ```json
115 | {
116 |     "q_1": {
117 |         "doc_2": 1.7536403,
118 |         "doc_1": 0.6931472
119 |     },
120 |     ...
121 | }
122 | ```
123 | 
124 | - `batch_size`: how many searches to perform at once. Regulate it if you run into memory usage issues or want to maximize throughput.
125 | 
126 | ### Batch-Search
127 | 
128 | Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format.
129 | [bsearch](#batch-search) is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
130 | 
131 | ```python
132 | dr.bsearch(
133 |     queries=[{"id": "q_1", "text": "witches masses"}, ...],
134 |     cutoff=100,
135 |     batch_size=32,
136 |     show_progress=True,
137 |     qrels=None,  # To add relevance information to the saved files
138 |     path=None,   # Where to save the results, if you want to save them
139 | )
140 | ```
-------------------------------------------------------------------------------- /docs/faq.md: --------------------------------------------------------------------------------
1 | # FAQ
2 | 
3 | ## How do I change `retriv`'s working directory?
4 | ```python
5 | import retriv
6 | 
7 | retriv.set_base_path("new/working/path")
8 | ```
9 | 
10 | or
11 | 
12 | ```python
13 | import os
14 | 
15 | os.environ["RETRIV_BASE_PATH"] = "new/working/path"
16 | ```
-------------------------------------------------------------------------------- /docs/filters.md: --------------------------------------------------------------------------------
1 | ## Filtering Search Results
2 | 
3 | [retriv](https://github.com/AmenRa/retriv) supports two ways of filtering the search results (`where` and `where_not`) and several type-specific operators.
4 | 
5 | - `where` means that only the documents matching the filter will be considered during search.
6 | - `where_not` means that the documents matching the filter will be ignored during search.
7 | 
8 | Below we describe the effects of the supported operators for each data type and way of filtering.
9 | 
10 | ### Where
11 | 
12 | | Field Type | Operator | Value | Meaning |
13 | | ---------- | --------- | -------------------------- | ------- |
14 | | number | `eq` | number | Only the documents whose field value is **equal to** the provided value will be considered during search. |
15 | | number | `gt` | number | Only the documents whose field value is **greater than** the provided value will be considered during search. |
16 | | number | `gte` | number | Only the documents whose field value is **greater or equal to** the provided value will be considered during search. |
17 | | number | `lt` | number | Only the documents whose field value is **less than** the provided value will be considered during search. |
18 | | number | `lte` | number | Only the documents whose field value is **less or equal to** the provided value will be considered during search. |
19 | | number | `between` | pair of numbers | Only the documents whose field value is **between** the provided values (included) will be considered during search. |
20 | | bool | | True / False | Only the documents whose field value is **equal to** the provided value will be considered during search. |
21 | | keyword | | any value / list of values | Only the documents whose field value is **equal to** the provided value or **among** the provided values will be considered during search. |
22 | | keywords | `or` | any value / list of values | Only the documents whose field value **contains** the provided value or **contains one of** the provided values will be considered during search. |
23 | | keywords | `and` | any value / list of values | Only the documents whose field value **contains all** the provided values will be considered during search. |
24 | 
25 | Query example:
26 | ```python
27 | query = {
28 |     "text": "search terms",
29 |     "where": {
30 |         "numeric_field_name": ("gte", 1970),
31 |         "boolean_field_name": True,
32 |         "keyword_field_name": "kw_1",
33 |         "keywords_field_name": ("or", ["kws_23", "kws_666"]),
34 |     }
35 | }
36 | ```
37 | 
38 | Alternatively, you can omit the `where` key and use the following syntax:
39 | ```python
40 | query = {
41 |     "text": "search terms",
42 |     "numeric_field_name": ("gte", 1970),
43 |     "boolean_field_name": True,
44 |     "keyword_field_name": "kw_1",
45 |     "keywords_field_name": ("or", ["kws_23", "kws_666"]),
46 | }
47 | ```
48 | 
49 | 
50 | ### Where not
51 | 
52 | | Field Type | Operator | Value | Meaning |
53 | | ---------- | --------- | -------------------------- | ------- |
54 | | number | `eq` | number | The documents whose field value is **equal to** the provided value will be ignored. |
55 | | number | `gt` | number | The documents whose field value is **greater than** the provided value will be ignored. |
56 | | number | `gte` | number | The documents whose field value is **greater or equal to** the provided value will be ignored. |
57 | | number | `lt` | number | The documents whose field value is **less than** the provided value will be ignored. |
58 | | number | `lte` | number | The documents whose field value is **less or equal to** the provided value will be ignored. |
59 | | number | `between` | pair of numbers | The documents whose field value is **between** the provided values (included) will be ignored. |
60 | | bool | | True / False | The documents whose field value is **equal to** the provided value will be ignored. |
61 | | keyword | | any value / list of values | The documents whose field value is **equal to** the provided value or **among** the provided values will be ignored. |
62 | | keywords | `or` | any value / list of values | The documents whose field value **contains** the provided value or **contains one of** the provided values will be ignored. |
63 | | keywords | `and` | any value / list of values | The documents whose field value **contains all** the provided values will be ignored. |
64 | 
65 | Query example:
66 | ```python
67 | query = {
68 |     "text": "search terms",
69 |     "where_not": {
70 |         "numeric_field_name": ("gte", 1970),
71 |         "boolean_field_name": True,
72 |         "keyword_field_name": "kw_1",
73 |         "keywords_field_name": ("or", ["kws_23", "kws_666"]),
74 |     }
75 | }
76 | ```
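
Both filtering ways can be combined in a single query, as sketched below. The field names are placeholders, and the two-value form for `between` is an assumption based on the operator's description above:

```python
# Hypothetical query mixing `where` and `where_not` filters.
query = {
    "text": "search terms",
    "where": {
        "numeric_field_name": ("between", (1970, 2000)),  # assumed two-value form
        "keyword_field_name": ["kw_1", "kw_2"],           # matches any of these values
    },
    "where_not": {
        "keywords_field_name": ("and", ["kws_23", "kws_666"]),
    },
}
```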
| 64 | 65 | Query example: 66 | ```python 67 | query = { 68 | "text": "search terms", 69 | "where": { 70 | "numeric_field_name": ("gte", 1970), 71 | "boolean_field_name": True, 72 | "keyword_field_name": "kw_1", 73 | "keywords_field_name": ("or", ["kws_23", "kws_666"]), 74 | } 75 | } 76 | ``` -------------------------------------------------------------------------------- /docs/hybrid_retriever.md: -------------------------------------------------------------------------------- 1 | # Sparse Retriever 2 | 3 | The [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) is searcher based on both lexical and semantic matching. 4 | It comprises three components: the [Sparse Retriever]((https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md)), the [Dense Retriever]((https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md)), and the Merger. 5 | The Merger fuses the results of the Sparse and Dense Retrievers to compute the _hybrid_ results. 6 | 7 | In the following, we show how to build an hybrid search engine, index a document collection, and search it. 8 | 9 | ## Build 10 | 11 | You can instantiate and Hybrid Retriever as follows. 12 | Default parameter values are shown. 13 | 14 | ```python 15 | from retriv import HybridRetriever 16 | 17 | hr = HybridRetriever( 18 | # Shared params ------------------------------------------------------------ 19 | index_name="new-index", 20 | # Sparse retriever params -------------------------------------------------- 21 | sr_model="bm25", 22 | min_df=1, 23 | tokenizer="whitespace", 24 | stemmer="english", 25 | stopwords="english", 26 | do_lowercasing=True, 27 | do_ampersand_normalization=True, 28 | do_special_chars_normalization=True, 29 | do_acronyms_normalization=True, 30 | do_punctuation_removal=True, 31 | # Dense retriever params --------------------------------------------------- 32 | dr_model="sentence-transformers/all-MiniLM-L6-v2", 33 | normalize=True, 34 | max_length=128, 35 | use_ann=True, 36 | ) 37 | ``` 38 | 39 | - Shared params: 40 | - `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index. 41 | - Sparse Retriever params: 42 | - `sr_model`: defines the model to use for sparse retrieval (`bm25` or `tf-idf`). 43 | - `min_df`: terms that appear in less than `min_df` documents will be ignored. 44 | If integer, the parameter indicates the absolute count. 45 | If float, it represents a proportion of documents. 46 | - `tokenizer`: [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization setting the parameter to `None`. 47 | - `stemmer`: [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming setting the parameter to `None`. 48 | - `stopwords`: [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`. 49 | - `do_lowercasing`: whether to lowercase texts. 50 | - `do_ampersand_normalization`: whether to convert `&` in `and` during pre-processing. 51 | - `do_special_chars_normalization`: whether to remove special characters for letters, e.g., `übermensch` → `ubermensch`. 
52 |   - `do_acronyms_normalization`: whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`.
53 |   - `do_punctuation_removal`: whether to remove punctuation.
54 | - Dense Retriever params:
55 |   - `dr_model`: defines the model used to encode queries and documents into vectors. You can use a [Hugging Face Transformers](https://huggingface.co/models) pre-trained model by providing its ID or load a local model by providing its path.
56 | In the case of local models, the path must point to the directory containing the data saved with the [`PreTrainedModel.save_pretrained`](https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method.
57 | Note that the representations are computed with `mean pooling` over the `last_hidden_state`.
58 |   - `normalize`: whether to L2-normalize the vector representations.
59 |   - `max_length`: texts longer than `max_length` will be automatically truncated. Choose this parameter based on how the employed model was trained or is generally used.
60 |   - `use_ann`: whether to use approximate nearest neighbors search. Set it to `False` to use nearest neighbors search without approximation. If you have fewer than 20k documents in your collection, you probably want to disable approximation.
61 | 
62 | __Note:__ text pre-processing is equally applied to documents during indexing and to queries at search time.
63 | 
64 | ## Index
65 | 
66 | ### Create
67 | You can index a document collection from JSONl, CSV, or TSV files.
68 | CSV and TSV files must have a header.
69 | [retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it.
70 | Use the `callback` parameter to pass a function for converting your documents into the format supported by [retriv](https://github.com/AmenRa/retriv) on the fly.
71 | Documents must have a single `text` field and an `id`.
72 | The Hybrid Retriever sequentially builds the indices for the Sparse and Dense Retrievers.
73 | Indexes are automatically persisted on disk at the end of the process.
74 | To speed up the indexing process of the Dense Retriever, you can activate GPU-based encoding.
75 | 
76 | ```python
77 | hr = hr.index_file(
78 |     path="path/to/collection",  # File kind is automatically inferred
79 |     embeddings_path=None,       # Default value
80 |     use_gpu=False,              # Default value
81 |     batch_size=512,             # Default value
82 |     show_progress=True,         # Default value
83 |     callback=lambda doc: {      # Callback defaults to None.
84 |         "id": doc["id"],
85 |         "text": doc["title"] + ". " + doc["text"],
86 |     },
87 | )
88 | ```
89 | 
90 | ### Load
91 | ```python
92 | hr = HybridRetriever.load("index-name")
93 | ```
94 | 
95 | ### Delete
96 | ```python
97 | HybridRetriever.delete("index-name")
98 | ```
99 | 
100 | ## Search
101 | 
102 | During search, the Hybrid Retriever fuses the top 1000 results of the Sparse and Dense Retrievers.
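
To give an intuition of what such a fusion involves, here is an illustrative sketch. This is not retriv's actual Merger (see `retriv/merger/` for the real implementation); the min-max normalization and equal weighting are assumptions made for illustration:

```python
# Illustrative score fusion: min-max normalize each run, then combine linearly.
def fuse(sparse_run: dict, dense_run: dict, alpha: float = 0.5) -> dict:
    def min_max(run: dict) -> dict:
        lo, hi = min(run.values()), max(run.values())
        return {doc: (score - lo) / (hi - lo or 1.0) for doc, score in run.items()}

    s, d = min_max(sparse_run), min_max(dense_run)
    fused = {
        doc: alpha * s.get(doc, 0.0) + (1 - alpha) * d.get(doc, 0.0)
        for doc in set(s) | set(d)
    }
    return dict(sorted(fused.items(), key=lambda kv: kv[1], reverse=True))
```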
103 | 
104 | ### Search
105 | 
106 | Standard search functionality.
107 | 
108 | ```python
109 | hr.search(
110 |     query="witches masses",  # What to search for
111 |     return_docs=True,        # Default value, return the text of the documents
112 |     cutoff=100,              # Default value, number of results to return
113 | )
114 | ```
115 | Output:
116 | ```json
117 | [
118 |     {
119 |         "id": "doc_2",
120 |         "text": "Just like witches at black masses",
121 |         "score": 0.9536403
122 |     },
123 |     {
124 |         "id": "doc_1",
125 |         "text": "Generals gathered in their masses",
126 |         "score": 0.6931472
127 |     }
128 | ]
129 | ```
130 | 
131 | ### Multi-Search
132 | 
133 | Compute results for multiple queries at once.
134 | 
135 | ```python
136 | hr.msearch(
137 |     queries=[{"id": "q_1", "text": "witches masses"}, ...],
138 |     cutoff=100,     # Default value, number of results
139 |     batch_size=32,  # Default value.
140 | )
141 | ```
142 | Output:
143 | ```json
144 | {
145 |     "q_1": {
146 |         "doc_2": 1.7536403,
147 |         "doc_1": 0.6931472
148 |     },
149 |     ...
150 | }
151 | ```
152 | 
153 | - `batch_size`: how many searches to perform at once. Regulate it if you run into memory usage issues or want to maximize throughput.
154 | 
155 | ### Batch-Search
156 | 
157 | Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format.
158 | [bsearch](#batch-search) is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
159 | 
160 | ```python
161 | hr.bsearch(
162 |     queries=[{"id": "q_1", "text": "witches masses"}, ...],
163 |     cutoff=100,
164 |     batch_size=32,
165 |     show_progress=True,
166 |     qrels=None,  # To add relevance information to the saved files
167 |     path=None,   # Where to save the results, if you want to save them
168 | )
169 | ```
170 | 
171 | ## AutoTune
172 | 
173 | Use the AutoTune function to tune the parameters of the Sparse Retriever's [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) model and the importance given to the lexical and semantic relevance scores computed by the Sparse and Dense Retrievers, respectively.
174 | All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function.
175 | At the end of the process, the best parameter configurations are automatically applied and saved to disk.
176 | You can inspect the best configurations found by printing `hr.sparse_retriever.hyperparams`, `hr.merger.norm`, and `hr.merger.params`, as shown after the example below.
177 | 
178 | ```python
179 | hr.autotune(
180 |     queries=[{ "q_id": "q_1", "text": "...", ... }],  # Train queries
181 |     qrels={ "q_1": { "doc_1": 1, ... }, ... },        # Train qrels
182 |     metric="ndcg",   # Default value, metric to maximize
183 |     n_trials=100,    # Default value, number of trials
184 |     cutoff=100,      # Default value, number of results
185 |     batch_size=32,   # Default value
186 | )
187 | ```
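
After tuning, the chosen configuration can be inspected directly (the attribute names are the ones given above):

```python
# Inspect the configuration selected by AutoTune.
print(hr.sparse_retriever.hyperparams)  # Tuned BM25 parameters
print(hr.merger.norm)                   # Selected score normalization
print(hr.merger.params)                 # Selected fusion weights
```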
-------------------------------------------------------------------------------- /docs/sparse_retriever.md: --------------------------------------------------------------------------------
1 | # Sparse Retriever
2 | 
3 | The Sparse Retriever is a traditional searcher based on lexical matching.
4 | It supports [BM25](https://en.wikipedia.org/wiki/Okapi_BM25), the retrieval model used by major search engine libraries, such as [Lucene](https://en.wikipedia.org/wiki/Apache_Lucene) and [Elasticsearch](https://en.wikipedia.org/wiki/Elasticsearch).
5 | [retriv](https://github.com/AmenRa/retriv) also implements the classic relevance model [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf) for educational purposes.
6 | 
7 | The Sparse Retriever also provides several resources for multi-lingual [text pre-processing](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md), aiming to maximize its retrieval effectiveness.
8 | 
9 | In the following, we show how to build a search engine employing a sparse retriever, index a document collection, and search it.
10 | 
11 | ## Build
12 | 
13 | The Sparse Retriever provides several options to tailor its functioning to your preferences, as shown below.
14 | Default parameter values are shown.
15 | 
16 | ```python
17 | # Note: the SparseRetriever has an alias called SearchEngine, if you prefer
18 | from retriv import SparseRetriever
19 | 
20 | sr = SparseRetriever(
21 |     index_name="new-index",
22 |     model="bm25",
23 |     min_df=1,
24 |     tokenizer="whitespace",
25 |     stemmer="english",
26 |     stopwords="english",
27 |     do_lowercasing=True,
28 |     do_ampersand_normalization=True,
29 |     do_special_chars_normalization=True,
30 |     do_acronyms_normalization=True,
31 |     do_punctuation_removal=True,
32 | )
33 | ```
34 | 
35 | - `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index.
36 | - `model`: defines the retrieval model to use for searching (`bm25` or `tf-idf`).
37 | - `min_df`: terms that appear in fewer than `min_df` documents will be ignored.
38 | If integer, the parameter indicates the absolute count.
39 | If float, it represents a proportion of documents.
40 | - `tokenizer`: [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer (see the sketch after this list) or disable tokenization by setting the parameter to `None`.
41 | - `stemmer`: [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming by setting the parameter to `None`.
42 | - `stopwords`: [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`.
43 | - `do_lowercasing`: whether to lowercase texts.
44 | - `do_ampersand_normalization`: whether to convert `&` into `and` during pre-processing.
45 | - `do_special_chars_normalization`: whether to convert special characters to plain letters, e.g., `übermensch` → `ubermensch`.
46 | - `do_acronyms_normalization`: whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`.
47 | - `do_punctuation_removal`: whether to remove punctuation.
48 | 
49 | __Note:__ text pre-processing is equally applied to documents during indexing and to queries at search time.
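
A minimal sketch of plugging in custom callables. The expected signatures (a `str -> list[str]` tokenizer and a `str -> str` stemmer) are assumptions inferred from the parameter descriptions, not taken from retriv's source:

```python
# Hypothetical custom callables; the signatures are assumed, not documented.
def my_tokenizer(text: str) -> list:
    return text.split()  # naive whitespace tokenization

def my_stemmer(word: str) -> str:
    return word[:-1] if word.endswith("s") else word  # toy suffix stripping

sr = SparseRetriever(
    index_name="custom-index",
    tokenizer=my_tokenizer,
    stemmer=my_stemmer,
)
```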
50 | 
51 | ## Index
52 | 
53 | ### Create
54 | You can index a document collection from JSONl, CSV, or TSV files.
55 | CSV and TSV files must have a header.
56 | [retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it.
57 | Use the `callback` parameter to pass a function for converting your documents into the format supported by [retriv](https://github.com/AmenRa/retriv) on the fly.
58 | Documents must have a single `text` field and an `id`.
59 | Indexes are automatically persisted on disk at the end of the process.
60 | Indexing functionalities are built to have a minimal memory footprint while leveraging multi-processing for efficiency.
61 | Indexing 10M documents takes from 5 to 10 minutes on an [AMD Ryzen™ 9 5950X](https://www.amd.com/en/products/cpu/amd-ryzen-9-5950x), depending on the length of the documents.
62 | 
63 | ```python
64 | sr = sr.index_file(
65 |     path="path/to/collection",  # File kind is automatically inferred
66 |     show_progress=True,         # Default value
67 |     callback=lambda doc: {      # Callback defaults to None.
68 |         "id": doc["id"],
69 |         "text": doc["title"] + ". " + doc["text"],
70 |     },
71 | )
72 | ```
73 | 
74 | ### Load
75 | ```python
76 | sr = SparseRetriever.load("index-name")
77 | ```
78 | 
79 | ### Delete
80 | ```python
81 | SparseRetriever.delete("index-name")
82 | ```
83 | 
84 | ## Search
85 | 
86 | ### Search
87 | 
88 | Standard search functionality.
89 | 
90 | ```python
91 | sr.search(
92 |     query="witches masses",  # What to search for
93 |     return_docs=True,        # Default value, return the text of the documents
94 |     cutoff=100,              # Default value, number of results to return
95 | )
96 | ```
97 | Output:
98 | ```json
99 | [
100 |     {
101 |         "id": "doc_2",
102 |         "text": "Just like witches at black masses",
103 |         "score": 1.7536403
104 |     },
105 |     {
106 |         "id": "doc_1",
107 |         "text": "Generals gathered in their masses",
108 |         "score": 0.6931472
109 |     }
110 | ]
111 | ```
112 | 
113 | ### Multi-Search
114 | 
115 | Compute results for multiple queries at once.
116 | 
117 | ```python
118 | sr.msearch(
119 |     queries=[{"id": "q_1", "text": "witches masses"}, ...],
120 |     cutoff=100,  # Default value, number of results
121 | )
122 | ```
123 | Output:
124 | ```json
125 | {
126 |     "q_1": {
127 |         "doc_2": 1.7536403,
128 |         "doc_1": 0.6931472
129 |     },
130 |     ...
131 | }
132 | ```
133 | 
134 | ### Batch-Search
135 | 
136 | Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format.
137 | [bsearch](#batch-search) is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
138 | 
139 | ```python
140 | sr.bsearch(
141 |     queries=[{"id": "q_1", "text": "witches masses"}, ...],
142 |     cutoff=100,
143 |     batch_size=1_000,
144 |     show_progress=True,
145 |     qrels=None,  # To add relevance information to the saved files
146 |     path=None,   # Where to save the results, if you want to save them
147 | )
148 | ```
149 | 
150 | ## AutoTune
151 | 
152 | Use the AutoTune function to tune [BM25](https://en.wikipedia.org/wiki/Okapi_BM25)'s parameters w.r.t. your document collection and queries.
153 | All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function.
154 | At the end of the process, the best parameter configuration is automatically applied to the `SparseRetriever` instance and saved to disk.
155 | You can inspect the current configuration by printing `sr.hyperparams`.
156 | 
157 | ```python
158 | sr.autotune(
159 |     queries=[{ "q_id": "q_1", "text": "...", ... }],  # Train queries
160 |     qrels={ "q_1": { "doc_1": 1, ... }, ... },        # Train qrels
161 |     metric="ndcg",  # Default value, metric to maximize
162 |     n_trials=100,   # Default value, number of trials
163 |     cutoff=100,     # Default value, number of results
164 | )
165 | ```
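
For reference, the scoring function whose parameters ($k_1$ and $b$) AutoTune optimizes is BM25. A standard textbook formulation (not retriv-specific) is:

$$
\text{BM25}(q, d) = \sum_{t \in q} \mathrm{IDF}(t) \cdot \frac{f(t, d) \cdot (k_1 + 1)}{f(t, d) + k_1 \cdot \left(1 - b + b \cdot \frac{|d|}{\mathrm{avgdl}}\right)}
$$

where $f(t, d)$ is the frequency of term $t$ in document $d$, $|d|$ is the document length, and $\mathrm{avgdl}$ is the average document length in the collection; $k_1$ controls term-frequency saturation and $b$ controls the strength of length normalization.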
148 | 
149 | ## AutoTune
150 | 
151 | Use the AutoTune function to tune [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) parameters w.r.t. your document collection and queries.
152 | All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function.
153 | At the end of the process, the best parameter configuration is automatically applied to the `SparseRetriever` instance and saved to disk.
154 | You can inspect the current configuration by printing `sr.hyperparams`.
155 | 
156 | ```python
157 | sr.autotune(
158 |     queries=[{"id": "q_1", "text": "...", ...}, ...],  # Train queries
159 |     qrels={"q_1": {"doc_1": 1, ...}, ...},             # Train qrels
160 |     metric="ndcg",  # Default value, metric to maximize
161 |     n_trials=100,   # Default value, number of trials
162 |     cutoff=100,     # Default value, number of results
163 | )
164 | ```
165 | 
--------------------------------------------------------------------------------
/docs/speed.md:
--------------------------------------------------------------------------------

1 | ## Speed Comparison
2 | 
3 | TO BE UPDATED
4 | 
5 | We performed a speed test, comparing [retriv](https://github.com/AmenRa/retriv) to [rank_bm25](https://github.com/dorianbrown/rank_bm25), a popular [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) implementation in [Python](https://en.wikipedia.org/wiki/Python_(programming_language)), and [pyserini](https://github.com/castorini/pyserini), a [Python](https://en.wikipedia.org/wiki/Python_(programming_language)) binding to the [Lucene](https://en.wikipedia.org/wiki/Apache_Lucene) search engine.
6 | 
7 | We relied on the [MSMARCO Passage](https://microsoft.github.io/msmarco) dataset to collect documents and queries.
8 | Specifically, we used the original document collection and three sub-samples of it, accounting for 1k, 100k, and 1M documents, respectively, and sampled 1k queries from the original ones.
9 | We computed the top-100 results with each library (if possible).
10 | Results are reported below. Best results are highlighted in boldface.
11 | 
12 | | Library                                                   | Collection Size | Elapsed Time | Avg. Query Time | Throughput (q/s) |
13 | | --------------------------------------------------------- | --------------: | -----------: | --------------: | ---------------: |
14 | | [rank_bm25](https://github.com/dorianbrown/rank_bm25)     | 1,000           | 646ms        | 6.5ms           | 1548/s           |
15 | | [pyserini](https://github.com/castorini/pyserini)         | 1,000           | 1,438ms      | 1.4ms           | 695/s            |
16 | | [retriv](https://github.com/AmenRa/retriv)                | 1,000           | 140ms        | 0.1ms           | 7143/s           |
17 | | [retriv](https://github.com/AmenRa/retriv) (multi-search) | 1,000           | __134ms__    | __0.1ms__       | __7463/s__       |
18 | | [rank_bm25](https://github.com/dorianbrown/rank_bm25)     | 100,000         | 106,000ms    | 1060ms          | 1/s              |
19 | | [pyserini](https://github.com/castorini/pyserini)         | 100,000         | 2,532ms      | 2.5ms           | 395/s            |
20 | | [retriv](https://github.com/AmenRa/retriv)                | 100,000         | 314ms        | 0.3ms           | 3185/s           |
21 | | [retriv](https://github.com/AmenRa/retriv) (multi-search) | 100,000         | __256ms__    | __0.3ms__       | __3906/s__       |
22 | | [rank_bm25](https://github.com/dorianbrown/rank_bm25)     | 1,000,000       | N/A          | N/A             | N/A              |
23 | | [pyserini](https://github.com/castorini/pyserini)         | 1,000,000       | 4,060ms      | 4.1ms           | 246/s            |
24 | | [retriv](https://github.com/AmenRa/retriv)                | 1,000,000       | 1,018ms      | 1.0ms           | 982/s            |
25 | | [retriv](https://github.com/AmenRa/retriv) (multi-search) | 1,000,000       | __503ms__    | __0.5ms__       | __1988/s__       |
26 | | [rank_bm25](https://github.com/dorianbrown/rank_bm25)     | 8,841,823       | N/A          | N/A             | N/A              |
27 | | [pyserini](https://github.com/castorini/pyserini)         | 8,841,823       | 12,245ms     | 12.2ms          | 82/s             |
28 | | [retriv](https://github.com/AmenRa/retriv)                | 8,841,823       | 10,763ms     | 10.8ms          | 93/s             |
29 | | [retriv](https://github.com/AmenRa/retriv) (multi-search) | 8,841,823       | __4,476ms__  | __4.4ms__       | __227/s__        |
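If you want to benchmark [retriv](https://github.com/AmenRa/retriv) on your own data, a minimal timing harness along the following lines should suffice (a sketch: `sr` is assumed to be an already-built `SparseRetriever`, and `queries` a list of `{"id": ..., "text": ...}` dictionaries):

```python
import time

# Time one-by-one search
start = time.perf_counter()
for q in queries:
    sr.search(query=q["text"], return_docs=False, cutoff=100)
elapsed = time.perf_counter() - start
print(f"search:  {elapsed * 1000:.0f}ms total, {elapsed * 1000 / len(queries):.2f}ms/query")

# Time multi-search, which amortizes the per-query overhead over the whole batch
start = time.perf_counter()
sr.msearch(queries=queries, cutoff=100)
elapsed = time.perf_counter() - start
print(f"msearch: {elapsed * 1000:.0f}ms total, {elapsed * 1000 / len(queries):.2f}ms/query")
```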
--------------------------------------------------------------------------------
/docs/text_preprocessing.md:
--------------------------------------------------------------------------------

1 | # Text Pre-Processing
2 | 
3 | [retriv](https://github.com/AmenRa/retriv) provides several resources for multi-lingual text pre-processing, aiming to maximize its retrieval effectiveness.
4 | 
5 | ## Stemmers
6 | [Stemmers](https://en.wikipedia.org/wiki/Stemming) reduce words to their word stem, base, or root form.
7 | [retriv](https://github.com/AmenRa/retriv) supports the following stemmers:
8 | - [snowball](https://snowballstem.org) (default)
9 |   The following languages are supported by the Snowball Stemmer:
10 |   Arabic, Basque, Catalan, Danish, Dutch, English, Finnish, French, German, Greek, Hindi, Hungarian, Indonesian, Irish, Italian, Lithuanian, Nepali, Norwegian, Portuguese, Romanian, Russian, Spanish, Swedish, Tamil, Turkish.
11 |   To select your preferred language, simply pass its lowercased name to the `stemmer` parameter, e.g., `stemmer="italian"`.
12 | - [arlstem](https://www.nltk.org/api/nltk.stem.arlstem.html) (Arabic)
13 | - [arlstem2](https://www.nltk.org/api/nltk.stem.arlstem2.html) (Arabic)
14 | - [cistem](https://www.nltk.org/api/nltk.stem.cistem.html) (German)
15 | - [isri](https://www.nltk.org/api/nltk.stem.isri.html) (Arabic)
16 | - [krovetz](https://dl.acm.org/doi/10.1145/160688.160718) (English)
17 | - [lancaster](https://www.nltk.org/api/nltk.stem.lancaster.html) (English)
18 | - [porter](https://www.nltk.org/api/nltk.stem.porter.html) (English)
19 | 
20 | 
21 | ## Tokenizers
22 | [Tokenizers](https://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) divide a string into smaller units, such as words.
23 | [retriv](https://github.com/AmenRa/retriv) supports the following tokenizers:
24 | - [whitespace](https://www.nltk.org/api/nltk.tokenize.html)
25 | - [word](https://www.nltk.org/api/nltk.tokenize.html)
26 | - [wordpunct](https://www.nltk.org/api/nltk.tokenize.html)
27 | - [sent](https://www.nltk.org/api/nltk.tokenize.html)
28 | 
29 | 
30 | ## Stop-word Lists
31 | [retriv](https://github.com/AmenRa/retriv) supports [stop-word](https://en.wikipedia.org/wiki/Stop_word) lists for the following languages: Arabic, Azerbaijani, Basque, Bengali, Catalan, Chinese, Danish, Dutch, English, Finnish, French, German, Greek, Hebrew, Hinglish, Hungarian, Indonesian, Italian, Kazakh, Nepali, Norwegian, Portuguese, Romanian, Russian, Slovene, Spanish, Swedish, Tajik, and Turkish.
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------

1 | [tool.ruff]
2 | # E501 ignore long comment lines (black should take care of that)
3 | # E712 ignore == False, needed for numba checks.
4 | ignore = ["E501", "E712"] 5 | 6 | [tool.isort] 7 | profile = "black" -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | black[jupyter] 3 | blackdoc 4 | isort 5 | pytest 6 | pytest-cov 7 | pytest-xdist 8 | ruff 9 | twine 10 | typos 11 | wheel -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autofaiss 2 | faiss-cpu 3 | indxr 4 | krovetzstemmer 5 | multipipe 6 | nltk 7 | numba>=0.54.1 8 | numpy 9 | oneliner_utils 10 | optuna 11 | pystemmer==2.0.1 12 | ranx 13 | scikit-learn 14 | torch 15 | torchaudio 16 | torchvision 17 | tqdm 18 | transformers[torch] 19 | unidecode -------------------------------------------------------------------------------- /retriv/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "ANN_Searcher", 3 | "DenseRetriever", 4 | "Encoder", 5 | "SearchEngine", 6 | "SparseRetriever", 7 | "HybridRetriever", 8 | "Merger", 9 | ] 10 | 11 | import os 12 | from pathlib import Path 13 | 14 | from .dense_retriever.ann_searcher import ANN_Searcher 15 | from .dense_retriever.dense_retriever import DenseRetriever 16 | from .dense_retriever.encoder import Encoder 17 | from .hybrid_retriever import HybridRetriever 18 | from .merger.merger import Merger 19 | from .sparse_retriever.sparse_retriever import SparseRetriever 20 | from .sparse_retriever.sparse_retriever import SparseRetriever as SearchEngine 21 | 22 | # Set environment variables ---------------------------------------------------- 23 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 24 | if "RETRIV_BASE_PATH" not in os.environ: # allow user to set a different path in .bash_profile 25 | os.environ["RETRIV_BASE_PATH"] = str(Path.home() / ".retriv") 26 | 27 | def set_base_path(path: str): 28 | os.environ["RETRIV_BASE_PATH"] = path 29 | -------------------------------------------------------------------------------- /retriv/autotune/__init__.py: -------------------------------------------------------------------------------- 1 | from .bm25_autotune import tune_bm25 2 | from .merger_autotune import tune_merger 3 | 4 | __all__ = [ 5 | "tune_bm25", 6 | "tune_merger", 7 | ] 8 | -------------------------------------------------------------------------------- /retriv/autotune/bm25_autotune.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import optuna 4 | from optuna.exceptions import ExperimentalWarning 5 | from ranx import Qrels, Run, evaluate 6 | 7 | warnings.filterwarnings("ignore", category=ExperimentalWarning) 8 | 9 | 10 | def bm25_objective(trial, queries, qrels, se, metric, cutoff): 11 | b = trial.suggest_float("b", 0.0, 1.0, step=0.01) 12 | k1 = trial.suggest_float("k1", 0.0, 10.0, step=0.1) 13 | 14 | se.hyperparams = dict(b=b, k1=k1) 15 | run = Run(se.bsearch(queries=queries, cutoff=cutoff, show_progress=False)) 16 | 17 | return evaluate(qrels, run, metric) 18 | 19 | 20 | def tune_bm25(queries, qrels, se, metric, n_trials, cutoff): 21 | qrels = Qrels(qrels) 22 | 23 | optuna.logging.set_verbosity(optuna.logging.WARNING) 24 | 25 | sampler = optuna.samplers.TPESampler(seed=42) 26 | study = optuna.create_study(direction="maximize", sampler=sampler) 27 | study.optimize( 28 | lambda trial: bm25_objective(trial, queries, qrels, se, 
metric, cutoff), 29 | n_trials=n_trials, 30 | show_progress_bar=True, 31 | ) 32 | 33 | optuna.logging.set_verbosity(optuna.logging.INFO) 34 | 35 | # Set best params 36 | se.hyperparams = study.best_params 37 | 38 | return study.best_params 39 | -------------------------------------------------------------------------------- /retriv/autotune/merger_autotune.py: -------------------------------------------------------------------------------- 1 | from ranx import Qrels, Run, evaluate, fuse, optimize_fusion 2 | 3 | 4 | def tune_merger(qrels, runs, metric): 5 | ranx_qrels = Qrels(qrels) 6 | ranx_runs = [Run(run) for run in runs] 7 | 8 | best_score = 0.0 9 | best_config = None 10 | 11 | for norm in ["min-max", "max", "sum"]: 12 | best_params = optimize_fusion( 13 | qrels=ranx_qrels, 14 | runs=ranx_runs, 15 | norm=norm, 16 | method="wsum", 17 | metric=metric, 18 | show_progress=False, 19 | ) 20 | 21 | combined_run = fuse( 22 | runs=ranx_runs, norm=norm, method="wsum", params=best_params 23 | ) 24 | 25 | score = evaluate(ranx_qrels, combined_run, metric) 26 | if score > best_score: 27 | best_score = score 28 | best_config = { 29 | "norm": norm, 30 | "params": best_params, 31 | } 32 | 33 | return best_config 34 | -------------------------------------------------------------------------------- /retriv/base_retriever.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from typing import Iterable, List 4 | 5 | import numpy as np 6 | import orjson 7 | from indxr import Indxr 8 | from oneliner_utils import read_csv, read_jsonl 9 | 10 | from .paths import docs_path, index_path 11 | 12 | 13 | class BaseRetriever: 14 | def __init__(self, index_name: str = "new-index"): 15 | self.index_name = index_name 16 | self.id_mapping = None 17 | self.doc_count = None 18 | self.doc_index = None 19 | 20 | @staticmethod 21 | def delete(index_name="new-index"): 22 | try: 23 | shutil.rmtree(index_path(index_name)) 24 | print(f"{index_name} successfully removed.") 25 | except FileNotFoundError: 26 | print(f"{index_name} not found.") 27 | 28 | def collection_generator(self, path: str, callback: callable = None): 29 | kind = os.path.splitext(path)[1][1:] 30 | assert kind in { 31 | "jsonl", 32 | "csv", 33 | "tsv", 34 | }, "Only JSONl, CSV, and TSV are currently supported." 
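        # NOTE: the readers below are used with generator=True, so documents are
        # streamed from disk one at a time rather than fully loaded into memory;
        # the optional callback converts each document to the expected format on the fly.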
35 | 36 | if kind == "jsonl": 37 | collection = read_jsonl(path, generator=True, callback=callback) 38 | elif kind == "csv": 39 | collection = read_csv(path, generator=True, callback=callback) 40 | elif kind == "tsv": 41 | collection = read_csv( 42 | path, delimiter="\t", generator=True, callback=callback 43 | ) 44 | 45 | return collection 46 | 47 | def save_collection(self, collection: Iterable, callback: callable = None): 48 | with open(docs_path(self.index_name), "wb") as f: 49 | for doc in collection: 50 | x = callback(doc) if callback is not None else doc 51 | f.write(orjson.dumps(x) + "\n".encode()) 52 | 53 | def initialize_doc_index(self): 54 | self.doc_index = Indxr(docs_path(self.index_name)) 55 | 56 | def initialize_id_mapping(self): 57 | ids = read_jsonl( 58 | docs_path(self.index_name), 59 | generator=True, 60 | callback=lambda x: x["id"], 61 | ) 62 | self.id_mapping = dict(enumerate(ids)) 63 | 64 | def get_doc(self, doc_id: str) -> dict: 65 | return self.doc_index.get(doc_id) 66 | 67 | def get_docs(self, doc_ids: List[str]) -> List[dict]: 68 | return self.doc_index.mget(doc_ids) 69 | 70 | def prepare_results(self, doc_ids: List[str], scores: np.ndarray) -> List[dict]: 71 | docs = self.get_docs(doc_ids) 72 | results = [] 73 | for doc, score in zip(docs, scores): 74 | doc["score"] = score 75 | results.append(doc) 76 | 77 | return results 78 | 79 | def map_internal_ids_to_original_ids(self, doc_ids: Iterable) -> List[str]: 80 | return [self.id_mapping[doc_id] for doc_id in doc_ids] 81 | 82 | def save(self): 83 | raise NotImplementedError() 84 | 85 | @staticmethod 86 | def load(): 87 | raise NotImplementedError() 88 | 89 | def index(self): 90 | raise NotImplementedError() 91 | 92 | def index_file(self): 93 | raise NotImplementedError() 94 | 95 | def search(self): 96 | raise NotImplementedError() 97 | 98 | def msearch(self): 99 | raise NotImplementedError() 100 | 101 | def autotune(self): 102 | raise NotImplementedError() 103 | -------------------------------------------------------------------------------- /retriv/dense_retriever/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/dense_retriever/__init__.py -------------------------------------------------------------------------------- /retriv/dense_retriever/ann_searcher.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import faiss 4 | import numpy as np 5 | import psutil 6 | from autofaiss import build_index 7 | from oneliner_utils import read_json 8 | 9 | from ..paths import embeddings_folder_path, faiss_index_infos_path, faiss_index_path 10 | 11 | 12 | def get_ram(): 13 | size_bytes = psutil.virtual_memory().total 14 | i = int(math.floor(math.log(size_bytes, 1024))) 15 | p = math.pow(1024, i) 16 | s = round(size_bytes / p, 2) 17 | return f"{s}GB" 18 | 19 | 20 | class ANN_Searcher: 21 | def __init__(self, index_name: str = "new-index"): 22 | self.index_name = index_name 23 | self.faiss_index = None 24 | self.faiss_index_infos = None 25 | 26 | def build(self, use_gpu=False): 27 | index, index_infos = build_index( 28 | embeddings=str(embeddings_folder_path(self.index_name)), 29 | index_path=str(faiss_index_path(self.index_name)), 30 | index_infos_path=str(faiss_index_infos_path(self.index_name)), 31 | save_on_disk=True, 32 | metric_type="ip", 33 | # max_index_memory_usage="32GB", 34 | current_memory_available=get_ram(), 35 | 
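            # NOTE: the constraints below guide autofaiss' automatic selection of the
            # FAISS index type and hyperparameters (target query time and minimum
            # number of retrievable neighbors).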
max_index_query_time_ms=10, 36 | min_nearest_neighbors_to_retrieve=20, 37 | index_key=None, 38 | index_param=None, 39 | use_gpu=use_gpu, 40 | nb_cores=None, 41 | make_direct_map=False, 42 | should_be_memory_mappable=False, 43 | distributed=None, 44 | verbose=40, 45 | ) 46 | 47 | self.faiss_index = index 48 | self.faiss_index_infos = index_infos 49 | 50 | @staticmethod 51 | def load(index_name: str = "new-index"): 52 | ann_searcher = ANN_Searcher(index_name) 53 | ann_searcher.faiss_index = faiss.read_index(str(faiss_index_path(index_name))) 54 | ann_searcher.faiss_index_infos = read_json(faiss_index_infos_path(index_name)) 55 | return ann_searcher 56 | 57 | def search(self, query: np.ndarray, cutoff: int = 100): 58 | query = query.reshape(1, len(query)) 59 | ids, scores = self.msearch(query, cutoff) 60 | return ids[0], scores[0] 61 | 62 | def msearch(self, queries: np.ndarray, cutoff: int = 100): 63 | scores, ids = self.faiss_index.search(queries, cutoff) 64 | return ids, scores 65 | -------------------------------------------------------------------------------- /retriv/dense_retriever/dense_retriever.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from typing import Dict, Iterable, List 4 | 5 | import numpy as np 6 | import orjson 7 | from numba import njit, prange 8 | from numba.typed import List as TypedList 9 | from oneliner_utils import create_path 10 | from tqdm import tqdm 11 | 12 | from ..base_retriever import BaseRetriever 13 | from ..paths import docs_path, dr_state_path, embeddings_folder_path 14 | from .ann_searcher import ANN_Searcher 15 | from .encoder import Encoder 16 | 17 | 18 | class DenseRetriever(BaseRetriever): 19 | def __init__( 20 | self, 21 | index_name: str = "new-index", 22 | model: str = "sentence-transformers/all-MiniLM-L6-v2", 23 | normalize: bool = True, 24 | max_length: int = 128, 25 | use_ann: bool = True, 26 | ): 27 | """The Dense Retriever performs [semantic search](https://en.wikipedia.org/wiki/Semantic_search), i.e., it compares vector representations of queries and documents to compute the relevance scores of the latter. 28 | 29 | Args: 30 | index_name (str, optional): [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index. Defaults to "new-index". 31 | 32 | model (str, optional): defines the encoder model to encode queries and documents into vectors. You can use an [HuggingFace's Transformers](https://huggingface.co/models) pre-trained model by providing its ID or load a local model by providing its path. In the case of local models, the path must point to the directory containing the data saved with the [`PreTrainedModel.save_pretrained`](https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method. Note that the representations are computed with `mean pooling` over the `last_hidden_state`. Defaults to "sentence-transformers/all-MiniLM-L6-v2". 33 | 34 | normalize (bool, optional): whether to L2 normalize the vector representations. Defaults to True. 35 | 36 | max_length (int, optional): texts longer than `max_length` will be automatically truncated. Choose this parameter based on how the employed model was trained or is generally used. Defaults to 128. 37 | 38 | use_ann (bool, optional): whether to use approximate nearest neighbors search. Set it to `False` to use nearest neighbors search without approximation. 
If you have less than 20k documents in your collection, you probably want to disable approximation. Defaults to True. 39 | """ 40 | 41 | self.index_name = index_name 42 | self.model = model 43 | self.normalize = normalize 44 | self.max_length = max_length 45 | self.use_ann = use_ann 46 | 47 | self.encoder = Encoder( 48 | index_name=index_name, 49 | model=model, 50 | normalize=normalize, 51 | max_length=max_length, 52 | ) 53 | 54 | self.ann_searcher = ANN_Searcher(index_name=index_name) 55 | 56 | self.id_mapping = None 57 | self.doc_count = None 58 | self.doc_index = None 59 | 60 | self.embeddings = None 61 | 62 | def save(self): 63 | """Save the state of the retriever to be able to restore it later.""" 64 | 65 | state = dict( 66 | init_args=dict( 67 | index_name=self.index_name, 68 | model=self.model, 69 | normalize=self.normalize, 70 | max_length=self.max_length, 71 | use_ann=self.use_ann, 72 | ), 73 | id_mapping=self.id_mapping, 74 | doc_count=self.doc_count, 75 | embeddings=True if self.embeddings is not None else None, 76 | ) 77 | np.savez_compressed(dr_state_path(self.index_name), state=state) 78 | 79 | @staticmethod 80 | def load(index_name: str = "new-index"): 81 | """Load a retriever and its index. 82 | 83 | Args: 84 | index_name (str, optional): Name of the index. Defaults to "new-index". 85 | 86 | Returns: 87 | DenseRetriever: Dense Retriever. 88 | """ 89 | 90 | state = np.load(dr_state_path(index_name), allow_pickle=True)["state"][()] 91 | dr = DenseRetriever(**state["init_args"]) 92 | dr.initialize_doc_index() 93 | dr.id_mapping = state["id_mapping"] 94 | dr.doc_count = state["doc_count"] 95 | if state["embeddings"]: 96 | dr.load_embeddings() 97 | if dr.use_ann: 98 | dr.ann_searcher = ANN_Searcher.load(index_name) 99 | return dr 100 | 101 | def load_embeddings(self): 102 | """Internal usage.""" 103 | path = embeddings_folder_path(self.index_name) 104 | npy_file_paths = sorted(os.listdir(path)) 105 | self.embeddings = np.concatenate( 106 | [np.load(path / npy_file_path) for npy_file_path in npy_file_paths] 107 | ) 108 | 109 | def import_embeddings(self, path: str): 110 | """Internal usage.""" 111 | shutil.copyfile(path, embeddings_folder_path(self.index_name) / "chunk_0.npy") 112 | 113 | def index_aux( 114 | self, 115 | embeddings_path: str = None, 116 | use_gpu: bool = False, 117 | batch_size: int = 512, 118 | callback: callable = None, 119 | show_progress: bool = True, 120 | ): 121 | """Internal usage.""" 122 | if embeddings_path is not None: 123 | self.import_embeddings(embeddings_path) 124 | else: 125 | self.encoder.change_device("cuda" if use_gpu else "cpu") 126 | self.encoder.encode_collection( 127 | path=docs_path(self.index_name), 128 | batch_size=batch_size, 129 | callback=callback, 130 | show_progress=show_progress, 131 | ) 132 | self.encoder.change_device("cpu") 133 | 134 | if self.use_ann: 135 | if show_progress: 136 | print("Building ANN Searcher") 137 | self.ann_searcher.build() 138 | else: 139 | if show_progress: 140 | print("Loading embeddings...") 141 | self.load_embeddings() 142 | 143 | def index( 144 | self, 145 | collection: Iterable, 146 | embeddings_path: str = None, 147 | use_gpu: bool = False, 148 | batch_size: int = 512, 149 | callback: callable = None, 150 | show_progress: bool = True, 151 | ): 152 | """Index a given collection of documents. 153 | 154 | Args: 155 | collection (Iterable): collection of documents to index. 
156 | 157 | embeddings_path (str, optional): in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file. Defaults to None. 158 | 159 | use_gpu (bool, optional): whether to use the GPU for document encoding. Defaults to False. 160 | 161 | batch_size (int, optional): how many documents to encode at once. Regulate it if you ran into memory usage issues or want to maximize throughput. Defaults to 512. 162 | 163 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None. 164 | 165 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True. 166 | 167 | Returns: 168 | DenseRetriever: Dense Retriever 169 | """ 170 | 171 | self.save_collection(collection, callback) 172 | self.initialize_doc_index() 173 | self.initialize_id_mapping() 174 | self.doc_count = len(self.id_mapping) 175 | self.index_aux( 176 | embeddings_path=embeddings_path, 177 | use_gpu=use_gpu, 178 | batch_size=batch_size, 179 | callback=callback, 180 | show_progress=show_progress, 181 | ) 182 | self.save() 183 | return self 184 | 185 | def index_file( 186 | self, 187 | path: str, 188 | embeddings_path: str = None, 189 | use_gpu: bool = False, 190 | batch_size: int = 512, 191 | callback: callable = None, 192 | show_progress: bool = True, 193 | ): 194 | """Index the collection contained in a given file. 195 | 196 | Args: 197 | path (str): path of file containing the collection to index. 198 | 199 | embeddings_path (str, optional): in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file. Defaults to None. 200 | 201 | use_gpu (bool, optional): whether to use the GPU for document encoding. Defaults to False. 202 | 203 | batch_size (int, optional): how many documents to encode at once. Regulate it if you ran into memory usage issues or want to maximize throughput. Defaults to 512. 204 | 205 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None. 206 | 207 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True. 208 | 209 | Returns: 210 | DenseRetriever: Dense Retriever. 211 | """ 212 | 213 | collection = self.collection_generator(path, callback) 214 | return self.index( 215 | collection, 216 | embeddings_path, 217 | use_gpu, 218 | batch_size, 219 | None, 220 | show_progress, 221 | ) 222 | 223 | def search( 224 | self, 225 | query: str, 226 | return_docs: bool = True, 227 | cutoff: int = 100, 228 | ) -> List: 229 | """Standard search functionality. 230 | 231 | Args: 232 | query (str): what to search for. 233 | 234 | return_docs (bool, optional): whether to return the texts of the documents. Defaults to True. 235 | 236 | cutoff (int, optional): number of results to return. Defaults to 100. 237 | 238 | Returns: 239 | List: results. 
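        Example (an illustrative sketch; assumes an index named "new-index" already exists):

            >>> dr = DenseRetriever.load("new-index")
            >>> results = dr.search("what is dense retrieval?", cutoff=10)
            >>> results[0]["score"]  # results are sorted by descending score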
240 | """ 241 | 242 | encoded_query = self.encoder(query) 243 | 244 | if self.use_ann: 245 | doc_ids, scores = self.ann_searcher.search(encoded_query, cutoff) 246 | else: 247 | if self.embeddings is None: 248 | self.load_embeddings() 249 | doc_ids, scores = compute_scores(encoded_query, self.embeddings, cutoff) 250 | 251 | doc_ids = self.map_internal_ids_to_original_ids(doc_ids) 252 | 253 | return ( 254 | self.prepare_results(doc_ids, scores) 255 | if return_docs 256 | else dict(zip(doc_ids, scores)) 257 | ) 258 | 259 | def msearch( 260 | self, 261 | queries: List[Dict[str, str]], 262 | cutoff: int = 100, 263 | batch_size: int = 32, 264 | ) -> Dict: 265 | """Compute results for multiple queries at once. 266 | 267 | Args: 268 | queries (List[Dict[str, str]]): what to search for. 269 | 270 | cutoff (int, optional): number of results to return. Defaults to 100. 271 | 272 | batch_size (int, optional): how many queries to search at once. Regulate it if you ran into memory usage issues or want to maximize throughput. Defaults to 32. 273 | 274 | Returns: 275 | Dict: results. 276 | """ 277 | 278 | q_ids = [x["id"] for x in queries] 279 | q_texts = [x["text"] for x in queries] 280 | encoded_queries = self.encoder(q_texts, batch_size, show_progress=False) 281 | 282 | if self.use_ann: 283 | doc_ids, scores = self.ann_searcher.msearch(encoded_queries, cutoff) 284 | else: 285 | if self.embeddings is None: 286 | self.load_embeddings() 287 | doc_ids, scores = compute_scores_multi( 288 | encoded_queries, self.embeddings, cutoff 289 | ) 290 | 291 | doc_ids = [ 292 | self.map_internal_ids_to_original_ids(_doc_ids) for _doc_ids in doc_ids 293 | ] 294 | 295 | results = {q: dict(zip(doc_ids[i], scores[i])) for i, q in enumerate(q_ids)} 296 | 297 | return {q_id: results[q_id] for q_id in q_ids} 298 | 299 | def bsearch( 300 | self, 301 | queries: List[Dict[str, str]], 302 | cutoff: int = 100, 303 | batch_size: int = 32, 304 | show_progress: bool = True, 305 | qrels: Dict[str, Dict[str, float]] = None, 306 | path: str = None, 307 | ): 308 | """Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format. bsearch is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM. 309 | 310 | Args: 311 | queries (List[Dict[str, str]]): what to search for. 312 | 313 | cutoff (int, optional): number of results to return. Defaults to 100. 314 | 315 | batch_size (int, optional): how many queries to search at once. Regulate it if you ran into memory usage issues or want to maximize throughput. Defaults to 32. 316 | 317 | show_progress (bool, optional): whether to show a progress bar for the search process. Defaults to True. 318 | 319 | qrels (Dict[str, Dict[str, float]], optional): query relevance judgements for the queries. Defaults to None. 320 | 321 | path (str, optional): where to save the results. Defaults to None. 322 | 323 | Returns: 324 | Dict: results. 
325 | """ 326 | 327 | batches = [ 328 | queries[i : i + batch_size] for i in range(0, len(queries), batch_size) 329 | ] 330 | 331 | results = {} 332 | 333 | pbar = tqdm( 334 | total=len(queries), 335 | disable=not show_progress, 336 | desc="Batch search", 337 | dynamic_ncols=True, 338 | mininterval=0.5, 339 | ) 340 | 341 | if path is None: 342 | for batch in batches: 343 | new_results = self.msearch( 344 | queries=batch, cutoff=cutoff, batch_size=len(batch) 345 | ) 346 | results = {**results, **new_results} 347 | pbar.update(min(batch_size, len(batch))) 348 | else: 349 | path = create_path(path) 350 | path.parent.mkdir(parents=True, exist_ok=True) 351 | 352 | with open(path, "wb") as f: 353 | for batch in batches: 354 | new_results = self.msearch(queries=batch, cutoff=cutoff) 355 | 356 | for i, (k, v) in enumerate(new_results.items()): 357 | x = { 358 | "id": k, 359 | "text": batch[i]["text"], 360 | "dense_doc_ids": list(v.keys()), 361 | "dense_scores": [float(s) for s in list(v.values())], 362 | } 363 | if qrels is not None: 364 | x["rel_doc_ids"] = list(qrels[k].keys()) 365 | x["rel_scores"] = list(qrels[k].values()) 366 | f.write(orjson.dumps(x) + "\n".encode()) 367 | 368 | pbar.update(min(batch_size, len(batch))) 369 | 370 | return results 371 | 372 | 373 | @njit(cache=True) 374 | def compute_scores(query: np.ndarray, docs: np.ndarray, cutoff: int): 375 | """Internal usage.""" 376 | 377 | scores = docs @ query 378 | indices = np.argsort(-scores)[:cutoff] 379 | 380 | return indices, scores[indices] 381 | 382 | 383 | @njit(cache=True, parallel=True) 384 | def compute_scores_multi(queries: np.ndarray, docs: np.ndarray, cutoff: int): 385 | """Internal usage.""" 386 | 387 | n = len(queries) 388 | ids = TypedList([np.empty(1, dtype=np.int64) for _ in range(n)]) 389 | scores = TypedList([np.empty(1, dtype=np.float32) for _ in range(n)]) 390 | 391 | for i in prange(len(queries)): 392 | ids[i], scores[i] = compute_scores(queries[i], docs, cutoff) 393 | 394 | return ids, scores 395 | -------------------------------------------------------------------------------- /retriv/dense_retriever/encoder.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from typing import Generator, List, Union 3 | 4 | import numpy as np 5 | import torch 6 | from oneliner_utils import read_jsonl 7 | from torch import Tensor, einsum 8 | from torch.nn.functional import normalize 9 | from tqdm import tqdm 10 | from transformers import AutoConfig, AutoModel, AutoTokenizer 11 | 12 | from ..paths import embeddings_folder_path, encoder_state_path 13 | 14 | pbar_kwargs = dict(position=0, dynamic_ncols=True, mininterval=1.0) 15 | 16 | 17 | def count_lines(path: str): 18 | """Counts the number of lines in a file.""" 19 | return sum(1 for _ in open(path)) 20 | 21 | 22 | def generate_batch(docs: Union[List, Generator], batch_size: int) -> Generator: 23 | texts = [] 24 | 25 | for doc in docs: 26 | texts.append(doc["text"]) 27 | 28 | if len(texts) == batch_size: 29 | yield texts 30 | texts = [] 31 | 32 | if texts: 33 | yield texts 34 | 35 | 36 | class Encoder: 37 | def __init__( 38 | self, 39 | index_name: str = "new-index", 40 | model: str = "sentence-transformers/all-MiniLM-L6-v2", 41 | normalize: bool = True, 42 | return_numpy: bool = True, 43 | max_length: int = 128, 44 | device: str = "cpu", 45 | ): 46 | self.index_name = index_name 47 | self.model = model 48 | self.tokenizer = AutoTokenizer.from_pretrained(model) 49 | self.encoder = 
AutoModel.from_pretrained(model).to(device).eval() 50 | self.embedding_dim = AutoConfig.from_pretrained(model).hidden_size 51 | self.max_length = max_length 52 | self.normalize = normalize 53 | self.return_numpy = return_numpy 54 | self.device = device 55 | self.tokenizer_kwargs = { 56 | "padding": True, 57 | "truncation": True, 58 | "max_length": self.max_length, 59 | "return_tensors": "pt", 60 | } 61 | 62 | def save(self): 63 | state = dict( 64 | index_name=self.index_name, 65 | model=self.model, 66 | normalize=self.normalize, 67 | return_numpy=self.return_numpy, 68 | max_length=self.max_length, 69 | device=self.device, 70 | ) 71 | np.save(encoder_state_path(self.index_name), state) 72 | 73 | @staticmethod 74 | def load(index_name: str, device: str = None): 75 | state = np.load(encoder_state_path(index_name), allow_pickle=True)[()] 76 | if device is not None: 77 | state["device"] = device 78 | return Encoder(**state) 79 | 80 | def change_device(self, device: str = "cpu"): 81 | self.device = device 82 | self.encoder.to(device) 83 | 84 | def tokenize(self, texts: List[str]): 85 | tokens = self.tokenizer(texts, **self.tokenizer_kwargs) 86 | return {k: v.to(self.device) for k, v in tokens.items()} 87 | 88 | def mean_pooling(self, embeddings: Tensor, mask: Tensor) -> Tensor: 89 | numerators = einsum("xyz,xy->xyz", embeddings, mask).sum(dim=1) 90 | denominators = torch.clamp(mask.sum(dim=1), min=1e-9) 91 | return einsum("xz,x->xz", numerators, 1 / denominators) 92 | 93 | def __call__( 94 | self, 95 | x: Union[str, List[str]], 96 | batch_size: int = 32, 97 | show_progress: bool = True, 98 | ): 99 | if isinstance(x, str): 100 | return self.encode(x) 101 | else: 102 | return self.bencode(x, batch_size=batch_size, show_progress=show_progress) 103 | 104 | def encode(self, text: str): 105 | return self.bencode([text], batch_size=1, show_progress=False)[0] 106 | 107 | def bencode( 108 | self, 109 | texts: List[str], 110 | batch_size: int = 32, 111 | show_progress: bool = True, 112 | ): 113 | embeddings = [] 114 | 115 | pbar = tqdm( 116 | total=len(texts), 117 | desc="Generating embeddings", 118 | disable=not show_progress, 119 | **pbar_kwargs, 120 | ) 121 | 122 | for i in range(ceil(len(texts) / batch_size)): 123 | start, stop = i * batch_size, (i + 1) * batch_size 124 | tokens = self.tokenize(texts[start:stop]) 125 | 126 | with torch.no_grad(): 127 | emb = self.encoder(**tokens).last_hidden_state 128 | emb = self.mean_pooling(emb, tokens["attention_mask"]) 129 | if self.normalize: 130 | emb = normalize(emb, dim=-1) 131 | 132 | embeddings.append(emb) 133 | 134 | pbar.update(stop - start) 135 | pbar.close() 136 | 137 | embeddings = torch.cat(embeddings) 138 | 139 | if self.return_numpy: 140 | embeddings = embeddings.detach().cpu().numpy() 141 | 142 | return embeddings 143 | 144 | def encode_collection( 145 | self, 146 | path: str, 147 | batch_size: int = 512, 148 | callback: callable = None, 149 | show_progress: bool = True, 150 | ): 151 | n_docs = count_lines(path) 152 | collection = read_jsonl(path, callback=callback, generator=True) 153 | 154 | reservoir = np.empty((1_000_000, self.embedding_dim), dtype=np.float32) 155 | reservoir_n = 0 156 | offset = 0 157 | 158 | pbar = tqdm( 159 | total=n_docs, 160 | desc="Embedding documents", 161 | disable=not show_progress, 162 | **pbar_kwargs, 163 | ) 164 | 165 | for texts in generate_batch(collection, batch_size): 166 | # Compute embeddings ----------------------------------------------- 167 | embeddings = self.bencode(texts, batch_size=len(texts), 
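            # bencode handles tokenization, mean pooling, and optional L2 normalization;
            # batch_size=len(texts) encodes the whole batch in a single pass.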
show_progress=False) 168 | 169 | # Compute new offset ----------------------------------------------- 170 | new_offset = offset + len(embeddings) 171 | 172 | if new_offset >= len(reservoir): 173 | np.save( 174 | embeddings_folder_path(self.index_name) 175 | / f"chunk_{reservoir_n}.npy", 176 | reservoir[:offset], 177 | ) 178 | reservoir = np.empty((1_000_000, self.embedding_dim), dtype=np.float32) 179 | reservoir_n += 1 180 | offset = 0 181 | new_offset = len(embeddings) 182 | 183 | # Save embeddings in the reservoir --------------------------------- 184 | reservoir[offset:new_offset] = embeddings 185 | 186 | # Update offeset --------------------------------------------------- 187 | offset = new_offset 188 | 189 | pbar.update(len(embeddings)) 190 | 191 | if offset < len(reservoir): 192 | np.save( 193 | embeddings_folder_path(self.index_name) / f"chunk_{reservoir_n}.npy", 194 | reservoir[:offset], 195 | ) 196 | reservoir = [] 197 | 198 | assert len(reservoir) == 0, "Reservoir is not empty." 199 | -------------------------------------------------------------------------------- /retriv/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["AdvancedRetriever"] 2 | 3 | 4 | from .advanced_retriever import AdvancedRetriever 5 | -------------------------------------------------------------------------------- /retriv/hybrid_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterable, List, Set, Union 2 | 3 | import numpy as np 4 | import orjson 5 | from oneliner_utils import create_path 6 | from tqdm import tqdm 7 | 8 | from .base_retriever import BaseRetriever 9 | from .dense_retriever.dense_retriever import DenseRetriever 10 | from .merger.merger import Merger 11 | from .paths import hr_state_path 12 | from .sparse_retriever.sparse_retriever import SparseRetriever 13 | 14 | 15 | class HybridRetriever(BaseRetriever): 16 | def __init__( 17 | self, 18 | # Global params 19 | index_name: str = "new-index", 20 | # Sparse retriever params 21 | sr_model: str = "bm25", 22 | min_df: int = 1, 23 | tokenizer: Union[str, callable] = "whitespace", 24 | stemmer: Union[str, callable] = "english", 25 | stopwords: Union[str, List[str], Set[str]] = "english", 26 | do_lowercasing: bool = True, 27 | do_ampersand_normalization: bool = True, 28 | do_special_chars_normalization: bool = True, 29 | do_acronyms_normalization: bool = True, 30 | do_punctuation_removal: bool = True, 31 | # Dense retriever params 32 | dr_model: str = "sentence-transformers/all-MiniLM-L6-v2", 33 | normalize: bool = True, 34 | max_length: int = 128, 35 | use_ann: bool = True, 36 | # For already instantiated modules 37 | sparse_retriever: SparseRetriever = None, 38 | dense_retriever: DenseRetriever = None, 39 | merger: Merger = None, 40 | ): 41 | """The [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) is searcher based on both lexical and semantic matching. It comprises three components: the [Sparse Retriever]((https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md)), the [Dense Retriever]((https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md)), and the Merger. The Merger fuses the results of the Sparse and Dense Retrievers to compute the _hybrid_ results. 42 | 43 | Args: 44 | index_name (str, optional): [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index. Defaults to "new-index". 
45 | 46 | sr_model (str, optional): defines the model to use for sparse retrieval (`bm25` or `tf-idf`). Defaults to "bm25". 47 | 48 | min_df (int, optional): terms that appear in less than `min_df` documents will be ignored. If integer, the parameter indicates the absolute count. If float, it represents a proportion of documents. Defaults to 1. 49 | 50 | tokenizer (Union[str, callable], optional): [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization setting the parameter to `None`. Defaults to "whitespace". 51 | 52 | stemmer (Union[str, callable], optional): [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming setting the parameter to `None`. Defaults to "english". 53 | 54 | stopwords (Union[str, List[str], Set[str]], optional): [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`. Defaults to "english". 55 | 56 | do_lowercasing (bool, optional): whether to lowercase texts. Defaults to True. 57 | 58 | do_ampersand_normalization (bool, optional): whether to convert `&` in `and` during pre-processing. Defaults to True. 59 | 60 | do_special_chars_normalization (bool, optional): whether to remove special characters for letters, e.g., `übermensch` → `ubermensch`. Defaults to True. 61 | 62 | do_acronyms_normalization (bool, optional): whether to remove full stop symbols from acronyms without splitting them in multiple words, e.g., `P.C.I.` → `PCI`. Defaults to True. 63 | 64 | do_punctuation_removal (bool, optional): whether to remove punctuation. Defaults to True. 65 | 66 | dr_model (str, optional): defines the model to use for encoding queries and documents into vectors. You can use an [HuggingFace's Transformers](https://huggingface.co/models) pre-trained model by providing its ID or load a local model by providing its path. In the case of local models, the path must point to the directory containing the data saved with the [`PreTrainedModel.save_pretrained`](https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method. Note that the representations are computed with `mean pooling` over the `last_hidden_state`. Defaults to "sentence-transformers/all-MiniLM-L6-v2". 67 | 68 | normalize (bool, optional): whether to L2 normalize the vector representations. Defaults to True. 69 | 70 | max_length (int, optional): texts longer than `max_length` will be automatically truncated. Choose this parameter based on how the employed model was trained or is generally used. Defaults to 128. 71 | 72 | use_ann (bool, optional): whether to use approximate nearest neighbors search. Set it to `False` to use nearest neighbors search without approximation. If you have less than 20k documents in your collection, you probably want to disable approximation. Defaults to True. 
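        Example (an illustrative sketch using the default components; the collection path is a placeholder):

            >>> hr = HybridRetriever(index_name="new-index")
            >>> hr = hr.index_file(path="path/to/collection.jsonl")
            >>> hr.search(query="witches masses", cutoff=10)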
73 | """ 74 | 75 | self.index_name = index_name 76 | 77 | self.sparse_retriever = ( 78 | sparse_retriever 79 | if sparse_retriever is not None 80 | else SparseRetriever( 81 | index_name=index_name, 82 | model=sr_model, 83 | min_df=min_df, 84 | tokenizer=tokenizer, 85 | stemmer=stemmer, 86 | stopwords=stopwords, 87 | do_lowercasing=do_lowercasing, 88 | do_ampersand_normalization=do_ampersand_normalization, 89 | do_special_chars_normalization=do_special_chars_normalization, 90 | do_acronyms_normalization=do_acronyms_normalization, 91 | do_punctuation_removal=do_punctuation_removal, 92 | ) 93 | ) 94 | 95 | self.dense_retriever = ( 96 | dense_retriever 97 | if dense_retriever is not None 98 | else DenseRetriever( 99 | index_name=index_name, 100 | model=dr_model, 101 | normalize=normalize, 102 | max_length=max_length, 103 | use_ann=use_ann, 104 | ) 105 | ) 106 | 107 | self.merger = merger if merger is not None else Merger(index_name=index_name) 108 | 109 | def index( 110 | self, 111 | collection: Iterable, 112 | embeddings_path: str = None, 113 | use_gpu: bool = False, 114 | batch_size: int = 512, 115 | callback: callable = None, 116 | show_progress: bool = True, 117 | ): 118 | """Index a given collection of documents. 119 | 120 | Args: 121 | collection (Iterable): collection of documents to index. 122 | 123 | embeddings_path (str, optional): in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file. Defaults to None. 124 | 125 | use_gpu (bool, optional): whether to use the GPU for document encoding. Defaults to False. 126 | 127 | batch_size (int, optional): how many documents to encode at once. Regulate it if you ran into memory usage issues or want to maximize throughput. Defaults to 512. 128 | 129 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None. 130 | 131 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True. 132 | 133 | Returns: 134 | HybridRetriever: Hybrid Retriever 135 | """ 136 | 137 | self.save_collection(collection, callback) 138 | 139 | self.initialize_doc_index() 140 | self.initialize_id_mapping() 141 | self.doc_count = len(self.id_mapping) 142 | 143 | # Sparse --------------------------------------------------------------- 144 | self.sparse_retriever.doc_index = self.doc_index 145 | self.sparse_retriever.id_mapping = self.id_mapping 146 | self.sparse_retriever.doc_count = self.doc_count 147 | self.sparse_retriever.index_aux(show_progress) 148 | 149 | # Dense ---------------------------------------------------------------- 150 | self.dense_retriever.doc_index = self.doc_index 151 | self.dense_retriever.id_mapping = self.id_mapping 152 | self.dense_retriever.doc_count = self.doc_count 153 | self.dense_retriever.index_aux( 154 | embeddings_path, use_gpu, batch_size, callback, show_progress 155 | ) 156 | 157 | self.save() 158 | 159 | return self 160 | 161 | def index_file( 162 | self, 163 | path: str, 164 | embeddings_path: str = None, 165 | use_gpu: bool = False, 166 | batch_size: int = 512, 167 | callback: callable = None, 168 | show_progress: bool = True, 169 | ): 170 | """Index the collection contained in a given file. 171 | 172 | Args: 173 | path (str): path of file containing the collection to index. 
174 | 175 | embeddings_path (str, optional): in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file. Defaults to None. 176 | 177 | use_gpu (bool, optional): whether to use the GPU for document encoding. Defaults to False. 178 | 179 | batch_size (int, optional): how many documents to encode at once. Regulate it if you ran into memory usage issues or want to maximize throughput. Defaults to 512. 180 | 181 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None. 182 | 183 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True. 184 | 185 | Returns: 186 | HybridRetriever: Hybrid Retriever. 187 | """ 188 | 189 | collection = self.collection_generator(path, callback) 190 | return self.index( 191 | collection, 192 | embeddings_path, 193 | use_gpu, 194 | batch_size, 195 | None, 196 | show_progress, 197 | ) 198 | 199 | def save(self): 200 | """Save the state of the retriever to be able to restore it later.""" 201 | 202 | state = dict( 203 | id_mapping=self.id_mapping, 204 | doc_count=self.doc_count, 205 | ) 206 | np.savez_compressed(hr_state_path(self.index_name), state=state) 207 | 208 | self.sparse_retriever.save() 209 | self.dense_retriever.save() 210 | self.merger.save() 211 | 212 | @staticmethod 213 | def load(index_name: str = "new-index"): 214 | """Load a retriever and its index. 215 | 216 | Args: 217 | index_name (str, optional): Name of the index. Defaults to "new-index". 218 | 219 | Returns: 220 | HybridRetriever: Hybrid Retriever. 221 | """ 222 | 223 | state = np.load(hr_state_path(index_name), allow_pickle=True)["state"][()] 224 | 225 | hr = HybridRetriever(index_name) 226 | hr.initialize_doc_index() 227 | hr.id_mapping = state["id_mapping"] 228 | hr.doc_count = state["doc_count"] 229 | 230 | hr.sparse_retriever = SparseRetriever.load(index_name) 231 | hr.dense_retriever = DenseRetriever.load(index_name) 232 | hr.merger = Merger.load(index_name) 233 | return hr 234 | 235 | def search( 236 | self, 237 | query: str, 238 | return_docs: bool = True, 239 | cutoff: int = 100, 240 | ) -> List: 241 | """Standard search functionality. 242 | 243 | Args: 244 | query (str): what to search for. 245 | 246 | return_docs (bool, optional): whether to return the texts of the documents. Defaults to True. 247 | 248 | cutoff (int, optional): number of results to return. Defaults to 100. 249 | 250 | Returns: 251 | List: results. 252 | """ 253 | 254 | sparse_results = self.sparse_retriever.search(query, False, 1_000) 255 | dense_results = self.dense_retriever.search(query, False, 1_000) 256 | hybrid_results = self.merger.fuse([sparse_results, dense_results]) 257 | return ( 258 | self.prepare_results( 259 | list(hybrid_results.keys())[:cutoff], 260 | list(hybrid_results.values())[:cutoff], 261 | ) 262 | if return_docs 263 | else hybrid_results 264 | ) 265 | 266 | def msearch( 267 | self, 268 | queries: List[Dict[str, str]], 269 | cutoff: int = 100, 270 | batch_size: int = 32, 271 | ) -> Dict: 272 | """Compute results for multiple queries at once. 273 | 274 | Args: 275 | queries (List[Dict[str, str]]): what to search for. 276 | 277 | cutoff (int, optional): number of results to return. Defaults to 100. 278 | 279 | batch_size (int, optional): how many queries to search at once. Regulate it if you ran into memory usage issues or want to maximize throughput. 
Defaults to 32. 280 | 281 | Returns: 282 | Dict: results. 283 | """ 284 | 285 | sparse_results = self.sparse_retriever.msearch(queries, 1_000) 286 | dense_results = self.dense_retriever.msearch(queries, 1_000, batch_size) 287 | return self.merger.mfuse([sparse_results, dense_results], cutoff) 288 | 289 | def bsearch( 290 | self, 291 | queries: List[Dict[str, str]], 292 | cutoff: int = 100, 293 | batch_size: int = 32, 294 | show_progress: bool = True, 295 | qrels: Dict[str, Dict[str, float]] = None, 296 | path: str = None, 297 | ): 298 | """Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format. bsearch is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM. 299 | 300 | Args: 301 | queries (List[Dict[str, str]]): what to search for. 302 | 303 | cutoff (int, optional): number of results to return. Defaults to 100. 304 | 305 | batch_size (int, optional): how many queries to search at once. Regulate it if you ran into memory usage issues or want to maximize throughput. Defaults to 32. 306 | 307 | show_progress (bool, optional): whether to show a progress bar for the search process. Defaults to True. 308 | 309 | qrels (Dict[str, Dict[str, float]], optional): query relevance judgements for the queries. Defaults to None. 310 | 311 | path (str, optional): where to save the results. Defaults to None. 312 | 313 | Returns: 314 | Dict: results. 315 | """ 316 | 317 | batches = [ 318 | queries[i : i + batch_size] for i in range(0, len(queries), batch_size) 319 | ] 320 | 321 | results = {} 322 | 323 | pbar = tqdm( 324 | total=len(queries), 325 | disable=not show_progress, 326 | desc="Batch search", 327 | dynamic_ncols=True, 328 | mininterval=0.5, 329 | ) 330 | 331 | if path is None: 332 | for batch in batches: 333 | new_results = self.msearch( 334 | queries=batch, cutoff=cutoff, batch_size=len(batch) 335 | ) 336 | results = {**results, **new_results} 337 | pbar.update(min(batch_size, len(batch))) 338 | else: 339 | path = create_path(path) 340 | path.parent.mkdir(parents=True, exist_ok=True) 341 | 342 | with open(path, "wb") as f: 343 | for batch in batches: 344 | new_results = self.msearch(queries=batch, cutoff=cutoff) 345 | 346 | for i, (k, v) in enumerate(new_results.items()): 347 | x = { 348 | "id": k, 349 | "text": batch[i]["text"], 350 | "hybrid_doc_ids": list(v.keys()), 351 | "hybrid_scores": [float(s) for s in list(v.values())], 352 | } 353 | if qrels is not None: 354 | x["rel_doc_ids"] = list(qrels[k].keys()) 355 | x["rel_scores"] = list(qrels[k].values()) 356 | f.write(orjson.dumps(x) + "\n".encode()) 357 | 358 | pbar.update(min(batch_size, len(batch))) 359 | 360 | return results 361 | 362 | def autotune( 363 | self, 364 | queries: List[Dict[str, str]], 365 | qrels: Dict[str, Dict[str, float]], 366 | metric: str = "ndcg", 367 | n_trials: int = 100, 368 | cutoff: int = 100, 369 | batch_size: int = 32, 370 | ): 371 | """Use the AutoTune function to tune the Sparse Retriever's model [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) parameters and the importance given to the lexical and semantic relevance scores computed by the Sparse and Dense Retrievers, respectively. All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function. At the of the process, the best parameter configurations are automatically applied and saved to disk. 
You can inspect the best configurations found by printing `hr.sparse_retriever.hyperparams`, `hr.merger.norm` and `hr.merger.params`. 372 | 373 | Args: 374 | queries (List[Dict[str, str]]): queries to use for the optimization process. 375 | 376 | qrels (Dict[str, Dict[str, float]]): query relevance judgements for the queries. 377 | 378 | metric (str, optional): metric to optimize for. Defaults to "ndcg". 379 | 380 | n_trials (int, optional): number of configuration to evaluate. Defaults to 100. 381 | 382 | cutoff (int, optional): number of results to consider for the optimization process. Defaults to 100. 383 | """ 384 | 385 | # Tune sparse ---------------------------------------------------------- 386 | self.sparse_retriever.autotune( 387 | queries=queries, 388 | qrels=qrels, 389 | metric=metric, 390 | n_trials=n_trials, 391 | cutoff=cutoff, 392 | ) 393 | 394 | # Tune merger ---------------------------------------------------------- 395 | sparse_results = self.sparse_retriever.msearch(queries, 1_000) 396 | dense_results = self.dense_retriever.msearch(queries, 1_000, batch_size) 397 | self.merger.autotune(qrels, [sparse_results, dense_results], metric) 398 | -------------------------------------------------------------------------------- /retriv/merger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/merger/__init__.py -------------------------------------------------------------------------------- /retriv/merger/merger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | 6 | from ..autotune import tune_merger 7 | from ..paths import merger_state_path 8 | from .normalization import max_norm_multi, min_max_norm_multi, sum_norm_multi 9 | 10 | 11 | class Merger: 12 | def __init__(self, index_name: str = "new-index"): 13 | self.index_name = index_name 14 | self.norm = "min-max" 15 | self.params = None 16 | 17 | def fuse( 18 | self, results: List[Dict[str, float]], cutoff: int = 100 19 | ) -> Dict[str, float]: 20 | return self.mfuse([{"q_0": res} for res in results], cutoff)["q_0"] 21 | 22 | def mfuse( 23 | self, runs: List[Dict[str, Dict[str, float]]], cutoff: int = 100 24 | ) -> Dict[str, Dict[str, float]]: 25 | if self.norm == "min-max": 26 | normalized_runs = min_max_norm_multi(runs) 27 | elif self.norm == "max": 28 | normalized_runs = max_norm_multi(runs) 29 | elif self.norm == "sum": 30 | normalized_runs = sum_norm_multi(runs) 31 | else: 32 | raise NotImplementedError 33 | 34 | weights = [1.0 for _ in runs] if self.params is None else self.params["weights"] 35 | 36 | fused_run = defaultdict(lambda: defaultdict(float)) 37 | for i, run in enumerate(normalized_runs): 38 | for q_id in run: 39 | for doc_id in run[q_id]: 40 | fused_run[q_id][doc_id] += weights[i] * run[q_id][doc_id] 41 | 42 | # Sort results by descending value and ascending key 43 | for q_id, results in list(fused_run.items()): 44 | fused_run[q_id] = dict(sorted(results.items(), key=lambda x: (-x[1], x[0]))) 45 | 46 | # Apply cutoff 47 | for q_id, results in list(fused_run.items()): 48 | fused_run[q_id] = dict(list(results.items())[:cutoff]) 49 | 50 | return dict(fused_run) 51 | 52 | def save(self): 53 | state = dict( 54 | init_args=dict(index_name=self.index_name), 55 | norm=self.norm, 56 | params=self.params, 57 | ) 58 | 
np.savez_compressed(merger_state_path(self.index_name), state=state) 59 | 60 | @staticmethod 61 | def load(index_name: str = "new-index"): 62 | state = np.load(merger_state_path(index_name), allow_pickle=True)["state"][()] 63 | merger = Merger(**state["init_args"]) 64 | merger.norm = state["norm"] 65 | merger.params = state["params"] 66 | return merger 67 | 68 | def autotune( 69 | self, 70 | qrels: Dict[str, Dict[str, float]], 71 | runs: List[Dict[str, Dict[str, float]]], 72 | metric: str = "ndcg", 73 | ): 74 | config = tune_merger(qrels, runs, metric) 75 | self.norm = config["norm"] 76 | self.params = config["params"] 77 | self.save() 78 | -------------------------------------------------------------------------------- /retriv/merger/normalization.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | 4 | def extract_scores(results): 5 | """Extract the scores from a given results dictionary.""" 6 | scores = [None] * len(results) 7 | for i, v in enumerate(results.values()): 8 | scores[i] = v 9 | return scores 10 | 11 | 12 | def safe_max(x): 13 | return max(x) if len(x) != 0 else 0 14 | 15 | 16 | def safe_min(x): 17 | return min(x) if len(x) != 0 else 0 18 | 19 | 20 | def min_max_norm(run: Dict[str, float]): 21 | """Apply `min-max normalization` to a given run.""" 22 | normalized_run = {} 23 | 24 | for q_id, results in run.items(): 25 | scores = extract_scores(results) 26 | min_score = safe_min(scores) 27 | max_score = safe_max(scores) 28 | denominator = max(max_score - min_score, 1e-9) 29 | 30 | normalized_results = { 31 | doc_id: (results[doc_id] - min_score) / (denominator) for doc_id in results 32 | } 33 | 34 | normalized_run[q_id] = normalized_results 35 | 36 | return normalized_run 37 | 38 | 39 | def max_norm(run: Dict[str, float]): 40 | """Apply `max normalization` to a given run.""" 41 | normalized_run = {} 42 | 43 | for q_id, results in run.items(): 44 | scores = extract_scores(results) 45 | max_score = safe_max(scores) 46 | denominator = max(max_score, 1e-9) 47 | 48 | normalized_results = { 49 | doc_id: results[doc_id] / denominator for doc_id in results 50 | } 51 | 52 | normalized_run[q_id] = normalized_results 53 | 54 | return normalized_run 55 | 56 | 57 | def sum_norm(run: Dict[str, float]): 58 | """Apply `sum normalization` to a given run.""" 59 | normalized_run = {} 60 | 61 | for q_id, results in run.items(): 62 | scores = extract_scores(results) 63 | min_score = safe_min(scores) 64 | sum_score = sum(scores) 65 | denominator = sum_score - min_score * len(results) 66 | denominator = max(denominator, 1e-9) 67 | 68 | normalized_results = { 69 | doc_id: (results[doc_id] - min_score) / (denominator) for doc_id in results 70 | } 71 | 72 | normalized_run[q_id] = normalized_results 73 | 74 | return normalized_run 75 | 76 | 77 | def min_max_norm_multi(runs: List[Dict[str, float]]): 78 | """Apply `min-max normalization` to a list of given runs.""" 79 | return [min_max_norm(run) for run in runs] 80 | 81 | 82 | def max_norm_multi(runs: List[Dict[str, float]]): 83 | """Apply `max normalization` to a list of given runs.""" 84 | return [max_norm(run) for run in runs] 85 | 86 | 87 | def sum_norm_multi(runs: List[Dict[str, float]]): 88 | """Apply `sum normalization` to a list of given runs.""" 89 | return [sum_norm(run) for run in runs] 90 | -------------------------------------------------------------------------------- /retriv/paths.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | from pathlib import Path 3 | 4 | 5 | def base_path(): 6 | p = Path(os.environ.get("RETRIV_BASE_PATH", str(Path.home() / ".retriv")))  # fall back to ~/.retriv when the env var is unset 7 | p.mkdir(parents=True, exist_ok=True) 8 | return p 9 | 10 | 11 | def collections_path(): 12 | p = base_path() / "collections" 13 | p.mkdir(parents=True, exist_ok=True) 14 | return p 15 | 16 | 17 | def index_path(index_name: str): 18 | path = collections_path() / index_name 19 | path.mkdir(parents=True, exist_ok=True) 20 | return path 21 | 22 | 23 | def docs_path(index_name: str): 24 | return index_path(index_name) / "docs.jsonl" 25 | 26 | 27 | def sr_state_path(index_name: str): 28 | return index_path(index_name) / "sr_state.npz" 29 | 30 | 31 | def fr_state_path(index_name: str): 32 | return index_path(index_name) / "fr_state.npz" 33 | 34 | 35 | def embeddings_path(index_name: str): 36 | return index_path(index_name) / "embeddings.h5" 37 | 38 | 39 | def embeddings_folder_path(index_name: str): 40 | path = index_path(index_name) / "embeddings" 41 | path.mkdir(parents=True, exist_ok=True) 42 | return path 43 | 44 | 45 | def faiss_index_path(index_name: str): 46 | return index_path(index_name) / "faiss.index" 47 | 48 | 49 | def faiss_index_infos_path(index_name: str): 50 | return index_path(index_name) / "faiss_index_infos.json" 51 | 52 | 53 | def dr_state_path(index_name: str): 54 | return index_path(index_name) / "dr_state.npz" 55 | 56 | 57 | def hr_state_path(index_name: str): 58 | return index_path(index_name) / "hr_state.npz" 59 | 60 | 61 | def encoder_state_path(index_name: str): 62 | return index_path(index_name) / "encoder_state.json" 63 | 64 | 65 | def merger_state_path(index_name: str): 66 | return index_path(index_name) / "merger_state.npz" 67 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/sparse_retriever/__init__.py -------------------------------------------------------------------------------- /retriv/sparse_retriever/build_inverted_index.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Iterable, Tuple 3 | 4 | import numpy as np 5 | from sklearn.feature_extraction.text import CountVectorizer 6 | from tqdm import tqdm 7 | 8 | 9 | def convert_df_matrix_into_inverted_index( 10 | df_matrix, vocabulary, show_progress: bool = True 11 | ): 12 | inverted_index = defaultdict(dict) 13 | 14 | for i, term in enumerate( 15 | tqdm( 16 | vocabulary, 17 | disable=not show_progress, 18 | desc="Building inverted index", 19 | dynamic_ncols=True, 20 | mininterval=0.5, 21 | ) 22 | ): 23 | inverted_index[term]["doc_ids"] = df_matrix[i].indices 24 | inverted_index[term]["tfs"] = df_matrix[i].data 25 | 26 | return inverted_index 27 | 28 | 29 | def build_inverted_index( 30 | collection: Iterable, 31 | n_docs: int, 32 | min_df: int = 1, 33 | show_progress: bool = True, 34 | ) -> Tuple[Dict, np.ndarray, np.ndarray]: 35 | vectorizer = CountVectorizer( 36 | tokenizer=lambda x: x, 37 | preprocessor=lambda x: x, 38 | min_df=min_df, 39 | dtype=np.int16, 40 | token_pattern=None, 41 | ) 42 | 43 | # [n_docs x n_terms] 44 | df_matrix = vectorizer.fit_transform( 45 | tqdm( 46 | collection, 47 | total=n_docs, 48 | disable=not show_progress, 49 | desc="Building TDF matrix", 50 | dynamic_ncols=True, 51 | mininterval=0.5, 52 | ) 53 | ) 54 | # [n_terms x n_docs] 55 | df_matrix =
df_matrix.transpose().tocsr() 56 | vocabulary = vectorizer.get_feature_names_out() 57 | inverted_index = convert_df_matrix_into_inverted_index( 58 | df_matrix=df_matrix, 59 | vocabulary=vocabulary, 60 | show_progress=show_progress, 61 | ) 62 | 63 | doc_lens = np.squeeze(np.asarray(df_matrix.sum(axis=0), dtype=np.float32)) 64 | relative_doc_lens = doc_lens / np.mean(doc_lens, dtype=np.float32) 65 | 66 | return dict(inverted_index), doc_lens, relative_doc_lens 67 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "lowercasing", 3 | "normalize_acronyms", 4 | "normalize_ampersand", 5 | "normalize_special_chars", 6 | "remove_punctuation", 7 | "strip_whitespaces", 8 | "get_stemmer", 9 | "get_tokenizer", 10 | "get_stopwords", 11 | ] 12 | 13 | 14 | from typing import Callable, List, Set 15 | 16 | from multipipe import Multipipe 17 | 18 | from .normalization import ( 19 | lowercasing, 20 | normalize_acronyms, 21 | normalize_ampersand, 22 | normalize_special_chars, 23 | remove_punctuation, 24 | strip_whitespaces, 25 | ) 26 | from .stemmer import get_stemmer 27 | from .stopwords import get_stopwords 28 | from .tokenizer import get_tokenizer 29 | 30 | 31 | def preprocessing( 32 | x: str, 33 | tokenizer: Callable, 34 | stopwords: Set[str], 35 | stemmer: Callable, 36 | do_lowercasing: bool, 37 | do_ampersand_normalization: bool, 38 | do_special_chars_normalization: bool, 39 | do_acronyms_normalization: bool, 40 | do_punctuation_removal: bool, 41 | ) -> List[str]: 42 | if do_lowercasing: 43 | x = lowercasing(x) 44 | if do_ampersand_normalization: 45 | x = normalize_ampersand(x) 46 | if do_special_chars_normalization: 47 | x = normalize_special_chars(x) 48 | if do_acronyms_normalization: 49 | x = normalize_acronyms(x) 50 | 51 | if tokenizer == str.split and do_punctuation_removal: 52 | x = remove_punctuation(x) 53 | x = strip_whitespaces(x) 54 | 55 | x = tokenizer(x) 56 | 57 | if tokenizer != str.split and do_punctuation_removal: 58 | x = [remove_punctuation(t) for t in x] 59 | x = [t for t in x if t] 60 | 61 | x = [t for t in x if t not in stopwords] 62 | 63 | return [stemmer(t) for t in x] 64 | 65 | 66 | def preprocessing_multi( 67 | tokenizer: callable, 68 | stopwords: List[str], 69 | stemmer: callable, 70 | do_lowercasing: bool, 71 | do_ampersand_normalization: bool, 72 | do_special_chars_normalization: bool, 73 | do_acronyms_normalization: bool, 74 | do_punctuation_removal: bool, 75 | ): 76 | callables = [] 77 | 78 | if do_lowercasing: 79 | callables.append(lowercasing) 80 | if do_ampersand_normalization: 81 | callables.append(normalize_ampersand) 82 | if do_special_chars_normalization: 83 | callables.append(normalize_special_chars) 84 | if do_acronyms_normalization: 85 | callables.append(normalize_acronyms) 86 | if tokenizer == str.split and do_punctuation_removal: 87 | callables.append(remove_punctuation) 88 | callables.append(strip_whitespaces) 89 | 90 | callables.append(tokenizer) 91 | 92 | if tokenizer != str.split and do_punctuation_removal: 93 | 94 | def rp(x): 95 | x = [remove_punctuation(t) for t in x] 96 | return [t for t in x if t] 97 | 98 | callables.append(rp) 99 | 100 | def sw(x): 101 | return [t for t in x if t not in stopwords] 102 | 103 | callables.append(sw) 104 | 105 | def stem(x): 106 | return [stemmer(t) for t in x] 107 | 108 | callables.append(stem) 109 | 110 | return Multipipe(callables) 111 | 
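The two entry points above mirror each other: `preprocessing` applies the enabled steps to a single string, while `preprocessing_multi` compiles the same steps into a `Multipipe` that processes whole collections in parallel. Below is a minimal usage sketch of the single-string path; the input sentence and the exact output tokens are illustrative, everything else is the package's own API.

from retriv.sparse_retriever.preprocessing import (
    get_stemmer,
    get_stopwords,
    get_tokenizer,
    preprocessing,
)

tokens = preprocessing(
    "The P.C.I. bus & its devices",
    tokenizer=get_tokenizer("whitespace"),
    stopwords=get_stopwords("english"),
    stemmer=get_stemmer("english"),
    do_lowercasing=True,
    do_ampersand_normalization=True,
    do_special_chars_normalization=True,
    do_acronyms_normalization=True,
    do_punctuation_removal=True,
)
print(tokens)  # expected to be something like: ['pci', 'bus', 'devic']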
-------------------------------------------------------------------------------- /retriv/sparse_retriever/preprocessing/normalization.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | from unidecode import unidecode 5 | 6 | 7 | def lowercasing(x: str) -> str: 8 | return x.lower() 9 | 10 | 11 | def normalize_ampersand(x: str) -> str: 12 | return x.replace("&", " and ") 13 | 14 | 15 | def normalize_diacritics(x: str) -> str: 16 | return unidecode(x) 17 | 18 | 19 | def normalize_special_chars(x: str) -> str: 20 | special_chars_trans = dict( 21 | [(ord(x), ord(y)) for x, y in zip("‘’´“”–-", "'''\"\"--")] 22 | ) 23 | return x.translate(special_chars_trans) 24 | 25 | 26 | def normalize_acronyms(x: str) -> str: 27 | return re.sub(r"\.(?!(\S[^. ])|\d)", "", x) 28 | 29 | 30 | def remove_punctuation(x: str) -> str: 31 | translator = str.maketrans(string.punctuation, " " * len(string.punctuation)) 32 | return x.translate(translator) 33 | 34 | 35 | def strip_whitespaces(x: str) -> str: 36 | x = x.strip() 37 | 38 | while " " in x: 39 | x = x.replace(" ", " ") 40 | 41 | return x 42 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/preprocessing/stemmer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Union 3 | 4 | import nltk 5 | from krovetzstemmer import Stemmer as KrovetzStemmer 6 | from Stemmer import Stemmer as SnowballStemmer 7 | 8 | from .utils import identity_function 9 | 10 | stemmers_dict = { 11 | "krovetz": partial(KrovetzStemmer()), 12 | "porter": partial(nltk.stem.PorterStemmer().stem), 13 | "lancaster": partial(nltk.stem.LancasterStemmer().stem), 14 | "arlstem": partial(nltk.stem.ARLSTem().stem), # Arabic 15 | "arlstem2": partial(nltk.stem.ARLSTem2().stem), # Arabic 16 | "cistem": partial(nltk.stem.Cistem().stem), # German 17 | "isri": partial(nltk.stem.ISRIStemmer().stem), # Arabic 18 | "arabic": partial(SnowballStemmer("arabic").stemWord), 19 | "basque": partial(SnowballStemmer("basque").stemWord), 20 | "catalan": partial(SnowballStemmer("catalan").stemWord), 21 | "danish": partial(SnowballStemmer("danish").stemWord), 22 | "dutch": partial(SnowballStemmer("dutch").stemWord), 23 | "english": partial(nltk.stem.SnowballStemmer("english").stem), 24 | "finnish": partial(SnowballStemmer("finnish").stemWord), 25 | "french": partial(SnowballStemmer("french").stemWord), 26 | "german": partial(SnowballStemmer("german").stemWord), 27 | "greek": partial(SnowballStemmer("greek").stemWord), 28 | "hindi": partial(SnowballStemmer("hindi").stemWord), 29 | "hungarian": partial(SnowballStemmer("hungarian").stemWord), 30 | "indonesian": partial(SnowballStemmer("indonesian").stemWord), 31 | "irish": partial(SnowballStemmer("irish").stemWord), 32 | "italian": partial(SnowballStemmer("italian").stemWord), 33 | "lithuanian": partial(SnowballStemmer("lithuanian").stemWord), 34 | "nepali": partial(SnowballStemmer("nepali").stemWord), 35 | "norwegian": partial(SnowballStemmer("norwegian").stemWord), 36 | "portuguese": partial(SnowballStemmer("portuguese").stemWord), 37 | "romanian": partial(SnowballStemmer("romanian").stemWord), 38 | "russian": partial(SnowballStemmer("russian").stemWord), 39 | "spanish": partial(SnowballStemmer("spanish").stemWord), 40 | "swedish": partial(SnowballStemmer("swedish").stemWord), 41 | "tamil": partial(SnowballStemmer("tamil").stemWord), 42 | "turkish": 
partial(SnowballStemmer("turkish").stemWord), 43 | } 44 | 45 | 46 | def krovetz_f(x: str) -> str: 47 | return stemmers_dict["krovetz"](x) 48 | 49 | 50 | def porter_f(x: str) -> str: 51 | return stemmers_dict["porter"](x) 52 | 53 | 54 | def lancaster_f(x: str) -> str: 55 | return stemmers_dict["lancaster"](x) 56 | 57 | 58 | def arlstem_f(x: str) -> str: 59 | return stemmers_dict["arlstem"](x) 60 | 61 | 62 | def arlstem2_f(x: str) -> str: 63 | return stemmers_dict["arlstem2"](x) 64 | 65 | 66 | def cistem_f(x: str) -> str: 67 | return stemmers_dict["cistem"](x) 68 | 69 | 70 | def isri_f(x: str) -> str: 71 | return stemmers_dict["isri"](x) 72 | 73 | 74 | def arabic_f(x: str) -> str: 75 | return stemmers_dict["arabic"](x) 76 | 77 | 78 | def basque_f(x: str) -> str: 79 | return stemmers_dict["basque"](x) 80 | 81 | 82 | def catalan_f(x: str) -> str: 83 | return stemmers_dict["catalan"](x) 84 | 85 | 86 | def danish_f(x: str) -> str: 87 | return stemmers_dict["danish"](x) 88 | 89 | 90 | def dutch_f(x: str) -> str: 91 | return stemmers_dict["dutch"](x) 92 | 93 | 94 | def english_f(x: str) -> str: 95 | return stemmers_dict["english"](x) 96 | 97 | 98 | def finnish_f(x: str) -> str: 99 | return stemmers_dict["finnish"](x) 100 | 101 | 102 | def french_f(x: str) -> str: 103 | return stemmers_dict["french"](x) 104 | 105 | 106 | def german_f(x: str) -> str: 107 | return stemmers_dict["german"](x) 108 | 109 | 110 | def greek_f(x: str) -> str: 111 | return stemmers_dict["greek"](x) 112 | 113 | 114 | def hindi_f(x: str) -> str: 115 | return stemmers_dict["hindi"](x) 116 | 117 | 118 | def hungarian_f(x: str) -> str: 119 | return stemmers_dict["hungarian"](x) 120 | 121 | 122 | def indonesian_f(x: str) -> str: 123 | return stemmers_dict["indonesian"](x) 124 | 125 | 126 | def irish_f(x: str) -> str: 127 | return stemmers_dict["irish"](x) 128 | 129 | 130 | def italian_f(x: str) -> str: 131 | return stemmers_dict["italian"](x) 132 | 133 | 134 | def lithuanian_f(x: str) -> str: 135 | return stemmers_dict["lithuanian"](x) 136 | 137 | 138 | def nepali_f(x: str) -> str: 139 | return stemmers_dict["nepali"](x) 140 | 141 | 142 | def norwegian_f(x: str) -> str: 143 | return stemmers_dict["norwegian"](x) 144 | 145 | 146 | def portuguese_f(x: str) -> str: 147 | return stemmers_dict["portuguese"](x) 148 | 149 | 150 | def romanian_f(x: str) -> str: 151 | return stemmers_dict["romanian"](x) 152 | 153 | 154 | def russian_f(x: str) -> str: 155 | return stemmers_dict["russian"](x) 156 | 157 | 158 | def spanish_f(x: str) -> str: 159 | return stemmers_dict["spanish"](x) 160 | 161 | 162 | def swedish_f(x: str) -> str: 163 | return stemmers_dict["swedish"](x) 164 | 165 | 166 | def tamil_f(x: str) -> str: 167 | return stemmers_dict["tamil"](x) 168 | 169 | 170 | def turkish_f(x: str) -> str: 171 | return stemmers_dict["turkish"](x) 172 | 173 | 174 | stemmers_f_dict = { 175 | "krovetz": krovetz_f, 176 | "porter": porter_f, 177 | "lancaster": lancaster_f, 178 | "arlstem": arlstem_f, 179 | "arlstem2": arlstem2_f, 180 | "cistem": cistem_f, 181 | "isri": isri_f, 182 | "arabic": arabic_f, 183 | "basque": basque_f, 184 | "catalan": catalan_f, 185 | "danish": danish_f, 186 | "dutch": dutch_f, 187 | "english": english_f, 188 | "finnish": finnish_f, 189 | "french": french_f, 190 | "german": german_f, 191 | "greek": greek_f, 192 | "hindi": hindi_f, 193 | "hungarian": hungarian_f, 194 | "indonesian": indonesian_f, 195 | "irish": irish_f, 196 | "italian": italian_f, 197 | "lithuanian": lithuanian_f, 198 | "nepali": nepali_f, 199 | "norwegian": 
norwegian_f, 200 | "portuguese": portuguese_f, 201 | "romanian": romanian_f, 202 | "russian": russian_f, 203 | "spanish": spanish_f, 204 | "swedish": swedish_f, 205 | "tamil": tamil_f, 206 | "turkish": turkish_f, 207 | } 208 | 209 | 210 | def _get_stemmer(stemmer: str) -> callable: 211 | assert stemmer.lower() in stemmers_f_dict, f"Stemmer {stemmer} not supported." 212 | return stemmers_f_dict[stemmer.lower()] 213 | 214 | 215 | def get_stemmer(stemmer: Union[str, callable, bool]) -> callable: 216 | if isinstance(stemmer, str): 217 | return _get_stemmer(stemmer) 218 | elif callable(stemmer): 219 | return stemmer 220 | elif stemmer is None: 221 | return identity_function 222 | else: 223 | raise (NotImplementedError) 224 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/preprocessing/stopwords.py: -------------------------------------------------------------------------------- 1 | from typing import List, Set, Union 2 | 3 | import nltk 4 | 5 | supported_languages = { 6 | "arabic", 7 | "azerbaijani", 8 | "basque", 9 | "bengali", 10 | "catalan", 11 | "chinese", 12 | "danish", 13 | "dutch", 14 | "english", 15 | "finnish", 16 | "french", 17 | "german", 18 | "greek", 19 | "hebrew", 20 | "hinglish", 21 | "hungarian", 22 | "indonesian", 23 | "italian", 24 | "kazakh", 25 | "nepali", 26 | "norwegian", 27 | "portuguese", 28 | "romanian", 29 | "russian", 30 | "slovene", 31 | "spanish", 32 | "swedish", 33 | "tajik", 34 | "turkish", 35 | } 36 | 37 | 38 | def _get_stopwords(lang: str) -> List[str]: 39 | nltk.download("stopwords", quiet=True) 40 | assert ( 41 | lang.lower() in supported_languages 42 | ), f"Stop-words for {lang.capitalize()} are not available." 43 | return nltk.corpus.stopwords.words(lang) 44 | 45 | 46 | def get_stopwords(sw_list: Union[str, List[str], Set[str], bool]) -> List[str]: 47 | if isinstance(sw_list, str): 48 | return _get_stopwords(sw_list) 49 | elif type(sw_list) is list and all(isinstance(x, str) for x in sw_list): 50 | return sw_list 51 | elif type(sw_list) is set: 52 | return list(sw_list) 53 | elif sw_list is None: 54 | return [] 55 | else: 56 | raise (NotImplementedError) 57 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/preprocessing/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import nltk 4 | 5 | from .utils import identity_function 6 | 7 | tokenizers_dict = { 8 | "whitespace": str.split, 9 | "word": nltk.tokenize.word_tokenize, 10 | "wordpunct": nltk.tokenize.wordpunct_tokenize, 11 | "sent": nltk.tokenize.sent_tokenize, 12 | } 13 | 14 | 15 | def _get_tokenizer(tokenizer: str) -> callable: 16 | assert tokenizer.lower() in tokenizers_dict, f"Tokenizer {tokenizer} not supported." 
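# NLTK's "word" and "sent" tokenizers depend on the "punkt" sentence-boundary
# model, so it must be fetched before either of them is returned.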
17 | if tokenizer.lower() in {"word", "sent"}: 18 | nltk.download("punkt", quiet=True) 19 | return tokenizers_dict[tokenizer.lower()] 20 | 21 | 22 | def get_tokenizer(tokenizer: Union[str, callable, bool]) -> callable: 23 | if isinstance(tokenizer, str): 24 | return _get_tokenizer(tokenizer) 25 | elif callable(tokenizer): 26 | return tokenizer 27 | elif tokenizer is None: 28 | return identity_function 29 | else: 30 | raise (NotImplementedError) 31 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | def identity_function(x): 2 | return x 3 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/sparse_retrieval_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/sparse_retriever/sparse_retrieval_models/__init__.py -------------------------------------------------------------------------------- /retriv/sparse_retriever/sparse_retrieval_models/bm25.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numba as nb 4 | import numpy as np 5 | from numba import njit, prange 6 | from numba.typed import List as TypedList 7 | 8 | from ...utils.numba_utils import ( 9 | intersect_sorted, 10 | intersect_sorted_multi, 11 | union_sorted_multi, 12 | unsorted_top_k, 13 | ) 14 | 15 | 16 | @njit(cache=True) 17 | def bm25( 18 | b: float, 19 | k1: float, 20 | term_doc_freqs: nb.typed.List[np.ndarray], 21 | doc_ids: nb.typed.List[np.ndarray], 22 | relative_doc_lens: nb.typed.List[np.ndarray], 23 | doc_count: int, 24 | cutoff: int, 25 | operator: str = "OR", 26 | subset_doc_ids: np.ndarray = None, 27 | ) -> Tuple[np.ndarray, np.ndarray]: 28 | if operator == "AND": 29 | unique_doc_ids = intersect_sorted_multi(doc_ids) 30 | elif operator == "OR": 31 | unique_doc_ids = union_sorted_multi(doc_ids) 32 | 33 | if subset_doc_ids is not None: 34 | unique_doc_ids = intersect_sorted(unique_doc_ids, subset_doc_ids) 35 | 36 | scores = np.empty(doc_count, dtype=np.float32) 37 | scores[unique_doc_ids] = 0.0  # Initialize scores 38 | 39 | for i in range(len(term_doc_freqs)): 40 | indices = doc_ids[i] 41 | freqs = term_doc_freqs[i] 42 | 43 | df = np.float32(len(indices)) 44 | idf = np.float32(np.log(1.0 + (((doc_count - df) + 0.5) / (df + 0.5)))) 45 | 46 | scores[indices] += idf * ( 47 | (freqs * (k1 + 1.0)) 48 | / (freqs + k1 * (1.0 - b + (b * relative_doc_lens[indices]))) 49 | ) 50 | 51 | scores = scores[unique_doc_ids] 52 | 53 | if cutoff < len(scores): 54 | scores, indices = unsorted_top_k(scores, cutoff) 55 | unique_doc_ids = unique_doc_ids[indices] 56 | 57 | indices = np.argsort(-scores) 58 | 59 | return unique_doc_ids[indices], scores[indices] 60 | 61 | 62 | @njit(cache=True, parallel=True) 63 | def bm25_multi( 64 | b: float, 65 | k1: float, 66 | term_doc_freqs: nb.typed.List[nb.typed.List[np.ndarray]], 67 | doc_ids: nb.typed.List[nb.typed.List[np.ndarray]], 68 | relative_doc_lens: nb.typed.List[np.ndarray], 69 | doc_count: int, 70 | cutoff: int, 71 | ) -> Tuple[nb.typed.List[np.ndarray], nb.typed.List[np.ndarray]]: 72 | unique_doc_ids = TypedList([np.empty(1, dtype=np.int32) for _ in doc_ids]) 73 | scores = TypedList([np.empty(1, dtype=np.float32) for _ in doc_ids]) 74 | 75 | for i in prange(len(term_doc_freqs)): 76 | _term_doc_freqs =
term_doc_freqs[i] 77 | _doc_ids = doc_ids[i] 78 | 79 | _unique_doc_ids = union_sorted_multi(_doc_ids) 80 | 81 | _scores = np.empty(doc_count, dtype=np.float32) 82 | _scores[_unique_doc_ids] = 0.0 # Initialize _scores 83 | 84 | for j in range(len(_term_doc_freqs)): 85 | indices = _doc_ids[j] 86 | freqs = _term_doc_freqs[j] 87 | 88 | df = np.float32(len(indices)) 89 | idf = np.float32(np.log(1.0 + (((doc_count - df) + 0.5) / (df + 0.5)))) 90 | 91 | _scores[indices] += idf * ( 92 | (freqs * (k1 + 1.0)) 93 | / (freqs + k1 * (1.0 - b + (b * relative_doc_lens[indices]))) 94 | ) 95 | 96 | _scores = _scores[_unique_doc_ids] 97 | 98 | if cutoff < len(_scores): 99 | _scores, indices = unsorted_top_k(_scores, cutoff) 100 | _unique_doc_ids = _unique_doc_ids[indices] 101 | 102 | indices = np.argsort(_scores)[::-1] 103 | 104 | unique_doc_ids[i] = _unique_doc_ids[indices] 105 | scores[i] = _scores[indices] 106 | 107 | return unique_doc_ids, scores 108 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/sparse_retrieval_models/tf_idf.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numba as nb 4 | import numpy as np 5 | from numba import njit, prange 6 | from numba.typed import List as TypedList 7 | 8 | from ...utils.numba_utils import ( 9 | intersect_sorted, 10 | intersect_sorted_multi, 11 | union_sorted_multi, 12 | unsorted_top_k, 13 | ) 14 | 15 | 16 | @njit(cache=True) 17 | def tf_idf( 18 | term_doc_freqs: nb.typed.List[np.ndarray], 19 | doc_ids: nb.typed.List[np.ndarray], 20 | doc_lens: nb.typed.List[np.ndarray], 21 | cutoff: int, 22 | operator: str = "OR", 23 | subset_doc_ids: np.ndarray = None, 24 | ) -> Tuple[np.ndarray]: 25 | if operator == "AND": 26 | unique_doc_ids = intersect_sorted_multi(doc_ids) 27 | elif operator == "OR": 28 | unique_doc_ids = union_sorted_multi(doc_ids) 29 | 30 | if subset_doc_ids is not None: 31 | unique_doc_ids = intersect_sorted(unique_doc_ids, subset_doc_ids) 32 | 33 | doc_count = len(doc_lens) 34 | scores = np.empty(doc_count, dtype=np.float32) 35 | scores[unique_doc_ids] = 0.0 # Initialize scores 36 | 37 | for i in range(len(term_doc_freqs)): 38 | indices = doc_ids[i] 39 | freqs = term_doc_freqs[i] 40 | 41 | tf = freqs / doc_lens[indices] 42 | 43 | df = np.float32(len(indices)) 44 | idf = np.float32(np.log((1.0 + doc_count) / (1.0 + df)) + 1.0) 45 | 46 | scores[indices] += tf * idf 47 | 48 | scores = scores[unique_doc_ids] 49 | 50 | if cutoff < len(scores): 51 | scores, indices = unsorted_top_k(scores, cutoff) 52 | unique_doc_ids = unique_doc_ids[indices] 53 | 54 | indices = np.argsort(-scores) 55 | 56 | return unique_doc_ids[indices], scores[indices] 57 | 58 | 59 | @njit(cache=True, parallel=True) 60 | def tf_idf_multi( 61 | term_doc_freqs: nb.typed.List[nb.typed.List[np.ndarray]], 62 | doc_ids: nb.typed.List[nb.typed.List[np.ndarray]], 63 | doc_lens: nb.typed.List[nb.typed.List[np.ndarray]], 64 | cutoff: int, 65 | ) -> Tuple[nb.typed.List[np.ndarray]]: 66 | unique_doc_ids = TypedList([np.empty(1, dtype=np.int32) for _ in doc_ids]) 67 | scores = TypedList([np.empty(1, dtype=np.float32) for _ in doc_ids]) 68 | 69 | for i in prange(len(term_doc_freqs)): 70 | _term_doc_freqs = term_doc_freqs[i] 71 | _doc_ids = doc_ids[i] 72 | 73 | _unique_doc_ids = union_sorted_multi(_doc_ids) 74 | 75 | doc_count = len(doc_lens) 76 | _scores = np.empty(doc_count, dtype=np.float32) 77 | _scores[_unique_doc_ids] = 0.0 # Initialize _scores 78 | 79 | for j 
in range(len(_term_doc_freqs)): 80 | indices = _doc_ids[j] 81 | freqs = _term_doc_freqs[j] 82 | 83 | tf = freqs / doc_lens[indices] 84 | 85 | df = np.float32(len(indices)) 86 | idf = np.float32(np.log((1.0 + doc_count) / (1.0 + df)) + 1.0) 87 | 88 | _scores[indices] += tf * idf 89 | 90 | _scores = _scores[_unique_doc_ids] 91 | 92 | if cutoff < len(_scores): 93 | _scores, indices = unsorted_top_k(_scores, cutoff) 94 | _unique_doc_ids = _unique_doc_ids[indices] 95 | 96 | indices = np.argsort(_scores)[::-1] 97 | 98 | unique_doc_ids[i] = _unique_doc_ids[indices] 99 | scores[i] = _scores[indices] 100 | 101 | return unique_doc_ids, scores 102 | -------------------------------------------------------------------------------- /retriv/sparse_retriever/sparse_retriever.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Iterable, List, Set, Union 3 | 4 | import numba as nb 5 | import numpy as np 6 | import orjson 7 | from numba.typed import List as TypedList 8 | from oneliner_utils import create_path, read_jsonl 9 | from tqdm import tqdm 10 | 11 | from ..autotune import tune_bm25 12 | from ..base_retriever import BaseRetriever 13 | from ..paths import docs_path, sr_state_path 14 | from .build_inverted_index import build_inverted_index 15 | from .preprocessing import ( 16 | get_stemmer, 17 | get_stopwords, 18 | get_tokenizer, 19 | preprocessing, 20 | preprocessing_multi, 21 | ) 22 | from .sparse_retrieval_models.bm25 import bm25, bm25_multi 23 | from .sparse_retrieval_models.tf_idf import tf_idf, tf_idf_multi 24 | 25 | 26 | class SparseRetriever(BaseRetriever): 27 | def __init__( 28 | self, 29 | index_name: str = "new-index", 30 | model: str = "bm25", 31 | min_df: int = 1, 32 | tokenizer: Union[str, callable] = "whitespace", 33 | stemmer: Union[str, callable] = "english", 34 | stopwords: Union[str, List[str], Set[str]] = "english", 35 | do_lowercasing: bool = True, 36 | do_ampersand_normalization: bool = True, 37 | do_special_chars_normalization: bool = True, 38 | do_acronyms_normalization: bool = True, 39 | do_punctuation_removal: bool = True, 40 | hyperparams: dict = None, 41 | ): 42 | """The Sparse Retriever is a traditional searcher based on lexical matching. It supports BM25, the retrieval model used by major search engine libraries, such as Lucene and Elasticsearch. retriv also implements the classic relevance model TF-IDF for educational purposes. 43 | 44 | Args: 45 | index_name (str, optional): [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index. Defaults to "new-index". 46 | 47 | model (str, optional): defines the retrieval model to use for searching (`bm25` or `tf-idf`). Defaults to "bm25". 48 | 49 | min_df (int, optional): terms that appear in fewer than `min_df` documents will be ignored. If integer, the parameter indicates the absolute count. If float, it represents a proportion of documents. Defaults to 1. 50 | 51 | tokenizer (Union[str, callable], optional): [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization by setting the parameter to `None`. Defaults to "whitespace". 52 | 53 | stemmer (Union[str, callable], optional): [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming by setting the parameter to `None`. Defaults to "english".
54 | 55 | stopwords (Union[str, List[str], Set[str]], optional): [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`. Defaults to "english". 56 | 57 | do_lowercasing (bool, optional): whether or not to lowercase texts. Defaults to True. 58 | 59 | do_ampersand_normalization (bool, optional): whether to convert `&` into `and` during pre-processing. Defaults to True. 60 | 61 | do_special_chars_normalization (bool, optional): whether to replace special characters with plain letters, e.g., `übermensch` → `ubermensch`. Defaults to True. 62 | 63 | do_acronyms_normalization (bool, optional): whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`. Defaults to True. 64 | 65 | do_punctuation_removal (bool, optional): whether to remove punctuation. Defaults to True. 66 | 67 | hyperparams (dict, optional): Retrieval model hyperparams. If `None`, it is automatically set to `{b: 0.75, k1: 1.2}`. Defaults to None. 68 | """ 69 | 70 | assert model.lower() in {"bm25", "tf-idf"} 71 | assert min_df > 0, "`min_df` must be greater than zero." 72 | self.init_args = { 73 | "model": model.lower(), 74 | "min_df": min_df, 75 | "index_name": index_name, 76 | "do_lowercasing": do_lowercasing, 77 | "do_ampersand_normalization": do_ampersand_normalization, 78 | "do_special_chars_normalization": do_special_chars_normalization, 79 | "do_acronyms_normalization": do_acronyms_normalization, 80 | "do_punctuation_removal": do_punctuation_removal, 81 | "tokenizer": tokenizer, 82 | "stemmer": stemmer, 83 | "stopwords": stopwords, 84 | } 85 | 86 | self.model = model.lower() 87 | self.min_df = min_df 88 | self.index_name = index_name 89 | 90 | self.do_lowercasing = do_lowercasing 91 | self.do_ampersand_normalization = do_ampersand_normalization 92 | self.do_special_chars_normalization = do_special_chars_normalization 93 | self.do_acronyms_normalization = do_acronyms_normalization 94 | self.do_punctuation_removal = do_punctuation_removal 95 | 96 | self.tokenizer = get_tokenizer(tokenizer) 97 | self.stemmer = get_stemmer(stemmer) 98 | self.stopwords = [self.stemmer(sw) for sw in get_stopwords(stopwords)] 99 | 100 | self.id_mapping = None 101 | self.inverted_index = None 102 | self.vocabulary = None 103 | self.doc_count = None 104 | self.doc_lens = None 105 | self.avg_doc_len = None 106 | self.relative_doc_lens = None 107 | self.doc_index = None 108 | 109 | self.preprocessing_kwargs = { 110 | "tokenizer": self.tokenizer, 111 | "stemmer": self.stemmer, 112 | "stopwords": self.stopwords, 113 | "do_lowercasing": self.do_lowercasing, 114 | "do_ampersand_normalization": self.do_ampersand_normalization, 115 | "do_special_chars_normalization": self.do_special_chars_normalization, 116 | "do_acronyms_normalization": self.do_acronyms_normalization, 117 | "do_punctuation_removal": self.do_punctuation_removal, 118 | } 119 | 120 | self.preprocessing_pipe = preprocessing_multi(**self.preprocessing_kwargs) 121 | 122 | self.hyperparams = dict(b=0.75, k1=1.2) if hyperparams is None else hyperparams 123 | 124 | def save(self) -> None: 125 | """Save the state of the retriever to be able to restore it later.""" 126 | 127 | state = { 128 | "init_args": self.init_args, 129 | "id_mapping": self.id_mapping, 130 | "doc_count": self.doc_count, 131 | "inverted_index": self.inverted_index, 132 | "vocabulary": self.vocabulary, 133 | "doc_lens":
self.doc_lens, 134 | "relative_doc_lens": self.relative_doc_lens, 135 | "hyperparams": self.hyperparams, 136 | } 137 | 138 | np.savez_compressed(sr_state_path(self.index_name), state=state) 139 | 140 | @staticmethod 141 | def load(index_name: str = "new-index"): 142 | """Load a retriever and its index. 143 | 144 | Args: 145 | index_name (str, optional): Name of the index. Defaults to "new-index". 146 | 147 | Returns: 148 | SparseRetriever: Sparse Retriever. 149 | """ 150 | 151 | state = np.load(sr_state_path(index_name), allow_pickle=True)["state"][()] 152 | 153 | se = SparseRetriever(**state["init_args"]) 154 | se.initialize_doc_index() 155 | se.id_mapping = state["id_mapping"] 156 | se.doc_count = state["doc_count"] 157 | se.inverted_index = state["inverted_index"] 158 | se.vocabulary = set(se.inverted_index) 159 | se.doc_lens = state["doc_lens"] 160 | se.relative_doc_lens = state["relative_doc_lens"] 161 | se.hyperparams = state["hyperparams"] 162 | 163 | state = { 164 | "init_args": se.init_args, 165 | "id_mapping": se.id_mapping, 166 | "doc_count": se.doc_count, 167 | "inverted_index": se.inverted_index, 168 | "vocabulary": se.vocabulary, 169 | "doc_lens": se.doc_lens, 170 | "relative_doc_lens": se.relative_doc_lens, 171 | "hyperparams": se.hyperparams, 172 | } 173 | 174 | return se 175 | 176 | def index_aux(self, show_progress: bool = True): 177 | """Internal usage.""" 178 | collection = read_jsonl( 179 | docs_path(self.index_name), 180 | generator=True, 181 | callback=lambda x: x["text"], 182 | ) 183 | 184 | # Preprocessing -------------------------------------------------------- 185 | collection = self.preprocessing_pipe(collection, generator=True) 186 | 187 | # Inverted index ------------------------------------------------------- 188 | ( 189 | self.inverted_index, 190 | self.doc_lens, 191 | self.relative_doc_lens, 192 | ) = build_inverted_index( 193 | collection=collection, 194 | n_docs=self.doc_count, 195 | min_df=self.min_df, 196 | show_progress=show_progress, 197 | ) 198 | self.avg_doc_len = np.mean(self.doc_lens, dtype=np.float32) 199 | self.vocabulary = set(self.inverted_index) 200 | 201 | def index( 202 | self, 203 | collection: Iterable, 204 | callback: callable = None, 205 | show_progress: bool = True, 206 | ): 207 | """Index a given collection of documents. 208 | 209 | Args: 210 | collection (Iterable): collection of documents to index. 211 | 212 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None. 213 | 214 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True. 215 | 216 | Returns: 217 | SparseRetriever: Sparse Retriever. 218 | """ 219 | 220 | self.save_collection(collection, callback) 221 | self.initialize_doc_index() 222 | self.initialize_id_mapping() 223 | self.doc_count = len(self.id_mapping) 224 | self.index_aux(show_progress) 225 | self.save() 226 | return self 227 | 228 | def index_file( 229 | self, path: str, callback: callable = None, show_progress: bool = True 230 | ): 231 | """Index the collection contained in a given file. 232 | 233 | Args: 234 | path (str): path of file containing the collection to index. 235 | 236 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None. 237 | 238 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True. 
239 | 240 | Returns: 241 | SparseRetriever: Sparse Retriever 242 | """ 243 | 244 | collection = self.collection_generator(path=path, callback=callback) 245 | return self.index(collection=collection, show_progress=show_progress) 246 | 247 | # SEARCH =================================================================== 248 | def query_preprocessing(self, query: str) -> List[str]: 249 | """Internal usage.""" 250 | return preprocessing(query, **self.preprocessing_kwargs) 251 | 252 | def get_term_doc_freqs(self, query_terms: List[str]) -> nb.types.List: 253 | """Internal usage.""" 254 | return TypedList([self.inverted_index[t]["tfs"] for t in query_terms]) 255 | 256 | def get_doc_ids(self, query_terms: List[str]) -> nb.types.List: 257 | """Internal usage.""" 258 | return TypedList([self.inverted_index[t]["doc_ids"] for t in query_terms]) 259 | 260 | def search(self, query: str, return_docs: bool = True, cutoff: int = 100) -> Union[List, Dict]: 261 | """Standard search functionality. 262 | 263 | Args: 264 | query (str): what to search for. 265 | 266 | return_docs (bool, optional): whether to return the texts of the documents. Defaults to True. 267 | 268 | cutoff (int, optional): number of results to return. Defaults to 100. 269 | 270 | Returns: 271 | Union[List, Dict]: results. 272 | """ 273 | 274 | query_terms = self.query_preprocessing(query) 275 | if not query_terms: 276 | return {} 277 | query_terms = [t for t in query_terms if t in self.vocabulary] 278 | if not query_terms: 279 | return {} 280 | 281 | doc_ids = self.get_doc_ids(query_terms) 282 | term_doc_freqs = self.get_term_doc_freqs(query_terms) 283 | 284 | if self.model == "bm25": 285 | unique_doc_ids, scores = bm25( 286 | term_doc_freqs=term_doc_freqs, 287 | doc_ids=doc_ids, 288 | relative_doc_lens=self.relative_doc_lens, 289 | doc_count=self.doc_count, 290 | cutoff=cutoff, 291 | **self.hyperparams, 292 | ) 293 | elif self.model == "tf-idf": 294 | unique_doc_ids, scores = tf_idf( 295 | term_doc_freqs=term_doc_freqs, 296 | doc_ids=doc_ids, 297 | doc_lens=self.doc_lens, 298 | cutoff=cutoff, 299 | ) 300 | else: 301 | raise NotImplementedError() 302 | 303 | unique_doc_ids = self.map_internal_ids_to_original_ids(unique_doc_ids) 304 | 305 | if not return_docs: 306 | return dict(zip(unique_doc_ids, scores)) 307 | 308 | return self.prepare_results(unique_doc_ids, scores) 309 | 310 | def msearch(self, queries: List[Dict[str, str]], cutoff: int = 100) -> Dict: 311 | """Compute results for multiple queries at once. 312 | 313 | Args: 314 | queries (List[Dict[str, str]]): what to search for. 315 | 316 | cutoff (int, optional): number of results to return. Defaults to 100. 317 | 318 | Returns: 319 | Dict: results.
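        Example (query ids, texts, and scores below are illustrative; `sr`
        denotes a SparseRetriever instance):

            >>> sr.msearch(
            ...     queries=[
            ...         {"id": "q_1", "text": "witches masses"},
            ...         {"id": "q_2", "text": "evil minds"},
            ...     ],
            ...     cutoff=100,
            ... )
            {"q_1": {"doc_2": 1.7, "doc_1": 0.8}, "q_2": {"doc_3": 2.1}}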
320 | """ 321 | 322 | term_doc_freqs = TypedList() 323 | doc_ids = TypedList() 324 | q_ids = [] 325 | no_results_q_ids = [] 326 | 327 | for q in queries: 328 | q_id, query = q["id"], q["text"] 329 | query_terms = self.query_preprocessing(query) 330 | query_terms = [t for t in query_terms if t in self.vocabulary] 331 | if not query_terms: 332 | no_results_q_ids.append(q_id) 333 | continue 334 | 335 | if all(t not in self.inverted_index for t in query_terms): 336 | no_results_q_ids.append(q_id) 337 | continue 338 | 339 | q_ids.append(q_id) 340 | term_doc_freqs.append(self.get_term_doc_freqs(query_terms)) 341 | doc_ids.append(self.get_doc_ids(query_terms)) 342 | 343 | if not q_ids: 344 | return {q_id: {} for q_id in [q["id"] for q in queries]} 345 | 346 | if self.model == "bm25": 347 | unique_doc_ids, scores = bm25_multi( 348 | term_doc_freqs=term_doc_freqs, 349 | doc_ids=doc_ids, 350 | relative_doc_lens=self.relative_doc_lens, 351 | doc_count=self.doc_count, 352 | cutoff=cutoff, 353 | **self.hyperparams, 354 | ) 355 | elif self.model == "tf-idf": 356 | unique_doc_ids, scores = tf_idf_multi( 357 | term_doc_freqs=term_doc_freqs, 358 | doc_ids=doc_ids, 359 | doc_lens=self.doc_lens, 360 | cutoff=cutoff, 361 | ) 362 | else: 363 | raise NotImplementedError() 364 | 365 | unique_doc_ids = [ 366 | self.map_internal_ids_to_original_ids(_unique_doc_ids) 367 | for _unique_doc_ids in unique_doc_ids 368 | ] 369 | 370 | results = { 371 | q: dict(zip(unique_doc_ids[i], scores[i])) for i, q in enumerate(q_ids) 372 | } 373 | 374 | for q_id in no_results_q_ids: 375 | results[q_id] = {} 376 | 377 | # Order as queries 378 | return {q_id: results[q_id] for q_id in [q["id"] for q in queries]} 379 | 380 | def bsearch( 381 | self, 382 | queries: List[Dict[str, str]], 383 | cutoff: int = 100, 384 | batch_size: int = 1_000, 385 | show_progress: bool = True, 386 | qrels: Dict[str, Dict[str, float]] = None, 387 | path: str = None, 388 | ) -> Dict: 389 | """Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format. bsearch is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM. 390 | 391 | Args: 392 | queries (List[Dict[str, str]]): what to search for. 393 | 394 | cutoff (int, optional): number of results to return. Defaults to 100. 395 | 396 | batch_size (int, optional): number of queries to perform simultaneously. Defaults to 1_000. 397 | 398 | show_progress (bool, optional): whether to show a progress bar for the search process. Defaults to True. 399 | 400 | qrels (Dict[str, Dict[str, float]], optional): query relevance judgements for the queries. Defaults to None. 401 | 402 | path (str, optional): where to save the results. Defaults to None. 403 | 404 | Returns: 405 | Dict: results.
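        Example (file name and queries are illustrative; `sr` denotes a
        SparseRetriever instance):

            >>> sr.bsearch(queries=queries, cutoff=100, path="bm25_results.jsonl")

        When `path` is given, each output line is a JSON object holding the
        query id and text, the retrieved document ids and scores (under
        `<model>_doc_ids` / `<model>_scores`), and, if `qrels` is provided,
        the relevant document ids and scores.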
406 | """ 407 | 408 | batches = [ 409 | queries[i : i + batch_size] for i in range(0, len(queries), batch_size) 410 | ] 411 | 412 | results = {} 413 | 414 | pbar = tqdm( 415 | total=len(queries), 416 | disable=not show_progress, 417 | desc="Batch search", 418 | dynamic_ncols=True, 419 | mininterval=0.5, 420 | ) 421 | 422 | if path is None: 423 | for batch in batches: 424 | new_results = self.msearch(queries=batch, cutoff=cutoff) 425 | results = {**results, **new_results} 426 | pbar.update(min(batch_size, len(batch))) 427 | else: 428 | path = create_path(path) 429 | path.parent.mkdir(parents=True, exist_ok=True) 430 | 431 | with open(path, "wb") as f: 432 | for batch in batches: 433 | new_results = self.msearch(queries=batch, cutoff=cutoff) 434 | 435 | for i, (k, v) in enumerate(new_results.items()): 436 | x = { 437 | "id": k, 438 | "text": batch[i]["text"], 439 | f"{self.model}_doc_ids": list(v.keys()), 440 | f"{self.model}_scores": [ 441 | float(s) for s in list(v.values()) 442 | ], 443 | } 444 | if qrels is not None: 445 | x["rel_doc_ids"] = list(qrels[k].keys()) 446 | x["rel_scores"] = list(qrels[k].values()) 447 | f.write(orjson.dumps(x) + "\n".encode()) 448 | 449 | pbar.update(min(batch_size, len(batch))) 450 | 451 | return results 452 | 453 | def autotune( 454 | self, 455 | queries: List[Dict[str, str]], 456 | qrels: Dict[str, Dict[str, float]], 457 | metric: str = "ndcg", 458 | n_trials: int = 100, 459 | cutoff: int = 100, 460 | ): 461 | """Use the AutoTune function to tune [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) parameters w.r.t. your document collection and queries. 462 | All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function. At the end of the process, the best parameter configuration is automatically applied to the `SparseRetriever` instance and saved to disk. You can inspect the current configuration by printing `sr.hyperparams`. 463 | 464 | Args: 465 | queries (List[Dict[str, str]]): queries to use for the optimization process. 466 | 467 | qrels (Dict[str, Dict[str, float]]): query relevance judgements for the queries. 468 | 469 | metric (str, optional): metric to optimize for. Defaults to "ndcg". 470 | 471 | n_trials (int, optional): number of configurations to evaluate. Defaults to 100. 472 | 473 | cutoff (int, optional): number of results to consider for the optimization process. Defaults to 100.
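        Example (`sr` denotes a SparseRetriever instance; the tuned values
        shown are illustrative):

            >>> sr.autotune(queries=queries, qrels=qrels, metric="ndcg", n_trials=100)
            >>> sr.hyperparams
            {'b': 0.82, 'k1': 1.6}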
474 | """ 475 | 476 | hyperparams = tune_bm25( 477 | queries=queries, 478 | qrels=qrels, 479 | se=self, 480 | metric=metric, 481 | n_trials=n_trials, 482 | cutoff=cutoff, 483 | ) 484 | self.hyperparams = hyperparams 485 | self.save() 486 | -------------------------------------------------------------------------------- /retriv/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/utils/__init__.py -------------------------------------------------------------------------------- /retriv/utils/numba_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numba import njit 3 | 4 | 5 | # UNION ------------------------------------------------------------------------ 6 | @njit(cache=True) 7 | def union_sorted(a1: np.array, a2: np.array): 8 | result = np.empty(len(a1) + len(a2), dtype=np.int32) 9 | i = 0 10 | j = 0 11 | k = 0 12 | 13 | while i < len(a1) and j < len(a2): 14 | if a1[i] < a2[j]: 15 | result[k] = a1[i] 16 | i += 1 17 | elif a1[i] > a2[j]: 18 | result[k] = a2[j] 19 | j += 1 20 | else: # a1[i] == a2[j] 21 | result[k] = a1[i] 22 | i += 1 23 | j += 1 24 | k += 1 25 | 26 | result = result[:k] 27 | 28 | if i < len(a1): 29 | result = np.concatenate((result, a1[i:])) 30 | elif j < len(a2): 31 | result = np.concatenate((result, a2[j:])) 32 | 33 | return result 34 | 35 | 36 | @njit(cache=True) 37 | def union_sorted_multi(arrays): 38 | if len(arrays) == 1: 39 | return arrays[0] 40 | elif len(arrays) == 2: 41 | return union_sorted(arrays[0], arrays[1]) 42 | else: 43 | return union_sorted( 44 | union_sorted_multi(arrays[:2]), union_sorted_multi(arrays[2:]) 45 | ) 46 | 47 | 48 | # INTERSECTION ----------------------------------------------------------------- 49 | @njit(cache=True) 50 | def intersect_sorted(a1: np.array, a2: np.array): 51 | result = np.empty(min(len(a1), len(a2)), dtype=np.int32) 52 | i = 0 53 | j = 0 54 | k = 0 55 | 56 | while i < len(a1) and j < len(a2): 57 | if a1[i] < a2[j]: 58 | i += 1 59 | elif a1[i] > a2[j]: 60 | j += 1 61 | else: # a1[i] == a2[j] 62 | result[k] = a1[i] 63 | i += 1 64 | j += 1 65 | k += 1 66 | 67 | return result[:k] 68 | 69 | 70 | @njit(cache=True) 71 | def intersect_sorted_multi(arrays): 72 | a = arrays[0] 73 | 74 | for i in range(1, len(arrays)): 75 | a = intersect_sorted(a, arrays[i]) 76 | 77 | return a 78 | 79 | 80 | # DIFFERENCE ------------------------------------------------------------------- 81 | @njit(cache=True) 82 | def diff_sorted(a1: np.array, a2: np.array): 83 | result = np.empty(len(a1), dtype=np.int32) 84 | i = 0 85 | j = 0 86 | k = 0 87 | 88 | while i < len(a1) and j < len(a2): 89 | if a1[i] < a2[j]: 90 | result[k] = a1[i] 91 | i += 1 92 | k += 1 93 | elif a1[i] > a2[j]: 94 | j += 1 95 | else: # a1[i] == a2[j] 96 | i += 1 97 | j += 1 98 | 99 | result = result[:k] 100 | 101 | if i < len(a1): 102 | result = np.concatenate((result, a1[i:])) 103 | 104 | return result 105 | 106 | 107 | # ----------------------------------------------------------------------------- 108 | @njit(cache=True) 109 | def concat1d(X): 110 | out = np.empty(sum([len(x) for x in X]), dtype=X[0].dtype) 111 | 112 | i = 0 113 | for x in X: 114 | for j in range(len(x)): 115 | out[i] = x[j] 116 | i = i + 1 117 | 118 | return out 119 | 120 | 121 | @njit(cache=True) 122 | def get_indices(array, scores): 123 | n_scores = len(scores) 124 | min_score = min(scores) 125 | 
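    # Single left-to-right scan: each score in `scores` claims the index of its
    # first unclaimed occurrence in `array`; values outside [min_score, max_score]
    # cannot match any score and are skipped. Scores never found keep index -1.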
max_score = max(scores) 126 | indices = np.full(n_scores, -1, dtype=np.int64) 127 | counter = 0 128 | 129 | for i in range(len(array)): 130 | if array[i] >= min_score and array[i] <= max_score: 131 | for j in range(len(scores)): 132 | if indices[j] == -1: 133 | if scores[j] == array[i]: 134 | indices[j] = i 135 | counter += 1 136 | if len(indices) == counter: 137 | return indices 138 | break 139 | 140 | return indices 141 | 142 | 143 | @njit(cache=True) 144 | def unsorted_top_k(array: np.ndarray, k: int): 145 | top_k_values = np.zeros(k, dtype=np.float32) 146 | top_k_indices = np.zeros(k, dtype=np.int32) 147 | 148 | min_value = 0.0 149 | min_value_idx = 0 150 | 151 | for i, value in enumerate(array): 152 | if value > min_value: 153 | top_k_values[min_value_idx] = value 154 | top_k_indices[min_value_idx] = i 155 | min_value_idx = top_k_values.argmin() 156 | min_value = top_k_values[min_value_idx] 157 | 158 | return top_k_values, top_k_indices 159 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="retriv", 8 | version="0.2.3", 9 | author="Elias Bassani", 10 | author_email="elias.bssn@gmail.com", 11 | description="retriv: A Python Search Engine for Humans.", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/AmenRa/retriv", 15 | packages=setuptools.find_packages(), 16 | install_requires=[ 17 | "numpy", 18 | "nltk", 19 | "numba>=0.54.1", 20 | "tqdm", 21 | "optuna", 22 | "krovetzstemmer", 23 | "pystemmer==2.0.1", 24 | "unidecode", 25 | "scikit-learn", 26 | "ranx", 27 | "indxr", 28 | "oneliner_utils", 29 | "torch", 30 | "torchvision", 31 | "torchaudio", 32 | "transformers[torch]", 33 | "faiss-cpu", 34 | "autofaiss", 35 | "multipipe", 36 | ], 37 | classifiers=[ 38 | "Programming Language :: Python :: 3", 39 | "License :: OSI Approved :: MIT License", 40 | "Intended Audience :: Science/Research", 41 | "Operating System :: OS Independent", 42 | "Topic :: Text Processing :: General", 43 | ], 44 | keywords=[ 45 | "information retrieval", 46 | "search engine", 47 | "bm25", 48 | "numba", 49 | "sparse retrieval", 50 | "dense retrieval", 51 | "hybrid retrieval", 52 | "neural information retrieval", 53 | ], 54 | python_requires=">=3.8", 55 | ) 56 | -------------------------------------------------------------------------------- /test_env.yml: -------------------------------------------------------------------------------- 1 | name: asd 2 | dependencies: 3 | - python=3.8 4 | - black 5 | - conda-forge::icecream 6 | - isort 7 | - ipykernel 8 | - ipywidgets 9 | - conda-forge::mkdocs-material 10 | - conda-forge::mkdocs-autorefs 11 | - conda-forge::mkdocstrings 12 | - conda-forge::mkdocstrings-python 13 | - pygments>=2.12 14 | - notebook 15 | - pytest 16 | # Pip 17 | - pip 18 | - pip: 19 | - retriv 20 | 21 | -------------------------------------------------------------------------------- /tests/dense_retriever/encoder_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from retriv.dense_retriever.encoder import Encoder 4 | 5 | 6 | # FIXTURES ===================================================================== 7 | @pytest.fixture 8 | def encoder(): 9 | return Encoder(model="sentence-transformers/all-MiniLM-L6-v2") 10 | 11 | 12 
| @pytest.fixture 13 | def texts(): 14 | return [ 15 | "Generals gathered in their masses", 16 | "Just like witches at black masses", 17 | "Evil minds that plot destruction", 18 | "Sorcerer of death's construction", 19 | ] 20 | 21 | 22 | # TESTS ======================================================================== 23 | def test_call(encoder, texts): 24 | embeddings = encoder(texts) 25 | assert embeddings.shape[0] == 4 26 | assert embeddings.shape[1] == encoder.embedding_dim 27 | -------------------------------------------------------------------------------- /tests/merger/merger_test.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | 3 | import pytest 4 | 5 | from retriv.merger.merger import Merger 6 | from retriv.merger.normalization import min_max_norm 7 | 8 | 9 | # FIXTURES ===================================================================== 10 | @pytest.fixture 11 | def merger(): 12 | return Merger() 13 | 14 | 15 | @pytest.fixture 16 | def run_a(): 17 | return { 18 | "q1": { 19 | "d1": 2.0, 20 | "d2": 0.7, 21 | "d3": 0.5, 22 | }, 23 | "q2": { 24 | "d1": 1.0, 25 | "d2": 0.7, 26 | "d3": 0.5, 27 | }, 28 | } 29 | 30 | 31 | @pytest.fixture 32 | def run_b(): 33 | return { 34 | "q1": { 35 | "d3": 2.0, 36 | "d1": 0.7, 37 | }, 38 | "q2": { 39 | "d1": 1.0, 40 | "d2": 0.7, 41 | "d3": 0.5, 42 | }, 43 | } 44 | 45 | 46 | # TESTS ======================================================================== 47 | def test_fuse(merger, run_a, run_b): 48 | fused_results = merger.fuse([run_a["q1"], run_b["q1"]]) 49 | 50 | norm_run_a = min_max_norm(run_a) 51 | norm_run_b = min_max_norm(run_b) 52 | 53 | assert isclose(fused_results["d1"], norm_run_a["q1"]["d1"] + norm_run_b["q1"]["d1"]) 54 | assert isclose(fused_results["d2"], norm_run_a["q1"]["d2"]) 55 | assert isclose(fused_results["d3"], norm_run_a["q1"]["d3"] + norm_run_b["q1"]["d3"]) 56 | 57 | 58 | def test_mfuse(merger, run_a, run_b): 59 | fused_run = merger.mfuse([run_a, run_b]) 60 | 61 | norm_run_a = min_max_norm(run_a) 62 | norm_run_b = min_max_norm(run_b) 63 | 64 | assert isclose( 65 | fused_run["q1"]["d1"], norm_run_a["q1"]["d1"] + norm_run_b["q1"]["d1"] 66 | ) 67 | assert isclose(fused_run["q1"]["d2"], norm_run_a["q1"]["d2"]) 68 | assert isclose( 69 | fused_run["q1"]["d3"], norm_run_a["q1"]["d3"] + norm_run_b["q1"]["d3"] 70 | ) 71 | 72 | assert isclose( 73 | fused_run["q2"]["d1"], norm_run_a["q2"]["d1"] + norm_run_b["q2"]["d1"] 74 | ) 75 | assert isclose( 76 | fused_run["q2"]["d2"], norm_run_a["q2"]["d2"] + norm_run_b["q2"]["d2"] 77 | ) 78 | assert isclose( 79 | fused_run["q2"]["d3"], norm_run_a["q2"]["d3"] + norm_run_b["q2"]["d3"] 80 | ) 81 | -------------------------------------------------------------------------------- /tests/merger/score_normalization_test.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | 3 | import pytest 4 | 5 | from retriv.merger.normalization import ( 6 | extract_scores, 7 | max_norm, 8 | max_norm_multi, 9 | min_max_norm, 10 | min_max_norm_multi, 11 | safe_max, 12 | safe_min, 13 | sum_norm, 14 | sum_norm_multi, 15 | ) 16 | 17 | 18 | # FIXTURES ===================================================================== 19 | @pytest.fixture 20 | def run_a(): 21 | return { 22 | "q1": { 23 | "d1": 2.0, 24 | "d2": 0.7, 25 | "d3": 0.5, 26 | }, 27 | "q2": { 28 | "d1": 1.0, 29 | "d2": 0.7, 30 | "d3": 0.5, 31 | }, 32 | } 33 | 34 | 35 | @pytest.fixture 36 | def run_b(): 37 | return { 38 | "q1": { 39 | 
"d3": 2.0, 40 | "d1": 0.7, 41 | }, 42 | "q2": { 43 | "d1": 1.0, 44 | "d2": 0.7, 45 | "d3": 0.5, 46 | }, 47 | } 48 | 49 | 50 | # TESTS ======================================================================== 51 | def test_extract_scores(run_a, run_b): 52 | assert extract_scores(run_a["q1"]) == [2.0, 0.7, 0.5] 53 | assert extract_scores(run_a["q2"]) == [1.0, 0.7, 0.5] 54 | assert extract_scores(run_b["q1"]) == [2.0, 0.7] 55 | assert extract_scores(run_b["q2"]) == [1.0, 0.7, 0.5] 56 | 57 | 58 | def test_safe_max(): 59 | assert safe_max([1.0, 0.7, 0.5]) == 1.0 60 | assert safe_max([]) == 0.0 61 | 62 | 63 | def test_safe_min(): 64 | assert safe_min([1.0, 0.7, 0.5]) == 0.5 65 | assert safe_min([]) == 0.0 66 | 67 | 68 | def test_max_norm(run_a, run_b): 69 | normalized_run = max_norm(run_a) 70 | 71 | assert isclose(normalized_run["q1"]["d1"], 1.0) 72 | assert isclose(normalized_run["q1"]["d2"], 0.7 / 2.0) 73 | assert isclose(normalized_run["q1"]["d3"], 0.5 / 2.0) 74 | 75 | assert isclose(normalized_run["q2"]["d1"], 1.0) 76 | assert isclose(normalized_run["q2"]["d2"], 0.7) 77 | assert isclose(normalized_run["q2"]["d3"], 0.5) 78 | 79 | normalized_run = max_norm(run_b) 80 | 81 | assert isclose(normalized_run["q1"]["d3"], 1.0) 82 | assert isclose(normalized_run["q1"]["d1"], 0.7 / 2.0) 83 | 84 | assert isclose(normalized_run["q2"]["d1"], 1.0) 85 | assert isclose(normalized_run["q2"]["d2"], 0.7) 86 | assert isclose(normalized_run["q2"]["d3"], 0.5) 87 | 88 | 89 | def test_min_max_norm(run_a, run_b): 90 | normalized_run = min_max_norm(run_a) 91 | 92 | assert isclose(normalized_run["q1"]["d1"], (2.0 - 0.5) / (2.0 - 0.5)) 93 | assert isclose(normalized_run["q1"]["d2"], (0.7 - 0.5) / (2.0 - 0.5)) 94 | assert isclose(normalized_run["q1"]["d3"], (0.5 - 0.5) / (2.0 - 0.5)) 95 | 96 | assert isclose(normalized_run["q2"]["d1"], (1.0 - 0.5) / (1.0 - 0.5)) 97 | assert isclose(normalized_run["q2"]["d2"], (0.7 - 0.5) / (1.0 - 0.5)) 98 | assert isclose(normalized_run["q2"]["d3"], (0.5 - 0.5) / (1.0 - 0.5)) 99 | 100 | normalized_run = min_max_norm(run_b) 101 | 102 | assert isclose(normalized_run["q1"]["d3"], (2.0 - 0.7) / (2.0 - 0.7)) 103 | assert isclose(normalized_run["q1"]["d1"], (0.7 - 0.7) / (2.0 - 0.7)) 104 | 105 | assert isclose(normalized_run["q2"]["d1"], (1.0 - 0.5) / (1.0 - 0.5)) 106 | assert isclose(normalized_run["q2"]["d2"], (0.7 - 0.5) / (1.0 - 0.5)) 107 | assert isclose(normalized_run["q2"]["d3"], (0.5 - 0.5) / (1.0 - 0.5)) 108 | 109 | 110 | def test_sum_norm(run_a, run_b): 111 | normalized_run = sum_norm(run_a) 112 | 113 | denominator = (2.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5) 114 | assert isclose(normalized_run["q1"]["d1"], (2.0 - 0.5) / denominator) 115 | assert isclose(normalized_run["q1"]["d2"], (0.7 - 0.5) / denominator) 116 | assert isclose(normalized_run["q1"]["d3"], (0.5 - 0.5) / denominator) 117 | 118 | denominator = (1.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5) 119 | assert isclose(normalized_run["q2"]["d1"], (1.0 - 0.5) / denominator) 120 | assert isclose(normalized_run["q2"]["d2"], (0.7 - 0.5) / denominator) 121 | assert isclose(normalized_run["q2"]["d3"], (0.5 - 0.5) / denominator) 122 | 123 | normalized_run = sum_norm(run_b) 124 | 125 | denominator = (1.0 - 0.7) + (0.7 - 0.7) 126 | assert isclose(normalized_run["q1"]["d3"], (1.0 - 0.7) / denominator) 127 | assert isclose(normalized_run["q1"]["d1"], (0.7 - 0.7) / denominator) 128 | 129 | denominator = (1.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5) 130 | assert isclose(normalized_run["q2"]["d1"], (1.0 - 0.5) / denominator) 131 | assert 
isclose(normalized_run["q2"]["d2"], (0.7 - 0.5) / denominator) 132 | assert isclose(normalized_run["q2"]["d3"], (0.5 - 0.5) / denominator) 133 | 134 | 135 | def test_max_norm_multi(run_a, run_b): 136 | runs = [run_a, run_b] 137 | normalized_runs = max_norm_multi(runs) 138 | 139 | assert isclose(normalized_runs[0]["q1"]["d1"], 1.0) 140 | assert isclose(normalized_runs[0]["q1"]["d2"], 0.7 / 2.0) 141 | assert isclose(normalized_runs[0]["q1"]["d3"], 0.5 / 2.0) 142 | assert isclose(normalized_runs[0]["q2"]["d1"], 1.0) 143 | assert isclose(normalized_runs[0]["q2"]["d2"], 0.7) 144 | assert isclose(normalized_runs[0]["q2"]["d3"], 0.5) 145 | 146 | assert isclose(normalized_runs[1]["q1"]["d3"], 1.0) 147 | assert isclose(normalized_runs[1]["q1"]["d1"], 0.7 / 2.0) 148 | assert isclose(normalized_runs[1]["q2"]["d1"], 1.0) 149 | assert isclose(normalized_runs[1]["q2"]["d2"], 0.7) 150 | assert isclose(normalized_runs[1]["q2"]["d3"], 0.5) 151 | 152 | 153 | def test_min_max_norm_multi(run_a, run_b): 154 | runs = [run_a, run_b] 155 | normalized_runs = min_max_norm_multi(runs) 156 | 157 | assert isclose(normalized_runs[0]["q1"]["d1"], (2.0 - 0.5) / (2.0 - 0.5)) 158 | assert isclose(normalized_runs[0]["q1"]["d2"], (0.7 - 0.5) / (2.0 - 0.5)) 159 | assert isclose(normalized_runs[0]["q1"]["d3"], (0.5 - 0.5) / (2.0 - 0.5)) 160 | assert isclose(normalized_runs[0]["q2"]["d1"], (1.0 - 0.5) / (1.0 - 0.5)) 161 | assert isclose(normalized_runs[0]["q2"]["d2"], (0.7 - 0.5) / (1.0 - 0.5)) 162 | assert isclose(normalized_runs[0]["q2"]["d3"], (0.5 - 0.5) / (1.0 - 0.5)) 163 | 164 | assert isclose(normalized_runs[1]["q1"]["d3"], (2.0 - 0.7) / (2.0 - 0.7)) 165 | assert isclose(normalized_runs[1]["q1"]["d1"], (0.7 - 0.7) / (2.0 - 0.7)) 166 | assert isclose(normalized_runs[1]["q2"]["d1"], (1.0 - 0.5) / (1.0 - 0.5)) 167 | assert isclose(normalized_runs[1]["q2"]["d2"], (0.7 - 0.5) / (1.0 - 0.5)) 168 | assert isclose(normalized_runs[1]["q2"]["d3"], (0.5 - 0.5) / (1.0 - 0.5)) 169 | 170 | 171 | def test_sum_norm_multi(run_a, run_b): 172 | runs = [run_a, run_b] 173 | normalized_runs = sum_norm_multi(runs) 174 | 175 | denominator = (2.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5) 176 | assert isclose(normalized_runs[0]["q1"]["d1"], (2.0 - 0.5) / denominator) 177 | assert isclose(normalized_runs[0]["q1"]["d2"], (0.7 - 0.5) / denominator) 178 | assert isclose(normalized_runs[0]["q1"]["d3"], (0.5 - 0.5) / denominator) 179 | 180 | denominator = (1.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5) 181 | assert isclose(normalized_runs[0]["q2"]["d1"], (1.0 - 0.5) / denominator) 182 | assert isclose(normalized_runs[0]["q2"]["d2"], (0.7 - 0.5) / denominator) 183 | assert isclose(normalized_runs[0]["q2"]["d3"], (0.5 - 0.5) / denominator) 184 | 185 | denominator = (2.0 - 0.7) + (0.7 - 0.7) 186 | assert isclose(normalized_runs[1]["q1"]["d3"], (2.0 - 0.7) / denominator) 187 | assert isclose(normalized_runs[1]["q1"]["d1"], (0.7 - 0.7) / denominator) 188 | denominator = (1.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5) 189 | assert isclose(normalized_runs[1]["q2"]["d1"], (1.0 - 0.5) / denominator) 190 | assert isclose(normalized_runs[1]["q2"]["d2"], (0.7 - 0.5) / denominator) 191 | assert isclose(normalized_runs[1]["q2"]["d3"], (0.5 - 0.5) / denominator) 192 | -------------------------------------------------------------------------------- /tests/numba_utils_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from numba.typed import List as TypedList 4 | 5 | from retriv.utils.numba_utils import ( 6 |
concat1d, 7 | diff_sorted, 8 | get_indices, 9 | intersect_sorted, 10 | intersect_sorted_multi, 11 | union_sorted, 12 | union_sorted_multi, 13 | unsorted_top_k, 14 | ) 15 | 16 | 17 | # TESTS ======================================================================== 18 | def test_union_sorted(): 19 | a1 = np.array([1, 3, 4, 7], dtype=np.int32) 20 | a2 = np.array([1, 4, 7, 9], dtype=np.int32) 21 | result = union_sorted(a1, a2) 22 | expected = np.array([1, 3, 4, 7, 9], dtype=np.int32) 23 | 24 | assert np.array_equal(result, expected) 25 | 26 | 27 | def test_union_sorted_multi(): 28 | a1 = np.array([1, 3, 4, 7], dtype=np.int32) 29 | a2 = np.array([1, 4, 7, 9], dtype=np.int32) 30 | a3 = np.array([10, 11], dtype=np.int32) 31 | a4 = np.array([11, 12, 13], dtype=np.int32) 32 | 33 | arrays = TypedList([a1, a2, a3, a4]) 34 | 35 | result = union_sorted_multi(arrays) 36 | expected = np.array([1, 3, 4, 7, 9, 10, 11, 12, 13], dtype=np.int32) 37 | 38 | assert np.array_equal(result, expected) 39 | 40 | 41 | def test_intersect_sorted(): 42 | a1 = np.array([1, 3, 4, 7], dtype=np.int32) 43 | a2 = np.array([1, 4, 7, 9], dtype=np.int32) 44 | result = intersect_sorted(a1, a2) 45 | expected = np.array([1, 4, 7], dtype=np.int32) 46 | 47 | assert np.array_equal(result, expected) 48 | 49 | 50 | def test_intersect_sorted_multi(): 51 | a1 = np.array([1, 3, 4, 7], dtype=np.int32) 52 | a2 = np.array([1, 4, 7, 9], dtype=np.int32) 53 | a3 = np.array([4, 7], dtype=np.int32) 54 | a4 = np.array([3, 7, 9], dtype=np.int32) 55 | 56 | arrays = TypedList([a1, a2, a3, a4]) 57 | 58 | result = intersect_sorted_multi(arrays) 59 | expected = np.array([7], dtype=np.int32) 60 | 61 | print(result) 62 | 63 | assert np.array_equal(result, expected) 64 | 65 | 66 | def test_diff_sorted(): 67 | a1 = np.array([1, 3, 4, 7], dtype=np.int32) 68 | a2 = np.array([1, 4, 7, 9], dtype=np.int32) 69 | result = diff_sorted(a1, a2) 70 | expected = np.array([3], dtype=np.int32) 71 | 72 | assert np.array_equal(result, expected) 73 | 74 | a1 = np.array([1, 3, 4, 7, 11], dtype=np.int32) 75 | a2 = np.array([1, 4, 7, 9], dtype=np.int32) 76 | result = diff_sorted(a1, a2) 77 | expected = np.array([3, 11], dtype=np.int32) 78 | 79 | assert np.array_equal(result, expected) 80 | 81 | 82 | def test_concat1d(): 83 | a1 = np.array([1, 3, 4, 7], dtype=np.int32) 84 | a2 = np.array([1, 4, 7, 9], dtype=np.int32) 85 | a3 = np.array([10, 11], dtype=np.int32) 86 | a4 = np.array([11, 12, 13], dtype=np.int32) 87 | 88 | arrays = TypedList([a1, a2, a3, a4]) 89 | 90 | result = concat1d(arrays) 91 | expected = np.array([1, 3, 4, 7, 1, 4, 7, 9, 10, 11, 11, 12, 13], dtype=np.int32) 92 | 93 | assert np.array_equal(result, expected) 94 | 95 | 96 | def test_get_indices(): 97 | array = np.array([0.1, 0.3, 0.2, 0.4], dtype=np.float32) 98 | scores = np.array([0.4, 0.3, 0.2, 0.1], dtype=np.float32) 99 | 100 | result = get_indices(array, scores) 101 | expected = np.array([3, 1, 2, 0], dtype=np.int64) 102 | 103 | assert np.array_equal(result, expected) 104 | 105 | 106 | def test_unsorted_top_k(): 107 | array = np.array([0.1, 0.3, 0.2, 0.4], dtype=np.float32) 108 | k = 2 109 | 110 | top_k_values, top_k_indices = unsorted_top_k(array, k) 111 | 112 | assert len(top_k_values) == 2 113 | assert 0.3 in top_k_values 114 | assert 0.4 in top_k_values 115 | assert len(top_k_indices) == 2 116 | assert 1 in top_k_indices 117 | assert 3 in top_k_indices 118 | -------------------------------------------------------------------------------- /tests/sparse_retriever/preprocessing_test.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from retriv.sparse_retriever.preprocessing import preprocessing, preprocessing_multi 4 | from retriv.sparse_retriever.preprocessing.stemmer import get_stemmer 5 | from retriv.sparse_retriever.preprocessing.stopwords import get_stopwords 6 | from retriv.sparse_retriever.preprocessing.tokenizer import get_tokenizer 7 | 8 | 9 | # FIXTURES ===================================================================== 10 | @pytest.fixture 11 | def stemmer(): 12 | return get_stemmer("english") 13 | 14 | 15 | @pytest.fixture 16 | def stopwords(): 17 | return get_stopwords("english") 18 | 19 | 20 | @pytest.fixture 21 | def tokenizer(): 22 | return get_tokenizer("whitespace") 23 | 24 | 25 | @pytest.fixture 26 | def docs(): 27 | return """Black Sabbath were an English rock band formed in Birmingham in 1968 by guitarist Tony Iommi, drummer Bill Ward, bassist Geezer Butler and vocalist Ozzy Osbourne. They are often cited as pioneers of heavy metal music. The band helped define the genre with releases such as Black Sabbath (1970), Paranoid (1970) and Master of Reality (1971). The band had multiple line-up changes following Osbourne's departure in 1979 and Iommi is the only constant member throughout their history. 28 | After previous iterations of the group - the Polka Tulk Blues Band and Earth - the band settled on the name Black Sabbath in 1969. They distinguished themselves through occult themes with horror-inspired lyrics and down-tuned guitars. Signing to Philips Records in November 1969, they released their first single, "Evil Woman", in January 1970, and their debut album, Black Sabbath, was released the following month. Though it received a negative critical response, the album was a commercial success, leading to a follow-up record, Paranoid, later that year. The band's popularity grew, and by 1973's Sabbath Bloody Sabbath, critics were starting to respond favourably. 29 | Osbourne's excessive substance abuse led to his firing in 1979. He was replaced by former Rainbow vocalist Ronnie James Dio. Following two albums with Dio, Black Sabbath endured many personnel changes in the 1980s and 1990s that included vocalists Ian Gillan, Glenn Hughes, Ray Gillen and Tony Martin, as well as several drummers and bassists. Martin, who replaced Gillen in 1987, was the second-longest serving vocalist and recorded three albums with Black Sabbath before his dismissal in 1991. That same year, Iommi and Butler were rejoined by Dio and drummer Vinny Appice to record Dehumanizer (1992). After two more studio albums with Martin, who replaced Dio in 1993, the band's original line-up reunited in 1997 and released a live album, Reunion, the following year; they continued to tour occasionally until 2005. Other than various back catalogue reissues and compilation albums, as well as the Mob Rules-era line-up reunited as Heaven & Hell, there was no further activity under the Black Sabbath name for six years. They reunited in 2011 and released their final studio album and 19th overall, 13, in 2013, which features all of the original members except Ward. During their farewell tour, the band played their final concert in their home city of Birmingham on 4 February 2017. Occasional partial reunions have happened since, most recently when Osbourne and Iommi performed together at the closing ceremony of the 2022 Commonwealth Games in Birmingham. 
30 | Black Sabbath have sold over 70 million records worldwide as of 2013, making them one of the most commercially successful heavy metal bands. Black Sabbath, together with Deep Purple and Led Zeppelin, have been referred to as the "unholy trinity of British hard rock and heavy metal in the early to mid-seventies". They were ranked by MTV as the "Greatest Metal Band of All Time" and placed second on VH1's "100 Greatest Artists of Hard Rock" list. Rolling Stone magazine ranked them number 85 on their "100 Greatest Artists of All Time". Black Sabbath were inducted into the UK Music Hall of Fame in 2005 and the Rock and Roll Hall of Fame in 2006. They have also won two Grammy Awards for Best Metal Performance, and in 2019 the band were presented a Grammy Lifetime Achievement Award.""".split( 31 | "\n" 32 | ) 33 | 34 | 35 | # TESTS ======================================================================== 36 | def test_preprocessing_multi(docs, stemmer, stopwords, tokenizer): 37 | out = [ 38 | preprocessing( 39 | doc, 40 | stemmer=stemmer, 41 | stopwords=stopwords, 42 | tokenizer=tokenizer, 43 | do_lowercasing=True, 44 | do_ampersand_normalization=True, 45 | do_special_chars_normalization=True, 46 | do_acronyms_normalization=True, 47 | do_punctuation_removal=True, 48 | ) 49 | for doc in docs 50 | ] 51 | 52 | pipeline = preprocessing_multi( 53 | tokenizer=tokenizer, 54 | stopwords=stopwords, 55 | stemmer=stemmer, 56 | do_lowercasing=True, 57 | do_ampersand_normalization=True, 58 | do_special_chars_normalization=True, 59 | do_acronyms_normalization=True, 60 | do_punctuation_removal=True, 61 | ) 62 | multi_out = pipeline(docs) 63 | 64 | assert len(out) == len(multi_out) 65 | assert out[0] == multi_out[0] 66 | assert out[1] == multi_out[1] 67 | assert out[2] == multi_out[2] 68 | assert out[3] == multi_out[3] 69 | assert out == multi_out 70 | -------------------------------------------------------------------------------- /tests/sparse_retriever/search_engine_test.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | 3 | import pytest 4 | 5 | from retriv import SearchEngine 6 | 7 | REL_TOL = 1e-6 8 | 9 | 10 | # FIXTURES ===================================================================== 11 | @pytest.fixture 12 | def collection(): 13 | return [ 14 | {"id": 1, "text": "Shane"}, 15 | {"id": 2, "text": "Shane C"}, 16 | {"id": 3, "text": "Shane P Connelly"}, 17 | {"id": 4, "text": "Shane Connelly"}, 18 | {"id": 5, "text": "Shane Shane Connelly Connelly"}, 19 | {"id": 6, "text": "Shane Shane Shane Connelly Connelly Connelly"}, 20 | ] 21 | 22 | 23 | def test_search_bm25(collection): 24 | se = SearchEngine(hyperparams=dict(b=0.5, k1=0)) 25 | se.index(collection) 26 | 27 | query = "shane" 28 | 29 | results = se.search(query=query, return_docs=False) 30 | 31 | print(se.inverted_index) 32 | 33 | print(results) 34 | assert isclose(results[1], 0.07410797, rel_tol=REL_TOL) 35 | assert isclose(results[2], 0.07410797, rel_tol=REL_TOL) 36 | assert isclose(results[3], 0.07410797, rel_tol=REL_TOL) 37 | assert isclose(results[4], 0.07410797, rel_tol=REL_TOL) 38 | assert isclose(results[5], 0.07410797, rel_tol=REL_TOL) 39 | assert isclose(results[6], 0.07410797, rel_tol=REL_TOL) 40 | 41 | se.hyperparams = dict(b=0, k1=10) 42 | results = se.search(query=query, return_docs=False) 43 | print(results) 44 | assert isclose(results[1], 0.07410797, rel_tol=REL_TOL) 45 | assert isclose(results[2], 0.07410797, rel_tol=REL_TOL) 46 | assert isclose(results[3], 
0.07410797, rel_tol=REL_TOL) 47 | assert isclose(results[4], 0.07410797, rel_tol=REL_TOL) 48 | assert isclose(results[5], 0.13586462, rel_tol=REL_TOL) 49 | assert isclose(results[6], 0.18812023, rel_tol=REL_TOL) 50 | 51 | se.hyperparams = dict(b=1, k1=5) 52 | results = se.search(query=query, return_docs=False) 53 | print(results) 54 | assert isclose(results[1], 0.16674294, rel_tol=REL_TOL) 55 | assert isclose(results[2], 0.10261103, rel_tol=REL_TOL) 56 | assert isclose(results[3], 0.07410797, rel_tol=REL_TOL) 57 | assert isclose(results[4], 0.10261103, rel_tol=REL_TOL) 58 | assert isclose(results[5], 0.10261103, rel_tol=REL_TOL) 59 | assert isclose(results[6], 0.10261105, rel_tol=REL_TOL) 60 | 61 | 62 | def test_msearch_bm25(collection): 63 | se = SearchEngine(hyperparams=dict(b=0.5, k1=0)) 64 | se.index(collection) 65 | 66 | queries = [ 67 | {"id": "q_1", "text": "shane"}, 68 | {"id": "q_2", "text": "connelly"}, 69 | ] 70 | 71 | results = se.msearch(queries=queries) 72 | 73 | print(results) 74 | assert isclose(results["q_1"][1], 0.07410797, rel_tol=REL_TOL) 75 | assert isclose(results["q_1"][2], 0.07410797, rel_tol=REL_TOL) 76 | assert isclose(results["q_1"][3], 0.07410797, rel_tol=REL_TOL) 77 | assert isclose(results["q_1"][4], 0.07410797, rel_tol=REL_TOL) 78 | assert isclose(results["q_1"][5], 0.07410797, rel_tol=REL_TOL) 79 | assert isclose(results["q_1"][6], 0.07410797, rel_tol=REL_TOL) 80 | assert isclose(results["q_2"][3], 0.44183275, rel_tol=REL_TOL) 81 | assert isclose(results["q_2"][4], 0.44183275, rel_tol=REL_TOL) 82 | assert isclose(results["q_2"][5], 0.44183275, rel_tol=REL_TOL) 83 | assert isclose(results["q_2"][6], 0.44183275, rel_tol=REL_TOL) 84 | 85 | se.hyperparams = dict(b=0, k1=10) 86 | results = se.msearch(queries=queries) 87 | print(results) 88 | assert isclose(results["q_1"][1], 0.07410797, rel_tol=REL_TOL) 89 | assert isclose(results["q_1"][2], 0.07410797, rel_tol=REL_TOL) 90 | assert isclose(results["q_1"][3], 0.07410797, rel_tol=REL_TOL) 91 | assert isclose(results["q_1"][4], 0.07410797, rel_tol=REL_TOL) 92 | assert isclose(results["q_1"][5], 0.13586462, rel_tol=REL_TOL) 93 | assert isclose(results["q_1"][6], 0.18812023, rel_tol=REL_TOL) 94 | assert isclose(results["q_2"][3], 0.44183275, rel_tol=REL_TOL) 95 | assert isclose(results["q_2"][4], 0.44183275, rel_tol=REL_TOL) 96 | assert isclose(results["q_2"][5], 0.8100267, rel_tol=REL_TOL) 97 | assert isclose(results["q_2"][6], 1.1215755, rel_tol=REL_TOL) 98 | 99 | se.hyperparams = dict(b=1, k1=5) 100 | results = se.msearch(queries=queries) 101 | print(results) 102 | assert isclose(results["q_1"][1], 0.16674294, rel_tol=REL_TOL) 103 | assert isclose(results["q_1"][2], 0.10261103, rel_tol=REL_TOL) 104 | assert isclose(results["q_1"][3], 0.07410797, rel_tol=REL_TOL) 105 | assert isclose(results["q_1"][4], 0.10261103, rel_tol=REL_TOL) 106 | assert isclose(results["q_1"][5], 0.10261103, rel_tol=REL_TOL) 107 | assert isclose(results["q_1"][6], 0.10261105, rel_tol=REL_TOL) 108 | assert isclose(results["q_2"][3], 0.44183275, rel_tol=REL_TOL) 109 | assert isclose(results["q_2"][4], 0.6117684, rel_tol=REL_TOL) 110 | assert isclose(results["q_2"][5], 0.6117684, rel_tol=REL_TOL) 111 | assert isclose(results["q_2"][6], 0.6117684, rel_tol=REL_TOL) 112 | -------------------------------------------------------------------------------- /tests/sparse_retriever/stemmer_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from retriv.sparse_retriever.preprocessing.stemmer 
import get_stemmer 4 | 5 | 6 | # FIXTURES ===================================================================== 7 | @pytest.fixture 8 | def supported_stemmers(): 9 | # fmt: off 10 | return [ 11 | "krovetz", "porter", "lancaster", "arlstem", "arlstem2", "cistem", 12 | "isri", "arabic", 13 | "basque", "catalan", "danish", "dutch", "english", "finnish", "french", "german", "greek", "hindi", "hungarian", "indonesian", "irish", "italian", "lithuanian", "nepali", "norwegian", "portuguese", "romanian", "russian", 14 | "spanish", "swedish", "tamil", "turkish", 15 | ] 16 | # fmt: on 17 | 18 | 19 | # TESTS ======================================================================== 20 | def test_get_stemmer(supported_stemmers): 21 | for stemmer in supported_stemmers: 22 | assert callable(get_stemmer(stemmer)) 23 | 24 | 25 | def test_get_stemmer_fails(): 26 | with pytest.raises(Exception): 27 | get_stemmer("foobar") 28 | 29 | 30 | def test_get_stemmer_callable(): 31 | assert callable(get_stemmer(lambda x: x)) 32 | 33 | 34 | def test_get_stemmer_none(): 35 | assert callable(get_stemmer(None)) 36 | assert get_stemmer(None)("incredible") == "incredible" 37 | -------------------------------------------------------------------------------- /tests/sparse_retriever/stopwords_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from retriv.sparse_retriever.preprocessing.stopwords import get_stopwords 4 | 5 | 6 | # FIXTURES ===================================================================== 7 | @pytest.fixture 8 | def supported_languages(): 9 | # fmt: off 10 | return [ 11 | "arabic", "azerbaijani", "basque", "catalan", "bengali", "chinese", 12 | "danish", "dutch", "english", "finnish", "french", "german", "greek", 13 | "hebrew", "hinglish", "hungarian", "indonesian", "italian", "kazakh", 14 | "nepali", "norwegian", "portuguese", "romanian", "russian", "slovene", 15 | "spanish", "swedish", "tajik", "turkish", 16 | ] 17 | # fmt: on 18 | 19 | 20 | @pytest.fixture 21 | def sw_list(): 22 | return ["a", "the"] 23 | 24 | 25 | @pytest.fixture 26 | def sw_set(): 27 | return {"a", "the"} 28 | 29 | 30 | # TESTS ======================================================================== 31 | def test_get_stopwords_from_lang(supported_languages): 32 | for lang in supported_languages: 33 | assert type(get_stopwords(lang)) == list 34 | assert len(get_stopwords(lang)) > 0 35 | 36 | 37 | def test_get_stopwords_from_lang_fails(): 38 | with pytest.raises(Exception): 39 | get_stopwords("foobar") 40 | 41 | 42 | def test_get_stopwords_from_list(sw_list): 43 | assert type(get_stopwords(sw_list)) == list 44 | assert set(get_stopwords(sw_list)) == {"a", "the"} 45 | assert len(get_stopwords(sw_list)) > 0 46 | 47 | 48 | def test_get_stopwords_from_set(sw_set): 49 | assert type(get_stopwords(sw_set)) == list 50 | assert set(get_stopwords(sw_set)) == {"a", "the"} 51 | assert len(get_stopwords(sw_set)) > 0 52 | 53 | 54 | def test_get_stopwords_none(): 55 | assert type(get_stopwords(None)) == list 56 | assert len(get_stopwords(None)) == 0 57 | -------------------------------------------------------------------------------- /tests/sparse_retriever/text_normalization_test.py: -------------------------------------------------------------------------------- 1 | from retriv.sparse_retriever.preprocessing.normalization import ( 2 | lowercasing, 3 | normalize_acronyms, 4 | normalize_ampersand, 5 | normalize_special_chars, 6 | remove_punctuation, 7 | strip_whitespaces, 8 | ) 9 | 10 | 11 | # 
TESTS ======================================================================== 12 | def test_lowercasing(): 13 | assert lowercasing("hEllO") == "hello" 14 | 15 | 16 | def test_normalize_ampersand(): 17 | assert normalize_ampersand("black&sabbath") == "black and sabbath" 18 | 19 | 20 | def test_normalize_special_chars(): 21 | assert normalize_special_chars("‘’") == "''" 22 | 23 | 24 | def test_normalize_acronyms(): 25 | assert normalize_acronyms("a.b.c.") == "abc" 26 | assert normalize_acronyms("foo.bar") == "foo.bar" 27 | assert normalize_acronyms("a.b@hello.com") == "a.b@hello.com" 28 | 29 | 30 | def test_remove_punctuation(): 31 | assert remove_punctuation("foo.bar?") == "foo bar " 32 | # assert remove_punctuation("a.b@hello.com") == "a.b@hello.com" 33 | 34 | 35 | def test_strip_whitespaces(): 36 | assert strip_whitespaces(" hello world ") == "hello world" 37 | -------------------------------------------------------------------------------- /tests/sparse_retriever/tokenizer_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from retriv.sparse_retriever.preprocessing.tokenizer import get_tokenizer 4 | 5 | 6 | # FIXTURES ===================================================================== 7 | @pytest.fixture 8 | def supported_tokenizers(): 9 | return ["whitespace", "word", "wordpunct", "sent"] 10 | 11 | 12 | # TESTS ======================================================================== 13 | def test_get_tokenizer(supported_tokenizers): 14 | for tokenizer in supported_tokenizers: 15 | assert callable(get_tokenizer(tokenizer)) 16 | 17 | 18 | def test_get_tokenizer_fails(): 19 | with pytest.raises(Exception): 20 | get_tokenizer("foobar") 21 | 22 | 23 | def test_get_tokenizer_callable(): 24 | assert callable(get_tokenizer(lambda x: x)) 25 | 26 | 27 | def test_get_tokenizer_none(): 28 | assert callable(get_tokenizer(None)) 29 | assert get_tokenizer(None)("black sabbath") == "black sabbath" 30 | --------------------------------------------------------------------------------
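
The merger and normalization tests above pin down the fusion behavior precisely: scores are min-max normalized per query and then summed across runs (CombSUM-style), and a document missing from one run simply contributes nothing. The sketch below re-derives that behavior from the assertions alone; min_max_norm_results and combsum_fuse are hypothetical names, not retriv's implementation (see retriv/merger/ for the real one), and the constant-score guard is an assumption the tests never exercise.

from math import isclose

def min_max_norm_results(results):
    # Per-query min-max normalization, as asserted by test_min_max_norm.
    lo, hi = min(results.values()), max(results.values())
    span = (hi - lo) or 1.0  # assumption: avoid 0/0 when all scores are equal
    return {doc: (score - lo) / span for doc, score in results.items()}

def combsum_fuse(results_list):
    # Normalize each result set, then sum scores per document, as test_fuse
    # asserts; documents absent from a run contribute zero.
    fused = {}
    for results in results_list:
        for doc, score in min_max_norm_results(results).items():
            fused[doc] = fused.get(doc, 0.0) + score
    return fused

run_a_q1 = {"d1": 2.0, "d2": 0.7, "d3": 0.5}
run_b_q1 = {"d3": 2.0, "d1": 0.7}
fused = combsum_fuse([run_a_q1, run_b_q1])
assert isclose(fused["d1"], 1.0)  # best in run_a (1.0) plus worst in run_b (0.0)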
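
The numba_utils tests exercise classic two-pointer merges over sorted document-id arrays, the core primitives for unioning and intersecting posting lists. For reference, here is a plain-Python rendering of what union_sorted and intersect_sorted are asserted to return; the shipped versions are Numba-compiled, so these names and bodies are illustrative only.

import numpy as np

def union_sorted_py(a, b):
    # Merge two sorted id arrays, emitting each id once.
    out, i, j = [], 0, 0
    while i < len(a) and j < len(b):
        if a[i] < b[j]:
            out.append(a[i])
            i += 1
        elif a[i] > b[j]:
            out.append(b[j])
            j += 1
        else:  # present in both arrays
            out.append(a[i])
            i += 1
            j += 1
    out.extend(a[i:])
    out.extend(b[j:])
    return np.array(out, dtype=np.int32)

def intersect_sorted_py(a, b):
    # Keep only the ids present in both sorted arrays.
    out, i, j = [], 0, 0
    while i < len(a) and j < len(b):
        if a[i] < b[j]:
            i += 1
        elif a[i] > b[j]:
            j += 1
        else:
            out.append(a[i])
            i += 1
            j += 1
    return np.array(out, dtype=np.int32)

a1 = np.array([1, 3, 4, 7], dtype=np.int32)
a2 = np.array([1, 4, 7, 9], dtype=np.int32)
assert np.array_equal(union_sorted_py(a1, a2), np.array([1, 3, 4, 7, 9], dtype=np.int32))
assert np.array_equal(intersect_sorted_py(a1, a2), np.array([1, 4, 7], dtype=np.int32))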
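
The preprocessing tests verify that the batched pipeline (preprocessing_multi) is equivalent to mapping preprocessing over documents one by one. A minimal usage sketch, mirroring exactly the components and keyword arguments the tests pass (the sample strings are made up); as the stemmer, stopwords, and tokenizer tests show, each factory also accepts a custom callable or collection, or None to disable the step:

from retriv.sparse_retriever.preprocessing import preprocessing, preprocessing_multi
from retriv.sparse_retriever.preprocessing.stemmer import get_stemmer
from retriv.sparse_retriever.preprocessing.stopwords import get_stopwords
from retriv.sparse_retriever.preprocessing.tokenizer import get_tokenizer

stemmer = get_stemmer("english")         # or a callable, or None
stopwords = get_stopwords("english")     # or a list/set of words, or None
tokenizer = get_tokenizer("whitespace")  # or a callable, or None

# Per-document call: preprocess a single text.
out = preprocessing(
    "Generals gathered in their masses",
    stemmer=stemmer,
    stopwords=stopwords,
    tokenizer=tokenizer,
    do_lowercasing=True,
    do_ampersand_normalization=True,
    do_special_chars_normalization=True,
    do_acronyms_normalization=True,
    do_punctuation_removal=True,
)

# Batched pipeline: built once, applied to a whole collection. The tests
# above assert its output matches the per-document calls element by element.
pipeline = preprocessing_multi(
    tokenizer=tokenizer,
    stopwords=stopwords,
    stemmer=stemmer,
    do_lowercasing=True,
    do_ampersand_normalization=True,
    do_special_chars_normalization=True,
    do_acronyms_normalization=True,
    do_punctuation_removal=True,
)
multi_out = pipeline(["Generals gathered in their masses", "Evil minds that plot destruction"])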
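
Finally, the search-engine tests double as minimal usage documentation for the BM25 retriever: construct a SearchEngine with b/k1 hyperparameters, index a list of {"id", "text"} records, and search. The sketch below mirrors those calls; the b/k1 values shown are the common textbook BM25 defaults, an assumption rather than retriv's own defaults. Note that the extreme settings in the tests probe the scoring function deliberately: with k1=0 the term-frequency factor cancels and BM25 reduces to pure IDF, which is why all six documents tie in the first assertion block.

from retriv import SearchEngine

collection = [
    {"id": 1, "text": "Shane"},
    {"id": 2, "text": "Shane Connelly"},
]

se = SearchEngine(hyperparams=dict(b=0.75, k1=1.2))  # textbook defaults (assumption)
se.index(collection)

# return_docs=False yields a {doc_id: score} mapping, which is what the
# assertions above index into.
scores = se.search(query="shane", return_docs=False)

# Hyperparameters can be reassigned without re-indexing, as the tests do
# between assertion blocks; only the scoring changes.
se.hyperparams = dict(b=0.5, k1=2.0)
scores = se.search(query="shane", return_docs=False)

# msearch scores several queries at once: {query_id: {doc_id: score}}.
queries = [{"id": "q_1", "text": "shane"}, {"id": "q_2", "text": "connelly"}]
run = se.msearch(queries=queries)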