├── .gitignore
├── CITATION.cff
├── LICENSE
├── Makefile
├── README.md
├── _typos.toml
├── changelog.md
├── conftest.py
├── dev_env.yml
├── docs
│   ├── advanced_retriever.md
│   ├── dense_retriever.md
│   ├── faq.md
│   ├── filters.md
│   ├── hybrid_retriever.md
│   ├── sparse_retriever.md
│   ├── speed.md
│   └── text_preprocessing.md
├── pyproject.toml
├── requirements-dev.txt
├── requirements.txt
├── retriv
│   ├── __init__.py
│   ├── autotune
│   │   ├── __init__.py
│   │   ├── bm25_autotune.py
│   │   └── merger_autotune.py
│   ├── base_retriever.py
│   ├── dense_retriever
│   │   ├── __init__.py
│   │   ├── ann_searcher.py
│   │   ├── dense_retriever.py
│   │   └── encoder.py
│   ├── experimental
│   │   ├── __init__.py
│   │   └── advanced_retriever.py
│   ├── hybrid_retriever.py
│   ├── merger
│   │   ├── __init__.py
│   │   ├── merger.py
│   │   └── normalization.py
│   ├── paths.py
│   ├── sparse_retriever
│   │   ├── __init__.py
│   │   ├── build_inverted_index.py
│   │   ├── preprocessing
│   │   │   ├── __init__.py
│   │   │   ├── normalization.py
│   │   │   ├── stemmer.py
│   │   │   ├── stopwords.py
│   │   │   ├── tokenizer.py
│   │   │   └── utils.py
│   │   ├── sparse_retrieval_models
│   │   │   ├── __init__.py
│   │   │   ├── bm25.py
│   │   │   └── tf_idf.py
│   │   └── sparse_retriever.py
│   └── utils
│       ├── __init__.py
│       └── numba_utils.py
├── setup.py
├── test_env.yml
└── tests
    ├── advanced_retriever
    │   └── advanced_retriever_test.py
    ├── dense_retriever
    │   └── encoder_test.py
    ├── merger
    │   ├── merger_test.py
    │   └── score_normalization_test.py
    ├── numba_utils_test.py
    └── sparse_retriever
        ├── preprocessing_test.py
        ├── search_engine_test.py
        ├── stemmer_test.py
        ├── stopwords_test.py
        ├── text_normalization_test.py
        └── tokenizer_test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /*.ipynb
2 | /beir
3 | /dictionaries
4 | /collections
5 | /datasets
6 | /bkp
7 | *.npy
8 | *.tar.gz
9 | *.csv
10 | *.tsv
11 | *.jsonl
12 | /compute_runs.py
13 | .vscode
14 |
15 | raw_collection.jsonl
16 | .DS_Store
17 |
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 |
23 | # C extensions
24 | *.so
25 |
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | pip-wheel-metadata/
41 | share/python-wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 | MANIFEST
46 |
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 |
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .tox/
60 | .nox/
61 | .coverage
62 | .coverage.*
63 | .cache
64 | nosetests.xml
65 | coverage.xml
66 | *.cover
67 | *.py,cover
68 | .hypothesis/
69 | .pytest_cache/
70 |
71 | # Translations
72 | *.mo
73 | *.pot
74 |
75 | # Django stuff:
76 | *.log
77 | local_settings.py
78 | db.sqlite3
79 | db.sqlite3-journal
80 |
81 | # Flask stuff:
82 | instance/
83 | .webassets-cache
84 |
85 | # Scrapy stuff:
86 | .scrapy
87 |
88 | # Sphinx documentation
89 | docs/_build/
90 |
91 | # PyBuilder
92 | target/
93 |
94 | # Jupyter Notebook
95 | .ipynb_checkpoints
96 |
97 | # IPython
98 | profile_default/
99 | ipython_config.py
100 |
101 | # pyenv
102 | .python-version
103 |
104 | # pipenv
105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
108 | # install all needed dependencies.
109 | #Pipfile.lock
110 |
111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
112 | __pypackages__/
113 |
114 | # Celery stuff
115 | celerybeat-schedule
116 | celerybeat.pid
117 |
118 | # SageMath parsed files
119 | *.sage.py
120 |
121 | # Environments
122 | .env
123 | .venv
124 | env/
125 | venv/
126 | ENV/
127 | env.bak/
128 | venv.bak/
129 |
130 | # Spyder project settings
131 | .spyderproject
132 | .spyproject
133 |
134 | # Rope project settings
135 | .ropeproject
136 |
137 | # mkdocs documentation
138 | /site
139 |
140 | # mypy
141 | .mypy_cache/
142 | .dmypy.json
143 | dmypy.json
144 |
145 | # Pyre type checker
146 | .pyre/
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Bassani"
5 | given-names: "Elias"
6 | orcid: "https://orcid.org/0000-0001-7922-2578"
7 | title: "retriv: A Python Search Engine for the Common Man"
8 | version: 0.2.1
9 | doi: 10.5281/zenodo.7978820
10 | date-released: 2023-05-28
11 | url: "https://github.com/AmenRa/retriv"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Elias Bassani
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # from https://github.com/alexander-beedie/polars/blob/a083a26ca092b3821bf4044aed74d91ee127bad1/py-polars/Makefile
2 | .DEFAULT_GOAL := help
3 |
4 | SHELL := /bin/bash
5 | .SHELLFLAGS := -eu -o pipefail -c
6 | PYTHONPATH=
7 | VENV = .venv
8 | MAKE = make
9 |
10 | VENV_BIN=$(VENV)/bin
11 |
12 | .venv: ## Set up virtual environment and install requirements
13 | python3 -m venv $(VENV)
14 | $(MAKE) requirements
15 |
16 | .PHONY: requirements
17 | requirements: .venv ## Install/refresh all project requirements
18 | $(VENV_BIN)/python -m pip install --upgrade pip
19 | $(VENV_BIN)/pip install -r requirements-dev.txt
20 | $(VENV_BIN)/pip install -r requirements.txt
21 |
22 | .PHONY: build
23 | build: .venv ## Compile and install retriv
24 | . $(VENV_BIN)/activate
25 |
26 |
27 | .PHONY: fix-lint
28 | fix-lint: .venv ## Fix linting
29 | . $(VENV_BIN)/activate
30 | $(VENV_BIN)/black .
31 | $(VENV_BIN)/isort .
32 |
33 | .PHONY: lint
34 | lint: .venv ## Check linting
35 | . $(VENV_BIN)/activate
36 | $(VENV_BIN)/isort --check .
37 | $(VENV_BIN)/black --check .
38 | $(VENV_BIN)/blackdoc .
39 | $(VENV_BIN)/ruff .
40 | $(VENV_BIN)/typos .
41 | # $(VENV_BIN)/mypy retriv
42 |
43 | .PHONY: test
44 | test: .venv build ## Run unittest
45 | . $(VENV_BIN)/activate
46 | $(VENV_BIN)/pytest tests
47 |
48 | .PHONY: coverage
49 | coverage: .venv build ## Run tests and report coverage
50 | . $(VENV_BIN)/activate
51 | $(VENV_BIN)/pytest tests --cov -n auto --dist worksteal -m "not benchmark"
52 |
53 | .PHONY: release
54 | release: .venv build ## Release a new version
55 | . $(VENV_BIN)/activate
56 | $(VENV_BIN)/python setup.py sdist bdist_wheel
57 | $(VENV_BIN)/twine upload --repository-url https://upload.pypi.org/legacy/ dist/*
58 |
59 | .PHONY: clean
60 | clean: ## Clean up caches and build artifacts
61 | @rm -rf .venv/
62 | @rm -rf target/
63 | @rm -rf .pytest_cache/
64 | @rm -rf .coverage
65 | @rm -rf retriv.egg-info
66 | @rm -rf dist
67 | @rm -rf build
68 |
69 | .PHONY: help
70 | help: ## Display this help screen
71 | @echo -e "\033[1mAvailable commands:\033[0m\n"
72 | @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 | ## 🔥 News
27 | - [August 23, 2023] `retriv` 0.2.2 is out!
28 | This release adds _experimental_ support for multi-field documents and filters.
29 | Please, refer to the [Advanced Retriever](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md) documentation.
30 |
31 | - [February 18, 2023] `retriv` 0.2.0 is out!
32 | This release adds support for Dense and Hybrid Retrieval.
33 | Dense Retrieval leverages the semantic similarity of the queries' and documents' vector representations, which can be computed directly by `retriv` or imported from other sources.
34 | Hybrid Retrieval mixes the results of traditional retrieval, informally called Sparse Retrieval, and Dense Retrieval to further improve retrieval effectiveness.
35 | As the library was almost completely redone, indices built with previous versions are no longer supported.
36 |
37 | ## ⚡️ Introduction
38 |
39 | [retriv](https://github.com/AmenRa/retriv) is a user-friendly and efficient [search engine](https://en.wikipedia.org/wiki/Search_engine) implemented in [Python](https://en.wikipedia.org/wiki/Python_(programming_language)) supporting Sparse (traditional search with [BM25](https://en.wikipedia.org/wiki/Okapi_BM25), [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf)), Dense ([semantic search](https://en.wikipedia.org/wiki/Semantic_search)) and Hybrid retrieval (a mix of Sparse and Dense Retrieval).
40 | It allows you to build a search engine in a __single line of code__.
41 |
42 | [retriv](https://github.com/AmenRa/retriv) is built upon [Numba](https://github.com/numba/numba) for high-speed [vector operations](https://en.wikipedia.org/wiki/Automatic_vectorization) and [automatic parallelization](https://en.wikipedia.org/wiki/Automatic_parallelization), [PyTorch](https://pytorch.org) and [Transformers](https://huggingface.co/docs/transformers/index) for easy access and usage of [Transformer-based Language Models](https://web.stanford.edu/~jurafsky/slp3/10.pdf), and [Faiss](https://github.com/facebookresearch/faiss) for approximate [nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search).
43 | In addition, it provides automatic tuning functionalities to allow you to tune its internal components with minimal intervention.
44 |
45 |
46 | ## ✨ Main Features
47 |
48 | ### Retrievers
49 | - [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md): standard searcher based on lexical matching.
50 | [retriv](https://github.com/AmenRa/retriv) implements [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) as its main retrieval model.
51 | [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf) is also supported for educational purposes.
52 | The sparse retriever comes armed with multiple [stemmers](https://en.wikipedia.org/wiki/Stemming), [tokenizers](https://en.wikipedia.org/wiki/Lexical_analysis#Tokenization), and [stop-word](https://en.wikipedia.org/wiki/Stop_word) lists for multiple languages.
53 | Click [here](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md) to learn more.
54 | - [Dense Retriever](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md): a dense retriever is a retrieval model that performs [semantic search](https://en.wikipedia.org/wiki/Semantic_search).
55 | Click [here](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md) to learn more.
56 | - [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md): a hybrid retriever is a retrieval model built on top of a sparse and a dense retriever.
57 | Click [here](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) to learn more.
58 | - [Advanced Retriever](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md): an advanced sparse retriever supporting filters. This is an experimental feature.
59 | Click [here](https://github.com/AmenRa/retriv/blob/main/docs/advanced_retriever.md) to learn more.
60 |
61 | ### Unified Search Interface
62 | All the supported retrievers share the same search interface:
63 | - [search](#search): standard search functionality, what you expect from a search engine.
64 | - [msearch](#multi-search): computes the results for multiple queries at once.
65 | It leverages [automatic parallelization](https://en.wikipedia.org/wiki/Automatic_parallelization) whenever possible.
66 | - [bsearch](#batch-search): similar to [msearch](#multi-search) but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format. [bsearch](#batch-search) is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM. Pre-computed results can be leveraged for negative sampling during the training of [Neural Models](https://en.wikipedia.org/wiki/Artificial_neural_network) for [Information Retrieval](https://en.wikipedia.org/wiki/Information_retrieval).
67 |
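The snippet below sketches how the three methods look on a Sparse Retriever (a minimal sketch based on the examples in the docs; `se` is an index like the one built in the Minimal Working Example below):

```python
from retriv import SearchEngine

se = SearchEngine.load("new-index")  # Load a previously built index

# Single query: returns a ranked list of documents
results = se.search(query="witches masses", cutoff=100)

# Several queries at once: returns {query_id: {doc_id: score, ...}, ...}
run = se.msearch(queries=[{"id": "q_1", "text": "witches masses"}], cutoff=100)

# Very many queries: batches internally and can stream results to disk as JSONl
se.bsearch(
    queries=[{"id": "q_1", "text": "witches masses"}],
    cutoff=100,
    batch_size=1_000,
    path="results.jsonl",  # Omit to keep results in memory only
)
```
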
68 | ### AutoTune
69 | [retriv](https://github.com/AmenRa/retriv) automatically tunes [Faiss](https://github.com/facebookresearch/faiss) configuration for approximate nearest neighbors search by leveraging [AutoFaiss](https://github.com/criteo/autofaiss) to guarantee 10ms response time based on your available hardware.
70 | Moreover, it offers an automatic tuning functionality for [BM25](https://en.wikipedia.org/wiki/Okapi_BM25)'s parameters, which requires minimal user intervention.
71 | Under the hood, [retriv](https://github.com/AmenRa/retriv) leverages [Optuna](https://optuna.org), a [hyperparameter optimization](https://en.wikipedia.org/wiki/Hyperparameter_optimization) framework, and [ranx](https://github.com/AmenRa/ranx), an [Information Retrieval](https://en.wikipedia.org/wiki/Information_retrieval) evaluation library, to test several parameter configurations for [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) and choose the best one.
72 | Finally, it can automatically balance the importance of lexical and semantic relevance scores computed by the [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) to maximize retrieval effectiveness.
73 |
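As a quick taste, BM25 auto-tuning boils down to a single call (a sketch mirroring the Sparse Retriever docs; the training queries and qrels are up to you):

```python
se.autotune(
    queries=[{"q_id": "q_1", "text": "..."}],  # Train queries
    qrels={"q_1": {"doc_1": 1}},               # Train qrels
    metric="ndcg",                             # Metric to maximize
    n_trials=100,                              # Number of configurations to try
    cutoff=100,                                # Number of results
)
```
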
74 | ## 📚 Documentation
75 |
76 | - [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md)
77 | - [Dense Retriever](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md)
78 | - [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md)
79 | - [Text Pre-Processing](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md)
80 | - [FAQ](https://github.com/AmenRa/retriv/blob/main/docs/faq.md)
81 |
82 | ## 🔌 Requirements
83 | ```
84 | python>=3.8
85 | ```
86 |
87 | ## 💾 Installation
88 | ```bash
89 | pip install retriv
90 | ```
91 |
92 | ## 💡 Minimal Working Example
93 |
94 | ```python
95 | # Note: SearchEngine is an alias for the SparseRetriever
96 | from retriv import SearchEngine
97 |
98 | collection = [
99 | {"id": "doc_1", "text": "Generals gathered in their masses"},
100 | {"id": "doc_2", "text": "Just like witches at black masses"},
101 | {"id": "doc_3", "text": "Evil minds that plot destruction"},
102 | {"id": "doc_4", "text": "Sorcerer of death's construction"},
103 | ]
104 |
105 | se = SearchEngine("new-index").index(collection)
106 |
107 | se.search("witches masses")
108 | ```
109 | Output:
110 | ```json
111 | [
112 | {
113 | "id": "doc_2",
114 | "text": "Just like witches at black masses",
115 | "score": 1.7536403
116 | },
117 | {
118 | "id": "doc_1",
119 | "text": "Generals gathered in their masses",
120 | "score": 0.6931472
121 | }
122 | ]
123 | ```
124 |
125 |
126 |
127 |
128 |
129 |
130 | ## 🎁 Feature Requests
131 | Would you like to see other features implemented? Please, open a [feature request](https://github.com/AmenRa/retriv/issues/new?assignees=&labels=enhancement&template=feature_request.md&title=%5BFeature+Request%5D+title).
132 |
133 |
134 | ## 🤘 Want to contribute?
135 | Would you like to contribute? Please, drop me an [e-mail](mailto:elias.bssn@gmail.com?subject=[GitHub]%20retriv).
136 |
137 |
138 | ## 📄 License
139 | [retriv](https://github.com/AmenRa/retriv) is open-source software licensed under the [MIT license](LICENSE).
140 |
--------------------------------------------------------------------------------
/_typos.toml:
--------------------------------------------------------------------------------
1 | [default]
2 | extend-ignore-identifiers-re = [
3 | ".*Hsi",
4 | ]
5 |
--------------------------------------------------------------------------------
/changelog.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | All notable changes to this project will be documented in this file.
3 |
4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
6 |
7 | ## [0.2.2] - 2023-08-23
8 | ### Added
9 | - Added `advanced_retriever.py`.
10 |
11 | ### Changed
12 | - Text preprocessing refactoring.
13 |
14 | ## [0.2.1] - 2023-05-16
15 | ### Added
16 | - Added docstrings to `sparse_retriever.py`.
17 | - Added docstrings to `dense_retriever.py`.
18 | - Added docstrings to `hybrid_retriever.py`.
19 |
20 | ## [0.2.0] - 2023-02-19
21 | ### Added
22 | - Added `sparse_retriever.py`.
23 | - Added `dense_retriever.py`.
24 | - Added `hybrid_retriever.py`.
25 | - Added `encoder.py`.
26 | - Added `ann_searcher.py`.
27 | - Added `merger.py`.
28 |
29 | ### Changed
30 | - Almost everything.
31 |
32 | ### Removed
33 | - Removed dependency on `cyhunspell`.
34 |
35 | ## [0.1.5] - 2023-01-26
36 | ### Changed
37 | - Search efficiency improvements.
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | # NEEDED BY PYTEST
2 |
--------------------------------------------------------------------------------
/dev_env.yml:
--------------------------------------------------------------------------------
1 | name: retriv
2 | dependencies:
3 | - python=3.8
4 | - black
5 | - conda-forge::icecream
6 | - isort
7 | - ipykernel
8 | - ipywidgets
9 | - conda-forge::mkdocs-material
10 | - conda-forge::mkdocs-autorefs
11 | - conda-forge::mkdocstrings
12 | - conda-forge::mkdocstrings-python
13 | - pygments>=2.12
14 | - notebook
15 | - pytest
16 | # Pip
17 | - pip
18 | - pip:
19 | - ranx
20 | - krovetzstemmer
21 | - twine
22 | - oneliner_utils
23 | - indxr
24 | - numpy
25 | - nltk
26 | - numba>=0.54.1
27 | - tqdm
28 | - optuna
29 | - pystemmer==2.0.1
30 | - unidecode
31 | - scikit-learn
32 | - torch
33 | - torchvision
34 | - torchaudio
35 | - transformers[torch]
36 | - faiss-cpu
37 | - autofaiss
38 | - multipipe
39 |
--------------------------------------------------------------------------------
/docs/advanced_retriever.md:
--------------------------------------------------------------------------------
1 | # Advanced Retriever
2 |
3 | ⚠️ This is an experimental feature.
4 |
5 | The Advanced Retriever is a searcher based on lexical matching and search filters.
6 | It supports [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) and [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf), like the [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md), and provides the same resources for multi-lingual [text pre-processing](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md). In addition, it supports search filters, i.e., a set of rules that can be used to filter out documents from the search results.
7 |
8 | In the following, we show how to build a search engine employing an advanced retriever, index a document collection, and search it.
9 |
10 | ## Schema
11 |
12 | The first step to building an Advanced Retriever is to define the `schema` of the document collection.
13 | The `schema` is a dictionary describing the documents' `fields` and their `data types`.
14 | Based on the `data types`, search `filters` can be defined and applied to the search results.
15 |
16 | [retriv](https://github.com/AmenRa/retriv) supports the following data types:
17 | - __id:__ field used for the document IDs.
18 | - __text:__ text field used for lexical matching.
19 | - __number:__ numeric value.
20 | - __bool:__ boolean value (True or False).
21 | - __keyword:__ string or number representing a keyword or a category.
22 | - __keywords:__ list of keywords.
23 |
24 | An example of `schema` for a collection of books is shown below.
25 | NB: At the time of writing, [retriv](https://github.com/AmenRa/retriv) supports only one text field per schema.
26 | Therefore, the `content` field is used for both the title and the abstract of the books.
27 |
28 | ```python
29 | schema = {
30 | "isbn": "id",
31 | "content": "text",
32 | "year": "number",
33 | "is_english": "bool",
34 | "author": "keyword",
35 | "genres": "keywords",
36 | }
37 | ```
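
A document conforming to this schema could look like the following (values are purely illustrative):

```python
doc = {
    "isbn": "000-0-00-000000-0",           # id
    "content": "Book Title. Abstract...",  # text (title and abstract concatenated)
    "year": 1970,                          # number
    "is_english": True,                    # bool
    "author": "Jane Doe",                  # keyword
    "genres": ["fiction", "adventure"],    # keywords
}
```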
38 |
39 | ## Build
40 |
41 | The Advanced Retriever provides several options to tailor its functioning to your preferences, as shown below.
42 |
43 | ```python
44 | from retriv.experimental import AdvancedRetriever
45 |
46 | ar = AdvancedRetriever(
47 | schema=schema,
48 | index_name="new-index",
49 | model="bm25",
50 | min_df=1,
51 | tokenizer="whitespace",
52 | stemmer="english",
53 | stopwords="english",
54 | do_lowercasing=True,
55 | do_ampersand_normalization=True,
56 | do_special_chars_normalization=True,
57 | do_acronyms_normalization=True,
58 | do_punctuation_removal=True,
59 | )
60 | ```
61 |
62 | - `schema`: the documents' schema.
63 | - `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index.
64 | - `model`: defines the retrieval model to use for searching (`bm25` or `tf-idf`).
65 | - `min_df`: terms that appear in less than `min_df` documents will be ignored.
66 | If integer, the parameter indicates the absolute count.
67 | If float, it represents a proportion of documents.
68 | - `tokenizer`: [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization by setting the parameter to `None`.
69 | - `stemmer`: [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming by setting the parameter to `None`.
70 | - `stopwords`: [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`.
71 | - `do_lowercasing`: whether to lowercase texts.
72 | - `do_ampersand_normalization`: whether to convert `&` to `and` during pre-processing.
73 | - `do_special_chars_normalization`: whether to replace special characters with their plain-letter equivalents, e.g., `übermensch` → `ubermensch`.
74 | - `do_acronyms_normalization`: whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`.
75 | - `do_punctuation_removal`: whether to remove punctuation.
76 |
77 | __Note:__ text pre-processing is equally applied to documents during indexing and to queries at search time.
78 |
79 | ## Index
80 |
81 | ### Create
82 | You can index a document collection from JSONl, CSV, or TSV files.
83 | CSV and TSV files must have a header.
84 | [retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it.
85 | Use the `callback` parameter to pass a function for converting your documents into the format defined by your `schema` on the fly.
86 | Indexes are automatically persisted on disk at the end of the process.
87 |
88 | ```python
89 | ar = ar.index_file(
90 | path="path/to/collection", # File kind is automatically inferred
91 | show_progress=True, # Default value
92 | callback=lambda doc: { # Callback defaults to None.
93 | "id": doc["id"],
94 | "text": doc["title"] + ". " + doc["text"],
95 | ...
96 |     })
97 | ```
98 |
99 | ### Load
100 | ```python
101 | ar = AdvancedRetriever.load("index-name")
102 | ```
103 |
104 | ### Delete
105 | ```python
106 | AdvancedRetriever.delete("index-name")
107 | ```
108 |
109 | ## Search
110 |
111 | ### Query & Filters
112 |
113 | The Advanced Retriever's search query can be either a string or a dictionary.
114 | In the former case, the string is used as the query text and no filters are applied.
115 | In the latter case, the dictionary defines the query text and the filters to apply to the search results. If the query text is omitted from the dictionary, documents matching the filters will be returned.
116 |
117 | [retriv](https://github.com/AmenRa/retriv) supports two ways of filtering the search results (`where` and `where_not`) and several type-specific operators.
118 |
119 | - `where` means that only the documents matching the filter will be considered during search.
120 | - `where_not` means that the documents matching the filter will be ignored during search.
121 |
122 | Below we describe the effects of the supported operators for each data type and way of filtering.
123 |
124 | #### Where
125 |
126 | | Field Type | Operator | Value | Meaning |
127 | | ---------- | --------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
128 | | number | `eq` | number | Only the documents whose field value is **equal to** the provided value will be considered during search. |
129 | | number | `gt` | number | Only the documents whose field value is **greater than** the provided value will be considered during search. |
130 | | number | `gte` | number | Only the documents whose field value is **greater or equal to** the provided value will be considered during search. |
131 | | number | `lt` | number | Only the documents whose field value is **less than** the provided value will be considered during search. |
132 | | number | `lte` | number | Only the documents whose field value is **less or equal to** the provided value will be considered during search. |
133 | | number     | `between` | pair of numbers            | Only the documents whose field value is **between** the provided values (included) will be considered during search.                                  |
134 | | bool | | True / False | Only the documents whose field value is **equal to** the provided value will be considered during search. |
135 | | keyword | | any value / list of values | Only the documents whose field value is **equal to** the provided value or **among** the provided values will be considered during search. |
136 | | keywords   | `or`      | any value / list of values | Only the documents whose field value **contains** the provided value or **contains one of** the provided values will be considered during search.     |
137 | | keywords | `and` | any value / list of values | Only the documents whose field value **contains all** the provided values will be considered during search. |
138 |
139 | Query example:
140 | ```python
141 | query = {
142 | "text": "search terms",
143 | "where": {
144 | "numeric_field_name": ("gte", 1970),
145 | "boolean_field_name": True,
146 | "keyword_field_name": "kw_1",
147 | "keywords_field_name": ("or", ["kws_23", "kws_666"]),
148 | }
149 | }
150 | ```
151 |
152 | Alternatively, you can omit the `where` key and use the following syntax:
153 | ```python
154 | query = {
155 | "text": "search terms",
156 | "numeric_field_name": ("gte", 1970),
157 | "boolean_field_name": True,
158 | "keyword_field_name": "kw_1",
159 | "keywords_field_name": ("or", ["kws_23", "kws_666"]),
160 | }
161 | ```
162 |
163 |
164 | #### Where not
165 |
166 | | Field Type | Operator | Value | Meaning |
167 | | ---------- | --------- | -------------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
168 | | number | `eq` | number | The documents whose field value is **equal to** the provided value will be ignored. |
169 | | number | `gt` | number | The documents whose field value is **greater than** the provided value will be ignored. |
170 | | number | `gte` | number | The documents whose field value is **greater or equal to** the provided value will be ignored. |
171 | | number | `lt` | number | The documents whose field value is **less than** the provided value will be ignored. |
172 | | number | `lte` | number | The documents whose field value is **less or equal to** the provided value will be ignored. |
173 | | number     | `between` | pair of numbers            | The documents whose field value is **between** the provided values (included) will be ignored.                                  |
174 | | bool | | True / False | The documents whose field value is **equal to** the provided value will be ignored. |
175 | | keyword | | any value / list of values | The documents whose field value is **equal to** the provided value or **among** the provided values will be ignored. |
176 | | keywords   | `or`      | any value / list of values | The documents whose field value **contains** the provided value or **contains one of** the provided values will be ignored.     |
177 | | keywords | `and` | any value / list of values | The documents whose field value **contains all** the provided values will be ignored. |
178 |
179 | Query example:
180 | ```python
181 | query = {
182 | "text": "search terms",
183 |     "where_not": {
184 | "numeric_field_name": ("gte", 1970),
185 | "boolean_field_name": True,
186 | "keyword_field_name": "kw_1",
187 | "keywords_field_name": ("or", ["kws_23", "kws_666"]),
188 | }
189 | }
190 | ```
191 |
192 | ### Search
193 |
194 | ```python
195 | ar.search(
196 |     query=...,             # What to search for, see Query & Filters above.
197 |     return_docs=True,      # Default value.
198 |     cutoff=100,            # Default value.
199 |     operator="OR",         # Default value.
200 |     subset_doc_ids=None,   # Default value.
201 | )
202 | ```
203 |
204 | - `query`: what to search for and which filters to apply. See the section [Query & Filters](#query--filters) for more details.
205 | - `return_docs`: whether to return documents or only their IDs.
206 | - `cutoff`: number of results to return.
207 | - `operator`: whether to perform conjunctive (`AND`) or disjunctive (`OR`) search. Conjunctive search retrieves documents that contain **all** the query terms. Disjunctive search retrieves documents that contain **at least one** of the query terms.
208 | - `subset_doc_ids`: restrict the search to the subset of documents having the provided IDs.
209 |
210 | Sample output:
211 | ```json
212 | [
213 | {
214 | "id": "doc_2",
215 | "text": "Just like witches at black masses",
216 | "score": 1.7536403
217 | },
218 | {
219 | "id": "doc_1",
220 | "text": "Generals gathered in their masses",
221 | "score": 0.6931472
222 | }
223 | ]
224 | ```
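
Putting it all together with the book schema from the [Schema](#schema) section (a sketch; filter values are illustrative):

```python
query = {
    "text": "wizards and dragons",
    "year": ("gte", 1970),
    "is_english": True,
    "genres": ("or", ["fiction", "adventure"]),
}

results = ar.search(query=query, cutoff=10)
```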
225 |
226 |
227 |
228 |
229 |
--------------------------------------------------------------------------------
/docs/dense_retriever.md:
--------------------------------------------------------------------------------
1 | # Dense Retriever
2 |
3 | The Dense Retriever performs [semantic search](https://en.wikipedia.org/wiki/Semantic_search), i.e., it compares vector representations of queries and documents to compute the relevance scores of the latter.
4 | The Dense Retriever comprises two components: the `Encoder` and the `ANN Searcher`, described below.
5 |
6 | In the following, we show how to build a search engine for [semantic search](https://en.wikipedia.org/wiki/Semantic_search) based on dense retrieval, index a document collection, and search it.
7 |
8 | ## Build
9 |
10 | Building a Dense Retriever is as simple as shown below.
11 | Default parameter values are shown.
12 |
13 | ```python
14 | from retriv import DenseRetriever
15 |
16 | dr = DenseRetriever(
17 | index_name="new-index",
18 | model="sentence-transformers/all-MiniLM-L6-v2",
19 | normalize=True,
20 | max_length=128,
21 | use_ann=True,
22 | )
23 | ```
24 |
25 | - `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index.
26 | - `model`: defines the encoder model to encode queries and documents into vectors. You can use a [Hugging Face Transformers](https://huggingface.co/models) pre-trained model by providing its ID or load a local model by providing its path.
27 | In the case of local models, the path must point to the directory containing the data saved with the [`PreTrainedModel.save_pretrained`](https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method.
28 | Note that the representations are computed with `mean pooling` over the `last_hidden_state`.
29 | - `normalize`: whether to L2 normalize the vector representations.
30 | - `max_length`: texts longer than `max_length` will be automatically truncated. Choose this parameter based on how the employed model was trained or is generally used.
31 | - `use_ann`: whether to use approximate nearest neighbors search. Set it to `False` to use nearest neighbors search without approximation. If you have less than 20k documents in your collection, you probably want to disable approximation.
32 |
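Conceptually, the encoding step resembles the following sketch (simplified, not retriv's internal code): tokenize, run the model, mean-pool the `last_hidden_state` over non-padding tokens, and optionally L2 normalize.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_id = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

texts = ["Just like witches at black masses"]
batch = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")

with torch.no_grad():
    hidden = model(**batch).last_hidden_state    # (batch, tokens, dim)

# Mean pooling: average the token embeddings, ignoring padding tokens
mask = batch["attention_mask"].unsqueeze(-1)     # (batch, tokens, 1)
embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
embeddings = F.normalize(embeddings, p=2, dim=1)  # normalize=True
```
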
33 | ## Index
34 |
35 | ### Create
36 | You can index a document collection from JSONl, CSV, or TSV files.
37 | CSV and TSV files must have a header.
38 | [retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it.
39 | Use the `callback` parameter to pass a function for converting your documents into the format supported by [retriv](https://github.com/AmenRa/retriv) on the fly.
40 | Documents must have a single `text` field and an `id`.
41 | Indexes are automatically persisted on disk at the end of the process.
42 | To speed up the indexing process, you can activate GPU-based encoding.
43 |
44 | ```python
45 | dr = dr.index_file(
46 | path="path/to/collection", # File kind is automatically inferred
47 | embeddings_path=None, # Default value
48 | use_gpu=False, # Default value
49 | batch_size=512, # Default value
50 | show_progress=True, # Default value
51 | callback=lambda doc: { # Callback defaults to None.
52 | "id": doc["id"],
53 | "text": doc["title"] + ". " + doc["text"],
54 |     },
55 | )
56 | ```
57 |
58 | - `embeddings_path`: in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file.
59 | - `use_gpu`: whether to use the GPU for document encoding.
60 | - `batch_size`: how many documents to encode at once. Regulate it if you run into memory usage issues or want to maximize throughput.
61 |
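For example, to index pre-computed embeddings (a sketch; `compute_embeddings` is a hypothetical stand-in for whatever model or pipeline produced your vectors):

```python
import numpy as np

# One row per document, in the same order as the collection file
embeddings = compute_embeddings(documents)  # Hypothetical helper
np.save("embeddings.npy", embeddings)

dr = dr.index_file(
    path="path/to/collection",
    embeddings_path="embeddings.npy",
)
```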
62 |
63 | ### Load
64 | ```python
65 | dr = DenseRetriever.load("index-name")
66 | ```
67 |
68 | ### Delete
69 | ```python
70 | DenseRetriever.delete("index-name")
71 | ```
72 |
73 | ## Search
74 |
75 | ### Search
76 |
77 | Standard search functionality.
78 |
79 | ```python
80 | dr.search(
81 | query="witches masses", # What to search for
82 | return_docs=True, # Default value, return the text of the documents
83 | cutoff=100, # Default value, number of results to return
84 | )
85 | ```
86 | Output:
87 | ```json
88 | [
89 | {
90 | "id": "doc_2",
91 | "text": "Just like witches at black masses",
92 | "score": 0.9536403
93 | },
94 | {
95 | "id": "doc_1",
96 | "text": "Generals gathered in their masses",
97 | "score": 0.6931472
98 | }
99 | ]
100 | ```
101 |
102 | ### Multi-Search
103 |
104 | Compute results for multiple queries at once.
105 |
106 | ```python
107 | dr.msearch(
108 | queries=[{"id": "q_1", "text": "witches masses"}, ...],
109 | cutoff=100, # Default value, number of results
110 | batch_size=32, # Default value.
111 | )
112 | ```
113 | Output:
114 | ```json
115 | {
116 | "q_1": {
117 | "doc_2": 1.7536403,
118 | "doc_1": 0.6931472
119 | },
120 | ...
121 | }
122 | ```
123 |
124 | - `batch_size`: how many searches to perform at once. Regulate it if you run into memory usage issues or want to maximize throughput.
125 |
126 | ### Batch-Search
127 |
128 | Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format.
129 | [bsearch](#batch-search) is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
130 |
131 | ```python
132 | dr.bsearch(
133 | queries=[{"id": "q_1", "text": "witches masses"}, ...],
134 | cutoff=100,
135 | batch_size=32,
136 | show_progress=True,
137 | qrels=None, # To add relevance information to the saved files
138 | path=None, # Where to save the results, if you want to save them
139 | )
140 | ```
141 |
--------------------------------------------------------------------------------
/docs/faq.md:
--------------------------------------------------------------------------------
1 | # FAQ
2 |
3 | ## How do I change `retriv` working directory?
4 | ```python
5 | import retriv
6 |
7 | retriv.set_base_path("new/working/path")
8 | ```
9 |
10 | or
11 |
12 | ```python
13 | import os
14 |
15 | os.environ["RETRIV_BASE_PATH"] = "new/working/path"
16 | ```
17 |
--------------------------------------------------------------------------------
/docs/filters.md:
--------------------------------------------------------------------------------
1 | ## Filtering Search Results
2 |
3 | [retriv](https://github.com/AmenRa/retriv) supports two ways of filtering the search results (`where` and `where_not`) and several type-specific operators.
4 |
5 | - `where` means that only the documents matching the filter will be considered during search.
6 | - `where_not` means that the documents matching the filter will be ignored during search.
7 |
8 | Below we describe the effects of the supported operators for each data type and way of filtering.
9 |
10 | ### Where
11 |
12 | | Field Type | Operator | Value | Meaning |
13 | | ---------- | --------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
14 | | number | `eq` | number | Only the documents whose field value is **equal to** the provided value will be considered during search. |
15 | | number | `gt` | number | Only the documents whose field value is **greater than** the provided value will be considered during search. |
16 | | number | `gte` | number | Only the documents whose field value is **greater or equal to** the provided value will be considered during search. |
17 | | number | `lt` | number | Only the documents whose field value is **less than** the provided value will be considered during search. |
18 | | number | `lte` | number | Only the documents whose field value is **less or equal to** the provided value will be considered during search. |
19 | | number     | `between` | pair of numbers            | Only the documents whose field value is **between** the provided values (included) will be considered during search.                                  |
20 | | bool | | True / False | Only the documents whose field value is **equal to** the provided value will be considered during search. |
21 | | keyword | | any value / list of values | Only the documents whose field value is **equal to** the provided value or **among** the provided values will be considered during search. |
22 | | keywords   | `or`      | any value / list of values | Only the documents whose field value **contains** the provided value or **contains one of** the provided values will be considered during search.     |
23 | | keywords | `and` | any value / list of values | Only the documents whose field value **contains all** the provided values will be considered during search. |
24 |
25 | Query example:
26 | ```python
27 | query = {
28 | "text": "search terms",
29 | "where": {
30 | "numeric_field_name": ("gte", 1970),
31 | "boolean_field_name": True,
32 | "keyword_field_name": "kw_1",
33 | "keywords_field_name": ("or", ["kws_23", "kws_666"]),
34 | }
35 | }
36 | ```
37 |
38 | Alternatively, you can omit the `where` key and use the following syntax:
39 | ```python
40 | query = {
41 | "text": "search terms",
42 | "numeric_field_name": ("gte", 1970),
43 | "boolean_field_name": True,
44 | "keyword_field_name": "kw_1",
45 | "keywords_field_name": ("or", ["kws_23", "kws_666"]),
46 | }
47 | ```
48 |
49 |
50 | ### Where not
51 |
52 | | Field Type | Operator | Value | Meaning |
53 | | ---------- | --------- | -------------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
54 | | number | `eq` | number | The documents whose field value is **equal to** the provided value will be ignored. |
55 | | number | `gt` | number | The documents whose field value is **greater than** the provided value will be ignored. |
56 | | number | `gte` | number | The documents whose field value is **greater or equal to** the provided value will be ignored. |
57 | | number | `lt` | number | The documents whose field value is **less than** the provided value will be ignored. |
58 | | number | `lte` | number | The documents whose field value is **less or equal to** the provided value will be ignored. |
59 | | number     | `between` | pair of numbers            | The documents whose field value is **between** the provided values (included) will be ignored.                                  |
60 | | bool | | True / False | The documents whose field value is **equal to** the provided value will be ignored. |
61 | | keyword | | any value / list of values | The documents whose field value is **equal to** the provided value or **among** the provided values will be ignored. |
62 | | keywords   | `or`      | any value / list of values | The documents whose field value **contains** the provided value or **contains one of** the provided values will be ignored.     |
63 | | keywords | `and` | any value / list of values | The documents whose field value **contains all** the provided values will be ignored. |
64 |
65 | Query example:
66 | ```python
67 | query = {
68 | "text": "search terms",
69 |     "where_not": {
70 | "numeric_field_name": ("gte", 1970),
71 | "boolean_field_name": True,
72 | "keyword_field_name": "kw_1",
73 | "keywords_field_name": ("or", ["kws_23", "kws_666"]),
74 | }
75 | }
76 | ```
--------------------------------------------------------------------------------
/docs/hybrid_retriever.md:
--------------------------------------------------------------------------------
1 | # Hybrid Retriever
2 |
3 | The [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) is a searcher based on both lexical and semantic matching.
4 | It comprises three components: the [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md), the [Dense Retriever](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md), and the Merger.
5 | The Merger fuses the results of the Sparse and Dense Retrievers to compute the _hybrid_ results.
6 |
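Conceptually, the fusion step resembles a weighted combination of normalized score lists, as in the sketch below (a simplified illustration, not retriv's exact implementation; see `retriv/merger/` for the real one):

```python
def min_max_norm(run):
    """Rescale scores to [0, 1] so lexical and semantic scores are comparable."""
    lo, hi = min(run.values()), max(run.values())
    if hi == lo:
        return {doc: 1.0 for doc in run}
    return {doc: (score - lo) / (hi - lo) for doc, score in run.items()}


def fuse(sparse_run, dense_run, alpha=0.5):
    """Convex combination of normalized sparse and dense scores."""
    sparse_run, dense_run = min_max_norm(sparse_run), min_max_norm(dense_run)
    docs = set(sparse_run) | set(dense_run)
    fused = {
        doc: alpha * sparse_run.get(doc, 0.0) + (1 - alpha) * dense_run.get(doc, 0.0)
        for doc in docs
    }
    return dict(sorted(fused.items(), key=lambda item: item[1], reverse=True))
```
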
7 | In the following, we show how to build a hybrid search engine, index a document collection, and search it.
8 |
9 | ## Build
10 |
11 | You can instantiate a Hybrid Retriever as follows.
12 | Default parameter values are shown.
13 |
14 | ```python
15 | from retriv import HybridRetriever
16 |
17 | hr = HybridRetriever(
18 | # Shared params ------------------------------------------------------------
19 | index_name="new-index",
20 | # Sparse retriever params --------------------------------------------------
21 | sr_model="bm25",
22 | min_df=1,
23 | tokenizer="whitespace",
24 | stemmer="english",
25 | stopwords="english",
26 | do_lowercasing=True,
27 | do_ampersand_normalization=True,
28 | do_special_chars_normalization=True,
29 | do_acronyms_normalization=True,
30 | do_punctuation_removal=True,
31 | # Dense retriever params ---------------------------------------------------
32 | dr_model="sentence-transformers/all-MiniLM-L6-v2",
33 | normalize=True,
34 | max_length=128,
35 | use_ann=True,
36 | )
37 | ```
38 |
39 | - Shared params:
40 | - `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index.
41 | - Sparse Retriever params:
42 | - `sr_model`: defines the model to use for sparse retrieval (`bm25` or `tf-idf`).
43 | - `min_df`: terms that appear in less than `min_df` documents will be ignored.
44 | If integer, the parameter indicates the absolute count.
45 | If float, it represents a proportion of documents.
46 |   - `tokenizer`: [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization by setting the parameter to `None`.
47 |   - `stemmer`: [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming by setting the parameter to `None`.
48 | - `stopwords`: [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`.
49 | - `do_lowercasing`: whether to lowercase texts.
50 |   - `do_ampersand_normalization`: whether to convert `&` to `and` during pre-processing.
51 |   - `do_special_chars_normalization`: whether to replace special characters with their plain-letter equivalents, e.g., `übermensch` → `ubermensch`.
52 |   - `do_acronyms_normalization`: whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`.
53 | - `do_punctuation_removal`: whether to remove punctuation.
54 | - Dense Retriever params:
55 |   - `dr_model`: defines the model to use for encoding queries and documents into vectors. You can use a [Hugging Face Transformers](https://huggingface.co/models) pre-trained model by providing its ID or load a local model by providing its path.
56 | In the case of local models, the path must point to the directory containing the data saved with the [`PreTrainedModel.save_pretrained`](https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method.
57 | Note that the representations are computed with `mean pooling` over the `last_hidden_state`.
58 | - `normalize`: whether to L2 normalize the vector representations.
59 | - `max_length`: texts longer than `max_length` will be automatically truncated. Choose this parameter based on how the employed model was trained or is generally used.
60 | - `use_ann`: whether to use approximate nearest neighbors search. Set it to `False` to use nearest neighbors search without approximation. If you have less than 20k documents in your collection, you probably want to disable approximation.
61 |
62 | __Note:__ text pre-processing is equally applied to documents during indexing and to queries at search time.
63 |
64 | ## Index
65 |
66 | ### Create
67 | You can index a document collection from JSONl, CSV, or TSV files.
68 | CSV and TSV files must have a header.
69 | [retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it.
70 | Use the `callback` parameter to pass a function for converting your documents into the format supported by [retriv](https://github.com/AmenRa/retriv) on the fly.
71 | Documents must have a single `text` field and an `id`.
72 | The Hybrid Retriever sequentially builds the indices for the Sparse and Dense Retrievers.
73 | Indexes are automatically persisted on disk at the end of the process.
74 | To speed up the indexing process of the Dense Retriever, you can activate GPU-based encoding.
75 |
76 | ```python
77 | hr = hr.index_file(
78 | path="path/to/collection", # File kind is automatically inferred
79 | embeddings_path=None, # Default value
80 | use_gpu=False, # Default value
81 | batch_size=512, # Default value
82 | show_progress=True, # Default value
83 | callback=lambda doc: { # Callback defaults to None.
84 | "id": doc["id"],
85 | "text": doc["title"] + ". " + doc["text"],
86 |     },
87 | )
88 | ```
89 |
90 | ### Load
91 | ```python
92 | hr = HybridRetriever.load("index-name")
93 | ```
94 |
95 | ### Delete
96 | ```python
97 | HybridRetriever.delete("index-name")
98 | ```
99 |
100 | ## Search
101 |
102 | During search, the Hybrid Retriever fuses the top 1000 results of the Sparse and Dense Retrievers.
103 |
104 | ### Search
105 |
106 | Standard search functionality.
107 |
108 | ```python
109 | hr.search(
110 | query="witches masses", # What to search for
111 | return_docs=True, # Default value, return the text of the documents
112 | cutoff=100, # Default value, number of results to return
113 | )
114 | ```
115 | Output:
116 | ```json
117 | [
118 | {
119 | "id": "doc_2",
120 | "text": "Just like witches at black masses",
121 | "score": 0.9536403
122 | },
123 | {
124 | "id": "doc_1",
125 | "text": "Generals gathered in their masses",
126 | "score": 0.6931472
127 | }
128 | ]
129 | ```
130 |
131 | ### Multi-Search
132 |
133 | Compute results for multiple queries at once.
134 |
135 | ```python
136 | hr.msearch(
137 | queries=[{"id": "q_1", "text": "witches masses"}, ...],
138 | cutoff=100, # Default value, number of results
139 | batch_size=32, # Default value.
140 | )
141 | ```
142 | Output:
143 | ```json
144 | {
145 | "q_1": {
146 | "doc_2": 1.7536403,
147 | "doc_1": 0.6931472
148 | },
149 | ...
150 | }
151 | ```
152 |
153 | - `batch_size`: how many searches to perform at once. Regulate it if you run into memory usage issues or want to maximize throughput.
154 |
155 | ### Batch-Search
156 |
157 | Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format.
158 | [bsearch](#batch-search) is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
159 |
160 | ```python
161 | hr.bsearch(
162 | queries=[{"id": "q_1", "text": "witches masses"}, ...],
163 | cutoff=100,
164 | batch_size=32,
165 | show_progress=True,
166 | qrels=None, # To add relevance information to the saved files
167 | path=None, # Where to save the results, if you want to save them
168 | )
169 | ```
170 |
171 | ## AutoTune
172 |
173 | Use the AutoTune function to tune the parameters of the Sparse Retriever's [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) model and the importance given to the lexical and semantic relevance scores computed by the Sparse and Dense Retrievers, respectively.
174 | All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function.
175 | At the end of the process, the best parameter configurations are automatically applied and saved to disk.
176 | You can inspect the best configurations found by printing `hr.sparse_retriever.hyperparams`, `hr.merger.norm` and `hr.merger.params`.
177 |
178 | ```python
179 | hr.autotune(
180 | queries=[{ "q_id": "q_1", "text": "...", ... }], # Train queries
181 | qrels={ "q_1": { "doc_1": 1, ... }, ... }, # Train qrels
182 | metric="ndcg", # Default value, metric to maximize
183 | n_trials=100, # Default value, number of trials
184 | cutoff=100, # Default value, number of results
185 | batch_size=32, # Default value
186 | )
187 | ```
188 |
--------------------------------------------------------------------------------
/docs/sparse_retriever.md:
--------------------------------------------------------------------------------
1 | # Sparse Retriever
2 |
3 | The Sparse Retriever is a traditional searcher based on lexical matching.
4 | It supports [BM25](https://en.wikipedia.org/wiki/Okapi_BM25), the retrieval model used by major search engine libraries, such as [Lucene](https://en.wikipedia.org/wiki/Apache_Lucene) and [Elasticsearch](https://en.wikipedia.org/wiki/Elasticsearch).
5 | [retriv](https://github.com/AmenRa/retriv) also implements the classic relevance model [TF-IDF](https://en.wikipedia.org/wiki/Tf–idf) for educational purposes.
6 |
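For reference, the standard BM25 scoring function is shown below ($k_1$ and $b$ are the parameters tuned by [AutoTune](#autotune); the exact variant implemented by retriv may differ in minor details):

$$\mathrm{score}(D, Q) = \sum_{i=1}^{n} \mathrm{IDF}(q_i) \cdot \frac{f(q_i, D) \cdot (k_1 + 1)}{f(q_i, D) + k_1 \cdot \left(1 - b + b \cdot \frac{|D|}{\mathrm{avgdl}}\right)}$$

where $f(q_i, D)$ is the frequency of query term $q_i$ in document $D$, $|D|$ is the length of $D$, and $\mathrm{avgdl}$ is the average document length in the collection.
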
7 | The Sparse Retriever also provides several resources for multi-lingual [text pre-processing](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md), aiming to maximize its retrieval effectiveness.
8 |
9 | In the following, we show how to build a search engine employing a sparse retriever, index a document collection, and search it.
10 |
11 | ## Build
12 |
13 | The Sparse Retriever provides several options to tailor its functioning to your preferences, as shown below.
14 | Default parameter values are shown.
15 |
16 | ```python
17 | # Note: the SparseRetriever has an alias called SearchEngine, if you prefer
18 | from retriv import SparseRetriever
19 |
20 | sr = SparseRetriever(
21 | index_name="new-index",
22 | model="bm25",
23 | min_df=1,
24 | tokenizer="whitespace",
25 | stemmer="english",
26 | stopwords="english",
27 | do_lowercasing=True,
28 | do_ampersand_normalization=True,
29 | do_special_chars_normalization=True,
30 | do_acronyms_normalization=True,
31 | do_punctuation_removal=True,
32 | )
33 | ```
34 |
35 | - `index_name`: [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index.
36 | - `model`: defines the retrieval model to use for searching (`bm25` or `tf-idf`).
37 | - `min_df`: terms that appear in less than `min_df` documents will be ignored.
38 | If integer, the parameter indicates the absolute count.
39 | If float, it represents a proportion of documents.
40 | - `tokenizer`: [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization by setting the parameter to `None`.
41 | - `stemmer`: [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming by setting the parameter to `None`.
42 | - `stopwords`: [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`.
43 | - `do_lowercasing`: whether to lowercase texts.
44 | - `do_ampersand_normalization`: whether to convert `&` to `and` during pre-processing.
45 | - `do_special_chars_normalization`: whether to replace special characters with their plain-letter equivalents, e.g., `übermensch` → `ubermensch`.
46 | - `do_acronyms_normalization`: whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`.
47 | - `do_punctuation_removal`: whether to remove punctuation.
48 |
49 | __Note:__ text pre-processing is equally applied to documents during indexing and to queries at search time.
50 |
51 | ## Index
52 |
53 | ### Create
54 | You can index a document collection from JSONl, CSV, or TSV files.
55 | CSV and TSV files must have a header.
56 | [retriv](https://github.com/AmenRa/retriv) automatically infers the file kind, so there's no need to specify it.
57 | Use the `callback` parameter to pass a function for converting your documents into the format supported by [retriv](https://github.com/AmenRa/retriv) on the fly.
58 | Documents must have a single `text` field and an `id`.
59 | Indexes are automatically persisted on disk at the end of the process.
60 | Indexing functionalities are built to have a minimal memory footprint while leveraging multi-processing for efficiency.
61 | Indexing 10M documents takes from 5 to 10 minutes on an [AMD Ryzen™ 9 5950X](https://www.amd.com/en/products/cpu/amd-ryzen-9-5950x), depending on the length of the documents.
62 |
63 | ```python
64 | sr = sr.index_file(
65 | path="path/to/collection", # File kind is automatically inferred
66 | show_progress=True, # Default value
67 | callback=lambda doc: { # Callback defaults to None.
68 | "id": doc["id"],
69 | "text": doc["title"] + ". " + doc["text"],
70 |     })
71 | ```
72 |
73 | ### Load
74 | ```python
75 | sr = SparseRetriever.load("index-name")
76 | ```
77 |
78 | ### Delete
79 | ```python
80 | SparseRetriever.delete("index-name")
81 | ```
82 |
83 | ## Search
84 |
85 | ### Search
86 |
87 | Standard search functionality.
88 |
89 | ```python
90 | sr.search(
91 | query="witches masses", # What to search for
92 | return_docs=True, # Default value, return the text of the documents
93 | cutoff=100, # Default value, number of results to return
94 | )
95 | ```
96 | Output:
97 | ```json
98 | [
99 | {
100 | "id": "doc_2",
101 | "text": "Just like witches at black masses",
102 | "score": 1.7536403
103 | },
104 | {
105 | "id": "doc_1",
106 | "text": "Generals gathered in their masses",
107 | "score": 0.6931472
108 | }
109 | ]
110 | ```
111 |
112 | ### Multi-Search
113 |
114 | Compute results for multiple queries at once.
115 |
116 | ```python
117 | sr.msearch(
118 | queries=[{"id": "q_1", "text": "witches masses"}, ...],
119 | cutoff=100, # Default value, number of results
120 | )
121 | ```
122 | Output:
123 | ```json
124 | {
125 | "q_1": {
126 | "doc_2": 1.7536403,
127 | "doc_1": 0.6931472
128 | },
129 | ...
130 | }
131 | ```
132 |
133 | ### Batch-Search
134 |
135 | Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format.
136 | [bsearch](#batch-search) is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
137 |
138 | ```python
139 | sr.bsearch(
140 | queries=[{"id": "q_1", "text": "witches masses"}, ...],
141 | cutoff=100,
142 | batch_size=1_000,
143 | show_progress=True,
144 | qrels=None, # To add relevance information to the saved files
145 | path=None, # Where to save the results, if you want to save them
146 | )
147 | ```
148 |
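149 | When `path` is provided, the results are written to disk as they are computed, one JSON object per line. A minimal sketch for streaming them back (the file path is illustrative; fields other than `id` and `text` depend on the retriever and on whether `qrels` were provided):
150 | 
151 | ```python
152 | import json
153 | 
154 | with open("path/to/results.jsonl") as f:
155 |     for line in f:
156 |         result = json.loads(line)  # one result set per query
157 |         print(result["id"], result["text"])
158 | ```
159 | 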
149 | ## AutoTune
150 |
151 | Use the AutoTune function to tune [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) parameters w.r.t. your document collection and queries.
152 | All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function.
153 | At the end of the process, the best parameter configuration is automatically applied to the `SparseRetriever` instance and saved to disk.
154 | You can inspect the current configuration by printing `sr.hyperparams`.
155 |
156 | ```python
157 | sr.autotune(
158 |     queries=[{ "id": "q_1", "text": "...", ... }], # Train queries
159 | qrels={ "q_1": { "doc_1": 1, ... }, ... }, # Train qrels
160 | metric="ndcg", # Default value, metric to maximize
161 | n_trials=100, # Default value, number of trials
162 | cutoff=100, # Default value, number of results
163 | )
164 | ```
165 |
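166 | For example, after tuning you can check the selected configuration as follows (the values shown are illustrative):
167 | 
168 | ```python
169 | print(sr.hyperparams)
170 | # e.g., {'b': 0.75, 'k1': 1.2}
171 | ```
172 | 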
--------------------------------------------------------------------------------
/docs/speed.md:
--------------------------------------------------------------------------------
1 | ## Speed Comparison
2 |
3 | TO BE UPDATED
4 |
5 | We performed a speed test, comparing [retriv](https://github.com/AmenRa/retriv) to [rank_bm25](https://github.com/dorianbrown/rank_bm25), a popular [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) implementation in [Python](https://en.wikipedia.org/wiki/Python_(programming_language)), and [pyserini](https://github.com/castorini/pyserini), a [Python](https://en.wikipedia.org/wiki/Python_(programming_language)) binding to the [Lucene](https://en.wikipedia.org/wiki/Apache_Lucene) search engine.
6 |
7 | We relied on the [MSMARCO Passage](https://microsoft.github.io/msmarco) dataset to collect documents and queries.
8 | Specifically, we used the original document collection and three sub-samples of it, accounting for 1k, 100k, and 1M documents, respectively, and sampled 1k queries from the original ones.
9 | We computed the top-100 results with each library (if possible).
10 | Results are reported below. Best results are highlighted in boldface.
11 |
12 | | Library | Collection Size | Elapsed Time | Avg. Query Time | Throughput (q/s) |
13 | | --------------------------------------------------------- | --------------: | -----------: | --------------: | ---------------: |
14 | | [rank_bm25](https://github.com/dorianbrown/rank_bm25) | 1,000 | 646ms | 0.6ms | 1548/s |
15 | | [pyserini](https://github.com/castorini/pyserini) | 1,000 | 1,438ms | 1.4ms | 695/s |
16 | | [retriv](https://github.com/AmenRa/retriv) | 1,000 | 140ms | 0.1ms | 7143/s |
17 | | [retriv](https://github.com/AmenRa/retriv) (multi-search) | 1,000 | __134ms__ | __0.1ms__ | __7463/s__ |
18 | | [rank_bm25](https://github.com/dorianbrown/rank_bm25) | 100,000 | 1,060,000ms | 1060ms | 1/s |
19 | | [pyserini](https://github.com/castorini/pyserini) | 100,000 | 2,532ms | 2.5ms | 395/s |
20 | | [retriv](https://github.com/AmenRa/retriv) | 100,000 | 314ms | 0.3ms | 3185/s |
21 | | [retriv](https://github.com/AmenRa/retriv) (multi-search) | 100,000 | __256ms__ | __0.3ms__ | __3906/s__ |
22 | | [rank_bm25](https://github.com/dorianbrown/rank_bm25) | 1,000,000 | N/A | N/A | N/A |
23 | | [pyserini](https://github.com/castorini/pyserini) | 1,000,000 | 4,060ms | 4.1ms | 246/s |
24 | | [retriv](https://github.com/AmenRa/retriv) | 1,000,000 | 1,018ms | 1.0ms | 982/s |
25 | | [retriv](https://github.com/AmenRa/retriv) (multi-search) | 1,000,000 | __503ms__ | __0.5ms__ | __1988/s__ |
26 | | [rank_bm25](https://github.com/dorianbrown/rank_bm25) | 8,841,823 | N/A | N/A | N/A |
27 | | [pyserini](https://github.com/castorini/pyserini) | 8,841,823 | 12,245ms | 12.2ms | 82/s |
28 | | [retriv](https://github.com/AmenRa/retriv) | 8,841,823 | 10,763ms | 10.8ms | 93/s |
29 | | [retriv](https://github.com/AmenRa/retriv) (multi-search) | 8,841,823 | __4,476ms__ | __4.4ms__ | __227/s__ |
--------------------------------------------------------------------------------
/docs/text_preprocessing.md:
--------------------------------------------------------------------------------
1 | # Text Pre-Processing
2 |
3 | [retriv](https://github.com/AmenRa/retriv) provides several resources for multi-lingual text pre-processing, aiming to maximize its retrieval effectiveness.
4 |
5 | ## Stemmers
6 | [Stemmers](https://en.wikipedia.org/wiki/Stemming) reduce words to their word stem, base or root form.
7 | [retriv](https://github.com/AmenRa/retriv) supports the following stemmers:
8 | - [snowball](https://snowballstem.org) (default)
9 |   The following languages are supported by the Snowball Stemmer:
10 |   Arabic, Basque, Catalan, Danish, Dutch, English, Finnish, French, German, Greek, Hindi, Hungarian, Indonesian, Irish, Italian, Lithuanian, Nepali, Norwegian, Portuguese, Romanian, Russian, Spanish, Swedish, Tamil, Turkish.
11 |   To select your preferred language, simply pass its lowercase name, e.g., `stemmer="italian"`.
12 | - [arlstem](https://www.nltk.org/api/nltk.stem.arlstem.html) (Arabic)
13 | - [arlstem2](https://www.nltk.org/api/nltk.stem.arlstem2.html) (Arabic)
14 | - [cistem](https://www.nltk.org/api/nltk.stem.cistem.html) (German)
15 | - [isri](https://www.nltk.org/api/nltk.stem.isri.html) (Arabic)
16 | - [krovetz](https://dl.acm.org/doi/10.1145/160688.160718) (English)
17 | - [lancaster](https://www.nltk.org/api/nltk.stem.lancaster.html) (English)
18 | - [porter](https://www.nltk.org/api/nltk.stem.porter.html) (English)
19 |
20 |
21 | ## Tokenizers
22 | [Tokenizers](https://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) divide a string into smaller units, such as words.
23 | [retriv](https://github.com/AmenRa/retriv) supports the following tokenizers:
24 | - [whitespace](https://www.nltk.org/api/nltk.tokenize.html)
25 | - [word](https://www.nltk.org/api/nltk.tokenize.html)
26 | - [wordpunct](https://www.nltk.org/api/nltk.tokenize.html)
27 | - [sent](https://www.nltk.org/api/nltk.tokenize.html)
28 |
29 |
30 | ## Stop-word Lists
31 | [retriv](https://github.com/AmenRa/retriv) supports [stop-word](https://en.wikipedia.org/wiki/Stop_word) lists for the following languages: Arabic, Azerbaijani, Basque, Bengali, Catalan, Chinese, Danish, Dutch, English, Finnish, French, German, Greek, Hebrew, Hinglish, Hungarian, Indonesian, Italian, Kazakh, Nepali, Norwegian, Portuguese, Romanian, Russian, Slovene, Spanish, Swedish, Tajik, and Turkish.
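32 | 
33 | As a minimal sketch, the resources above can be combined when building a [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md), e.g., for Italian (the index name is illustrative):
34 | 
35 | ```python
36 | from retriv import SparseRetriever
37 | 
38 | sr = SparseRetriever(
39 |     index_name="italian-index",
40 |     tokenizer="word",     # NLTK word tokenizer
41 |     stemmer="italian",    # Snowball stemmer for Italian
42 |     stopwords="italian",  # Italian stop-word list
43 | )
44 | ```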
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | # E501 ignore long comment lines (black should take care of that)
3 | # E712 ignore == False, needed for numba checks.
4 | ignore = ["E501", "E712"]
5 |
6 | [tool.isort]
7 | profile = "black"
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | black
2 | black[jupyter]
3 | blackdoc
4 | isort
5 | pytest
6 | pytest-cov
7 | pytest-xdist
8 | ruff
9 | twine
10 | typos
11 | wheel
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | autofaiss
2 | faiss-cpu
3 | indxr
4 | krovetzstemmer
5 | multipipe
6 | nltk
7 | numba>=0.54.1
8 | numpy
9 | oneliner_utils
10 | optuna
11 | pystemmer==2.0.1
12 | ranx
13 | scikit-learn
14 | torch
15 | torchaudio
16 | torchvision
17 | tqdm
18 | transformers[torch]
19 | unidecode
--------------------------------------------------------------------------------
/retriv/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "ANN_Searcher",
3 | "DenseRetriever",
4 | "Encoder",
5 | "SearchEngine",
6 | "SparseRetriever",
7 | "HybridRetriever",
8 | "Merger",
9 | ]
10 |
11 | import os
12 | from pathlib import Path
13 |
14 | from .dense_retriever.ann_searcher import ANN_Searcher
15 | from .dense_retriever.dense_retriever import DenseRetriever
16 | from .dense_retriever.encoder import Encoder
17 | from .hybrid_retriever import HybridRetriever
18 | from .merger.merger import Merger
19 | from .sparse_retriever.sparse_retriever import SparseRetriever
20 | from .sparse_retriever.sparse_retriever import SparseRetriever as SearchEngine
21 |
22 | # Set environment variables ----------------------------------------------------
23 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
24 | if "RETRIV_BASE_PATH" not in os.environ: # allow user to set a different path in .bash_profile
25 | os.environ["RETRIV_BASE_PATH"] = str(Path.home() / ".retriv")
26 |
27 | def set_base_path(path: str):
28 | os.environ["RETRIV_BASE_PATH"] = path
29 |
--------------------------------------------------------------------------------
/retriv/autotune/__init__.py:
--------------------------------------------------------------------------------
1 | from .bm25_autotune import tune_bm25
2 | from .merger_autotune import tune_merger
3 |
4 | __all__ = [
5 | "tune_bm25",
6 | "tune_merger",
7 | ]
8 |
--------------------------------------------------------------------------------
/retriv/autotune/bm25_autotune.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import optuna
4 | from optuna.exceptions import ExperimentalWarning
5 | from ranx import Qrels, Run, evaluate
6 |
7 | warnings.filterwarnings("ignore", category=ExperimentalWarning)
8 |
9 |
10 | def bm25_objective(trial, queries, qrels, se, metric, cutoff):
11 |     b = trial.suggest_float("b", 0.0, 1.0, step=0.01)  # BM25 length-normalization parameter
12 |     k1 = trial.suggest_float("k1", 0.0, 10.0, step=0.1)  # BM25 term-frequency saturation parameter
13 |
14 | se.hyperparams = dict(b=b, k1=k1)
15 | run = Run(se.bsearch(queries=queries, cutoff=cutoff, show_progress=False))
16 |
17 | return evaluate(qrels, run, metric)
18 |
19 |
20 | def tune_bm25(queries, qrels, se, metric, n_trials, cutoff):
21 | qrels = Qrels(qrels)
22 |
23 | optuna.logging.set_verbosity(optuna.logging.WARNING)
24 |
25 | sampler = optuna.samplers.TPESampler(seed=42)
26 | study = optuna.create_study(direction="maximize", sampler=sampler)
27 | study.optimize(
28 | lambda trial: bm25_objective(trial, queries, qrels, se, metric, cutoff),
29 | n_trials=n_trials,
30 | show_progress_bar=True,
31 | )
32 |
33 | optuna.logging.set_verbosity(optuna.logging.INFO)
34 |
35 | # Set best params
36 | se.hyperparams = study.best_params
37 |
38 | return study.best_params
39 |
--------------------------------------------------------------------------------
/retriv/autotune/merger_autotune.py:
--------------------------------------------------------------------------------
1 | from ranx import Qrels, Run, evaluate, fuse, optimize_fusion
2 |
3 |
4 | def tune_merger(qrels, runs, metric):
5 | ranx_qrels = Qrels(qrels)
6 | ranx_runs = [Run(run) for run in runs]
7 |
8 | best_score = 0.0
9 | best_config = None
10 |
11 | for norm in ["min-max", "max", "sum"]:
12 | best_params = optimize_fusion(
13 | qrels=ranx_qrels,
14 | runs=ranx_runs,
15 | norm=norm,
16 | method="wsum",
17 | metric=metric,
18 | show_progress=False,
19 | )
20 |
21 | combined_run = fuse(
22 | runs=ranx_runs, norm=norm, method="wsum", params=best_params
23 | )
24 |
25 | score = evaluate(ranx_qrels, combined_run, metric)
26 | if score > best_score:
27 | best_score = score
28 | best_config = {
29 | "norm": norm,
30 | "params": best_params,
31 | }
32 |
33 | return best_config
34 |
--------------------------------------------------------------------------------
/retriv/base_retriever.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from typing import Iterable, List
4 |
5 | import numpy as np
6 | import orjson
7 | from indxr import Indxr
8 | from oneliner_utils import read_csv, read_jsonl
9 |
10 | from .paths import docs_path, index_path
11 |
12 |
13 | class BaseRetriever:
14 | def __init__(self, index_name: str = "new-index"):
15 | self.index_name = index_name
16 | self.id_mapping = None
17 | self.doc_count = None
18 | self.doc_index = None
19 |
20 | @staticmethod
21 | def delete(index_name="new-index"):
22 | try:
23 | shutil.rmtree(index_path(index_name))
24 | print(f"{index_name} successfully removed.")
25 | except FileNotFoundError:
26 | print(f"{index_name} not found.")
27 |
28 | def collection_generator(self, path: str, callback: callable = None):
29 | kind = os.path.splitext(path)[1][1:]
30 | assert kind in {
31 | "jsonl",
32 | "csv",
33 | "tsv",
34 | }, "Only JSONl, CSV, and TSV are currently supported."
35 |
36 | if kind == "jsonl":
37 | collection = read_jsonl(path, generator=True, callback=callback)
38 | elif kind == "csv":
39 | collection = read_csv(path, generator=True, callback=callback)
40 | elif kind == "tsv":
41 | collection = read_csv(
42 | path, delimiter="\t", generator=True, callback=callback
43 | )
44 |
45 | return collection
46 |
47 | def save_collection(self, collection: Iterable, callback: callable = None):
48 | with open(docs_path(self.index_name), "wb") as f:
49 | for doc in collection:
50 | x = callback(doc) if callback is not None else doc
51 | f.write(orjson.dumps(x) + "\n".encode())
52 |
53 | def initialize_doc_index(self):
54 | self.doc_index = Indxr(docs_path(self.index_name))
55 |
56 | def initialize_id_mapping(self):
57 | ids = read_jsonl(
58 | docs_path(self.index_name),
59 | generator=True,
60 | callback=lambda x: x["id"],
61 | )
62 | self.id_mapping = dict(enumerate(ids))
63 |
64 | def get_doc(self, doc_id: str) -> dict:
65 | return self.doc_index.get(doc_id)
66 |
67 | def get_docs(self, doc_ids: List[str]) -> List[dict]:
68 | return self.doc_index.mget(doc_ids)
69 |
70 | def prepare_results(self, doc_ids: List[str], scores: np.ndarray) -> List[dict]:
71 | docs = self.get_docs(doc_ids)
72 | results = []
73 | for doc, score in zip(docs, scores):
74 | doc["score"] = score
75 | results.append(doc)
76 |
77 | return results
78 |
79 | def map_internal_ids_to_original_ids(self, doc_ids: Iterable) -> List[str]:
80 | return [self.id_mapping[doc_id] for doc_id in doc_ids]
81 |
82 | def save(self):
83 | raise NotImplementedError()
84 |
85 | @staticmethod
86 | def load():
87 | raise NotImplementedError()
88 |
89 | def index(self):
90 | raise NotImplementedError()
91 |
92 | def index_file(self):
93 | raise NotImplementedError()
94 |
95 | def search(self):
96 | raise NotImplementedError()
97 |
98 | def msearch(self):
99 | raise NotImplementedError()
100 |
101 | def autotune(self):
102 | raise NotImplementedError()
103 |
--------------------------------------------------------------------------------
/retriv/dense_retriever/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/dense_retriever/__init__.py
--------------------------------------------------------------------------------
/retriv/dense_retriever/ann_searcher.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import faiss
4 | import numpy as np
5 | import psutil
6 | from autofaiss import build_index
7 | from oneliner_utils import read_json
8 |
9 | from ..paths import embeddings_folder_path, faiss_index_infos_path, faiss_index_path
10 |
11 |
12 | def get_ram():
13 |     size_bytes = psutil.virtual_memory().total
14 |     # Report the total RAM in gibibytes: autofaiss expects a string like "16.0GB",
15 |     # so a variable 1024**i scale would mislabel machines outside the GB range.
16 |     s = round(size_bytes / math.pow(1024, 3), 2)
17 |     return f"{s}GB"
18 |
19 |
20 | class ANN_Searcher:
21 | def __init__(self, index_name: str = "new-index"):
22 | self.index_name = index_name
23 | self.faiss_index = None
24 | self.faiss_index_infos = None
25 |
26 | def build(self, use_gpu=False):
27 | index, index_infos = build_index(
28 | embeddings=str(embeddings_folder_path(self.index_name)),
29 | index_path=str(faiss_index_path(self.index_name)),
30 | index_infos_path=str(faiss_index_infos_path(self.index_name)),
31 | save_on_disk=True,
32 | metric_type="ip",
33 | # max_index_memory_usage="32GB",
34 | current_memory_available=get_ram(),
35 | max_index_query_time_ms=10,
36 | min_nearest_neighbors_to_retrieve=20,
37 | index_key=None,
38 | index_param=None,
39 | use_gpu=use_gpu,
40 | nb_cores=None,
41 | make_direct_map=False,
42 | should_be_memory_mappable=False,
43 | distributed=None,
44 | verbose=40,
45 | )
46 |
47 | self.faiss_index = index
48 | self.faiss_index_infos = index_infos
49 |
50 | @staticmethod
51 | def load(index_name: str = "new-index"):
52 | ann_searcher = ANN_Searcher(index_name)
53 | ann_searcher.faiss_index = faiss.read_index(str(faiss_index_path(index_name)))
54 | ann_searcher.faiss_index_infos = read_json(faiss_index_infos_path(index_name))
55 | return ann_searcher
56 |
57 | def search(self, query: np.ndarray, cutoff: int = 100):
58 | query = query.reshape(1, len(query))
59 | ids, scores = self.msearch(query, cutoff)
60 | return ids[0], scores[0]
61 |
62 | def msearch(self, queries: np.ndarray, cutoff: int = 100):
63 | scores, ids = self.faiss_index.search(queries, cutoff)
64 | return ids, scores
65 |
--------------------------------------------------------------------------------
/retriv/dense_retriever/dense_retriever.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from typing import Dict, Iterable, List
4 |
5 | import numpy as np
6 | import orjson
7 | from numba import njit, prange
8 | from numba.typed import List as TypedList
9 | from oneliner_utils import create_path
10 | from tqdm import tqdm
11 |
12 | from ..base_retriever import BaseRetriever
13 | from ..paths import docs_path, dr_state_path, embeddings_folder_path
14 | from .ann_searcher import ANN_Searcher
15 | from .encoder import Encoder
16 |
17 |
18 | class DenseRetriever(BaseRetriever):
19 | def __init__(
20 | self,
21 | index_name: str = "new-index",
22 | model: str = "sentence-transformers/all-MiniLM-L6-v2",
23 | normalize: bool = True,
24 | max_length: int = 128,
25 | use_ann: bool = True,
26 | ):
27 | """The Dense Retriever performs [semantic search](https://en.wikipedia.org/wiki/Semantic_search), i.e., it compares vector representations of queries and documents to compute the relevance scores of the latter.
28 |
29 | Args:
30 | index_name (str, optional): [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index. Defaults to "new-index".
31 |
32 |             model (str, optional): defines the encoder model to encode queries and documents into vectors. You can use a [Hugging Face Transformers](https://huggingface.co/models) pre-trained model by providing its ID or load a local model by providing its path. In the case of local models, the path must point to the directory containing the data saved with the [`PreTrainedModel.save_pretrained`](https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method. Note that the representations are computed with `mean pooling` over the `last_hidden_state`. Defaults to "sentence-transformers/all-MiniLM-L6-v2".
33 |
34 | normalize (bool, optional): whether to L2 normalize the vector representations. Defaults to True.
35 |
36 | max_length (int, optional): texts longer than `max_length` will be automatically truncated. Choose this parameter based on how the employed model was trained or is generally used. Defaults to 128.
37 |
38 |             use_ann (bool, optional): whether to use approximate nearest neighbors search. Set it to `False` to use nearest neighbors search without approximation. If you have fewer than 20k documents in your collection, you probably want to disable approximation. Defaults to True.
39 | """
40 |
41 | self.index_name = index_name
42 | self.model = model
43 | self.normalize = normalize
44 | self.max_length = max_length
45 | self.use_ann = use_ann
46 |
47 | self.encoder = Encoder(
48 | index_name=index_name,
49 | model=model,
50 | normalize=normalize,
51 | max_length=max_length,
52 | )
53 |
54 | self.ann_searcher = ANN_Searcher(index_name=index_name)
55 |
56 | self.id_mapping = None
57 | self.doc_count = None
58 | self.doc_index = None
59 |
60 | self.embeddings = None
61 |
62 | def save(self):
63 | """Save the state of the retriever to be able to restore it later."""
64 |
65 | state = dict(
66 | init_args=dict(
67 | index_name=self.index_name,
68 | model=self.model,
69 | normalize=self.normalize,
70 | max_length=self.max_length,
71 | use_ann=self.use_ann,
72 | ),
73 | id_mapping=self.id_mapping,
74 | doc_count=self.doc_count,
75 | embeddings=True if self.embeddings is not None else None,
76 | )
77 | np.savez_compressed(dr_state_path(self.index_name), state=state)
78 |
79 | @staticmethod
80 | def load(index_name: str = "new-index"):
81 | """Load a retriever and its index.
82 |
83 | Args:
84 | index_name (str, optional): Name of the index. Defaults to "new-index".
85 |
86 | Returns:
87 | DenseRetriever: Dense Retriever.
88 | """
89 |
90 | state = np.load(dr_state_path(index_name), allow_pickle=True)["state"][()]
91 | dr = DenseRetriever(**state["init_args"])
92 | dr.initialize_doc_index()
93 | dr.id_mapping = state["id_mapping"]
94 | dr.doc_count = state["doc_count"]
95 | if state["embeddings"]:
96 | dr.load_embeddings()
97 | if dr.use_ann:
98 | dr.ann_searcher = ANN_Searcher.load(index_name)
99 | return dr
100 |
101 | def load_embeddings(self):
102 | """Internal usage."""
103 | path = embeddings_folder_path(self.index_name)
104 | npy_file_paths = sorted(os.listdir(path))
105 | self.embeddings = np.concatenate(
106 | [np.load(path / npy_file_path) for npy_file_path in npy_file_paths]
107 | )
108 |
109 | def import_embeddings(self, path: str):
110 | """Internal usage."""
111 | shutil.copyfile(path, embeddings_folder_path(self.index_name) / "chunk_0.npy")
112 |
113 | def index_aux(
114 | self,
115 | embeddings_path: str = None,
116 | use_gpu: bool = False,
117 | batch_size: int = 512,
118 | callback: callable = None,
119 | show_progress: bool = True,
120 | ):
121 | """Internal usage."""
122 | if embeddings_path is not None:
123 | self.import_embeddings(embeddings_path)
124 | else:
125 | self.encoder.change_device("cuda" if use_gpu else "cpu")
126 | self.encoder.encode_collection(
127 | path=docs_path(self.index_name),
128 | batch_size=batch_size,
129 | callback=callback,
130 | show_progress=show_progress,
131 | )
132 | self.encoder.change_device("cpu")
133 |
134 | if self.use_ann:
135 | if show_progress:
136 | print("Building ANN Searcher")
137 | self.ann_searcher.build()
138 | else:
139 | if show_progress:
140 | print("Loading embeddings...")
141 | self.load_embeddings()
142 |
143 | def index(
144 | self,
145 | collection: Iterable,
146 | embeddings_path: str = None,
147 | use_gpu: bool = False,
148 | batch_size: int = 512,
149 | callback: callable = None,
150 | show_progress: bool = True,
151 | ):
152 | """Index a given collection of documents.
153 |
154 | Args:
155 | collection (Iterable): collection of documents to index.
156 |
157 | embeddings_path (str, optional): in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file. Defaults to None.
158 |
159 | use_gpu (bool, optional): whether to use the GPU for document encoding. Defaults to False.
160 |
161 |             batch_size (int, optional): how many documents to encode at once. Adjust it if you run into memory usage issues or want to maximize throughput. Defaults to 512.
162 |
163 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None.
164 |
165 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True.
166 |
167 | Returns:
168 | DenseRetriever: Dense Retriever
169 | """
170 |
171 | self.save_collection(collection, callback)
172 | self.initialize_doc_index()
173 | self.initialize_id_mapping()
174 | self.doc_count = len(self.id_mapping)
175 | self.index_aux(
176 | embeddings_path=embeddings_path,
177 | use_gpu=use_gpu,
178 | batch_size=batch_size,
179 | callback=callback,
180 | show_progress=show_progress,
181 | )
182 | self.save()
183 | return self
184 |
185 | def index_file(
186 | self,
187 | path: str,
188 | embeddings_path: str = None,
189 | use_gpu: bool = False,
190 | batch_size: int = 512,
191 | callback: callable = None,
192 | show_progress: bool = True,
193 | ):
194 | """Index the collection contained in a given file.
195 |
196 | Args:
197 | path (str): path of file containing the collection to index.
198 |
199 | embeddings_path (str, optional): in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file. Defaults to None.
200 |
201 | use_gpu (bool, optional): whether to use the GPU for document encoding. Defaults to False.
202 |
203 |             batch_size (int, optional): how many documents to encode at once. Adjust it if you run into memory usage issues or want to maximize throughput. Defaults to 512.
204 |
205 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None.
206 |
207 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True.
208 |
209 | Returns:
210 | DenseRetriever: Dense Retriever.
211 | """
212 |
213 | collection = self.collection_generator(path, callback)
214 | return self.index(
215 | collection,
216 | embeddings_path,
217 | use_gpu,
218 | batch_size,
219 |             None,  # callback already applied by collection_generator
220 | show_progress,
221 | )
222 |
223 | def search(
224 | self,
225 | query: str,
226 | return_docs: bool = True,
227 | cutoff: int = 100,
228 | ) -> List:
229 | """Standard search functionality.
230 |
231 | Args:
232 | query (str): what to search for.
233 |
234 | return_docs (bool, optional): whether to return the texts of the documents. Defaults to True.
235 |
236 | cutoff (int, optional): number of results to return. Defaults to 100.
237 |
238 | Returns:
239 | List: results.
240 | """
241 |
242 | encoded_query = self.encoder(query)
243 |
244 | if self.use_ann:
245 | doc_ids, scores = self.ann_searcher.search(encoded_query, cutoff)
246 | else:
247 | if self.embeddings is None:
248 | self.load_embeddings()
249 | doc_ids, scores = compute_scores(encoded_query, self.embeddings, cutoff)
250 |
251 | doc_ids = self.map_internal_ids_to_original_ids(doc_ids)
252 |
253 | return (
254 | self.prepare_results(doc_ids, scores)
255 | if return_docs
256 | else dict(zip(doc_ids, scores))
257 | )
258 |
259 | def msearch(
260 | self,
261 | queries: List[Dict[str, str]],
262 | cutoff: int = 100,
263 | batch_size: int = 32,
264 | ) -> Dict:
265 | """Compute results for multiple queries at once.
266 |
267 | Args:
268 | queries (List[Dict[str, str]]): what to search for.
269 |
270 | cutoff (int, optional): number of results to return. Defaults to 100.
271 |
272 |             batch_size (int, optional): how many queries to search at once. Adjust it if you run into memory usage issues or want to maximize throughput. Defaults to 32.
273 |
274 | Returns:
275 | Dict: results.
276 | """
277 |
278 | q_ids = [x["id"] for x in queries]
279 | q_texts = [x["text"] for x in queries]
280 | encoded_queries = self.encoder(q_texts, batch_size, show_progress=False)
281 |
282 | if self.use_ann:
283 | doc_ids, scores = self.ann_searcher.msearch(encoded_queries, cutoff)
284 | else:
285 | if self.embeddings is None:
286 | self.load_embeddings()
287 | doc_ids, scores = compute_scores_multi(
288 | encoded_queries, self.embeddings, cutoff
289 | )
290 |
291 | doc_ids = [
292 | self.map_internal_ids_to_original_ids(_doc_ids) for _doc_ids in doc_ids
293 | ]
294 |
295 | results = {q: dict(zip(doc_ids[i], scores[i])) for i, q in enumerate(q_ids)}
296 |
297 | return {q_id: results[q_id] for q_id in q_ids}
298 |
299 | def bsearch(
300 | self,
301 | queries: List[Dict[str, str]],
302 | cutoff: int = 100,
303 | batch_size: int = 32,
304 | show_progress: bool = True,
305 | qrels: Dict[str, Dict[str, float]] = None,
306 | path: str = None,
307 | ):
308 | """Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format. bsearch is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
309 |
310 | Args:
311 | queries (List[Dict[str, str]]): what to search for.
312 |
313 | cutoff (int, optional): number of results to return. Defaults to 100.
314 |
315 |             batch_size (int, optional): how many queries to search at once. Adjust it if you run into memory usage issues or want to maximize throughput. Defaults to 32.
316 |
317 | show_progress (bool, optional): whether to show a progress bar for the search process. Defaults to True.
318 |
319 | qrels (Dict[str, Dict[str, float]], optional): query relevance judgements for the queries. Defaults to None.
320 |
321 | path (str, optional): where to save the results. Defaults to None.
322 |
323 | Returns:
324 | Dict: results.
325 | """
326 |
327 | batches = [
328 | queries[i : i + batch_size] for i in range(0, len(queries), batch_size)
329 | ]
330 |
331 | results = {}
332 |
333 | pbar = tqdm(
334 | total=len(queries),
335 | disable=not show_progress,
336 | desc="Batch search",
337 | dynamic_ncols=True,
338 | mininterval=0.5,
339 | )
340 |
341 | if path is None:
342 | for batch in batches:
343 | new_results = self.msearch(
344 | queries=batch, cutoff=cutoff, batch_size=len(batch)
345 | )
346 | results = {**results, **new_results}
347 | pbar.update(min(batch_size, len(batch)))
348 | else:
349 | path = create_path(path)
350 | path.parent.mkdir(parents=True, exist_ok=True)
351 |
352 | with open(path, "wb") as f:
353 | for batch in batches:
354 | new_results = self.msearch(queries=batch, cutoff=cutoff)
355 |
356 | for i, (k, v) in enumerate(new_results.items()):
357 | x = {
358 | "id": k,
359 | "text": batch[i]["text"],
360 | "dense_doc_ids": list(v.keys()),
361 | "dense_scores": [float(s) for s in list(v.values())],
362 | }
363 | if qrels is not None:
364 | x["rel_doc_ids"] = list(qrels[k].keys())
365 | x["rel_scores"] = list(qrels[k].values())
366 | f.write(orjson.dumps(x) + "\n".encode())
367 |
368 | pbar.update(min(batch_size, len(batch)))
369 |
370 |         return results  # NOTE: empty when a path is given, as results are streamed to disk
371 |
372 |
373 | @njit(cache=True)
374 | def compute_scores(query: np.ndarray, docs: np.ndarray, cutoff: int):
375 | """Internal usage."""
376 |
377 | scores = docs @ query
378 | indices = np.argsort(-scores)[:cutoff]
379 |
380 | return indices, scores[indices]
381 |
382 |
383 | @njit(cache=True, parallel=True)
384 | def compute_scores_multi(queries: np.ndarray, docs: np.ndarray, cutoff: int):
385 | """Internal usage."""
386 |
387 | n = len(queries)
388 | ids = TypedList([np.empty(1, dtype=np.int64) for _ in range(n)])
389 | scores = TypedList([np.empty(1, dtype=np.float32) for _ in range(n)])
390 |
391 | for i in prange(len(queries)):
392 | ids[i], scores[i] = compute_scores(queries[i], docs, cutoff)
393 |
394 | return ids, scores
395 |
--------------------------------------------------------------------------------
/retriv/dense_retriever/encoder.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | from typing import Generator, List, Union
3 |
4 | import numpy as np
5 | import torch
6 | from oneliner_utils import read_jsonl
7 | from torch import Tensor, einsum
8 | from torch.nn.functional import normalize
9 | from tqdm import tqdm
10 | from transformers import AutoConfig, AutoModel, AutoTokenizer
11 |
12 | from ..paths import embeddings_folder_path, encoder_state_path
13 |
14 | pbar_kwargs = dict(position=0, dynamic_ncols=True, mininterval=1.0)
15 |
16 |
17 | def count_lines(path: str):
18 | """Counts the number of lines in a file."""
19 | return sum(1 for _ in open(path))
20 |
21 |
22 | def generate_batch(docs: Union[List, Generator], batch_size: int) -> Generator:
23 | texts = []
24 |
25 | for doc in docs:
26 | texts.append(doc["text"])
27 |
28 | if len(texts) == batch_size:
29 | yield texts
30 | texts = []
31 |
32 | if texts:
33 | yield texts
34 |
35 |
36 | class Encoder:
37 | def __init__(
38 | self,
39 | index_name: str = "new-index",
40 | model: str = "sentence-transformers/all-MiniLM-L6-v2",
41 | normalize: bool = True,
42 | return_numpy: bool = True,
43 | max_length: int = 128,
44 | device: str = "cpu",
45 | ):
46 | self.index_name = index_name
47 | self.model = model
48 | self.tokenizer = AutoTokenizer.from_pretrained(model)
49 | self.encoder = AutoModel.from_pretrained(model).to(device).eval()
50 | self.embedding_dim = AutoConfig.from_pretrained(model).hidden_size
51 | self.max_length = max_length
52 | self.normalize = normalize
53 | self.return_numpy = return_numpy
54 | self.device = device
55 | self.tokenizer_kwargs = {
56 | "padding": True,
57 | "truncation": True,
58 | "max_length": self.max_length,
59 | "return_tensors": "pt",
60 | }
61 |
62 | def save(self):
63 | state = dict(
64 | index_name=self.index_name,
65 | model=self.model,
66 | normalize=self.normalize,
67 | return_numpy=self.return_numpy,
68 | max_length=self.max_length,
69 | device=self.device,
70 | )
71 | np.save(encoder_state_path(self.index_name), state)
72 |
73 | @staticmethod
74 | def load(index_name: str, device: str = None):
75 | state = np.load(encoder_state_path(index_name), allow_pickle=True)[()]
76 | if device is not None:
77 | state["device"] = device
78 | return Encoder(**state)
79 |
80 | def change_device(self, device: str = "cpu"):
81 | self.device = device
82 | self.encoder.to(device)
83 |
84 | def tokenize(self, texts: List[str]):
85 | tokens = self.tokenizer(texts, **self.tokenizer_kwargs)
86 | return {k: v.to(self.device) for k, v in tokens.items()}
87 |
88 | def mean_pooling(self, embeddings: Tensor, mask: Tensor) -> Tensor:
89 | numerators = einsum("xyz,xy->xyz", embeddings, mask).sum(dim=1)
90 | denominators = torch.clamp(mask.sum(dim=1), min=1e-9)
91 | return einsum("xz,x->xz", numerators, 1 / denominators)
92 |
93 | def __call__(
94 | self,
95 | x: Union[str, List[str]],
96 | batch_size: int = 32,
97 | show_progress: bool = True,
98 | ):
99 | if isinstance(x, str):
100 | return self.encode(x)
101 | else:
102 | return self.bencode(x, batch_size=batch_size, show_progress=show_progress)
103 |
104 | def encode(self, text: str):
105 | return self.bencode([text], batch_size=1, show_progress=False)[0]
106 |
107 | def bencode(
108 | self,
109 | texts: List[str],
110 | batch_size: int = 32,
111 | show_progress: bool = True,
112 | ):
113 | embeddings = []
114 |
115 | pbar = tqdm(
116 | total=len(texts),
117 | desc="Generating embeddings",
118 | disable=not show_progress,
119 | **pbar_kwargs,
120 | )
121 |
122 | for i in range(ceil(len(texts) / batch_size)):
123 | start, stop = i * batch_size, (i + 1) * batch_size
124 | tokens = self.tokenize(texts[start:stop])
125 |
126 | with torch.no_grad():
127 | emb = self.encoder(**tokens).last_hidden_state
128 | emb = self.mean_pooling(emb, tokens["attention_mask"])
129 | if self.normalize:
130 | emb = normalize(emb, dim=-1)
131 |
132 | embeddings.append(emb)
133 |
134 | pbar.update(stop - start)
135 | pbar.close()
136 |
137 | embeddings = torch.cat(embeddings)
138 |
139 | if self.return_numpy:
140 | embeddings = embeddings.detach().cpu().numpy()
141 |
142 | return embeddings
143 |
144 | def encode_collection(
145 | self,
146 | path: str,
147 | batch_size: int = 512,
148 | callback: callable = None,
149 | show_progress: bool = True,
150 | ):
151 | n_docs = count_lines(path)
152 | collection = read_jsonl(path, callback=callback, generator=True)
153 |
154 | reservoir = np.empty((1_000_000, self.embedding_dim), dtype=np.float32)
155 | reservoir_n = 0
156 | offset = 0
157 |
158 | pbar = tqdm(
159 | total=n_docs,
160 | desc="Embedding documents",
161 | disable=not show_progress,
162 | **pbar_kwargs,
163 | )
164 |
165 | for texts in generate_batch(collection, batch_size):
166 | # Compute embeddings -----------------------------------------------
167 | embeddings = self.bencode(texts, batch_size=len(texts), show_progress=False)
168 |
169 | # Compute new offset -----------------------------------------------
170 | new_offset = offset + len(embeddings)
171 |
172 | if new_offset >= len(reservoir):
173 | np.save(
174 | embeddings_folder_path(self.index_name)
175 | / f"chunk_{reservoir_n}.npy",
176 | reservoir[:offset],
177 | )
178 | reservoir = np.empty((1_000_000, self.embedding_dim), dtype=np.float32)
179 | reservoir_n += 1
180 | offset = 0
181 | new_offset = len(embeddings)
182 |
183 | # Save embeddings in the reservoir ---------------------------------
184 | reservoir[offset:new_offset] = embeddings
185 |
186 |         # Update offset ----------------------------------------------------
187 | offset = new_offset
188 |
189 | pbar.update(len(embeddings))
190 |
191 | if offset < len(reservoir):
192 | np.save(
193 | embeddings_folder_path(self.index_name) / f"chunk_{reservoir_n}.npy",
194 | reservoir[:offset],
195 | )
196 | reservoir = []
197 |
198 | assert len(reservoir) == 0, "Reservoir is not empty."
199 |
--------------------------------------------------------------------------------
/retriv/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ["AdvancedRetriever"]
2 |
3 |
4 | from .advanced_retriever import AdvancedRetriever
5 |
--------------------------------------------------------------------------------
/retriv/hybrid_retriever.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Iterable, List, Set, Union
2 |
3 | import numpy as np
4 | import orjson
5 | from oneliner_utils import create_path
6 | from tqdm import tqdm
7 |
8 | from .base_retriever import BaseRetriever
9 | from .dense_retriever.dense_retriever import DenseRetriever
10 | from .merger.merger import Merger
11 | from .paths import hr_state_path
12 | from .sparse_retriever.sparse_retriever import SparseRetriever
13 |
14 |
15 | class HybridRetriever(BaseRetriever):
16 | def __init__(
17 | self,
18 | # Global params
19 | index_name: str = "new-index",
20 | # Sparse retriever params
21 | sr_model: str = "bm25",
22 | min_df: int = 1,
23 | tokenizer: Union[str, callable] = "whitespace",
24 | stemmer: Union[str, callable] = "english",
25 | stopwords: Union[str, List[str], Set[str]] = "english",
26 | do_lowercasing: bool = True,
27 | do_ampersand_normalization: bool = True,
28 | do_special_chars_normalization: bool = True,
29 | do_acronyms_normalization: bool = True,
30 | do_punctuation_removal: bool = True,
31 | # Dense retriever params
32 | dr_model: str = "sentence-transformers/all-MiniLM-L6-v2",
33 | normalize: bool = True,
34 | max_length: int = 128,
35 | use_ann: bool = True,
36 | # For already instantiated modules
37 | sparse_retriever: SparseRetriever = None,
38 | dense_retriever: DenseRetriever = None,
39 | merger: Merger = None,
40 | ):
41 |         """The [Hybrid Retriever](https://github.com/AmenRa/retriv/blob/main/docs/hybrid_retriever.md) is a searcher based on both lexical and semantic matching. It comprises three components: the [Sparse Retriever](https://github.com/AmenRa/retriv/blob/main/docs/sparse_retriever.md), the [Dense Retriever](https://github.com/AmenRa/retriv/blob/main/docs/dense_retriever.md), and the Merger. The Merger fuses the results of the Sparse and Dense Retrievers to compute the _hybrid_ results.
42 |
43 | Args:
44 | index_name (str, optional): [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index. Defaults to "new-index".
45 |
46 | sr_model (str, optional): defines the model to use for sparse retrieval (`bm25` or `tf-idf`). Defaults to "bm25".
47 |
48 |             min_df (int, optional): terms that appear in fewer than `min_df` documents will be ignored. If integer, the parameter indicates the absolute count. If float, it represents a proportion of documents. Defaults to 1.
49 |
50 |             tokenizer (Union[str, callable], optional): [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization by setting the parameter to `None`. Defaults to "whitespace".
51 | 
52 |             stemmer (Union[str, callable], optional): [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming by setting the parameter to `None`. Defaults to "english".
53 |
54 | stopwords (Union[str, List[str], Set[str]], optional): [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`. Defaults to "english".
55 |
56 | do_lowercasing (bool, optional): whether to lowercase texts. Defaults to True.
57 |
58 |             do_ampersand_normalization (bool, optional): whether to convert `&` to `and` during pre-processing. Defaults to True.
59 | 
60 |             do_special_chars_normalization (bool, optional): whether to replace special characters with their plain-letter equivalents, e.g., `übermensch` → `ubermensch`. Defaults to True.
61 | 
62 |             do_acronyms_normalization (bool, optional): whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`. Defaults to True.
63 |
64 | do_punctuation_removal (bool, optional): whether to remove punctuation. Defaults to True.
65 |
66 |             dr_model (str, optional): defines the model to use for encoding queries and documents into vectors. You can use a [Hugging Face Transformers](https://huggingface.co/models) pre-trained model by providing its ID or load a local model by providing its path. In the case of local models, the path must point to the directory containing the data saved with the [`PreTrainedModel.save_pretrained`](https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method. Note that the representations are computed with `mean pooling` over the `last_hidden_state`. Defaults to "sentence-transformers/all-MiniLM-L6-v2".
67 |
68 | normalize (bool, optional): whether to L2 normalize the vector representations. Defaults to True.
69 |
70 | max_length (int, optional): texts longer than `max_length` will be automatically truncated. Choose this parameter based on how the employed model was trained or is generally used. Defaults to 128.
71 |
72 |             use_ann (bool, optional): whether to use approximate nearest neighbors search. Set it to `False` to use nearest neighbors search without approximation. If you have fewer than 20k documents in your collection, you probably want to disable approximation. Defaults to True.
73 | """
74 |
75 | self.index_name = index_name
76 |
77 | self.sparse_retriever = (
78 | sparse_retriever
79 | if sparse_retriever is not None
80 | else SparseRetriever(
81 | index_name=index_name,
82 | model=sr_model,
83 | min_df=min_df,
84 | tokenizer=tokenizer,
85 | stemmer=stemmer,
86 | stopwords=stopwords,
87 | do_lowercasing=do_lowercasing,
88 | do_ampersand_normalization=do_ampersand_normalization,
89 | do_special_chars_normalization=do_special_chars_normalization,
90 | do_acronyms_normalization=do_acronyms_normalization,
91 | do_punctuation_removal=do_punctuation_removal,
92 | )
93 | )
94 |
95 | self.dense_retriever = (
96 | dense_retriever
97 | if dense_retriever is not None
98 | else DenseRetriever(
99 | index_name=index_name,
100 | model=dr_model,
101 | normalize=normalize,
102 | max_length=max_length,
103 | use_ann=use_ann,
104 | )
105 | )
106 |
107 | self.merger = merger if merger is not None else Merger(index_name=index_name)
108 |
109 | def index(
110 | self,
111 | collection: Iterable,
112 | embeddings_path: str = None,
113 | use_gpu: bool = False,
114 | batch_size: int = 512,
115 | callback: callable = None,
116 | show_progress: bool = True,
117 | ):
118 | """Index a given collection of documents.
119 |
120 | Args:
121 | collection (Iterable): collection of documents to index.
122 |
123 | embeddings_path (str, optional): in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file. Defaults to None.
124 |
125 | use_gpu (bool, optional): whether to use the GPU for document encoding. Defaults to False.
126 |
127 |             batch_size (int, optional): how many documents to encode at once. Adjust it if you run into memory usage issues or want to maximize throughput. Defaults to 512.
128 |
129 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None.
130 |
131 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True.
132 |
133 | Returns:
134 | HybridRetriever: Hybrid Retriever
135 | """
136 |
137 | self.save_collection(collection, callback)
138 |
139 | self.initialize_doc_index()
140 | self.initialize_id_mapping()
141 | self.doc_count = len(self.id_mapping)
142 |
143 | # Sparse ---------------------------------------------------------------
144 | self.sparse_retriever.doc_index = self.doc_index
145 | self.sparse_retriever.id_mapping = self.id_mapping
146 | self.sparse_retriever.doc_count = self.doc_count
147 | self.sparse_retriever.index_aux(show_progress)
148 |
149 | # Dense ----------------------------------------------------------------
150 | self.dense_retriever.doc_index = self.doc_index
151 | self.dense_retriever.id_mapping = self.id_mapping
152 | self.dense_retriever.doc_count = self.doc_count
153 | self.dense_retriever.index_aux(
154 | embeddings_path, use_gpu, batch_size, callback, show_progress
155 | )
156 |
157 | self.save()
158 |
159 | return self
160 |
161 | def index_file(
162 | self,
163 | path: str,
164 | embeddings_path: str = None,
165 | use_gpu: bool = False,
166 | batch_size: int = 512,
167 | callback: callable = None,
168 | show_progress: bool = True,
169 | ):
170 | """Index the collection contained in a given file.
171 |
172 | Args:
173 | path (str): path of file containing the collection to index.
174 |
175 | embeddings_path (str, optional): in case you want to load pre-computed embeddings, you can provide the path to a `.npy` file. Embeddings must be in the same order as the documents in the collection file. Defaults to None.
176 |
177 | use_gpu (bool, optional): whether to use the GPU for document encoding. Defaults to False.
178 |
179 |             batch_size (int, optional): how many documents to encode at once. Adjust it if you run into memory usage issues or want to maximize throughput. Defaults to 512.
180 |
181 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None.
182 |
183 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True.
184 |
185 | Returns:
186 | HybridRetriever: Hybrid Retriever.
187 | """
188 |
189 | collection = self.collection_generator(path, callback)
190 | return self.index(
191 | collection,
192 | embeddings_path,
193 | use_gpu,
194 | batch_size,
195 |             None,  # callback already applied by collection_generator
196 | show_progress,
197 | )
198 |
199 | def save(self):
200 | """Save the state of the retriever to be able to restore it later."""
201 |
202 | state = dict(
203 | id_mapping=self.id_mapping,
204 | doc_count=self.doc_count,
205 | )
206 | np.savez_compressed(hr_state_path(self.index_name), state=state)
207 |
208 | self.sparse_retriever.save()
209 | self.dense_retriever.save()
210 | self.merger.save()
211 |
212 | @staticmethod
213 | def load(index_name: str = "new-index"):
214 | """Load a retriever and its index.
215 |
216 | Args:
217 | index_name (str, optional): Name of the index. Defaults to "new-index".
218 |
219 | Returns:
220 | HybridRetriever: Hybrid Retriever.
221 | """
222 |
223 | state = np.load(hr_state_path(index_name), allow_pickle=True)["state"][()]
224 |
225 | hr = HybridRetriever(index_name)
226 | hr.initialize_doc_index()
227 | hr.id_mapping = state["id_mapping"]
228 | hr.doc_count = state["doc_count"]
229 |
230 | hr.sparse_retriever = SparseRetriever.load(index_name)
231 | hr.dense_retriever = DenseRetriever.load(index_name)
232 | hr.merger = Merger.load(index_name)
233 | return hr
234 |
235 | def search(
236 | self,
237 | query: str,
238 | return_docs: bool = True,
239 | cutoff: int = 100,
240 | ) -> List:
241 | """Standard search functionality.
242 |
243 | Args:
244 | query (str): what to search for.
245 |
246 | return_docs (bool, optional): whether to return the texts of the documents. Defaults to True.
247 |
248 | cutoff (int, optional): number of results to return. Defaults to 100.
249 |
250 | Returns:
251 | List: results.
252 | """
253 |
254 | sparse_results = self.sparse_retriever.search(query, False, 1_000)
255 | dense_results = self.dense_retriever.search(query, False, 1_000)
256 | hybrid_results = self.merger.fuse([sparse_results, dense_results])
257 | return (
258 | self.prepare_results(
259 | list(hybrid_results.keys())[:cutoff],
260 | list(hybrid_results.values())[:cutoff],
261 | )
262 | if return_docs
263 | else hybrid_results
264 | )
265 |
266 | def msearch(
267 | self,
268 | queries: List[Dict[str, str]],
269 | cutoff: int = 100,
270 | batch_size: int = 32,
271 | ) -> Dict:
272 | """Compute results for multiple queries at once.
273 |
274 | Args:
275 | queries (List[Dict[str, str]]): what to search for.
276 |
277 | cutoff (int, optional): number of results to return. Defaults to 100.
278 |
279 |             batch_size (int, optional): how many queries to search at once. Adjust it if you run into memory usage issues or want to maximize throughput. Defaults to 32.
280 |
281 | Returns:
282 | Dict: results.
283 | """
284 |
285 | sparse_results = self.sparse_retriever.msearch(queries, 1_000)
286 | dense_results = self.dense_retriever.msearch(queries, 1_000, batch_size)
287 | return self.merger.mfuse([sparse_results, dense_results], cutoff)
288 |
289 | def bsearch(
290 | self,
291 | queries: List[Dict[str, str]],
292 | cutoff: int = 100,
293 | batch_size: int = 32,
294 | show_progress: bool = True,
295 | qrels: Dict[str, Dict[str, float]] = None,
296 | path: str = None,
297 | ):
298 | """Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONl](https://jsonlines.org) format. bsearch is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
299 |
300 | Args:
301 | queries (List[Dict[str, str]]): what to search for.
302 |
303 | cutoff (int, optional): number of results to return. Defaults to 100.
304 |
305 |             batch_size (int, optional): how many queries to search at once. Adjust it if you run into memory usage issues or want to maximize throughput. Defaults to 32.
306 |
307 | show_progress (bool, optional): whether to show a progress bar for the search process. Defaults to True.
308 |
309 | qrels (Dict[str, Dict[str, float]], optional): query relevance judgements for the queries. Defaults to None.
310 |
311 | path (str, optional): where to save the results. Defaults to None.
312 |
313 | Returns:
314 | Dict: results.
315 | """
316 |
317 | batches = [
318 | queries[i : i + batch_size] for i in range(0, len(queries), batch_size)
319 | ]
320 |
321 | results = {}
322 |
323 | pbar = tqdm(
324 | total=len(queries),
325 | disable=not show_progress,
326 | desc="Batch search",
327 | dynamic_ncols=True,
328 | mininterval=0.5,
329 | )
330 |
331 | if path is None:
332 | for batch in batches:
333 | new_results = self.msearch(
334 | queries=batch, cutoff=cutoff, batch_size=len(batch)
335 | )
336 | results = {**results, **new_results}
337 | pbar.update(min(batch_size, len(batch)))
338 | else:
339 | path = create_path(path)
340 | path.parent.mkdir(parents=True, exist_ok=True)
341 |
342 | with open(path, "wb") as f:
343 | for batch in batches:
344 | new_results = self.msearch(queries=batch, cutoff=cutoff)
345 |
346 | for i, (k, v) in enumerate(new_results.items()):
347 | x = {
348 | "id": k,
349 | "text": batch[i]["text"],
350 | "hybrid_doc_ids": list(v.keys()),
351 | "hybrid_scores": [float(s) for s in list(v.values())],
352 | }
353 | if qrels is not None:
354 | x["rel_doc_ids"] = list(qrels[k].keys())
355 | x["rel_scores"] = list(qrels[k].values())
356 | f.write(orjson.dumps(x) + "\n".encode())
357 |
358 | pbar.update(min(batch_size, len(batch)))
359 |
360 |         return results  # NOTE: empty when a path is given, as results are streamed to disk
361 |
362 | def autotune(
363 | self,
364 | queries: List[Dict[str, str]],
365 | qrels: Dict[str, Dict[str, float]],
366 | metric: str = "ndcg",
367 | n_trials: int = 100,
368 | cutoff: int = 100,
369 | batch_size: int = 32,
370 | ):
371 |         """Use the AutoTune function to tune the Sparse Retriever's model [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) parameters and the importance given to the lexical and semantic relevance scores computed by the Sparse and Dense Retrievers, respectively. All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function. At the end of the process, the best parameter configurations are automatically applied and saved to disk. You can inspect the best configurations found by printing `hr.sparse_retriever.hyperparams`, `hr.merger.norm`, and `hr.merger.params`.
372 |
373 | Args:
374 | queries (List[Dict[str, str]]): queries to use for the optimization process.
375 |
376 | qrels (Dict[str, Dict[str, float]]): query relevance judgements for the queries.
377 |
378 | metric (str, optional): metric to optimize for. Defaults to "ndcg".
379 |
380 |             n_trials (int, optional): number of configurations to evaluate. Defaults to 100.
381 |
382 | cutoff (int, optional): number of results to consider for the optimization process. Defaults to 100.
383 | """
384 |
385 | # Tune sparse ----------------------------------------------------------
386 | self.sparse_retriever.autotune(
387 | queries=queries,
388 | qrels=qrels,
389 | metric=metric,
390 | n_trials=n_trials,
391 | cutoff=cutoff,
392 | )
393 |
394 | # Tune merger ----------------------------------------------------------
395 | sparse_results = self.sparse_retriever.msearch(queries, 1_000)
396 | dense_results = self.dense_retriever.msearch(queries, 1_000, batch_size)
397 | self.merger.autotune(qrels, [sparse_results, dense_results], metric)
398 |
--------------------------------------------------------------------------------
/retriv/merger/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/merger/__init__.py
--------------------------------------------------------------------------------
/retriv/merger/merger.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Dict, List
3 |
4 | import numpy as np
5 |
6 | from ..autotune import tune_merger
7 | from ..paths import merger_state_path
8 | from .normalization import max_norm_multi, min_max_norm_multi, sum_norm_multi
9 |
10 |
11 | class Merger:
12 | def __init__(self, index_name: str = "new-index"):
13 | self.index_name = index_name
14 | self.norm = "min-max"
15 | self.params = None
16 |
17 | def fuse(
18 | self, results: List[Dict[str, float]], cutoff: int = 100
19 | ) -> Dict[str, float]:
20 | return self.mfuse([{"q_0": res} for res in results], cutoff)["q_0"]
21 |
22 | def mfuse(
23 | self, runs: List[Dict[str, Dict[str, float]]], cutoff: int = 100
24 | ) -> Dict[str, Dict[str, float]]:
25 | if self.norm == "min-max":
26 | normalized_runs = min_max_norm_multi(runs)
27 | elif self.norm == "max":
28 | normalized_runs = max_norm_multi(runs)
29 | elif self.norm == "sum":
30 | normalized_runs = sum_norm_multi(runs)
31 | else:
32 | raise NotImplementedError
33 |
34 | weights = [1.0 for _ in runs] if self.params is None else self.params["weights"]
35 |
36 | fused_run = defaultdict(lambda: defaultdict(float))
37 | for i, run in enumerate(normalized_runs):
38 | for q_id in run:
39 | for doc_id in run[q_id]:
40 | fused_run[q_id][doc_id] += weights[i] * run[q_id][doc_id]
41 |
42 | # Sort results by descending value and ascending key
43 | for q_id, results in list(fused_run.items()):
44 | fused_run[q_id] = dict(sorted(results.items(), key=lambda x: (-x[1], x[0])))
45 |
46 | # Apply cutoff
47 | for q_id, results in list(fused_run.items()):
48 | fused_run[q_id] = dict(list(results.items())[:cutoff])
49 |
50 | return dict(fused_run)
51 |
52 | def save(self):
53 | state = dict(
54 | init_args=dict(index_name=self.index_name),
55 | norm=self.norm,
56 | params=self.params,
57 | )
58 | np.savez_compressed(merger_state_path(self.index_name), state=state)
59 |
60 | @staticmethod
61 | def load(index_name: str = "new-index"):
62 | state = np.load(merger_state_path(index_name), allow_pickle=True)["state"][()]
63 | merger = Merger(**state["init_args"])
64 | merger.norm = state["norm"]
65 | merger.params = state["params"]
66 | return merger
67 |
68 | def autotune(
69 | self,
70 | qrels: Dict[str, Dict[str, float]],
71 | runs: List[Dict[str, Dict[str, float]]],
72 | metric: str = "ndcg",
73 | ):
74 | config = tune_merger(qrels, runs, metric)
75 | self.norm = config["norm"]
76 | self.params = config["params"]
77 | self.save()
78 |
--------------------------------------------------------------------------------
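
A toy walk-through of `Merger.mfuse` as implemented above: each run is min-max normalized per query, then scores are summed with equal weights, since `params` is `None` by default (sketch, not part of the repository):

    from retriv.merger.merger import Merger

    merger = Merger()
    sparse_run = {"q_1": {"d_1": 12.0, "d_2": 7.5}}  # e.g., BM25 scores
    dense_run = {"q_1": {"d_2": 0.91, "d_3": 0.42}}  # e.g., cosine similarities

    fused = merger.mfuse([sparse_run, dense_run], cutoff=10)
    print(fused["q_1"])  # {'d_1': 1.0, 'd_2': 1.0, 'd_3': 0.0}
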
/retriv/merger/normalization.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 |
4 | def extract_scores(results):
5 | """Extract the scores from a given results dictionary."""
6 |     return list(results.values())
10 |
11 |
12 | def safe_max(x):
13 | return max(x) if len(x) != 0 else 0
14 |
15 |
16 | def safe_min(x):
17 | return min(x) if len(x) != 0 else 0
18 |
19 |
20 | def min_max_norm(run: Dict[str, Dict[str, float]]):
21 | """Apply `min-max normalization` to a given run."""
22 | normalized_run = {}
23 |
24 | for q_id, results in run.items():
25 | scores = extract_scores(results)
26 | min_score = safe_min(scores)
27 | max_score = safe_max(scores)
28 | denominator = max(max_score - min_score, 1e-9)
29 |
30 | normalized_results = {
31 | doc_id: (results[doc_id] - min_score) / (denominator) for doc_id in results
32 | }
33 |
34 | normalized_run[q_id] = normalized_results
35 |
36 | return normalized_run
37 |
38 |
39 | def max_norm(run: Dict[str, Dict[str, float]]):
40 | """Apply `max normalization` to a given run."""
41 | normalized_run = {}
42 |
43 | for q_id, results in run.items():
44 | scores = extract_scores(results)
45 | max_score = safe_max(scores)
46 | denominator = max(max_score, 1e-9)
47 |
48 | normalized_results = {
49 | doc_id: results[doc_id] / denominator for doc_id in results
50 | }
51 |
52 | normalized_run[q_id] = normalized_results
53 |
54 | return normalized_run
55 |
56 |
57 | def sum_norm(run: Dict[str, Dict[str, float]]):
58 | """Apply `sum normalization` to a given run."""
59 | normalized_run = {}
60 |
61 | for q_id, results in run.items():
62 | scores = extract_scores(results)
63 | min_score = safe_min(scores)
64 | sum_score = sum(scores)
65 | denominator = sum_score - min_score * len(results)
66 | denominator = max(denominator, 1e-9)
67 |
68 | normalized_results = {
69 | doc_id: (results[doc_id] - min_score) / (denominator) for doc_id in results
70 | }
71 |
72 | normalized_run[q_id] = normalized_results
73 |
74 | return normalized_run
75 |
76 |
77 | def min_max_norm_multi(runs: List[Dict[str, Dict[str, float]]]):
78 | """Apply `min-max normalization` to a list of given runs."""
79 | return [min_max_norm(run) for run in runs]
80 |
81 |
82 | def max_norm_multi(runs: List[Dict[str, Dict[str, float]]]):
83 | """Apply `max normalization` to a list of given runs."""
84 | return [max_norm(run) for run in runs]
85 |
86 |
87 | def sum_norm_multi(runs: List[Dict[str, Dict[str, float]]]):
88 | """Apply `sum normalization` to a list of given runs."""
89 | return [sum_norm(run) for run in runs]
90 |
--------------------------------------------------------------------------------
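
A worked example of `min_max_norm` as defined above (not part of the repository):

    from retriv.merger.normalization import min_max_norm

    run = {"q_1": {"d_1": 2.0, "d_2": 0.7, "d_3": 0.5}}
    # (score - min) / (max - min):
    # d_1: (2.0 - 0.5) / 1.5 = 1.0, d_2: (0.7 - 0.5) / 1.5 ≈ 0.133, d_3: 0.0
    print(min_max_norm(run)["q_1"])
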
/retriv/paths.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 |
5 | def base_path():
6 | p = Path(os.environ.get("RETRIV_BASE_PATH"))
7 | p.mkdir(parents=True, exist_ok=True)
8 | return p
9 |
10 |
11 | def collections_path():
12 | p = base_path() / "collections"
13 | p.mkdir(parents=True, exist_ok=True)
14 | return p
15 |
16 |
17 | def index_path(index_name: str):
18 | path = collections_path() / index_name
19 | path.mkdir(parents=True, exist_ok=True)
20 | return path
21 |
22 |
23 | def docs_path(index_name: str):
24 | return index_path(index_name) / "docs.jsonl"
25 |
26 |
27 | def sr_state_path(index_name: str):
28 | return index_path(index_name) / "sr_state.npz"
29 |
30 |
31 | def fr_state_path(index_name: str):
32 | return index_path(index_name) / "fr_state.npz"
33 |
34 |
35 | def embeddings_path(index_name: str):
36 | return index_path(index_name) / "embeddings.h5"
37 |
38 |
39 | def embeddings_folder_path(index_name: str):
40 | path = index_path(index_name) / "embeddings"
41 | path.mkdir(parents=True, exist_ok=True)
42 | return path
43 |
44 |
45 | def faiss_index_path(index_name: str):
46 | return index_path(index_name) / "faiss.index"
47 |
48 |
49 | def faiss_index_infos_path(index_name: str):
50 | return index_path(index_name) / "faiss_index_infos.json"
51 |
52 |
53 | def dr_state_path(index_name: str):
54 | return index_path(index_name) / "dr_state.npz"
55 |
56 |
57 | def hr_state_path(index_name: str):
58 | return index_path(index_name) / "hr_state.npz"
59 |
60 |
61 | def encoder_state_path(index_name: str):
62 | return index_path(index_name) / "encoder_state.json"
63 |
64 |
65 | def merger_state_path(index_name: str):
66 | return index_path(index_name) / "merger_state.npz"
67 |
--------------------------------------------------------------------------------
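
All of the paths above hang off the `RETRIV_BASE_PATH` environment variable, and directories are created on first access. A quick sketch with a hypothetical location (not part of the repository; the package is expected to set a sensible default elsewhere):

    import os

    os.environ["RETRIV_BASE_PATH"] = "/tmp/retriv"  # hypothetical base directory

    from retriv.paths import docs_path, sr_state_path

    print(docs_path("my-index"))      # /tmp/retriv/collections/my-index/docs.jsonl
    print(sr_state_path("my-index"))  # /tmp/retriv/collections/my-index/sr_state.npz
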
/retriv/sparse_retriever/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/sparse_retriever/__init__.py
--------------------------------------------------------------------------------
/retriv/sparse_retriever/build_inverted_index.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Dict, Iterable
3 |
4 | import numpy as np
5 | from sklearn.feature_extraction.text import CountVectorizer
6 | from tqdm import tqdm
7 |
8 |
9 | def convert_df_matrix_into_inverted_index(
10 | df_matrix, vocabulary, show_progress: bool = True
11 | ):
12 | inverted_index = defaultdict(dict)
13 |
14 | for i, term in enumerate(
15 | tqdm(
16 | vocabulary,
17 | disable=not show_progress,
18 | desc="Building inverted index",
19 | dynamic_ncols=True,
20 | mininterval=0.5,
21 | )
22 | ):
23 | inverted_index[term]["doc_ids"] = df_matrix[i].indices
24 | inverted_index[term]["tfs"] = df_matrix[i].data
25 |
26 | return inverted_index
27 |
28 |
29 | def build_inverted_index(
30 | collection: Iterable,
31 | n_docs: int,
32 | min_df: int = 1,
33 | show_progress: bool = True,
34 | ) -> Dict:
35 | vectorizer = CountVectorizer(
36 | tokenizer=lambda x: x,
37 | preprocessor=lambda x: x,
38 | min_df=min_df,
39 | dtype=np.int16,
40 | token_pattern=None,
41 | )
42 |
43 | # [n_docs x n_terms]
44 | df_matrix = vectorizer.fit_transform(
45 | tqdm(
46 | collection,
47 | total=n_docs,
48 | disable=not show_progress,
49 |             desc="Building term-document matrix",
50 | dynamic_ncols=True,
51 | mininterval=0.5,
52 | )
53 | )
54 | # [n_terms x n_docs]
55 | df_matrix = df_matrix.transpose().tocsr()
56 | vocabulary = vectorizer.get_feature_names_out()
57 | inverted_index = convert_df_matrix_into_inverted_index(
58 | df_matrix=df_matrix,
59 | vocabulary=vocabulary,
60 | show_progress=show_progress,
61 | )
62 |
63 | doc_lens = np.squeeze(np.asarray(df_matrix.sum(axis=0), dtype=np.float32))
64 | relative_doc_lens = doc_lens / np.mean(doc_lens, dtype=np.float32)
65 |
66 | return dict(inverted_index), doc_lens, relative_doc_lens
67 |
--------------------------------------------------------------------------------
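
A minimal sketch of `build_inverted_index` on an already-tokenized toy collection (the real pipeline feeds it preprocessed documents; not part of the repository):

    from retriv.sparse_retriever.build_inverted_index import build_inverted_index

    collection = [["black", "sabbath"], ["black", "metal"], ["heavy", "metal"]]
    inverted_index, doc_lens, relative_doc_lens = build_inverted_index(
        collection=collection, n_docs=3, min_df=1, show_progress=False
    )

    print(inverted_index["black"]["doc_ids"])  # ids of the docs containing "black": [0 1]
    print(doc_lens)                            # [2. 2. 2.], so relative_doc_lens is all 1.0
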
/retriv/sparse_retriever/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "lowercasing",
3 | "normalize_acronyms",
4 | "normalize_ampersand",
5 | "normalize_special_chars",
6 | "remove_punctuation",
7 | "strip_whitespaces",
8 | "get_stemmer",
9 | "get_tokenizer",
10 | "get_stopwords",
11 | ]
12 |
13 |
14 | from typing import Callable, List, Set
15 |
16 | from multipipe import Multipipe
17 |
18 | from .normalization import (
19 | lowercasing,
20 | normalize_acronyms,
21 | normalize_ampersand,
22 | normalize_special_chars,
23 | remove_punctuation,
24 | strip_whitespaces,
25 | )
26 | from .stemmer import get_stemmer
27 | from .stopwords import get_stopwords
28 | from .tokenizer import get_tokenizer
29 |
30 |
31 | def preprocessing(
32 | x: str,
33 | tokenizer: Callable,
34 | stopwords: Set[str],
35 | stemmer: Callable,
36 | do_lowercasing: bool,
37 | do_ampersand_normalization: bool,
38 | do_special_chars_normalization: bool,
39 | do_acronyms_normalization: bool,
40 | do_punctuation_removal: bool,
41 | ) -> List[str]:
42 | if do_lowercasing:
43 | x = lowercasing(x)
44 | if do_ampersand_normalization:
45 | x = normalize_ampersand(x)
46 | if do_special_chars_normalization:
47 | x = normalize_special_chars(x)
48 | if do_acronyms_normalization:
49 | x = normalize_acronyms(x)
50 |
51 | if tokenizer == str.split and do_punctuation_removal:
52 | x = remove_punctuation(x)
53 | x = strip_whitespaces(x)
54 |
55 | x = tokenizer(x)
56 |
57 | if tokenizer != str.split and do_punctuation_removal:
58 | x = [remove_punctuation(t) for t in x]
59 | x = [t for t in x if t]
60 |
61 | x = [t for t in x if t not in stopwords]
62 |
63 | return [stemmer(t) for t in x]
64 |
65 |
66 | def preprocessing_multi(
67 | tokenizer: callable,
68 | stopwords: List[str],
69 | stemmer: callable,
70 | do_lowercasing: bool,
71 | do_ampersand_normalization: bool,
72 | do_special_chars_normalization: bool,
73 | do_acronyms_normalization: bool,
74 | do_punctuation_removal: bool,
75 | ):
76 | callables = []
77 |
78 | if do_lowercasing:
79 | callables.append(lowercasing)
80 | if do_ampersand_normalization:
81 | callables.append(normalize_ampersand)
82 | if do_special_chars_normalization:
83 | callables.append(normalize_special_chars)
84 | if do_acronyms_normalization:
85 | callables.append(normalize_acronyms)
86 | if tokenizer == str.split and do_punctuation_removal:
87 | callables.append(remove_punctuation)
88 | callables.append(strip_whitespaces)
89 |
90 | callables.append(tokenizer)
91 |
92 | if tokenizer != str.split and do_punctuation_removal:
93 |
94 | def rp(x):
95 | x = [remove_punctuation(t) for t in x]
96 | return [t for t in x if t]
97 |
98 | callables.append(rp)
99 |
100 | def sw(x):
101 | return [t for t in x if t not in stopwords]
102 |
103 | callables.append(sw)
104 |
105 | def stem(x):
106 | return [stemmer(t) for t in x]
107 |
108 | callables.append(stem)
109 |
110 | return Multipipe(callables)
111 |
--------------------------------------------------------------------------------
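
An end-to-end sketch of the `preprocessing` function above on a single string (not part of the repository):

    from retriv.sparse_retriever.preprocessing import (
        get_stemmer, get_stopwords, get_tokenizer, preprocessing,
    )

    tokens = preprocessing(
        "The Dark Side of the Moon",
        tokenizer=get_tokenizer("whitespace"),
        stopwords=set(get_stopwords("english")),
        stemmer=get_stemmer("english"),
        do_lowercasing=True,
        do_ampersand_normalization=True,
        do_special_chars_normalization=True,
        do_acronyms_normalization=True,
        do_punctuation_removal=True,
    )
    print(tokens)  # ['dark', 'side', 'moon']
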
/retriv/sparse_retriever/preprocessing/normalization.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | from unidecode import unidecode
5 |
6 |
7 | def lowercasing(x: str) -> str:
8 | return x.lower()
9 |
10 |
11 | def normalize_ampersand(x: str) -> str:
12 | return x.replace("&", " and ")
13 |
14 |
15 | def normalize_diacritics(x: str) -> str:
16 | return unidecode(x)
17 |
18 |
19 | def normalize_special_chars(x: str) -> str:
20 | special_chars_trans = dict(
21 | [(ord(x), ord(y)) for x, y in zip("‘’´“”–-", "'''\"\"--")]
22 | )
23 | return x.translate(special_chars_trans)
24 |
25 |
26 | def normalize_acronyms(x: str) -> str:
27 | return re.sub(r"\.(?!(\S[^. ])|\d)", "", x)
28 |
29 |
30 | def remove_punctuation(x: str) -> str:
31 | translator = str.maketrans(string.punctuation, " " * len(string.punctuation))
32 | return x.translate(translator)
33 |
34 |
35 | def strip_whitespaces(x: str) -> str:
36 | x = x.strip()
37 |
38 | while " " in x:
39 | x = x.replace(" ", " ")
40 |
41 | return x
42 |
--------------------------------------------------------------------------------
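
Quick behavior check of the normalizers defined above (not part of the repository):

    from retriv.sparse_retriever.preprocessing.normalization import (
        normalize_acronyms, normalize_ampersand, remove_punctuation, strip_whitespaces,
    )

    print(normalize_acronyms("P.C.I. express"))  # 'PCI express' (decimals like 3.14 are kept)
    print(strip_whitespaces(normalize_ampersand("rock & roll")))  # 'rock and roll'
    print(strip_whitespaces(remove_punctuation("end... of, story")))  # 'end of story'
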
/retriv/sparse_retriever/preprocessing/stemmer.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Union
3 |
4 | import nltk
5 | from krovetzstemmer import Stemmer as KrovetzStemmer
6 | from Stemmer import Stemmer as SnowballStemmer
7 |
8 | from .utils import identity_function
9 |
10 | stemmers_dict = {
11 | "krovetz": partial(KrovetzStemmer()),
12 | "porter": partial(nltk.stem.PorterStemmer().stem),
13 | "lancaster": partial(nltk.stem.LancasterStemmer().stem),
14 | "arlstem": partial(nltk.stem.ARLSTem().stem), # Arabic
15 | "arlstem2": partial(nltk.stem.ARLSTem2().stem), # Arabic
16 | "cistem": partial(nltk.stem.Cistem().stem), # German
17 | "isri": partial(nltk.stem.ISRIStemmer().stem), # Arabic
18 | "arabic": partial(SnowballStemmer("arabic").stemWord),
19 | "basque": partial(SnowballStemmer("basque").stemWord),
20 | "catalan": partial(SnowballStemmer("catalan").stemWord),
21 | "danish": partial(SnowballStemmer("danish").stemWord),
22 | "dutch": partial(SnowballStemmer("dutch").stemWord),
23 | "english": partial(nltk.stem.SnowballStemmer("english").stem),
24 | "finnish": partial(SnowballStemmer("finnish").stemWord),
25 | "french": partial(SnowballStemmer("french").stemWord),
26 | "german": partial(SnowballStemmer("german").stemWord),
27 | "greek": partial(SnowballStemmer("greek").stemWord),
28 | "hindi": partial(SnowballStemmer("hindi").stemWord),
29 | "hungarian": partial(SnowballStemmer("hungarian").stemWord),
30 | "indonesian": partial(SnowballStemmer("indonesian").stemWord),
31 | "irish": partial(SnowballStemmer("irish").stemWord),
32 | "italian": partial(SnowballStemmer("italian").stemWord),
33 | "lithuanian": partial(SnowballStemmer("lithuanian").stemWord),
34 | "nepali": partial(SnowballStemmer("nepali").stemWord),
35 | "norwegian": partial(SnowballStemmer("norwegian").stemWord),
36 | "portuguese": partial(SnowballStemmer("portuguese").stemWord),
37 | "romanian": partial(SnowballStemmer("romanian").stemWord),
38 | "russian": partial(SnowballStemmer("russian").stemWord),
39 | "spanish": partial(SnowballStemmer("spanish").stemWord),
40 | "swedish": partial(SnowballStemmer("swedish").stemWord),
41 | "tamil": partial(SnowballStemmer("tamil").stemWord),
42 | "turkish": partial(SnowballStemmer("turkish").stemWord),
43 | }
44 |
45 |
46 | def krovetz_f(x: str) -> str:
47 | return stemmers_dict["krovetz"](x)
48 |
49 |
50 | def porter_f(x: str) -> str:
51 | return stemmers_dict["porter"](x)
52 |
53 |
54 | def lancaster_f(x: str) -> str:
55 | return stemmers_dict["lancaster"](x)
56 |
57 |
58 | def arlstem_f(x: str) -> str:
59 | return stemmers_dict["arlstem"](x)
60 |
61 |
62 | def arlstem2_f(x: str) -> str:
63 | return stemmers_dict["arlstem2"](x)
64 |
65 |
66 | def cistem_f(x: str) -> str:
67 | return stemmers_dict["cistem"](x)
68 |
69 |
70 | def isri_f(x: str) -> str:
71 | return stemmers_dict["isri"](x)
72 |
73 |
74 | def arabic_f(x: str) -> str:
75 | return stemmers_dict["arabic"](x)
76 |
77 |
78 | def basque_f(x: str) -> str:
79 | return stemmers_dict["basque"](x)
80 |
81 |
82 | def catalan_f(x: str) -> str:
83 | return stemmers_dict["catalan"](x)
84 |
85 |
86 | def danish_f(x: str) -> str:
87 | return stemmers_dict["danish"](x)
88 |
89 |
90 | def dutch_f(x: str) -> str:
91 | return stemmers_dict["dutch"](x)
92 |
93 |
94 | def english_f(x: str) -> str:
95 | return stemmers_dict["english"](x)
96 |
97 |
98 | def finnish_f(x: str) -> str:
99 | return stemmers_dict["finnish"](x)
100 |
101 |
102 | def french_f(x: str) -> str:
103 | return stemmers_dict["french"](x)
104 |
105 |
106 | def german_f(x: str) -> str:
107 | return stemmers_dict["german"](x)
108 |
109 |
110 | def greek_f(x: str) -> str:
111 | return stemmers_dict["greek"](x)
112 |
113 |
114 | def hindi_f(x: str) -> str:
115 | return stemmers_dict["hindi"](x)
116 |
117 |
118 | def hungarian_f(x: str) -> str:
119 | return stemmers_dict["hungarian"](x)
120 |
121 |
122 | def indonesian_f(x: str) -> str:
123 | return stemmers_dict["indonesian"](x)
124 |
125 |
126 | def irish_f(x: str) -> str:
127 | return stemmers_dict["irish"](x)
128 |
129 |
130 | def italian_f(x: str) -> str:
131 | return stemmers_dict["italian"](x)
132 |
133 |
134 | def lithuanian_f(x: str) -> str:
135 | return stemmers_dict["lithuanian"](x)
136 |
137 |
138 | def nepali_f(x: str) -> str:
139 | return stemmers_dict["nepali"](x)
140 |
141 |
142 | def norwegian_f(x: str) -> str:
143 | return stemmers_dict["norwegian"](x)
144 |
145 |
146 | def portuguese_f(x: str) -> str:
147 | return stemmers_dict["portuguese"](x)
148 |
149 |
150 | def romanian_f(x: str) -> str:
151 | return stemmers_dict["romanian"](x)
152 |
153 |
154 | def russian_f(x: str) -> str:
155 | return stemmers_dict["russian"](x)
156 |
157 |
158 | def spanish_f(x: str) -> str:
159 | return stemmers_dict["spanish"](x)
160 |
161 |
162 | def swedish_f(x: str) -> str:
163 | return stemmers_dict["swedish"](x)
164 |
165 |
166 | def tamil_f(x: str) -> str:
167 | return stemmers_dict["tamil"](x)
168 |
169 |
170 | def turkish_f(x: str) -> str:
171 | return stemmers_dict["turkish"](x)
172 |
173 |
174 | stemmers_f_dict = {
175 | "krovetz": krovetz_f,
176 | "porter": porter_f,
177 | "lancaster": lancaster_f,
178 | "arlstem": arlstem_f,
179 | "arlstem2": arlstem2_f,
180 | "cistem": cistem_f,
181 | "isri": isri_f,
182 | "arabic": arabic_f,
183 | "basque": basque_f,
184 | "catalan": catalan_f,
185 | "danish": danish_f,
186 | "dutch": dutch_f,
187 | "english": english_f,
188 | "finnish": finnish_f,
189 | "french": french_f,
190 | "german": german_f,
191 | "greek": greek_f,
192 | "hindi": hindi_f,
193 | "hungarian": hungarian_f,
194 | "indonesian": indonesian_f,
195 | "irish": irish_f,
196 | "italian": italian_f,
197 | "lithuanian": lithuanian_f,
198 | "nepali": nepali_f,
199 | "norwegian": norwegian_f,
200 | "portuguese": portuguese_f,
201 | "romanian": romanian_f,
202 | "russian": russian_f,
203 | "spanish": spanish_f,
204 | "swedish": swedish_f,
205 | "tamil": tamil_f,
206 | "turkish": turkish_f,
207 | }
208 |
209 |
210 | def _get_stemmer(stemmer: str) -> callable:
211 | assert stemmer.lower() in stemmers_f_dict, f"Stemmer {stemmer} not supported."
212 | return stemmers_f_dict[stemmer.lower()]
213 |
214 |
215 | def get_stemmer(stemmer: Union[str, callable, None]) -> callable:
216 | if isinstance(stemmer, str):
217 | return _get_stemmer(stemmer)
218 | elif callable(stemmer):
219 | return stemmer
220 | elif stemmer is None:
221 | return identity_function
222 | else:
223 |         raise NotImplementedError()
224 |
--------------------------------------------------------------------------------
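
The module-level `*_f` wrappers above expose each stemmer as a plain importable function; `get_stemmer` is the entry point. A quick check (not part of the repository):

    from retriv.sparse_retriever.preprocessing.stemmer import get_stemmer

    print(get_stemmer("english")("connections"))  # 'connect' (Snowball English)
    print(get_stemmer("porter")("running"))       # 'run'
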
/retriv/sparse_retriever/preprocessing/stopwords.py:
--------------------------------------------------------------------------------
1 | from typing import List, Set, Union
2 |
3 | import nltk
4 |
5 | supported_languages = {
6 | "arabic",
7 | "azerbaijani",
8 | "basque",
9 | "bengali",
10 | "catalan",
11 | "chinese",
12 | "danish",
13 | "dutch",
14 | "english",
15 | "finnish",
16 | "french",
17 | "german",
18 | "greek",
19 | "hebrew",
20 | "hinglish",
21 | "hungarian",
22 | "indonesian",
23 | "italian",
24 | "kazakh",
25 | "nepali",
26 | "norwegian",
27 | "portuguese",
28 | "romanian",
29 | "russian",
30 | "slovene",
31 | "spanish",
32 | "swedish",
33 | "tajik",
34 | "turkish",
35 | }
36 |
37 |
38 | def _get_stopwords(lang: str) -> List[str]:
39 | nltk.download("stopwords", quiet=True)
40 | assert (
41 | lang.lower() in supported_languages
42 | ), f"Stop-words for {lang.capitalize()} are not available."
43 | return nltk.corpus.stopwords.words(lang)
44 |
45 |
46 | def get_stopwords(sw_list: Union[str, List[str], Set[str], None]) -> List[str]:
47 | if isinstance(sw_list, str):
48 | return _get_stopwords(sw_list)
49 | elif type(sw_list) is list and all(isinstance(x, str) for x in sw_list):
50 | return sw_list
51 | elif type(sw_list) is set:
52 | return list(sw_list)
53 | elif sw_list is None:
54 | return []
55 | else:
56 |         raise NotImplementedError()
57 |
--------------------------------------------------------------------------------
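
Usage sketch of `get_stopwords` (not part of the repository):

    from retriv.sparse_retriever.preprocessing.stopwords import get_stopwords

    english = get_stopwords("english")    # downloads the NLTK stop-word list on first use
    print("the" in english)               # True
    print(get_stopwords(["foo", "bar"]))  # custom lists pass through unchanged
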
/retriv/sparse_retriever/preprocessing/tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import nltk
4 |
5 | from .utils import identity_function
6 |
7 | tokenizers_dict = {
8 | "whitespace": str.split,
9 | "word": nltk.tokenize.word_tokenize,
10 | "wordpunct": nltk.tokenize.wordpunct_tokenize,
11 | "sent": nltk.tokenize.sent_tokenize,
12 | }
13 |
14 |
15 | def _get_tokenizer(tokenizer: str) -> callable:
16 | assert tokenizer.lower() in tokenizers_dict, f"Tokenizer {tokenizer} not supported."
17 |     if tokenizer.lower() in {"word", "sent"}:
18 |         nltk.download("punkt", quiet=True)  # required by NLTK's word/sent tokenizers
19 | return tokenizers_dict[tokenizer.lower()]
20 |
21 |
22 | def get_tokenizer(tokenizer: Union[str, callable, None]) -> callable:
23 | if isinstance(tokenizer, str):
24 | return _get_tokenizer(tokenizer)
25 | elif callable(tokenizer):
26 | return tokenizer
27 | elif tokenizer is None:
28 | return identity_function
29 | else:
30 |         raise NotImplementedError()
31 |
--------------------------------------------------------------------------------
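
Usage sketch of `get_tokenizer` (not part of the repository):

    from retriv.sparse_retriever.preprocessing.tokenizer import get_tokenizer

    tokenize = get_tokenizer("whitespace")  # plain str.split
    print(tokenize("so far, so good"))      # ['so', 'far,', 'so', 'good']
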
/retriv/sparse_retriever/preprocessing/utils.py:
--------------------------------------------------------------------------------
1 | def identity_function(x):
2 | return x
3 |
--------------------------------------------------------------------------------
/retriv/sparse_retriever/sparse_retrieval_models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/sparse_retriever/sparse_retrieval_models/__init__.py
--------------------------------------------------------------------------------
/retriv/sparse_retriever/sparse_retrieval_models/bm25.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numba as nb
4 | import numpy as np
5 | from numba import njit, prange
6 | from numba.typed import List as TypedList
7 |
8 | from ...utils.numba_utils import (
9 | intersect_sorted,
10 | intersect_sorted_multi,
11 | union_sorted_multi,
12 | unsorted_top_k,
13 | )
14 |
15 |
16 | @njit(cache=True)
17 | def bm25(
18 | b: float,
19 | k1: float,
20 | term_doc_freqs: nb.typed.List[np.ndarray],
21 | doc_ids: nb.typed.List[np.ndarray],
22 | relative_doc_lens: nb.typed.List[np.ndarray],
23 | doc_count: int,
24 | cutoff: int,
25 | operator: str = "OR",
26 | subset_doc_ids: np.ndarray = None,
27 | ) -> Tuple[np.ndarray, np.ndarray]:
28 | if operator == "AND":
29 | unique_doc_ids = intersect_sorted_multi(doc_ids)
30 | elif operator == "OR":
31 | unique_doc_ids = union_sorted_multi(doc_ids)
32 |
33 | if subset_doc_ids is not None:
34 | unique_doc_ids = intersect_sorted(unique_doc_ids, subset_doc_ids)
35 |
36 | scores = np.empty(doc_count, dtype=np.float32)
37 | scores[unique_doc_ids] = 0.0 # Initialize scores
38 |
39 | for i in range(len(term_doc_freqs)):
40 | indices = doc_ids[i]
41 | freqs = term_doc_freqs[i]
42 |
43 | df = np.float32(len(indices))
44 | idf = np.float32(np.log(1.0 + (((doc_count - df) + 0.5) / (df + 0.5))))
45 |
46 | scores[indices] += idf * (
47 | (freqs * (k1 + 1.0))
48 | / (freqs + k1 * (1.0 - b + (b * relative_doc_lens[indices])))
49 | )
50 |
51 | scores = scores[unique_doc_ids]
52 |
53 | if cutoff < len(scores):
54 | scores, indices = unsorted_top_k(scores, cutoff)
55 | unique_doc_ids = unique_doc_ids[indices]
56 |
57 | indices = np.argsort(-scores)
58 |
59 | return unique_doc_ids[indices], scores[indices]
60 |
61 |
62 | @njit(cache=True, parallel=True)
63 | def bm25_multi(
64 | b: float,
65 | k1: float,
66 | term_doc_freqs: nb.typed.List[nb.typed.List[np.ndarray]],
67 | doc_ids: nb.typed.List[nb.typed.List[np.ndarray]],
68 | relative_doc_lens: nb.typed.List[np.ndarray],
69 | doc_count: int,
70 | cutoff: int,
71 | ) -> Tuple[nb.typed.List[np.ndarray], nb.typed.List[np.ndarray]]:
72 | unique_doc_ids = TypedList([np.empty(1, dtype=np.int32) for _ in doc_ids])
73 | scores = TypedList([np.empty(1, dtype=np.float32) for _ in doc_ids])
74 |
75 | for i in prange(len(term_doc_freqs)):
76 | _term_doc_freqs = term_doc_freqs[i]
77 | _doc_ids = doc_ids[i]
78 |
79 | _unique_doc_ids = union_sorted_multi(_doc_ids)
80 |
81 | _scores = np.empty(doc_count, dtype=np.float32)
82 | _scores[_unique_doc_ids] = 0.0 # Initialize _scores
83 |
84 | for j in range(len(_term_doc_freqs)):
85 | indices = _doc_ids[j]
86 | freqs = _term_doc_freqs[j]
87 |
88 | df = np.float32(len(indices))
89 | idf = np.float32(np.log(1.0 + (((doc_count - df) + 0.5) / (df + 0.5))))
90 |
91 | _scores[indices] += idf * (
92 | (freqs * (k1 + 1.0))
93 | / (freqs + k1 * (1.0 - b + (b * relative_doc_lens[indices])))
94 | )
95 |
96 | _scores = _scores[_unique_doc_ids]
97 |
98 | if cutoff < len(_scores):
99 | _scores, indices = unsorted_top_k(_scores, cutoff)
100 | _unique_doc_ids = _unique_doc_ids[indices]
101 |
102 | indices = np.argsort(_scores)[::-1]
103 |
104 | unique_doc_ids[i] = _unique_doc_ids[indices]
105 | scores[i] = _scores[indices]
106 |
107 | return unique_doc_ids, scores
108 |
--------------------------------------------------------------------------------
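
The per-term update inside `bm25` above is standard Okapi BM25. A plain-NumPy sketch of the same formula with toy values (not part of the repository):

    import numpy as np

    b, k1, doc_count = 0.75, 1.2, 1_000
    freqs = np.array([3.0, 1.0], dtype=np.float32)              # tf in the two matching docs
    relative_doc_lens = np.array([0.8, 1.3], dtype=np.float32)  # doc_len / avg_doc_len

    df = np.float32(len(freqs))
    idf = np.log(1.0 + (((doc_count - df) + 0.5) / (df + 0.5)))
    scores = idf * (freqs * (k1 + 1.0)) / (freqs + k1 * (1.0 - b + b * relative_doc_lens))
    print(scores)  # higher tf and shorter documents yield higher scores
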
/retriv/sparse_retriever/sparse_retrieval_models/tf_idf.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numba as nb
4 | import numpy as np
5 | from numba import njit, prange
6 | from numba.typed import List as TypedList
7 |
8 | from ...utils.numba_utils import (
9 | intersect_sorted,
10 | intersect_sorted_multi,
11 | union_sorted_multi,
12 | unsorted_top_k,
13 | )
14 |
15 |
16 | @njit(cache=True)
17 | def tf_idf(
18 | term_doc_freqs: nb.typed.List[np.ndarray],
19 | doc_ids: nb.typed.List[np.ndarray],
20 | doc_lens: nb.typed.List[np.ndarray],
21 | cutoff: int,
22 | operator: str = "OR",
23 | subset_doc_ids: np.ndarray = None,
24 | ) -> Tuple[np.ndarray, np.ndarray]:
25 | if operator == "AND":
26 | unique_doc_ids = intersect_sorted_multi(doc_ids)
27 | elif operator == "OR":
28 | unique_doc_ids = union_sorted_multi(doc_ids)
29 |
30 | if subset_doc_ids is not None:
31 | unique_doc_ids = intersect_sorted(unique_doc_ids, subset_doc_ids)
32 |
33 | doc_count = len(doc_lens)
34 | scores = np.empty(doc_count, dtype=np.float32)
35 | scores[unique_doc_ids] = 0.0 # Initialize scores
36 |
37 | for i in range(len(term_doc_freqs)):
38 | indices = doc_ids[i]
39 | freqs = term_doc_freqs[i]
40 |
41 | tf = freqs / doc_lens[indices]
42 |
43 | df = np.float32(len(indices))
44 | idf = np.float32(np.log((1.0 + doc_count) / (1.0 + df)) + 1.0)
45 |
46 | scores[indices] += tf * idf
47 |
48 | scores = scores[unique_doc_ids]
49 |
50 | if cutoff < len(scores):
51 | scores, indices = unsorted_top_k(scores, cutoff)
52 | unique_doc_ids = unique_doc_ids[indices]
53 |
54 | indices = np.argsort(-scores)
55 |
56 | return unique_doc_ids[indices], scores[indices]
57 |
58 |
59 | @njit(cache=True, parallel=True)
60 | def tf_idf_multi(
61 | term_doc_freqs: nb.typed.List[nb.typed.List[np.ndarray]],
62 | doc_ids: nb.typed.List[nb.typed.List[np.ndarray]],
63 | doc_lens: nb.typed.List[nb.typed.List[np.ndarray]],
64 | cutoff: int,
65 | ) -> Tuple[nb.typed.List[np.ndarray], nb.typed.List[np.ndarray]]:
66 | unique_doc_ids = TypedList([np.empty(1, dtype=np.int32) for _ in doc_ids])
67 | scores = TypedList([np.empty(1, dtype=np.float32) for _ in doc_ids])
68 |
69 | for i in prange(len(term_doc_freqs)):
70 | _term_doc_freqs = term_doc_freqs[i]
71 | _doc_ids = doc_ids[i]
72 |
73 | _unique_doc_ids = union_sorted_multi(_doc_ids)
74 |
75 | doc_count = len(doc_lens)
76 | _scores = np.empty(doc_count, dtype=np.float32)
77 | _scores[_unique_doc_ids] = 0.0 # Initialize _scores
78 |
79 | for j in range(len(_term_doc_freqs)):
80 | indices = _doc_ids[j]
81 | freqs = _term_doc_freqs[j]
82 |
83 | tf = freqs / doc_lens[indices]
84 |
85 | df = np.float32(len(indices))
86 | idf = np.float32(np.log((1.0 + doc_count) / (1.0 + df)) + 1.0)
87 |
88 | _scores[indices] += tf * idf
89 |
90 | _scores = _scores[_unique_doc_ids]
91 |
92 | if cutoff < len(_scores):
93 | _scores, indices = unsorted_top_k(_scores, cutoff)
94 | _unique_doc_ids = _unique_doc_ids[indices]
95 |
96 | indices = np.argsort(_scores)[::-1]
97 |
98 | unique_doc_ids[i] = _unique_doc_ids[indices]
99 | scores[i] = _scores[indices]
100 |
101 | return unique_doc_ids, scores
102 |
--------------------------------------------------------------------------------
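
Same idea for the TF-IDF variant above, which uses a smoothed, sklearn-style IDF (sketch with toy values, not part of the repository):

    import numpy as np

    doc_count = 1_000
    freqs = np.array([3.0, 1.0], dtype=np.float32)        # raw term frequencies
    doc_lens = np.array([120.0, 45.0], dtype=np.float32)  # lengths of the matching docs

    tf = freqs / doc_lens
    idf = np.log((1.0 + doc_count) / (1.0 + len(freqs))) + 1.0
    print(tf * idf)  # per-document contribution of this term
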
/retriv/sparse_retriever/sparse_retriever.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Iterable, List, Set, Union
3 |
4 | import numba as nb
5 | import numpy as np
6 | import orjson
7 | from numba.typed import List as TypedList
8 | from oneliner_utils import create_path, read_jsonl
9 | from tqdm import tqdm
10 |
11 | from ..autotune import tune_bm25
12 | from ..base_retriever import BaseRetriever
13 | from ..paths import docs_path, sr_state_path
14 | from .build_inverted_index import build_inverted_index
15 | from .preprocessing import (
16 | get_stemmer,
17 | get_stopwords,
18 | get_tokenizer,
19 | preprocessing,
20 | preprocessing_multi,
21 | )
22 | from .sparse_retrieval_models.bm25 import bm25, bm25_multi
23 | from .sparse_retrieval_models.tf_idf import tf_idf, tf_idf_multi
24 |
25 |
26 | class SparseRetriever(BaseRetriever):
27 | def __init__(
28 | self,
29 | index_name: str = "new-index",
30 | model: str = "bm25",
31 | min_df: int = 1,
32 | tokenizer: Union[str, callable] = "whitespace",
33 | stemmer: Union[str, callable] = "english",
34 | stopwords: Union[str, List[str], Set[str]] = "english",
35 | do_lowercasing: bool = True,
36 | do_ampersand_normalization: bool = True,
37 | do_special_chars_normalization: bool = True,
38 | do_acronyms_normalization: bool = True,
39 | do_punctuation_removal: bool = True,
40 | hyperparams: dict = None,
41 | ):
42 |         """The Sparse Retriever is a traditional searcher based on lexical matching. It supports BM25, the retrieval model used by major search engine libraries, such as Lucene and Elasticsearch. retriv also implements the classic relevance model TF-IDF for educational purposes.
43 |
44 | Args:
45 | index_name (str, optional): [retriv](https://github.com/AmenRa/retriv) will use `index_name` as the identifier of your index. Defaults to "new-index".
46 |
47 | model (str, optional): defines the retrieval model to use for searching (`bm25` or `tf-idf`). Defaults to "bm25".
48 |
49 | min_df (int, optional): terms that appear in less than `min_df` documents will be ignored. If integer, the parameter indicates the absolute count. If float, it represents a proportion of documents. Defaults to 1.
50 |
51 | tokenizer (Union[str, callable], optional): [tokenizer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable tokenizer or disable tokenization by setting the parameter to `None`. Defaults to "whitespace".
52 |
53 |             stemmer (Union[str, callable], optional): [stemmer](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to use during preprocessing. You can pass a custom callable stemmer or disable stemming by setting the parameter to `None`. Defaults to "english".
54 |
55 | stopwords (Union[str, List[str], Set[str]], optional): [stopwords](https://github.com/AmenRa/retriv/blob/main/docs/text_preprocessing.md) to remove during preprocessing. You can pass a custom stop-word list or disable stop-words removal by setting the parameter to `None`. Defaults to "english".
56 |
57 | do_lowercasing (bool, optional): whether or not to lowercase texts. Defaults to True.
58 |
59 |             do_ampersand_normalization (bool, optional): whether to convert `&` to `and` during pre-processing. Defaults to True.
60 | 
61 |             do_special_chars_normalization (bool, optional): whether to replace special characters (e.g., curly quotes and long dashes) with their plain equivalents. Defaults to True.
62 | 
63 |             do_acronyms_normalization (bool, optional): whether to remove full stop symbols from acronyms without splitting them into multiple words, e.g., `P.C.I.` → `PCI`. Defaults to True.
64 |
65 | do_punctuation_removal (bool, optional): whether to remove punctuation. Defaults to True.
66 |
67 | hyperparams (dict, optional): Retrieval model hyperparams. If `None`, it is automatically set to `{b: 0.75, k1: 1.2}`. Defaults to None.
68 | """
69 |
70 | assert model.lower() in {"bm25", "tf-idf"}
71 | assert min_df > 0, "`min_df` must be greater than zero."
72 | self.init_args = {
73 | "model": model.lower(),
74 | "min_df": min_df,
75 | "index_name": index_name,
76 | "do_lowercasing": do_lowercasing,
77 | "do_ampersand_normalization": do_ampersand_normalization,
78 | "do_special_chars_normalization": do_special_chars_normalization,
79 | "do_acronyms_normalization": do_acronyms_normalization,
80 | "do_punctuation_removal": do_punctuation_removal,
81 | "tokenizer": tokenizer,
82 | "stemmer": stemmer,
83 | "stopwords": stopwords,
84 | }
85 |
86 | self.model = model.lower()
87 | self.min_df = min_df
88 | self.index_name = index_name
89 |
90 | self.do_lowercasing = do_lowercasing
91 | self.do_ampersand_normalization = do_ampersand_normalization
92 | self.do_special_chars_normalization = do_special_chars_normalization
93 | self.do_acronyms_normalization = do_acronyms_normalization
94 | self.do_punctuation_removal = do_punctuation_removal
95 |
96 | self.tokenizer = get_tokenizer(tokenizer)
97 | self.stemmer = get_stemmer(stemmer)
98 | self.stopwords = [self.stemmer(sw) for sw in get_stopwords(stopwords)]
99 |
100 | self.id_mapping = None
101 | self.inverted_index = None
102 | self.vocabulary = None
103 | self.doc_count = None
104 | self.doc_lens = None
105 | self.avg_doc_len = None
106 | self.relative_doc_lens = None
107 | self.doc_index = None
108 |
109 | self.preprocessing_kwargs = {
110 | "tokenizer": self.tokenizer,
111 | "stemmer": self.stemmer,
112 | "stopwords": self.stopwords,
113 | "do_lowercasing": self.do_lowercasing,
114 | "do_ampersand_normalization": self.do_ampersand_normalization,
115 | "do_special_chars_normalization": self.do_special_chars_normalization,
116 | "do_acronyms_normalization": self.do_acronyms_normalization,
117 | "do_punctuation_removal": self.do_punctuation_removal,
118 | }
119 |
120 | self.preprocessing_pipe = preprocessing_multi(**self.preprocessing_kwargs)
121 |
122 | self.hyperparams = dict(b=0.75, k1=1.2) if hyperparams is None else hyperparams
123 |
124 | def save(self) -> None:
125 | """Save the state of the retriever to be able to restore it later."""
126 |
127 | state = {
128 | "init_args": self.init_args,
129 | "id_mapping": self.id_mapping,
130 | "doc_count": self.doc_count,
131 | "inverted_index": self.inverted_index,
132 | "vocabulary": self.vocabulary,
133 | "doc_lens": self.doc_lens,
134 | "relative_doc_lens": self.relative_doc_lens,
135 | "hyperparams": self.hyperparams,
136 | }
137 |
138 | np.savez_compressed(sr_state_path(self.index_name), state=state)
139 |
140 | @staticmethod
141 | def load(index_name: str = "new-index"):
142 | """Load a retriever and its index.
143 |
144 | Args:
145 | index_name (str, optional): Name of the index. Defaults to "new-index".
146 |
147 | Returns:
148 | SparseRetriever: Sparse Retriever.
149 | """
150 |
151 | state = np.load(sr_state_path(index_name), allow_pickle=True)["state"][()]
152 |
153 | se = SparseRetriever(**state["init_args"])
154 | se.initialize_doc_index()
155 | se.id_mapping = state["id_mapping"]
156 | se.doc_count = state["doc_count"]
157 | se.inverted_index = state["inverted_index"]
158 | se.vocabulary = set(se.inverted_index)
159 | se.doc_lens = state["doc_lens"]
160 | se.relative_doc_lens = state["relative_doc_lens"]
161 | se.hyperparams = state["hyperparams"]
162 |
174 | return se
175 |
176 | def index_aux(self, show_progress: bool = True):
177 | """Internal usage."""
178 | collection = read_jsonl(
179 | docs_path(self.index_name),
180 | generator=True,
181 | callback=lambda x: x["text"],
182 | )
183 |
184 | # Preprocessing --------------------------------------------------------
185 | collection = self.preprocessing_pipe(collection, generator=True)
186 |
187 | # Inverted index -------------------------------------------------------
188 | (
189 | self.inverted_index,
190 | self.doc_lens,
191 | self.relative_doc_lens,
192 | ) = build_inverted_index(
193 | collection=collection,
194 | n_docs=self.doc_count,
195 | min_df=self.min_df,
196 | show_progress=show_progress,
197 | )
198 | self.avg_doc_len = np.mean(self.doc_lens, dtype=np.float32)
199 | self.vocabulary = set(self.inverted_index)
200 |
201 | def index(
202 | self,
203 | collection: Iterable,
204 | callback: callable = None,
205 | show_progress: bool = True,
206 | ):
207 | """Index a given collection of documents.
208 |
209 | Args:
210 | collection (Iterable): collection of documents to index.
211 |
212 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None.
213 |
214 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True.
215 |
216 | Returns:
217 | SparseRetriever: Sparse Retriever.
218 | """
219 |
220 | self.save_collection(collection, callback)
221 | self.initialize_doc_index()
222 | self.initialize_id_mapping()
223 | self.doc_count = len(self.id_mapping)
224 | self.index_aux(show_progress)
225 | self.save()
226 | return self
227 |
228 | def index_file(
229 | self, path: str, callback: callable = None, show_progress: bool = True
230 | ):
231 | """Index the collection contained in a given file.
232 |
233 | Args:
234 | path (str): path of file containing the collection to index.
235 |
236 | callback (callable, optional): callback to apply before indexing the documents to modify them on the fly if needed. Defaults to None.
237 |
238 | show_progress (bool, optional): whether to show a progress bar for the indexing process. Defaults to True.
239 |
240 | Returns:
241 | SparseRetriever: Sparse Retriever
242 | """
243 |
244 | collection = self.collection_generator(path=path, callback=callback)
245 | return self.index(collection=collection, show_progress=show_progress)
246 |
247 | # SEARCH ===================================================================
248 | def query_preprocessing(self, query: str) -> List[str]:
249 | """Internal usage."""
250 | return preprocessing(query, **self.preprocessing_kwargs)
251 |
252 | def get_term_doc_freqs(self, query_terms: List[str]) -> nb.types.List:
253 | """Internal usage."""
254 | return TypedList([self.inverted_index[t]["tfs"] for t in query_terms])
255 |
256 | def get_doc_ids(self, query_terms: List[str]) -> nb.types.List:
257 | """Internal usage."""
258 | return TypedList([self.inverted_index[t]["doc_ids"] for t in query_terms])
259 |
260 | def search(self, query: str, return_docs: bool = True, cutoff: int = 100) -> List:
261 | """Standard search functionality.
262 |
263 | Args:
264 | query (str): what to search for.
265 |
266 |             return_docs (bool, optional): whether to return the texts of the documents. Defaults to True.
267 |
268 | cutoff (int, optional): number of results to return. Defaults to 100.
269 |
270 | Returns:
271 | List: results.
272 | """
273 |
274 | query_terms = self.query_preprocessing(query)
275 | if not query_terms:
276 | return {}
277 | query_terms = [t for t in query_terms if t in self.vocabulary]
278 | if not query_terms:
279 | return {}
280 |
281 | doc_ids = self.get_doc_ids(query_terms)
282 | term_doc_freqs = self.get_term_doc_freqs(query_terms)
283 |
284 | if self.model == "bm25":
285 | unique_doc_ids, scores = bm25(
286 | term_doc_freqs=term_doc_freqs,
287 | doc_ids=doc_ids,
288 | relative_doc_lens=self.relative_doc_lens,
289 | doc_count=self.doc_count,
290 | cutoff=cutoff,
291 | **self.hyperparams,
292 | )
293 | elif self.model == "tf-idf":
294 | unique_doc_ids, scores = tf_idf(
295 | term_doc_freqs=term_doc_freqs,
296 | doc_ids=doc_ids,
297 | doc_lens=self.doc_lens,
298 | cutoff=cutoff,
299 | )
300 | else:
301 | raise NotImplementedError()
302 |
303 | unique_doc_ids = self.map_internal_ids_to_original_ids(unique_doc_ids)
304 |
305 | if not return_docs:
306 | return dict(zip(unique_doc_ids, scores))
307 |
308 | return self.prepare_results(unique_doc_ids, scores)
309 |
310 | def msearch(self, queries: List[Dict[str, str]], cutoff: int = 100) -> Dict:
311 | """Compute results for multiple queries at once.
312 |
313 | Args:
314 | queries (List[Dict[str, str]]): what to search for.
315 |
316 | cutoff (int, optional): number of results to return. Defaults to 100.
317 |
318 | Returns:
319 | Dict: results.
320 | """
321 |
322 | term_doc_freqs = TypedList()
323 | doc_ids = TypedList()
324 | q_ids = []
325 | no_results_q_ids = []
326 |
327 | for q in queries:
328 | q_id, query = q["id"], q["text"]
329 | query_terms = self.query_preprocessing(query)
330 | query_terms = [t for t in query_terms if t in self.vocabulary]
331 | if not query_terms:
332 | no_results_q_ids.append(q_id)
333 | continue
334 |
335 | if all(t not in self.inverted_index for t in query_terms):
336 | no_results_q_ids.append(q_id)
337 | continue
338 |
339 | q_ids.append(q_id)
340 | term_doc_freqs.append(self.get_term_doc_freqs(query_terms))
341 | doc_ids.append(self.get_doc_ids(query_terms))
342 |
343 | if not q_ids:
344 | return {q_id: {} for q_id in [q["id"] for q in queries]}
345 |
346 | if self.model == "bm25":
347 | unique_doc_ids, scores = bm25_multi(
348 | term_doc_freqs=term_doc_freqs,
349 | doc_ids=doc_ids,
350 | relative_doc_lens=self.relative_doc_lens,
351 | doc_count=self.doc_count,
352 | cutoff=cutoff,
353 | **self.hyperparams,
354 | )
355 | elif self.model == "tf-idf":
356 | unique_doc_ids, scores = tf_idf_multi(
357 | term_doc_freqs=term_doc_freqs,
358 | doc_ids=doc_ids,
359 | doc_lens=self.doc_lens,
360 | cutoff=cutoff,
361 | )
362 | else:
363 | raise NotImplementedError()
364 |
365 | unique_doc_ids = [
366 | self.map_internal_ids_to_original_ids(_unique_doc_ids)
367 | for _unique_doc_ids in unique_doc_ids
368 | ]
369 |
370 | results = {
371 | q: dict(zip(unique_doc_ids[i], scores[i])) for i, q in enumerate(q_ids)
372 | }
373 |
374 | for q_id in no_results_q_ids:
375 | results[q_id] = {}
376 |
377 | # Order as queries
378 | return {q_id: results[q_id] for q_id in [q["id"] for q in queries]}
379 |
380 | def bsearch(
381 | self,
382 | queries: List[Dict[str, str]],
383 | cutoff: int = 100,
384 | batch_size: int = 1_000,
385 | show_progress: bool = True,
386 | qrels: Dict[str, Dict[str, float]] = None,
387 | path: str = None,
388 | ) -> Dict:
389 |         """Batch-Search is similar to Multi-Search but automatically generates batches of queries to evaluate and allows dynamic writing of the search results to disk in [JSONL](https://jsonlines.org) format. bsearch is handy for computing results for hundreds of thousands or even millions of queries without hogging your RAM.
390 |
391 | Args:
392 | queries (List[Dict[str, str]]): what to search for.
393 |
394 | cutoff (int, optional): number of results to return. Defaults to 100.
395 |
396 |             batch_size (int, optional): number of queries to perform simultaneously. Defaults to 1_000.
397 |
398 | show_progress (bool, optional): whether to show a progress bar for the search process. Defaults to True.
399 |
400 | qrels (Dict[str, Dict[str, float]], optional): query relevance judgements for the queries. Defaults to None.
401 |
402 | path (str, optional): where to save the results. Defaults to None.
403 |
404 | Returns:
405 | Dict: results.
406 | """
407 |
408 | batches = [
409 | queries[i : i + batch_size] for i in range(0, len(queries), batch_size)
410 | ]
411 |
412 | results = {}
413 |
414 | pbar = tqdm(
415 | total=len(queries),
416 | disable=not show_progress,
417 | desc="Batch search",
418 | dynamic_ncols=True,
419 | mininterval=0.5,
420 | )
421 |
422 | if path is None:
423 | for batch in batches:
424 | new_results = self.msearch(queries=batch, cutoff=cutoff)
425 | results = {**results, **new_results}
426 | pbar.update(min(batch_size, len(batch)))
427 | else:
428 | path = create_path(path)
429 | path.parent.mkdir(parents=True, exist_ok=True)
430 |
431 | with open(path, "wb") as f:
432 | for batch in batches:
433 | new_results = self.msearch(queries=batch, cutoff=cutoff)
434 |
435 | for i, (k, v) in enumerate(new_results.items()):
436 | x = {
437 | "id": k,
438 | "text": batch[i]["text"],
439 | f"{self.model}_doc_ids": list(v.keys()),
440 | f"{self.model}_scores": [
441 | float(s) for s in list(v.values())
442 | ],
443 | }
444 | if qrels is not None:
445 | x["rel_doc_ids"] = list(qrels[k].keys())
446 | x["rel_scores"] = list(qrels[k].values())
447 | f.write(orjson.dumps(x) + "\n".encode())
448 |
449 | pbar.update(min(batch_size, len(batch)))
450 |
451 | return results
452 |
453 | def autotune(
454 | self,
455 | queries: List[Dict[str, str]],
456 | qrels: Dict[str, Dict[str, float]],
457 | metric: str = "ndcg",
458 | n_trials: int = 100,
459 | cutoff: int = 100,
460 | ):
461 |         """Use the AutoTune function to tune [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) parameters w.r.t. your document collection and queries.
462 |         All metrics supported by [ranx](https://github.com/AmenRa/ranx) are supported by the `autotune` function. At the end of the process, the best parameter configuration is automatically applied to the `SparseRetriever` instance and saved to disk. You can inspect the current configuration by printing `sr.hyperparams`.
463 |
464 | Args:
465 | queries (List[Dict[str, str]]): queries to use for the optimization process.
466 |
467 | qrels (Dict[str, Dict[str, float]]): query relevance judgements for the queries.
468 |
469 | metric (str, optional): metric to optimize for. Defaults to "ndcg".
470 |
471 |             n_trials (int, optional): number of configurations to evaluate. Defaults to 100.
472 |
473 | cutoff (int, optional): number of results to consider for the optimization process. Defaults to 100.
474 | """
475 |
476 | hyperparams = tune_bm25(
477 | queries=queries,
478 | qrels=qrels,
479 | se=self,
480 | metric=metric,
481 | n_trials=n_trials,
482 | cutoff=cutoff,
483 | )
484 | self.hyperparams = hyperparams
485 | self.save()
486 |
--------------------------------------------------------------------------------
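
A minimal end-to-end sketch of the `SparseRetriever` API documented above (toy collection; assumes `SparseRetriever` is exported at package level; not part of the repository):

    from retriv import SparseRetriever

    sr = SparseRetriever(index_name="toy-index", model="bm25")
    sr.index(
        [
            {"id": "doc_1", "text": "Generals gathered in their masses"},
            {"id": "doc_2", "text": "Just like witches at black masses"},
        ]
    )
    print(sr.search("witches masses", cutoff=10))  # doc_2 should rank first
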
/retriv/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmenRa/retriv/0c418a87f06a66e89d388ea3bf52575faf287d91/retriv/utils/__init__.py
--------------------------------------------------------------------------------
/retriv/utils/numba_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numba import njit
3 |
4 |
5 | # UNION ------------------------------------------------------------------------
6 | @njit(cache=True)
7 | def union_sorted(a1: np.array, a2: np.array):
8 | result = np.empty(len(a1) + len(a2), dtype=np.int32)
9 | i = 0
10 | j = 0
11 | k = 0
12 |
13 | while i < len(a1) and j < len(a2):
14 | if a1[i] < a2[j]:
15 | result[k] = a1[i]
16 | i += 1
17 | elif a1[i] > a2[j]:
18 | result[k] = a2[j]
19 | j += 1
20 | else: # a1[i] == a2[j]
21 | result[k] = a1[i]
22 | i += 1
23 | j += 1
24 | k += 1
25 |
26 | result = result[:k]
27 |
28 | if i < len(a1):
29 | result = np.concatenate((result, a1[i:]))
30 | elif j < len(a2):
31 | result = np.concatenate((result, a2[j:]))
32 |
33 | return result
34 |
35 |
36 | @njit(cache=True)
37 | def union_sorted_multi(arrays):
38 | if len(arrays) == 1:
39 | return arrays[0]
40 | elif len(arrays) == 2:
41 | return union_sorted(arrays[0], arrays[1])
42 | else:
43 | return union_sorted(
44 | union_sorted_multi(arrays[:2]), union_sorted_multi(arrays[2:])
45 | )
46 |
47 |
48 | # INTERSECTION -----------------------------------------------------------------
49 | @njit(cache=True)
50 | def intersect_sorted(a1: np.array, a2: np.array):
51 | result = np.empty(min(len(a1), len(a2)), dtype=np.int32)
52 | i = 0
53 | j = 0
54 | k = 0
55 |
56 | while i < len(a1) and j < len(a2):
57 | if a1[i] < a2[j]:
58 | i += 1
59 | elif a1[i] > a2[j]:
60 | j += 1
61 | else: # a1[i] == a2[j]
62 | result[k] = a1[i]
63 | i += 1
64 | j += 1
65 | k += 1
66 |
67 | return result[:k]
68 |
69 |
70 | @njit(cache=True)
71 | def intersect_sorted_multi(arrays):
72 | a = arrays[0]
73 |
74 | for i in range(1, len(arrays)):
75 | a = intersect_sorted(a, arrays[i])
76 |
77 | return a
78 |
79 |
80 | # DIFFERENCE -------------------------------------------------------------------
81 | @njit(cache=True)
82 | def diff_sorted(a1: np.array, a2: np.array):
83 | result = np.empty(len(a1), dtype=np.int32)
84 | i = 0
85 | j = 0
86 | k = 0
87 |
88 | while i < len(a1) and j < len(a2):
89 | if a1[i] < a2[j]:
90 | result[k] = a1[i]
91 | i += 1
92 | k += 1
93 | elif a1[i] > a2[j]:
94 | j += 1
95 | else: # a1[i] == a2[j]
96 | i += 1
97 | j += 1
98 |
99 | result = result[:k]
100 |
101 | if i < len(a1):
102 | result = np.concatenate((result, a1[i:]))
103 |
104 | return result
105 |
106 |
107 | # -----------------------------------------------------------------------------
108 | @njit(cache=True)
109 | def concat1d(X):
110 | out = np.empty(sum([len(x) for x in X]), dtype=X[0].dtype)
111 |
112 | i = 0
113 | for x in X:
114 | for j in range(len(x)):
115 | out[i] = x[j]
116 | i = i + 1
117 |
118 | return out
119 |
120 |
121 | @njit(cache=True)
122 | def get_indices(array, scores):
123 | n_scores = len(scores)
124 | min_score = min(scores)
125 | max_score = max(scores)
126 | indices = np.full(n_scores, -1, dtype=np.int64)
127 | counter = 0
128 |
129 | for i in range(len(array)):
130 | if array[i] >= min_score and array[i] <= max_score:
131 | for j in range(len(scores)):
132 | if indices[j] == -1:
133 | if scores[j] == array[i]:
134 | indices[j] = i
135 | counter += 1
136 | if len(indices) == counter:
137 | return indices
138 | break
139 |
140 | return indices
141 |
142 |
143 | @njit(cache=True)
144 | def unsorted_top_k(array: np.ndarray, k: int):
145 | top_k_values = np.zeros(k, dtype=np.float32)
146 | top_k_indices = np.zeros(k, dtype=np.int32)
147 |
148 | min_value = 0.0
149 | min_value_idx = 0
150 |
151 | for i, value in enumerate(array):
152 | if value > min_value:
153 | top_k_values[min_value_idx] = value
154 | top_k_indices[min_value_idx] = i
155 | min_value_idx = top_k_values.argmin()
156 | min_value = top_k_values[min_value_idx]
157 |
158 | return top_k_values, top_k_indices
159 |
--------------------------------------------------------------------------------
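
Quick check of the sorted-set helpers above on toy arrays (not part of the repository):

    import numpy as np

    from retriv.utils.numba_utils import diff_sorted, intersect_sorted, union_sorted

    a = np.array([1, 3, 5, 7], dtype=np.int32)
    b = np.array([3, 4, 5], dtype=np.int32)

    print(union_sorted(a, b))      # [1 3 4 5 7]
    print(intersect_sorted(a, b))  # [3 5]
    print(diff_sorted(a, b))       # [1 7] (elements of a not in b)
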
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="retriv",
8 | version="0.2.3",
9 | author="Elias Bassani",
10 | author_email="elias.bssn@gmail.com",
11 | description="retriv: A Python Search Engine for Humans.",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | url="https://github.com/AmenRa/retriv",
15 | packages=setuptools.find_packages(),
16 | install_requires=[
17 | "numpy",
18 | "nltk",
19 | "numba>=0.54.1",
20 | "tqdm",
21 | "optuna",
22 | "krovetzstemmer",
23 | "pystemmer==2.0.1",
24 | "unidecode",
25 | "scikit-learn",
26 | "ranx",
27 | "indxr",
28 | "oneliner_utils",
29 | "torch",
30 | "torchvision",
31 | "torchaudio",
32 | "transformers[torch]",
33 | "faiss-cpu",
34 | "autofaiss",
35 | "multipipe",
36 | ],
37 | classifiers=[
38 | "Programming Language :: Python :: 3",
39 | "License :: OSI Approved :: MIT License",
40 | "Intended Audience :: Science/Research",
41 | "Operating System :: OS Independent",
42 | "Topic :: Text Processing :: General",
43 | ],
44 | keywords=[
45 | "information retrieval",
46 | "search engine",
47 | "bm25",
48 | "numba",
49 | "sparse retrieval",
50 | "dense retrieval",
51 | "hybrid retrieval",
52 | "neural information retrieval",
53 | ],
54 | python_requires=">=3.8",
55 | )
56 |
--------------------------------------------------------------------------------
/test_env.yml:
--------------------------------------------------------------------------------
1 | name: asd
2 | dependencies:
3 | - python=3.8
4 | - black
5 | - conda-forge::icecream
6 | - isort
7 | - ipykernel
8 | - ipywidgets
9 | - conda-forge::mkdocs-material
10 | - conda-forge::mkdocs-autorefs
11 | - conda-forge::mkdocstrings
12 | - conda-forge::mkdocstrings-python
13 | - pygments>=2.12
14 | - notebook
15 | - pytest
16 | # Pip
17 | - pip
18 | - pip:
19 | - retriv
20 |
21 |
--------------------------------------------------------------------------------
/tests/dense_retriever/encoder_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from retriv.dense_retriever.encoder import Encoder
4 |
5 |
6 | # FIXTURES =====================================================================
7 | @pytest.fixture
8 | def encoder():
9 | return Encoder(model="sentence-transformers/all-MiniLM-L6-v2")
10 |
11 |
12 | @pytest.fixture
13 | def texts():
14 | return [
15 | "Generals gathered in their masses",
16 | "Just like witches at black masses",
17 | "Evil minds that plot destruction",
18 | "Sorcerer of death's construction",
19 | ]
20 |
21 |
22 | # TESTS ========================================================================
23 | def test_call(encoder, texts):
24 | embeddings = encoder(texts)
25 | assert embeddings.shape[0] == 4
26 | assert embeddings.shape[1] == encoder.embedding_dim
27 |
--------------------------------------------------------------------------------
/tests/merger/merger_test.py:
--------------------------------------------------------------------------------
1 | from math import isclose
2 |
3 | import pytest
4 |
5 | from retriv.merger.merger import Merger
6 | from retriv.merger.normalization import min_max_norm
7 |
8 |
9 | # FIXTURES =====================================================================
10 | @pytest.fixture
11 | def merger():
12 | return Merger()
13 |
14 |
15 | @pytest.fixture
16 | def run_a():
17 | return {
18 | "q1": {
19 | "d1": 2.0,
20 | "d2": 0.7,
21 | "d3": 0.5,
22 | },
23 | "q2": {
24 | "d1": 1.0,
25 | "d2": 0.7,
26 | "d3": 0.5,
27 | },
28 | }
29 |
30 |
31 | @pytest.fixture
32 | def run_b():
33 | return {
34 | "q1": {
35 | "d3": 2.0,
36 | "d1": 0.7,
37 | },
38 | "q2": {
39 | "d1": 1.0,
40 | "d2": 0.7,
41 | "d3": 0.5,
42 | },
43 | }
44 |
45 |
46 | # TESTS ========================================================================
47 | def test_fuse(merger, run_a, run_b):
48 | fused_results = merger.fuse([run_a["q1"], run_b["q1"]])
49 |
50 | norm_run_a = min_max_norm(run_a)
51 | norm_run_b = min_max_norm(run_b)
52 |
53 | assert isclose(fused_results["d1"], norm_run_a["q1"]["d1"] + norm_run_b["q1"]["d1"])
54 | assert isclose(fused_results["d2"], norm_run_a["q1"]["d2"])
55 | assert isclose(fused_results["d3"], norm_run_a["q1"]["d3"] + norm_run_b["q1"]["d3"])
56 |
57 |
58 | def test_mfuse(merger, run_a, run_b):
59 | fused_run = merger.mfuse([run_a, run_b])
60 |
61 | norm_run_a = min_max_norm(run_a)
62 | norm_run_b = min_max_norm(run_b)
63 |
64 | assert isclose(
65 | fused_run["q1"]["d1"], norm_run_a["q1"]["d1"] + norm_run_b["q1"]["d1"]
66 | )
67 | assert isclose(fused_run["q1"]["d2"], norm_run_a["q1"]["d2"])
68 | assert isclose(
69 | fused_run["q1"]["d3"], norm_run_a["q1"]["d3"] + norm_run_b["q1"]["d3"]
70 | )
71 |
72 | assert isclose(
73 | fused_run["q2"]["d1"], norm_run_a["q2"]["d1"] + norm_run_b["q2"]["d1"]
74 | )
75 | assert isclose(
76 | fused_run["q2"]["d2"], norm_run_a["q2"]["d2"] + norm_run_b["q2"]["d2"]
77 | )
78 | assert isclose(
79 | fused_run["q2"]["d3"], norm_run_a["q2"]["d3"] + norm_run_b["q2"]["d3"]
80 | )
81 |
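82 | # Worked example for "q1"/"d1", derived from the fixtures above: min-max
83 | # normalization maps run_a's 2.0 to (2.0 - 0.5) / (2.0 - 0.5) = 1.0 and
84 | # run_b's 0.7 to (0.7 - 0.7) / (2.0 - 0.7) = 0.0, so the fused score
85 | # asserted in test_fuse is 1.0 + 0.0 = 1.0.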
--------------------------------------------------------------------------------
/tests/merger/score_normalization_test.py:
--------------------------------------------------------------------------------
1 | from math import isclose
2 |
3 | import pytest
4 |
5 | from retriv.merger.normalization import (
6 | extract_scores,
7 | max_norm,
8 | max_norm_multi,
9 | min_max_norm,
10 | min_max_norm_multi,
11 | safe_max,
12 | safe_min,
13 | sum_norm,
14 | sum_norm_multi,
15 | )
16 |
17 |
18 | # FIXTURES =====================================================================
19 | @pytest.fixture
20 | def run_a():
21 | return {
22 | "q1": {
23 | "d1": 2.0,
24 | "d2": 0.7,
25 | "d3": 0.5,
26 | },
27 | "q2": {
28 | "d1": 1.0,
29 | "d2": 0.7,
30 | "d3": 0.5,
31 | },
32 | }
33 |
34 |
35 | @pytest.fixture
36 | def run_b():
37 | return {
38 | "q1": {
39 | "d3": 2.0,
40 | "d1": 0.7,
41 | },
42 | "q2": {
43 | "d1": 1.0,
44 | "d2": 0.7,
45 | "d3": 0.5,
46 | },
47 | }
48 |
49 |
50 | # TESTS ========================================================================
51 | def test_extract_scores(run_a, run_b):
52 | assert extract_scores(run_a["q1"]) == [2.0, 0.7, 0.5]
53 | assert extract_scores(run_a["q2"]) == [1.0, 0.7, 0.5]
54 | assert extract_scores(run_b["q1"]) == [2.0, 0.7]
55 | assert extract_scores(run_b["q2"]) == [1.0, 0.7, 0.5]
56 |
57 |
58 | def test_safe_max():
59 | assert safe_max([1.0, 0.7, 0.5]) == 1.0
60 | assert safe_max([]) == 0.0
61 |
62 |
63 | def test_safe_min():
64 | assert safe_min([1.0, 0.7, 0.5]) == 0.5
65 | assert safe_min([]) == 0.0
66 |
67 |
68 | def test_max_norm(run_a, run_b):
69 | normalized_run = max_norm(run_a)
70 |
71 | assert isclose(normalized_run["q1"]["d1"], 1.0)
72 | assert isclose(normalized_run["q1"]["d2"], 0.7 / 2.0)
73 | assert isclose(normalized_run["q1"]["d3"], 0.5 / 2.0)
74 |
75 | assert isclose(normalized_run["q2"]["d1"], 1.0)
76 | assert isclose(normalized_run["q2"]["d2"], 0.7)
77 | assert isclose(normalized_run["q2"]["d3"], 0.5)
78 |
79 | normalized_run = max_norm(run_b)
80 |
81 | assert isclose(normalized_run["q1"]["d3"], 1.0)
82 | assert isclose(normalized_run["q1"]["d1"], 0.7 / 2.0)
83 |
84 | assert isclose(normalized_run["q2"]["d1"], 1.0)
85 | assert isclose(normalized_run["q2"]["d2"], 0.7)
86 | assert isclose(normalized_run["q2"]["d3"], 0.5)
87 |
88 |
89 | def test_min_max_norm(run_a, run_b):
90 | normalized_run = min_max_norm(run_a)
91 |
92 | assert isclose(normalized_run["q1"]["d1"], (2.0 - 0.5) / (2.0 - 0.5))
93 | assert isclose(normalized_run["q1"]["d2"], (0.7 - 0.5) / (2.0 - 0.5))
94 | assert isclose(normalized_run["q1"]["d3"], (0.5 - 0.5) / (2.0 - 0.5))
95 |
96 | assert isclose(normalized_run["q2"]["d1"], (1.0 - 0.5) / (1.0 - 0.5))
97 | assert isclose(normalized_run["q2"]["d2"], (0.7 - 0.5) / (1.0 - 0.5))
98 | assert isclose(normalized_run["q2"]["d3"], (0.5 - 0.5) / (1.0 - 0.5))
99 |
100 | normalized_run = min_max_norm(run_b)
101 |
102 | assert isclose(normalized_run["q1"]["d3"], (2.0 - 0.7) / (2.0 - 0.7))
103 | assert isclose(normalized_run["q1"]["d1"], (0.7 - 0.7) / (2.0 - 0.7))
104 |
105 | assert isclose(normalized_run["q2"]["d1"], (1.0 - 0.5) / (1.0 - 0.5))
106 | assert isclose(normalized_run["q2"]["d2"], (0.7 - 0.5) / (1.0 - 0.5))
107 | assert isclose(normalized_run["q2"]["d3"], (0.5 - 0.5) / (1.0 - 0.5))
108 |
109 |
110 | def test_sum_norm(run_a, run_b):
111 | normalized_run = sum_norm(run_a)
112 |
113 | denominator = (2.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5)
114 | assert isclose(normalized_run["q1"]["d1"], (2.0 - 0.5) / denominator)
115 | assert isclose(normalized_run["q1"]["d2"], (0.7 - 0.5) / denominator)
116 | assert isclose(normalized_run["q1"]["d3"], (0.5 - 0.5) / denominator)
117 |
118 | denominator = (1.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5)
119 | assert isclose(normalized_run["q2"]["d1"], (1.0 - 0.5) / denominator)
120 | assert isclose(normalized_run["q2"]["d2"], (0.7 - 0.5) / denominator)
121 | assert isclose(normalized_run["q2"]["d3"], (0.5 - 0.5) / denominator)
122 |
123 | normalized_run = sum_norm(run_b)
124 |
125 | denominator = (2.0 - 0.7) + (0.7 - 0.7)
126 | assert isclose(normalized_run["q1"]["d3"], (2.0 - 0.7) / denominator)
127 | assert isclose(normalized_run["q1"]["d1"], (0.7 - 0.7) / denominator)
128 |
129 | denominator = (1.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5)
130 | assert isclose(normalized_run["q2"]["d1"], (1.0 - 0.5) / denominator)
131 | assert isclose(normalized_run["q2"]["d2"], (0.7 - 0.5) / denominator)
132 | assert isclose(normalized_run["q2"]["d3"], (0.5 - 0.5) / denominator)
133 |
134 |
135 | def test_max_norm_multi(run_a, run_b):
136 | runs = [run_a, run_b]
137 | normalized_runs = max_norm_multi(runs)
138 |
139 | assert isclose(normalized_runs[0]["q1"]["d1"], 1.0)
140 | assert isclose(normalized_runs[0]["q1"]["d2"], 0.7 / 2.0)
141 | assert isclose(normalized_runs[0]["q1"]["d3"], 0.5 / 2.0)
142 | assert isclose(normalized_runs[0]["q2"]["d1"], 1.0)
143 | assert isclose(normalized_runs[0]["q2"]["d2"], 0.7)
144 | assert isclose(normalized_runs[0]["q2"]["d3"], 0.5)
145 |
146 | assert isclose(normalized_runs[1]["q1"]["d3"], 1.0)
147 | assert isclose(normalized_runs[1]["q1"]["d1"], 0.7 / 2.0)
148 | assert isclose(normalized_runs[1]["q2"]["d1"], 1.0)
149 | assert isclose(normalized_runs[1]["q2"]["d2"], 0.7)
150 | assert isclose(normalized_runs[1]["q2"]["d3"], 0.5)
151 |
152 |
153 | def test_min_max_norm_multi(run_a, run_b):
154 | runs = [run_a, run_b]
155 | normalized_runs = min_max_norm_multi(runs)
156 |
157 | assert isclose(normalized_runs[0]["q1"]["d1"], (2.0 - 0.5) / (2.0 - 0.5))
158 | assert isclose(normalized_runs[0]["q1"]["d2"], (0.7 - 0.5) / (2.0 - 0.5))
159 | assert isclose(normalized_runs[0]["q1"]["d3"], (0.5 - 0.5) / (2.0 - 0.5))
160 | assert isclose(normalized_runs[0]["q2"]["d1"], (1.0 - 0.5) / (1.0 - 0.5))
161 | assert isclose(normalized_runs[0]["q2"]["d2"], (0.7 - 0.5) / (1.0 - 0.5))
162 | assert isclose(normalized_runs[0]["q2"]["d3"], (0.5 - 0.5) / (1.0 - 0.5))
163 |
164 | assert isclose(normalized_runs[1]["q1"]["d3"], (2.0 - 0.7) / (2.0 - 0.7))
165 | assert isclose(normalized_runs[1]["q1"]["d1"], (0.7 - 0.7) / (2.0 - 0.7))
166 | assert isclose(normalized_runs[1]["q2"]["d1"], (1.0 - 0.5) / (1.0 - 0.5))
167 | assert isclose(normalized_runs[1]["q2"]["d2"], (0.7 - 0.5) / (1.0 - 0.5))
168 | assert isclose(normalized_runs[1]["q2"]["d3"], (0.5 - 0.5) / (1.0 - 0.5))
169 |
170 |
171 | def test_sum_norm_multi(run_a, run_b):
172 | runs = [run_a, run_b]
173 | normalized_runs = sum_norm_multi(runs)
174 |
175 | denominator = (2.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5)
176 | assert isclose(normalized_runs[0]["q1"]["d1"], (2.0 - 0.5) / denominator)
177 | assert isclose(normalized_runs[0]["q1"]["d2"], (0.7 - 0.5) / denominator)
178 | assert isclose(normalized_runs[0]["q1"]["d3"], (0.5 - 0.5) / denominator)
179 |
180 | denominator = (1.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5)
181 | assert isclose(normalized_runs[0]["q2"]["d1"], (1.0 - 0.5) / denominator)
182 | assert isclose(normalized_runs[0]["q2"]["d2"], (0.7 - 0.5) / denominator)
183 | assert isclose(normalized_runs[0]["q2"]["d3"], (0.5 - 0.5) / denominator)
184 |
185 | denominator = (2.0 - 0.7) + (0.7 - 0.7)
186 | assert isclose(normalized_runs[1]["q1"]["d3"], (2.0 - 0.7) / denominator)
187 | assert isclose(normalized_runs[1]["q1"]["d1"], (0.7 - 0.7) / denominator)
188 | denominator = (1.0 - 0.5) + (0.7 - 0.5) + (0.5 - 0.5)
189 | assert isclose(normalized_runs[1]["q2"]["d1"], (1.0 - 0.5) / denominator)
190 | assert isclose(normalized_runs[1]["q2"]["d2"], (0.7 - 0.5) / denominator)
191 | assert isclose(normalized_runs[1]["q2"]["d3"], (0.5 - 0.5) / denominator)
192 |
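193 | # The per-query formulas these tests encode, with s a document's score and
194 | # min/max/sum taken over that query's scores:
195 | #   max_norm:     s / max
196 | #   min_max_norm: (s - min) / (max - min)
197 | #   sum_norm:     (s - min) / sum(s - min)
198 | # The *_multi variants apply the same normalization to each run in a list.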
--------------------------------------------------------------------------------
/tests/numba_utils_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from numba.typed import List as TypedList
4 |
5 | from retriv.utils.numba_utils import (
6 | concat1d,
7 | diff_sorted,
8 | get_indices,
9 | intersect_sorted,
10 | intersect_sorted_multi,
11 | union_sorted,
12 | union_sorted_multi,
13 | unsorted_top_k,
14 | )
15 |
16 |
17 | # TESTS ========================================================================
18 | def test_union_sorted():
19 | a1 = np.array([1, 3, 4, 7], dtype=np.int32)
20 | a2 = np.array([1, 4, 7, 9], dtype=np.int32)
21 | result = union_sorted(a1, a2)
22 | expected = np.array([1, 3, 4, 7, 9], dtype=np.int32)
23 |
24 | assert np.array_equal(result, expected)
25 |
26 |
27 | def test_union_sorted_multi():
28 | a1 = np.array([1, 3, 4, 7], dtype=np.int32)
29 | a2 = np.array([1, 4, 7, 9], dtype=np.int32)
30 | a3 = np.array([10, 11], dtype=np.int32)
31 | a4 = np.array([11, 12, 13], dtype=np.int32)
32 |
33 | arrays = TypedList([a1, a2, a3, a4])
34 |
35 | result = union_sorted_multi(arrays)
36 | expected = np.array([1, 3, 4, 7, 9, 10, 11, 12, 13], dtype=np.int32)
37 |
38 | assert np.array_equal(result, expected)
39 |
40 |
41 | def test_intersect_sorted():
42 | a1 = np.array([1, 3, 4, 7], dtype=np.int32)
43 | a2 = np.array([1, 4, 7, 9], dtype=np.int32)
44 | result = intersect_sorted(a1, a2)
45 | expected = np.array([1, 4, 7], dtype=np.int32)
46 |
47 | assert np.array_equal(result, expected)
48 |
49 |
50 | def test_intersect_sorted_multi():
51 | a1 = np.array([1, 3, 4, 7], dtype=np.int32)
52 | a2 = np.array([1, 4, 7, 9], dtype=np.int32)
53 | a3 = np.array([4, 7], dtype=np.int32)
54 | a4 = np.array([3, 7, 9], dtype=np.int32)
55 |
56 | arrays = TypedList([a1, a2, a3, a4])
57 |
58 | result = intersect_sorted_multi(arrays)
59 | expected = np.array([7], dtype=np.int32)
60 |
61 | print(result)
62 |
63 | assert np.array_equal(result, expected)
64 |
65 |
66 | def test_diff_sorted():
67 | a1 = np.array([1, 3, 4, 7], dtype=np.int32)
68 | a2 = np.array([1, 4, 7, 9], dtype=np.int32)
69 | result = diff_sorted(a1, a2)
70 | expected = np.array([3], dtype=np.int32)
71 |
72 | assert np.array_equal(result, expected)
73 |
74 | a1 = np.array([1, 3, 4, 7, 11], dtype=np.int32)
75 | a2 = np.array([1, 4, 7, 9], dtype=np.int32)
76 | result = diff_sorted(a1, a2)
77 | expected = np.array([3, 11], dtype=np.int32)
78 |
79 | assert np.array_equal(result, expected)
80 |
81 |
82 | def test_concat1d():
83 | a1 = np.array([1, 3, 4, 7], dtype=np.int32)
84 | a2 = np.array([1, 4, 7, 9], dtype=np.int32)
85 | a3 = np.array([10, 11], dtype=np.int32)
86 | a4 = np.array([11, 12, 13], dtype=np.int32)
87 |
88 | arrays = TypedList([a1, a2, a3, a4])
89 |
90 | result = concat1d(arrays)
91 | expected = np.array([1, 3, 4, 7, 1, 4, 7, 9, 10, 11, 11, 12, 13], dtype=np.int32)
92 |
93 | assert np.array_equal(result, expected)
94 |
95 |
96 | def test_get_indices():
97 | array = np.array([0.1, 0.3, 0.2, 0.4], dtype=np.float32)
98 | scores = np.array([0.4, 0.3, 0.2, 0.1], dtype=np.float32)
99 |
100 | result = get_indices(array, scores)
101 | expected = np.array([3, 1, 2, 0], dtype=np.int64)
102 |
103 | assert np.array_equal(result, expected)
104 |
105 |
106 | def test_unsorted_top_k():
107 | array = np.array([0.1, 0.3, 0.2, 0.4], dtype=np.float32)
108 | k = 2
109 |
110 | top_k_values, top_k_indices = unsorted_top_k(array, k)
111 |
112 | assert len(top_k_values) == 2
113 | assert 0.3 in top_k_values
114 | assert 0.4 in top_k_values
115 | assert len(top_k_indices) == 2
116 | assert 1 in top_k_indices
117 | assert 3 in top_k_indices
118 |
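119 | # Possible cross-check (a sketch; assumes inputs are sorted arrays of
120 | # unique values, as in the tests above): these helpers should agree with
121 | # NumPy's set operations:
122 | #
123 | #   np.array_equal(union_sorted(a1, a2), np.union1d(a1, a2))
124 | #   np.array_equal(intersect_sorted(a1, a2), np.intersect1d(a1, a2))
125 | #   np.array_equal(diff_sorted(a1, a2), np.setdiff1d(a1, a2))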
--------------------------------------------------------------------------------
/tests/sparse_retriever/preprocessing_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from retriv.sparse_retriever.preprocessing import preprocessing, preprocessing_multi
4 | from retriv.sparse_retriever.preprocessing.stemmer import get_stemmer
5 | from retriv.sparse_retriever.preprocessing.stopwords import get_stopwords
6 | from retriv.sparse_retriever.preprocessing.tokenizer import get_tokenizer
7 |
8 |
9 | # FIXTURES =====================================================================
10 | @pytest.fixture
11 | def stemmer():
12 | return get_stemmer("english")
13 |
14 |
15 | @pytest.fixture
16 | def stopwords():
17 | return get_stopwords("english")
18 |
19 |
20 | @pytest.fixture
21 | def tokenizer():
22 | return get_tokenizer("whitespace")
23 |
24 |
25 | @pytest.fixture
26 | def docs():
27 | return """Black Sabbath were an English rock band formed in Birmingham in 1968 by guitarist Tony Iommi, drummer Bill Ward, bassist Geezer Butler and vocalist Ozzy Osbourne. They are often cited as pioneers of heavy metal music. The band helped define the genre with releases such as Black Sabbath (1970), Paranoid (1970) and Master of Reality (1971). The band had multiple line-up changes following Osbourne's departure in 1979 and Iommi is the only constant member throughout their history.
28 | After previous iterations of the group - the Polka Tulk Blues Band and Earth - the band settled on the name Black Sabbath in 1969. They distinguished themselves through occult themes with horror-inspired lyrics and down-tuned guitars. Signing to Philips Records in November 1969, they released their first single, "Evil Woman", in January 1970, and their debut album, Black Sabbath, was released the following month. Though it received a negative critical response, the album was a commercial success, leading to a follow-up record, Paranoid, later that year. The band's popularity grew, and by 1973's Sabbath Bloody Sabbath, critics were starting to respond favourably.
29 | Osbourne's excessive substance abuse led to his firing in 1979. He was replaced by former Rainbow vocalist Ronnie James Dio. Following two albums with Dio, Black Sabbath endured many personnel changes in the 1980s and 1990s that included vocalists Ian Gillan, Glenn Hughes, Ray Gillen and Tony Martin, as well as several drummers and bassists. Martin, who replaced Gillen in 1987, was the second-longest serving vocalist and recorded three albums with Black Sabbath before his dismissal in 1991. That same year, Iommi and Butler were rejoined by Dio and drummer Vinny Appice to record Dehumanizer (1992). After two more studio albums with Martin, who replaced Dio in 1993, the band's original line-up reunited in 1997 and released a live album, Reunion, the following year; they continued to tour occasionally until 2005. Other than various back catalogue reissues and compilation albums, as well as the Mob Rules-era line-up reunited as Heaven & Hell, there was no further activity under the Black Sabbath name for six years. They reunited in 2011 and released their final studio album and 19th overall, 13, in 2013, which features all of the original members except Ward. During their farewell tour, the band played their final concert in their home city of Birmingham on 4 February 2017. Occasional partial reunions have happened since, most recently when Osbourne and Iommi performed together at the closing ceremony of the 2022 Commonwealth Games in Birmingham.
30 | Black Sabbath have sold over 70 million records worldwide as of 2013, making them one of the most commercially successful heavy metal bands. Black Sabbath, together with Deep Purple and Led Zeppelin, have been referred to as the "unholy trinity of British hard rock and heavy metal in the early to mid-seventies". They were ranked by MTV as the "Greatest Metal Band of All Time" and placed second on VH1's "100 Greatest Artists of Hard Rock" list. Rolling Stone magazine ranked them number 85 on their "100 Greatest Artists of All Time". Black Sabbath were inducted into the UK Music Hall of Fame in 2005 and the Rock and Roll Hall of Fame in 2006. They have also won two Grammy Awards for Best Metal Performance, and in 2019 the band were presented a Grammy Lifetime Achievement Award.""".split(
31 | "\n"
32 | )
33 |
34 |
35 | # TESTS ========================================================================
36 | def test_preprocessing_multi(docs, stemmer, stopwords, tokenizer):
37 | out = [
38 | preprocessing(
39 | doc,
40 | stemmer=stemmer,
41 | stopwords=stopwords,
42 | tokenizer=tokenizer,
43 | do_lowercasing=True,
44 | do_ampersand_normalization=True,
45 | do_special_chars_normalization=True,
46 | do_acronyms_normalization=True,
47 | do_punctuation_removal=True,
48 | )
49 | for doc in docs
50 | ]
51 |
52 | pipeline = preprocessing_multi(
53 | tokenizer=tokenizer,
54 | stopwords=stopwords,
55 | stemmer=stemmer,
56 | do_lowercasing=True,
57 | do_ampersand_normalization=True,
58 | do_special_chars_normalization=True,
59 | do_acronyms_normalization=True,
60 | do_punctuation_removal=True,
61 | )
62 | multi_out = pipeline(docs)
63 |
64 | assert len(out) == len(multi_out)
65 | assert out[0] == multi_out[0]
66 | assert out[1] == multi_out[1]
67 | assert out[2] == multi_out[2]
68 | assert out[3] == multi_out[3]
69 | assert out == multi_out
70 |
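71 | # Call-shape sketch; the exact tokens depend on the configured tokenizer,
72 | # stopword list, and stemmer:
73 | #
74 | #   tokens = preprocessing(
75 | #       "Black Sabbath were an English rock band",
76 | #       stemmer=stemmer, stopwords=stopwords, tokenizer=tokenizer,
77 | #       do_lowercasing=True, do_ampersand_normalization=True,
78 | #       do_special_chars_normalization=True, do_acronyms_normalization=True,
79 | #       do_punctuation_removal=True,
80 | #   )  # expected: a list of lowercased, stemmed tokens, stopwords removed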
--------------------------------------------------------------------------------
/tests/sparse_retriever/search_engine_test.py:
--------------------------------------------------------------------------------
1 | from math import isclose
2 |
3 | import pytest
4 |
5 | from retriv import SearchEngine
6 |
7 | REL_TOL = 1e-6
8 |
9 |
10 | # FIXTURES =====================================================================
11 | @pytest.fixture
12 | def collection():
13 | return [
14 | {"id": 1, "text": "Shane"},
15 | {"id": 2, "text": "Shane C"},
16 | {"id": 3, "text": "Shane P Connelly"},
17 | {"id": 4, "text": "Shane Connelly"},
18 | {"id": 5, "text": "Shane Shane Connelly Connelly"},
19 | {"id": 6, "text": "Shane Shane Shane Connelly Connelly Connelly"},
20 | ]
21 |
22 |
23 | def test_search_bm25(collection):
24 | se = SearchEngine(hyperparams=dict(b=0.5, k1=0))
25 | se.index(collection)
26 |
27 | query = "shane"
28 |
29 | results = se.search(query=query, return_docs=False)
30 |
31 | print(se.inverted_index)
32 |
33 | print(results)
34 | assert isclose(results[1], 0.07410797, rel_tol=REL_TOL)
35 | assert isclose(results[2], 0.07410797, rel_tol=REL_TOL)
36 | assert isclose(results[3], 0.07410797, rel_tol=REL_TOL)
37 | assert isclose(results[4], 0.07410797, rel_tol=REL_TOL)
38 | assert isclose(results[5], 0.07410797, rel_tol=REL_TOL)
39 | assert isclose(results[6], 0.07410797, rel_tol=REL_TOL)
40 |
41 | se.hyperparams = dict(b=0, k1=10)
42 | results = se.search(query=query, return_docs=False)
43 | print(results)
44 | assert isclose(results[1], 0.07410797, rel_tol=REL_TOL)
45 | assert isclose(results[2], 0.07410797, rel_tol=REL_TOL)
46 | assert isclose(results[3], 0.07410797, rel_tol=REL_TOL)
47 | assert isclose(results[4], 0.07410797, rel_tol=REL_TOL)
48 | assert isclose(results[5], 0.13586462, rel_tol=REL_TOL)
49 | assert isclose(results[6], 0.18812023, rel_tol=REL_TOL)
50 |
51 | se.hyperparams = dict(b=1, k1=5)
52 | results = se.search(query=query, return_docs=False)
53 | print(results)
54 | assert isclose(results[1], 0.16674294, rel_tol=REL_TOL)
55 | assert isclose(results[2], 0.10261103, rel_tol=REL_TOL)
56 | assert isclose(results[3], 0.07410797, rel_tol=REL_TOL)
57 | assert isclose(results[4], 0.10261103, rel_tol=REL_TOL)
58 | assert isclose(results[5], 0.10261103, rel_tol=REL_TOL)
59 | assert isclose(results[6], 0.10261105, rel_tol=REL_TOL)
60 |
61 |
62 | def test_msearch_bm25(collection):
63 | se = SearchEngine(hyperparams=dict(b=0.5, k1=0))
64 | se.index(collection)
65 |
66 | queries = [
67 | {"id": "q_1", "text": "shane"},
68 | {"id": "q_2", "text": "connelly"},
69 | ]
70 |
71 | results = se.msearch(queries=queries)
72 |
73 | print(results)
74 | assert isclose(results["q_1"][1], 0.07410797, rel_tol=REL_TOL)
75 | assert isclose(results["q_1"][2], 0.07410797, rel_tol=REL_TOL)
76 | assert isclose(results["q_1"][3], 0.07410797, rel_tol=REL_TOL)
77 | assert isclose(results["q_1"][4], 0.07410797, rel_tol=REL_TOL)
78 | assert isclose(results["q_1"][5], 0.07410797, rel_tol=REL_TOL)
79 | assert isclose(results["q_1"][6], 0.07410797, rel_tol=REL_TOL)
80 | assert isclose(results["q_2"][3], 0.44183275, rel_tol=REL_TOL)
81 | assert isclose(results["q_2"][4], 0.44183275, rel_tol=REL_TOL)
82 | assert isclose(results["q_2"][5], 0.44183275, rel_tol=REL_TOL)
83 | assert isclose(results["q_2"][6], 0.44183275, rel_tol=REL_TOL)
84 |
85 | se.hyperparams = dict(b=0, k1=10)
86 | results = se.msearch(queries=queries)
87 | print(results)
88 | assert isclose(results["q_1"][1], 0.07410797, rel_tol=REL_TOL)
89 | assert isclose(results["q_1"][2], 0.07410797, rel_tol=REL_TOL)
90 | assert isclose(results["q_1"][3], 0.07410797, rel_tol=REL_TOL)
91 | assert isclose(results["q_1"][4], 0.07410797, rel_tol=REL_TOL)
92 | assert isclose(results["q_1"][5], 0.13586462, rel_tol=REL_TOL)
93 | assert isclose(results["q_1"][6], 0.18812023, rel_tol=REL_TOL)
94 | assert isclose(results["q_2"][3], 0.44183275, rel_tol=REL_TOL)
95 | assert isclose(results["q_2"][4], 0.44183275, rel_tol=REL_TOL)
96 | assert isclose(results["q_2"][5], 0.8100267, rel_tol=REL_TOL)
97 | assert isclose(results["q_2"][6], 1.1215755, rel_tol=REL_TOL)
98 |
99 | se.hyperparams = dict(b=1, k1=5)
100 | results = se.msearch(queries=queries)
101 | print(results)
102 | assert isclose(results["q_1"][1], 0.16674294, rel_tol=REL_TOL)
103 | assert isclose(results["q_1"][2], 0.10261103, rel_tol=REL_TOL)
104 | assert isclose(results["q_1"][3], 0.07410797, rel_tol=REL_TOL)
105 | assert isclose(results["q_1"][4], 0.10261103, rel_tol=REL_TOL)
106 | assert isclose(results["q_1"][5], 0.10261103, rel_tol=REL_TOL)
107 | assert isclose(results["q_1"][6], 0.10261105, rel_tol=REL_TOL)
108 | assert isclose(results["q_2"][3], 0.44183275, rel_tol=REL_TOL)
109 | assert isclose(results["q_2"][4], 0.6117684, rel_tol=REL_TOL)
110 | assert isclose(results["q_2"][5], 0.6117684, rel_tol=REL_TOL)
111 | assert isclose(results["q_2"][6], 0.6117684, rel_tol=REL_TOL)
112 |
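113 | # The swept hyperparameters are the standard BM25 knobs: k1 controls term
114 | # frequency saturation and b controls document-length normalization. In the
115 | # common formulation a query term contributes
116 | #   idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))
117 | # so k1=0 makes the score independent of term frequency, which is why all
118 | # six documents tie on "shane" in the first block of assertions.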
--------------------------------------------------------------------------------
/tests/sparse_retriever/stemmer_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from retriv.sparse_retriever.preprocessing.stemmer import get_stemmer
4 |
5 |
6 | # FIXTURES =====================================================================
7 | @pytest.fixture
8 | def supported_stemmers():
9 | # fmt: off
10 | return [
11 | "krovetz", "porter", "lancaster", "arlstem", "arlstem2", "cistem",
12 | "isri", "arabic",
13 | "basque", "catalan", "danish", "dutch", "english", "finnish", "french", "german", "greek", "hindi", "hungarian", "indonesian", "irish", "italian", "lithuanian", "nepali", "norwegian", "portuguese", "romanian", "russian",
14 | "spanish", "swedish", "tamil", "turkish",
15 | ]
16 | # fmt: on
17 |
18 |
19 | # TESTS ========================================================================
20 | def test_get_stemmer(supported_stemmers):
21 | for stemmer in supported_stemmers:
22 | assert callable(get_stemmer(stemmer))
23 |
24 |
25 | def test_get_stemmer_fails():
26 | with pytest.raises(Exception):
27 | get_stemmer("foobar")
28 |
29 |
30 | def test_get_stemmer_callable():
31 | assert callable(get_stemmer(lambda x: x))
32 |
33 |
34 | def test_get_stemmer_none():
35 | assert callable(get_stemmer(None))
36 | assert get_stemmer(None)("incredible") == "incredible"
37 |
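38 | # Usage sketch (Porter stemming; other stemmers may produce different
39 | # outputs for the same word):
40 | #
41 | #   stem = get_stemmer("porter")
42 | #   stem("running")  # -> "run"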
--------------------------------------------------------------------------------
/tests/sparse_retriever/stopwords_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from retriv.sparse_retriever.preprocessing.stopwords import get_stopwords
4 |
5 |
6 | # FIXTURES =====================================================================
7 | @pytest.fixture
8 | def supported_languages():
9 | # fmt: off
10 | return [
11 | "arabic", "azerbaijani", "basque", "catalan", "bengali", "chinese",
12 | "danish", "dutch", "english", "finnish", "french", "german", "greek",
13 | "hebrew", "hinglish", "hungarian", "indonesian", "italian", "kazakh",
14 | "nepali", "norwegian", "portuguese", "romanian", "russian", "slovene",
15 | "spanish", "swedish", "tajik", "turkish",
16 | ]
17 | # fmt: on
18 |
19 |
20 | @pytest.fixture
21 | def sw_list():
22 | return ["a", "the"]
23 |
24 |
25 | @pytest.fixture
26 | def sw_set():
27 | return {"a", "the"}
28 |
29 |
30 | # TESTS ========================================================================
31 | def test_get_stopwords_from_lang(supported_languages):
32 | for lang in supported_languages:
33 | assert isinstance(get_stopwords(lang), list)
34 | assert len(get_stopwords(lang)) > 0
35 |
36 |
37 | def test_get_stopwords_from_lang_fails():
38 | with pytest.raises(Exception):
39 | get_stopwords("foobar")
40 |
41 |
42 | def test_get_stopwords_from_list(sw_list):
43 | assert isinstance(get_stopwords(sw_list), list)
44 | assert set(get_stopwords(sw_list)) == {"a", "the"}
45 | assert len(get_stopwords(sw_list)) > 0
46 |
47 |
48 | def test_get_stopwords_from_set(sw_set):
49 | assert isinstance(get_stopwords(sw_set), list)
50 | assert set(get_stopwords(sw_set)) == {"a", "the"}
51 | assert len(get_stopwords(sw_set)) > 0
52 |
53 |
54 | def test_get_stopwords_none():
55 | assert isinstance(get_stopwords(None), list)
56 | assert len(get_stopwords(None)) == 0
57 |
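58 | # As the tests above encode, get_stopwords always normalizes to a list:
59 | # a language name resolves to that language's stopword list, custom lists
60 | # and sets are passed through as lists (order not guaranteed for sets),
61 | # and None yields an empty list, e.g.:
62 | #
63 | #   set(get_stopwords({"a", "the"}))  # -> {"a", "the"}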
--------------------------------------------------------------------------------
/tests/sparse_retriever/text_normalization_test.py:
--------------------------------------------------------------------------------
1 | from retriv.sparse_retriever.preprocessing.normalization import (
2 | lowercasing,
3 | normalize_acronyms,
4 | normalize_ampersand,
5 | normalize_special_chars,
6 | remove_punctuation,
7 | strip_whitespaces,
8 | )
9 |
10 |
11 | # TESTS ========================================================================
12 | def test_lowercasing():
13 | assert lowercasing("hEllO") == "hello"
14 |
15 |
16 | def test_normalize_ampersand():
17 | assert normalize_ampersand("black&sabbath") == "black and sabbath"
18 |
19 |
20 | def test_normalize_special_chars():
21 | assert normalize_special_chars("‘’") == "''"
22 |
23 |
24 | def test_normalize_acronyms():
25 | assert normalize_acronyms("a.b.c.") == "abc"
26 | assert normalize_acronyms("foo.bar") == "foo.bar"
27 | assert normalize_acronyms("a.b@hello.com") == "a.b@hello.com"
28 |
29 |
30 | def test_remove_punctuation():
31 | assert remove_punctuation("foo.bar?") == "foo bar "
32 | # assert remove_punctuation("a.b@hello.com") == "a.b@hello.com"
33 |
34 |
35 | def test_strip_whitespaces():
36 | assert strip_whitespaces(" hello world ") == "hello world"
37 |
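38 | # As exercised above, remove_punctuation replaces punctuation with spaces
39 | # rather than deleting it ("foo.bar?" -> "foo bar "); the commented-out
40 | # assertion appears to flag e-mail addresses as a case it does not preserve.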
--------------------------------------------------------------------------------
/tests/sparse_retriever/tokenizer_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from retriv.sparse_retriever.preprocessing.tokenizer import get_tokenizer
4 |
5 |
6 | # FIXTURES =====================================================================
7 | @pytest.fixture
8 | def supported_tokenizers():
9 | return ["whitespace", "word", "wordpunct", "sent"]
10 |
11 |
12 | # TESTS ========================================================================
13 | def test_get_tokenizer(supported_tokenizers):
14 | for tokenizer in supported_tokenizers:
15 | assert callable(get_tokenizer(tokenizer))
16 |
17 |
18 | def test_get_tokenizer_fails():
19 | with pytest.raises(Exception):
20 | get_tokenizer("foobar")
21 |
22 |
23 | def test_get_tokenizer_callable():
24 | assert callable(get_tokenizer(lambda x: x))
25 |
26 |
27 | def test_get_tokenizer_none():
28 | assert callable(get_tokenizer(None))
29 | assert get_tokenizer(None)("black sabbath") == "black sabbath"
30 |
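31 | # Usage sketch (whitespace tokenization; "word", "wordpunct", and "sent"
32 | # presumably follow their NLTK namesakes):
33 | #
34 | #   tokenize = get_tokenizer("whitespace")
35 | #   tokenize("black sabbath")  # -> ["black", "sabbath"]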
--------------------------------------------------------------------------------