├── .github
│   ├── scripts
│   │   └── python
│   │       └── update_version.py
│   └── workflows
│       ├── publish-python.yaml
│       └── run-tests.yaml
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── comparison.png
├── bm25s
│   ├── __init__.py
│   ├── hf.py
│   ├── numba
│   │   ├── __init__.py
│   │   ├── retrieve_utils.py
│   │   └── selection.py
│   ├── scoring.py
│   ├── selection.py
│   ├── stopwords.py
│   ├── tokenization.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── beir.py
│   │   ├── benchmark.py
│   │   ├── corpus.py
│   │   └── json_functions.py
│   └── version.py
├── examples
│   ├── evaluate_on_beir.py
│   ├── index_and_retrieve_with_numba.py
│   ├── index_nq.py
│   ├── index_to_hf.py
│   ├── index_with_metadata.py
│   ├── nltk_stemmer.py
│   ├── retrieve_from_hf.py
│   ├── retrieve_nq.py
│   ├── retrieve_nq_with_batching.py
│   ├── retrieve_with_numba_advanced.py
│   ├── retrieve_with_numba_hf.py
│   ├── save_and_reload_end_to_end.py
│   ├── tokenize_multiprocess.py
│   └── tokenizer_class.py
├── setup.py
└── tests
    ├── README.md
    ├── __init__.py
    ├── comparison
    │   ├── test_bm25_pt.py
    │   ├── test_bm25s_indexing.py
    │   ├── test_jsonl_corpus.py
    │   ├── test_rank_bm25.py
    │   ├── test_rank_bm25l.py
    │   ├── test_rank_bm25plus.py
    │   └── test_utils_corpus.py
    ├── comparison_full
    │   ├── test_bm25_pt.py
    │   └── test_rank_bm25.py
    ├── core
    │   ├── test_allow_empty.py
    │   ├── test_retrieve.py
    │   ├── test_save_load.py
    │   ├── test_tokenizer.py
    │   ├── test_tokenizer_misc.py
    │   ├── test_topk.py
    │   ├── test_utils_corpus.py
    │   └── test_vocab_dict.py
    ├── data
    │   └── nfcorpus.txt
    ├── numba
    │   ├── test_numba_backend_retrieve.py
    │   └── test_topk_numba.py
    ├── requirements-comparison.txt
    ├── requirements-core.txt
    └── stopwords
        └── test_stopwords.py

/.github/scripts/python/update_version.py:
--------------------------------------------------------------------------------
1 | """
2 | This CLI script is used to update the version of the package. It is used by the
3 | CI/CD pipeline to update the version of the package when a new release is made.
4 | 
5 | It uses argparse to parse the command line arguments, which are the new version
6 | and the path to the package's version file.
7 | """
8 | 
9 | import argparse
10 | from pathlib import Path
11 | 
12 | def main():
13 |     parser = argparse.ArgumentParser(
14 |         description="Update the version of the package."
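        # A hypothetical invocation (the version string is illustrative), matching
        # how the publish workflow below calls this script on release:
        #
        #   python .github/scripts/python/update_version.py \
        #       --version 0.2.0 --path bm25s/version.py
        #
        # This rewrites bm25s/version.py to contain: __version__ = "0.2.0"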
15 | ) 16 | parser.add_argument( 17 | "--version", 18 | type=str, 19 | help="The new version of the package.", 20 | required=True, 21 | ) 22 | parser.add_argument( 23 | "--path", 24 | type=Path, 25 | help="The path to the package's version file.", 26 | ) 27 | args = parser.parse_args() 28 | 29 | with open(args.path, "w") as f: 30 | f.write(f"__version__ = \"{args.version}\"") 31 | 32 | 33 | if __name__ == "__main__": 34 | main() -------------------------------------------------------------------------------- /.github/workflows/publish-python.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Publish Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | jobs: 11 | bump-version-and-publish: 12 | name: Bump version and upload release to PyPI 13 | 14 | runs-on: ubuntu-latest 15 | permissions: 16 | # IMPORTANT: this permission is mandatory for trusted publishing 17 | id-token: write 18 | 19 | environment: 20 | name: pypi 21 | url: https://pypi.org/p/bm25s 22 | 23 | steps: 24 | - uses: actions/checkout@v2 25 | - name: Set up Python 26 | uses: actions/setup-python@v2 27 | with: 28 | python-version: '3.10' 29 | 30 | - name: Update version.py with release tag 31 | env: 32 | RELEASE_TAG: ${{ github.event.release.tag_name }} 33 | run: | 34 | python .github/scripts/python/update_version.py --version $RELEASE_TAG --path "bm25s/version.py" 35 | 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install setuptools wheel twine 40 | 41 | - name: Build package 42 | run: | 43 | python setup.py sdist bdist_wheel 44 | 45 | - name: Publish package distributions to PyPI 46 | uses: pypa/gh-action-pypi-publish@release/v1 47 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test-core: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Check out repository 15 | uses: actions/checkout@v2 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: '3.10' 21 | 22 | - name: Cache Python dependencies 23 | uses: actions/cache@v4 24 | with: 25 | path: ~/.cache/pip 26 | key: ${{ runner.os }}-pip-${{ hashFiles('tests/requirements-core.txt') }} 27 | restore-keys: | 28 | ${{ runner.os }}-pip- 29 | 30 | - name: Install core dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r tests/requirements-core.txt 34 | 35 | - name: Run core tests with dependencies 36 | run: | 37 | python -m unittest tests/core/test_*.py 38 | 39 | - name: Install Numba 40 | run: | 41 | pip install "numba>=0.60.0" 42 | 43 | - name: Run numba tests 44 | run: | 45 | python -m unittest tests/numba/test_*.py 46 | 47 | test-comparison: 48 | runs-on: ubuntu-latest 49 | 50 | steps: 51 | - name: Check out repository 52 | uses: actions/checkout@v2 53 | 54 | - name: Set up Python 55 | uses: actions/setup-python@v3 56 | with: 57 | python-version: '3.10' 58 | 59 | 60 | - name: Cache Python dependencies 61 | uses: actions/cache@v4 62 | with: 63 | path: ~/.cache/pip 64 
| key: ${{ runner.os }}-pip-${{ hashFiles('tests/requirements-comparison.txt') }} 65 | restore-keys: | 66 | ${{ runner.os }}-pip- 67 | 68 | - name: Install comparison dependencies 69 | run: | 70 | python -m pip install --upgrade pip 71 | pip install -r tests/requirements-comparison.txt 72 | 73 | - name: Run comparison tests with dependencies 74 | run: | 75 | python -m unittest tests/comparison/test_*.py 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | venv-*/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | 163 | /.rss 164 | venv/ 165 | *.c 166 | *.so 167 | datasets/ 168 | __pycache__/ 169 | bm25s.prof 170 | tests/artifacts/ 171 | results 172 | *.whl 173 | conda-env-* 174 | 175 | # Ignore elasticsearch 176 | elasticsearch-*.tar.gz 177 | elasticsearch-*/ 178 | 179 | # inference benchmarks 180 | benchmark/inference/models/ 181 | comparison_results 182 | bm25s_indices/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Xing Han Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xhluca/bm25s/f4dda5f2c9fece7329822c78eb21ef6dadcdff9b/assets/comparison.png -------------------------------------------------------------------------------- /bm25s/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | import tempfile 6 | from typing import Iterable, Union 7 | from . 
import BM25, __version__ 8 | from .tokenization import Tokenizer 9 | 10 | try: 11 | from huggingface_hub import HfApi 12 | except ImportError: 13 | raise ImportError( 14 | "Please install the huggingface_hub package to use the HuggingFace integrations for bm25s. You can install it via `pip install huggingface_hub`." 15 | ) 16 | 17 | README_TEMPLATE = """--- 18 | language: en 19 | library_name: bm25s 20 | tags: 21 | - bm25 22 | - bm25s 23 | - retrieval 24 | - search 25 | - lexical 26 | --- 27 | 28 | # BM25S Index 29 | 30 | This is a BM25S index created with the [`bm25s` library](https://github.com/xhluca/bm25s) (version `{version}`), an ultra-fast implementation of BM25. It can be used for lexical retrieval tasks. 31 | 32 | BM25S Related Links: 33 | 34 | * 🏠[Homepage](https://bm25s.github.io) 35 | * 💻[GitHub Repository](https://github.com/xhluca/bm25s) 36 | * 🤗[Blog Post](https://huggingface.co/blog/xhluca/bm25s) 37 | * 📝[Technical Report](https://arxiv.org/abs/2407.03618) 38 | 39 | 40 | ## Installation 41 | 42 | You can install the `bm25s` library with `pip`: 43 | 44 | ```bash 45 | pip install "bm25s=={version}" 46 | 47 | # Include extra dependencies like stemmer 48 | pip install "bm25s[full]=={version}" 49 | 50 | # For huggingface hub usage 51 | pip install huggingface_hub 52 | ``` 53 | 54 | ## Loading a `bm25s` index 55 | 56 | You can use this index for information retrieval tasks. Here is an example: 57 | 58 | ```python 59 | import bm25s 60 | from bm25s.hf import BM25HF 61 | 62 | # Load the index 63 | retriever = BM25HF.load_from_hub("{username}/{repo_name}") 64 | 65 | # You can retrieve now 66 | query = "a cat is a feline" 67 | results = retriever.retrieve(bm25s.tokenize(query), k=3) 68 | ``` 69 | 70 | ## Saving a `bm25s` index 71 | 72 | You can save a `bm25s` index to the Hugging Face Hub. 
Here is an example: 73 | 74 | ```python 75 | import bm25s 76 | from bm25s.hf import BM25HF 77 | 78 | corpus = [ 79 | "a cat is a feline and likes to purr", 80 | "a dog is the human's best friend and loves to play", 81 | "a bird is a beautiful animal that can fly", 82 | "a fish is a creature that lives in water and swims", 83 | ] 84 | 85 | retriever = BM25HF(corpus=corpus) 86 | retriever.index(bm25s.tokenize(corpus)) 87 | 88 | token = None # You can get a token from the Hugging Face website 89 | retriever.save_to_hub("{username}/{repo_name}", token=token) 90 | ``` 91 | 92 | ## Advanced usage 93 | 94 | You can leverage more advanced features of the BM25S library during `load_from_hub`: 95 | 96 | ```python 97 | # Load corpus and index in memory-map (mmap=True) to reduce memory 98 | retriever = BM25HF.load_from_hub("{username}/{repo_name}", load_corpus=True, mmap=True) 99 | 100 | # Load a different branch/revision 101 | retriever = BM25HF.load_from_hub("{username}/{repo_name}", revision="main") 102 | 103 | # Change directory where the local files should be downloaded 104 | retriever = BM25HF.load_from_hub("{username}/{repo_name}", local_dir="/path/to/dir") 105 | 106 | # Load private repositories with a token: 107 | retriever = BM25HF.load_from_hub("{username}/{repo_name}", token=token) 108 | ``` 109 | 110 | ## Tokenizer 111 | 112 | If you have saved a `Tokenizer` object with the index using the following approach: 113 | 114 | ```python 115 | from bm25s.hf import TokenizerHF 116 | 117 | token = "your_hugging_face_token" 118 | tokenizer = TokenizerHF(corpus=corpus, stopwords="english") 119 | tokenizer.save_to_hub("{username}/{repo_name}", token=token) 120 | 121 | # and stopwords too 122 | tokenizer.save_stopwords_to_hub("{username}/{repo_name}", token=token) 123 | ``` 124 | 125 | Then, you can load the tokenizer using the following code: 126 | 127 | ```python 128 | from bm25s.hf import TokenizerHF 129 | 130 | tokenizer = TokenizerHF(corpus=corpus, stopwords=[]) 131 | tokenizer.load_vocab_from_hub("{username}/{repo_name}", token=token) 132 | tokenizer.load_stopwords_from_hub("{username}/{repo_name}", token=token) 133 | ``` 134 | 135 | 136 | ## Stats 137 | 138 | This dataset was created using the following data: 139 | 140 | | Statistic | Value | 141 | | --- | --- | 142 | | Number of documents | {num_docs} | 143 | | Number of tokens | {num_tokens} | 144 | | Average tokens per document | {avg_tokens_per_doc} | 145 | 146 | ## Parameters 147 | 148 | The index was created with the following parameters: 149 | 150 | | Parameter | Value | 151 | | --- | --- | 152 | | k1 | `{k1}` | 153 | | b | `{b}` | 154 | | delta | `{delta}` | 155 | | method | `{method}` | 156 | | idf method | `{idf_method}` | 157 | 158 | ## Citation 159 | 160 | To cite `bm25s`, please use the following bibtex: 161 | 162 | ``` 163 | @misc{{lu_2024_bm25s, 164 | title={{BM25S: Orders of magnitude faster lexical search via eager sparse scoring}}, 165 | author={{Xing Han Lù}}, 166 | year={{2024}}, 167 | eprint={{2407.03618}}, 168 | archivePrefix={{arXiv}}, 169 | primaryClass={{cs.IR}}, 170 | url={{https://arxiv.org/abs/2407.03618}}, 171 | }} 172 | ``` 173 | 174 | """ 175 | 176 | 177 | def batch_tokenize(tokenizer, texts, add_special_tokens=False): 178 | from tqdm.auto import tqdm 179 | 180 | tokenizer_kwargs = dict( 181 | return_attention_mask=False, 182 | return_token_type_ids=False, 183 | add_special_tokens=add_special_tokens, 184 | max_length=None, 185 | ) 186 | tokenized = tokenizer(texts, **tokenizer_kwargs) 187 | output = [] 188 | 189 | 
for i in tqdm( 190 | range(len(texts)), desc="Processing tokens (huggingface tokenizer)", leave=False 191 | ): 192 | output.append(tokenized[i].tokens) 193 | 194 | return output 195 | 196 | 197 | def is_dir_empty(local_save_dir): 198 | """ 199 | Check if a directory is empty or not. 200 | 201 | Parameters 202 | ---------- 203 | local_save_dir: str 204 | The directory to check. 205 | 206 | Returns 207 | ------- 208 | bool 209 | True if the directory is empty, False otherwise. 210 | """ 211 | if not os.path.exists(local_save_dir): 212 | return True 213 | return len(os.listdir(local_save_dir)) == 0 214 | 215 | 216 | def can_save_locally(local_save_dir, overwrite_local: bool) -> bool: 217 | """ 218 | Check if it is possible to save the model to a local directory. 219 | 220 | Parameters 221 | ---------- 222 | local_save_dir: str 223 | The directory to save the model to. 224 | 225 | overwrite_local: bool 226 | Whether to overwrite the existing local directory if it exists. 227 | 228 | Returns 229 | ------- 230 | bool 231 | True if it is possible to save the model to the local directory, False otherwise. 232 | """ 233 | # if local_save_dir is None, we cannot save locally 234 | if local_save_dir is None: 235 | return False 236 | 237 | # if the directory is empty, we can save locally 238 | if is_dir_empty(local_save_dir): 239 | return True 240 | 241 | # if we are allowed to overwrite the directory, we can save locally 242 | if overwrite_local: 243 | return True 244 | 245 | 246 | class TokenizerHF(Tokenizer): 247 | def save_vocab_to_hub( 248 | self, 249 | repo_id: str, 250 | token: str = None, 251 | local_dir: str = None, 252 | commit_message: str = "Update tokenizer", 253 | overwrite_local: bool = False, 254 | private=True, 255 | **kwargs, 256 | ): 257 | """ 258 | This function saves the tokenizer's vocab to the Hugging Face Hub. 259 | 260 | Parameters 261 | ---------- 262 | repo_id: str 263 | The unique identifier of the repository to save the model to. 264 | The `repo_id` should be in the form of "username/repo_name". 265 | 266 | token: str 267 | The Hugging Face API token to use. 268 | 269 | local_dir: str 270 | The directory to save the model to before pushing to the Hub. 271 | If it is not empty and `overwrite_local` is False, it will fall 272 | back to saving to a temporary directory. 273 | 274 | commit_message: str 275 | The commit message to use when saving the model. 276 | 277 | overwrite_local: bool 278 | Whether to overwrite the existing local directory if it exists. 279 | 280 | kwargs: dict 281 | Additional keyword arguments to pass to `HfApi.upload_folder` call. 
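        Note: the repository is created with `exist_ok=True`; the `private`
        argument of this method (default: True) controls whether a newly
        created repository is private.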
282 | """ 283 | api = HfApi(token=token) 284 | repo_url = api.create_repo( 285 | repo_id=repo_id, 286 | token=api.token, 287 | private=private, 288 | repo_type="model", 289 | exist_ok=True, 290 | ) 291 | repo_id = repo_url.repo_id 292 | 293 | saving_locally = can_save_locally(local_dir, overwrite_local) 294 | if saving_locally: 295 | os.makedirs(local_dir, exist_ok=True) 296 | save_dir = local_dir 297 | else: 298 | # save to a temporary directory otherwise 299 | save_dir = tempfile.mkdtemp() 300 | 301 | self.save_vocab(save_dir) 302 | # push content of the temporary directory to the repo 303 | api.upload_folder( 304 | repo_id=repo_id, 305 | commit_message=commit_message, 306 | token=api.token, 307 | folder_path=save_dir, 308 | repo_type=repo_url.repo_type, 309 | **kwargs, 310 | ) 311 | # delete the temporary directory if it was created 312 | if not saving_locally: 313 | shutil.rmtree(save_dir) 314 | 315 | return repo_url 316 | 317 | def load_vocab_from_hub( 318 | cls, 319 | repo_id: str, 320 | revision=None, 321 | token=None, 322 | local_dir=None, 323 | ): 324 | """ 325 | This function loads the tokenizer's vocab from the Hugging Face Hub. 326 | 327 | Parameters 328 | ---------- 329 | repo_id: str 330 | The unique identifier of the repository to load the model from. 331 | The `repo_id` should be in the form of "username/repo_name". 332 | 333 | revision: str 334 | The revision of the model to load. 335 | 336 | token: str 337 | The Hugging Face API token to use. 338 | 339 | local_dir: str 340 | The local dir where the model will be stored after downloading. 341 | 342 | allow_pickle: bool 343 | Whether to allow pickling the model. Default is False. 344 | """ 345 | api = HfApi(token=token) 346 | # check if the model exists 347 | repo_url = api.repo_info(repo_id) 348 | if repo_url is None: 349 | raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.") 350 | 351 | snapshot = api.snapshot_download( 352 | repo_id=repo_id, revision=revision, token=token, local_dir=local_dir 353 | ) 354 | if snapshot is None: 355 | raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.") 356 | 357 | return cls.load_vocab(save_dir=snapshot) 358 | 359 | def save_stopwords_to_hub( 360 | self, 361 | repo_id: str, 362 | token: str = None, 363 | local_dir: str = None, 364 | commit_message: str = "Update stopwords", 365 | overwrite_local: bool = False, 366 | private=True, 367 | **kwargs, 368 | ): 369 | """ 370 | This function saves the tokenizer's stopwords to the Hugging Face Hub. 371 | 372 | Parameters 373 | ---------- 374 | repo_id: str 375 | The unique identifier of the repository to save the model to. 376 | The `repo_id` should be in the form of "username/repo_name". 377 | 378 | token: str 379 | The Hugging Face API token to use. 380 | 381 | local_dir: str 382 | The directory to save the model to before pushing to the Hub. 383 | If it is not empty and `overwrite_local` is False, it will fall 384 | back to saving to a temporary directory. 385 | 386 | commit_message: str 387 | The commit message to use when saving the model. 388 | 389 | overwrite_local: bool 390 | Whether to overwrite the existing local directory if it exists. 391 | 392 | kwargs: dict 393 | Additional keyword arguments to pass to `HfApi.upload_folder` call. 
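        Note: as with `save_vocab_to_hub`, the `private` argument
        (default: True) controls whether a newly created repository is private.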
394 | """ 395 | api = HfApi(token=token) 396 | repo_url = api.create_repo( 397 | repo_id=repo_id, 398 | token=api.token, 399 | private=private, 400 | repo_type="model", 401 | exist_ok=True, 402 | ) 403 | repo_id = repo_url.repo_id 404 | 405 | saving_locally = can_save_locally(local_dir, overwrite_local) 406 | if saving_locally: 407 | os.makedirs(local_dir, exist_ok=True) 408 | save_dir = local_dir 409 | else: 410 | # save to a temporary directory otherwise 411 | save_dir = tempfile.mkdtemp() 412 | 413 | self.save_stopwords(save_dir) 414 | # push content of the temporary directory to the repo 415 | api.upload_folder( 416 | repo_id=repo_id, 417 | commit_message=commit_message, 418 | token=api.token, 419 | folder_path=save_dir, 420 | repo_type=repo_url.repo_type, 421 | **kwargs, 422 | ) 423 | # delete the temporary directory if it was created 424 | if not saving_locally: 425 | shutil.rmtree(save_dir) 426 | 427 | return repo_url 428 | 429 | def load_stopwords_from_hub( 430 | self, 431 | repo_id: str, 432 | revision=None, 433 | token=None, 434 | local_dir=None, 435 | ): 436 | """ 437 | This function loads the tokenizer's stopwords from the Hugging Face Hub. 438 | 439 | Parameters 440 | ---------- 441 | repo_id: str 442 | The unique identifier of the repository to load the model from. 443 | The `repo_id` should be in the form of "username/repo_name". 444 | 445 | revision: str 446 | The revision of the model to load. 447 | 448 | token: str 449 | The Hugging Face API token to use. 450 | 451 | local_dir: str 452 | The local dir where the model will be stored after downloading. 453 | """ 454 | api = HfApi(token=token) 455 | # check if the model exists 456 | repo_url = api.repo_info(repo_id) 457 | if repo_url is None: 458 | raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.") 459 | 460 | snapshot = api.snapshot_download( 461 | repo_id=repo_id, revision=revision, token=token, local_dir=local_dir 462 | ) 463 | if snapshot is None: 464 | raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.") 465 | 466 | return self.load_stopwords(save_dir=snapshot) 467 | 468 | class BM25HF(BM25): 469 | def save_to_hub( 470 | self, 471 | repo_id: str, 472 | token: str = None, 473 | local_dir: str = None, 474 | corpus: Iterable[Union[str, dict, list, tuple]] = None, 475 | private=True, 476 | commit_message: str = "Update BM25S model", 477 | overwrite_local: bool = False, 478 | include_readme: bool = True, 479 | allow_pickle: bool = False, 480 | **kwargs, 481 | ): 482 | """ 483 | This function saves the BM25 model to the Hugging Face Hub. 484 | 485 | Parameters 486 | ---------- 487 | 488 | repo_id: str 489 | The name of the repository to save the model to. 490 | the `repo_id` should be in the form of "username/repo_name". 491 | 492 | token: str 493 | The Hugging Face API token to use. 494 | 495 | local_dir: str 496 | The directory to save the model to before pushing to the Hub. 497 | If it is not empty and `overwrite_local` is False, it will fall 498 | back to saving to a temporary directory. 499 | 500 | corpus: Iterable[str, dict, list, tuple] 501 | A corpus of documents to save with the model. If it is not None, 502 | the corpus will be saved to the repository, as a jsonl file. If it is 503 | a list of string, the dictionary will have a single key "text" with the 504 | value being the string. If it is a list of dictionaries, apply json.dumps 505 | to each dictionary before saving. 506 | 507 | private: bool 508 | Whether the repository should be private or not. Default is True. 
509 | 510 | commit_message: str 511 | The commit message to use when saving the model. 512 | 513 | overwrite_local: bool 514 | Whether to overwrite the existing local directory if it exists. 515 | 516 | include_readme: bool 517 | Whether to include a default README file with the model. 518 | 519 | allow_pickle: bool 520 | Whether to allow pickling the model. Default is False. 521 | 522 | kwargs: dict 523 | Additional keyword arguments to pass to `HfApi.upload_folder` call. 524 | """ 525 | api = HfApi(token=token) 526 | repo_url = api.create_repo( 527 | repo_id=repo_id, 528 | token=api.token, 529 | private=private, 530 | repo_type="model", 531 | exist_ok=True, 532 | ) 533 | repo_id = repo_url.repo_id 534 | 535 | username, repo_name = repo_id.split("/", 1) 536 | 537 | saving_locally = can_save_locally(local_dir, overwrite_local) 538 | if saving_locally: 539 | os.makedirs(local_dir, exist_ok=True) 540 | save_dir = local_dir 541 | else: 542 | # save to a temporary directory otherwise 543 | save_dir = tempfile.mkdtemp() 544 | 545 | self.save(save_dir, corpus=corpus, allow_pickle=allow_pickle) 546 | # if we include the README, write it to the directory 547 | if include_readme: 548 | num_docs = self.scores["num_docs"] 549 | num_tokens = self.scores["data"].shape[0] 550 | avg_tokens_per_doc = round(num_tokens / num_docs, 2) 551 | 552 | results = README_TEMPLATE.format( 553 | username=username, 554 | version=__version__, 555 | repo_name=repo_name, 556 | num_docs=num_docs, 557 | num_tokens=num_tokens, 558 | avg_tokens_per_doc=avg_tokens_per_doc, 559 | k1=self.k1, 560 | b=self.b, 561 | delta=self.delta, 562 | method=self.method, 563 | idf_method=self.idf_method, 564 | ) 565 | 566 | with open(os.path.join(save_dir, "README.md"), "w") as f: 567 | f.write(results) 568 | 569 | # push content of the temporary directory to the repo 570 | api.upload_folder( 571 | repo_id=repo_id, 572 | commit_message=commit_message, 573 | token=api.token, 574 | folder_path=save_dir, 575 | repo_type=repo_url.repo_type, 576 | **kwargs, 577 | ) 578 | # delete the temporary directory if it was created 579 | if not saving_locally: 580 | shutil.rmtree(save_dir) 581 | 582 | return repo_url 583 | 584 | @classmethod 585 | def load_from_hub( 586 | cls, 587 | repo_name: str, 588 | revision=None, 589 | token=None, 590 | local_dir=None, 591 | load_corpus=False, 592 | mmap=False, 593 | allow_pickle=False, 594 | ): 595 | """ 596 | This function loads the BM25 model from the Hugging Face Hub. 597 | 598 | Parameters 599 | ---------- 600 | 601 | repo_name: str 602 | The name of the repository to load the model from. 603 | 604 | revision: str 605 | The revision of the model to load. 606 | 607 | token: str 608 | The Hugging Face API token to use. 609 | 610 | local_dir: str 611 | The local dir where the model will be stored after downloading. 612 | 613 | load_corpus: bool 614 | Whether to load the corpus of documents saved with the model, if present. 615 | 616 | mmap: bool 617 | Whether to memory-map the model. Default is False, which loads the index 618 | (and potentially corpus) into memory. 619 | 620 | allow_pickle: bool 621 | Whether to allow pickling the model. Default is False. 
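        Example
        -------
        A minimal sketch (the repo id and query are illustrative):

        >>> import bm25s
        >>> retriever = BM25HF.load_from_hub("username/my-bm25-index", load_corpus=True, mmap=True)
        >>> results = retriever.retrieve(bm25s.tokenize("a cat is a feline"), k=3)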
622 | """ 623 | api = HfApi(token=token) 624 | # check if the model exists 625 | repo_url = api.repo_info(repo_name) 626 | if repo_url is None: 627 | raise ValueError(f"Model {repo_name} not found on the Hugging Face Hub.") 628 | 629 | snapshot = api.snapshot_download( 630 | repo_name, revision=revision, token=token, local_dir=local_dir 631 | ) 632 | if snapshot is None: 633 | raise ValueError(f"Model {repo_name} not found on the Hugging Face Hub.") 634 | 635 | return cls.load( 636 | save_dir=snapshot, 637 | load_corpus=load_corpus, 638 | mmap=mmap, 639 | allow_pickle=allow_pickle, 640 | ) 641 | -------------------------------------------------------------------------------- /bm25s/numba/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xhluca/bm25s/f4dda5f2c9fece7329822c78eb21ef6dadcdff9b/bm25s/numba/__init__.py -------------------------------------------------------------------------------- /bm25s/numba/retrieve_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from numba import njit, prange 3 | import numpy as np 4 | from typing import List, Tuple, Any 5 | import logging 6 | 7 | from .. import utils 8 | from ..scoring import _compute_relevance_from_scores_jit_ready 9 | from .selection import _numba_sorted_top_k 10 | 11 | _compute_relevance_from_scores_jit_ready = njit()(_compute_relevance_from_scores_jit_ready) 12 | 13 | @njit(parallel=True) 14 | def _retrieve_internal_jitted_parallel( 15 | query_tokens_ids_flat: np.ndarray, 16 | query_pointers: np.ndarray, 17 | k: int, 18 | sorted: bool, 19 | dtype: np.dtype, 20 | int_dtype: np.dtype, 21 | data: np.ndarray, 22 | indptr: np.ndarray, 23 | indices: np.ndarray, 24 | num_docs: int, 25 | nonoccurrence_array: np.ndarray = None, 26 | weight_mask: np.ndarray = None, 27 | ): 28 | N = len(query_pointers) - 1 29 | 30 | topk_scores = np.zeros((N, k), dtype=dtype) 31 | topk_indices = np.zeros((N, k), dtype=int_dtype) 32 | 33 | for i in prange(N): 34 | query_tokens_single = query_tokens_ids_flat[query_pointers[i] : query_pointers[i + 1]] 35 | 36 | # query_tokens_single = np.asarray(query_tokens_single, dtype=int_dtype) 37 | scores_single = _compute_relevance_from_scores_jit_ready( 38 | query_tokens_ids=query_tokens_single, 39 | data=data, 40 | indptr=indptr, 41 | indices=indices, 42 | num_docs=num_docs, 43 | dtype=dtype, 44 | ) 45 | 46 | # if there's a non-occurrence array, we need to add the non-occurrence score 47 | # back to the scores 48 | if nonoccurrence_array is not None: 49 | nonoccurrence_scores = nonoccurrence_array[query_tokens_single].sum() 50 | scores_single += nonoccurrence_scores 51 | 52 | if weight_mask is not None: 53 | scores_single = scores_single * weight_mask 54 | 55 | topk_scores_sing, topk_indices_sing = _numba_sorted_top_k( 56 | scores_single, k=k, sorted=sorted 57 | ) 58 | topk_scores[i] = topk_scores_sing 59 | topk_indices[i] = topk_indices_sing 60 | 61 | return topk_scores, topk_indices 62 | 63 | 64 | def _retrieve_numba_functional( 65 | query_tokens_ids, 66 | scores, 67 | corpus: List[Any] = None, 68 | k: int = 10, 69 | sorted: bool = True, 70 | return_as: str = "tuple", 71 | show_progress: bool = True, 72 | leave_progress: bool = False, 73 | n_threads: int = 0, 74 | chunksize: int = None, 75 | nonoccurrence_array=None, 76 | backend_selection="numba", 77 | dtype="float32", 78 | int_dtype="int32", 79 | weight_mask=None, 80 | ): 81 | from numba import get_num_threads, set_num_threads, njit 
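    # For reference, `scores` is expected to hold the CSC-style arrays built at
    # indexing time; the values below are purely illustrative:
    #
    #   scores = {
    #       "data": np.array([0.8, 0.3, 1.2], dtype=np.float32),  # precomputed token-document scores
    #       "indptr": np.array([0, 2, 3], dtype=np.int32),  # one slice per vocabulary token
    #       "indices": np.array([0, 3, 1], dtype=np.int32),  # document ids for each score
    #       "num_docs": 4,
    #   }
    #
    # so the scores for query token t live in data[indptr[t]:indptr[t+1]] and
    # belong to the documents listed in indices[indptr[t]:indptr[t+1]].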
82 | 
83 | 
84 |     if backend_selection != "numba":
85 |         error_msg = "This function requires `backend_selection` to be set to 'numba'. Please change the `backend_selection` parameter to 'numba', or use a non-numba retrieval function instead."
86 |         raise ValueError(error_msg)
87 | 
88 |     if chunksize is not None:
89 |         # warn the user that the chunksize parameter is ignored
90 |         logging.warning(
91 |             "The `chunksize` parameter is ignored in the `retrieve` function when using the `numba` backend. "
92 |             "The function will automatically determine the best chunksize."
93 |         )
94 | 
95 |     allowed_return_as = ["tuple", "documents"]
96 | 
97 |     if return_as not in allowed_return_as:
98 |         raise ValueError("`return_as` must be either 'tuple' or 'documents'")
99 |     else:
100 |         pass
101 | 
102 |     if n_threads == -1:
103 |         n_threads = os.cpu_count()
104 |     elif n_threads == 0:
105 |         n_threads = 1
106 | 
107 |     # get the original thread count
108 |     og_n_threads = get_num_threads()
109 |     set_num_threads(n_threads)
110 | 
111 | 
112 |     # convert query_tokens_ids from list of list to a flat 1-d np.ndarray with
113 |     # pointers to the start of each query to be used to find the boundaries of each query
114 |     query_pointers = np.cumsum([0] + [len(q) for q in query_tokens_ids], dtype=int_dtype)
115 |     query_tokens_ids_flat = np.concatenate(query_tokens_ids).astype(int_dtype)
116 | 
117 |     retrieved_scores, retrieved_indices = _retrieve_internal_jitted_parallel(
118 |         query_pointers=query_pointers,
119 |         query_tokens_ids_flat=query_tokens_ids_flat,
120 |         k=k,
121 |         sorted=sorted,
122 |         dtype=np.dtype(dtype),
123 |         int_dtype=np.dtype(int_dtype),
124 |         data=scores["data"],
125 |         indptr=scores["indptr"],
126 |         indices=scores["indices"],
127 |         num_docs=scores["num_docs"],
128 |         nonoccurrence_array=nonoccurrence_array,
129 |         weight_mask=weight_mask,
130 |     )
131 | 
132 |     # reset the number of threads
133 |     set_num_threads(og_n_threads)
134 | 
135 |     if corpus is None:
136 |         retrieved_docs = retrieved_indices
137 |     else:
138 |         # if it is a JsonlCorpus object, we do not need to convert it to a list
139 |         if isinstance(corpus, utils.corpus.JsonlCorpus):
140 |             retrieved_docs = corpus[retrieved_indices]
141 |         elif isinstance(corpus, np.ndarray) and corpus.ndim == 1:
142 |             retrieved_docs = corpus[retrieved_indices]
143 |         else:
144 |             index_flat = retrieved_indices.flatten().tolist()
145 |             results = [corpus[i] for i in index_flat]
146 |             retrieved_docs = np.array(results).reshape(retrieved_indices.shape)
147 | 
148 |     if return_as == "tuple":
149 |         return retrieved_docs, retrieved_scores
150 |     elif return_as == "documents":
151 |         return retrieved_docs
152 |     else:
153 |         raise ValueError("`return_as` must be either 'tuple' or 'documents'")
154 | 
--------------------------------------------------------------------------------
/bm25s/numba/selection.py:
--------------------------------------------------------------------------------
1 | """
2 | Acknowledgement:
3 | numba_unsorted_top_k is taken from retriv. The original code can be found at:
4 | https://github.com/AmenRa/retriv/blob/v0.2.1/retriv/utils/numba_utils.py
5 | 
6 | numba_sorted_top_k was created based on numba_unsorted_top_k, but modified to use a heap to keep track of the top-k values.
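Maintaining a heap of only k elements keeps the scan at O(n log k) rather than
the O(n log n) of fully sorting the scores, which matters when k is much
smaller than the number of documents.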
7 | """ 8 | 9 | import numpy as np 10 | from numba import njit 11 | 12 | 13 | @njit() 14 | def _numba_unsorted_top_k_legacy(array: np.ndarray, k: int): 15 | top_k_values = np.zeros(k, dtype=np.float32) 16 | top_k_indices = np.zeros(k, dtype=np.int32) 17 | 18 | min_value = 0.0 19 | min_value_idx = 0 20 | 21 | for i, value in enumerate(array): 22 | if value > min_value: 23 | top_k_values[min_value_idx] = value 24 | top_k_indices[min_value_idx] = i 25 | min_value_idx = top_k_values.argmin() 26 | min_value = top_k_values[min_value_idx] 27 | 28 | return top_k_values, top_k_indices 29 | 30 | 31 | @njit() 32 | def sift_down(values, indices, startpos, pos): 33 | new_value = values[pos] 34 | new_index = indices[pos] 35 | while pos > startpos: 36 | parentpos = (pos - 1) >> 1 37 | parent_value = values[parentpos] 38 | if new_value < parent_value: 39 | values[pos] = parent_value 40 | indices[pos] = indices[parentpos] 41 | pos = parentpos 42 | continue 43 | break 44 | values[pos] = new_value 45 | indices[pos] = new_index 46 | 47 | 48 | @njit() 49 | def sift_up(values, indices, pos, length): 50 | startpos = pos 51 | new_value = values[pos] 52 | new_index = indices[pos] 53 | childpos = 2 * pos + 1 54 | while childpos < length: 55 | rightpos = childpos + 1 56 | if rightpos < length and values[rightpos] < values[childpos]: 57 | childpos = rightpos 58 | values[pos] = values[childpos] 59 | indices[pos] = indices[childpos] 60 | pos = childpos 61 | childpos = 2 * pos + 1 62 | values[pos] = new_value 63 | indices[pos] = new_index 64 | sift_down(values, indices, startpos, pos) 65 | 66 | 67 | @njit() 68 | def heap_push(values, indices, value, index, length): 69 | values[length] = value 70 | indices[length] = index 71 | sift_down(values, indices, 0, length) 72 | 73 | 74 | @njit() 75 | def heap_pop(values, indices, length): 76 | return_value = values[0] 77 | return_index = indices[0] 78 | last_value = values[length - 1] 79 | last_index = indices[length - 1] 80 | values[0] = last_value 81 | indices[0] = last_index 82 | sift_up(values, indices, 0, length - 1) 83 | return return_value, return_index 84 | 85 | 86 | @njit() 87 | def _numba_sorted_top_k(array: np.ndarray, k: int, sorted=True): 88 | n = len(array) 89 | if k > n: 90 | k = n 91 | 92 | values = np.zeros(k, dtype=array.dtype) # aka scores 93 | indices = np.zeros(k, dtype=np.int32) 94 | length = 0 95 | 96 | for i, value in enumerate(array): 97 | if length < k: 98 | heap_push(values, indices, value, i, length) 99 | length += 1 100 | else: 101 | if value > values[0]: 102 | values[0] = value 103 | indices[0] = i 104 | sift_up(values, indices, 0, length) 105 | 106 | if sorted: 107 | # # This is the original code for sorting, we can skip it and return the values and indices 108 | # # to let numpy handle the sorting 109 | # top_k_values = np.zeros(k, dtype=array.dtype) 110 | # top_k_indices = np.zeros(k, dtype=np.int32) 111 | 112 | # for i in range(k - 1, -1, -1): 113 | # top_k_values[i], top_k_indices[i] = heap_pop(values, indices, length) 114 | # length -= 1 115 | # values = top_k_values 116 | # indices = top_k_indices 117 | 118 | # This is the new code that uses numpy to sort the values and indices instead of 119 | # using the heap to sort them. 
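        # Sorting just the k kept values with argsort costs O(k log k), which is
        # negligible next to the O(n log k) scan above.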
120 | sorted_indices = np.flip(np.argsort(values)) 121 | indices = indices[sorted_indices] 122 | values = values[sorted_indices] 123 | 124 | return values, indices 125 | 126 | 127 | def topk(query_scores, k, backend="numba", sorted=True): 128 | """ 129 | This function is used to retrieve the top-k results for a single query. It will only work 130 | on a 1-dimensional array of scores. 131 | """ 132 | if backend not in ["numba"]: 133 | raise ValueError( 134 | "Invalid backend. Only 'numba' is supported." 135 | ) 136 | elif backend == "numba": 137 | uns_scores, uns_indices = _numba_sorted_top_k(query_scores, k) 138 | if sorted: 139 | sorted_inds = np.flip(np.argsort(uns_scores)) 140 | query_inds = uns_indices[sorted_inds] 141 | query_scores = uns_scores[sorted_inds] 142 | else: 143 | query_inds = uns_indices 144 | query_scores = uns_scores 145 | 146 | return query_scores, query_inds 147 | 148 | else: 149 | raise ValueError("Invalid backend. Only 'numba' is supported.") 150 | -------------------------------------------------------------------------------- /bm25s/scoring.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import math 3 | 4 | import numpy as np 5 | 6 | try: 7 | from tqdm.auto import tqdm 8 | except ImportError: 9 | 10 | def tqdm(iterable, *args, **kwargs): 11 | return iterable 12 | 13 | 14 | def _calculate_doc_freqs( 15 | corpus_tokens, unique_tokens, show_progress=True, leave_progress=False 16 | ) -> dict: 17 | """ 18 | Document Frequency, aka DF, is the number of documents that contain a specific token. 19 | This function return a dictionary with the document frequency of each token, which is 20 | why it is called `doc_frequencies`. 21 | """ 22 | unique_tokens = set(unique_tokens) 23 | 24 | # Now that we have all the unique tokens, we can count the number of 25 | # documents that contain each token 26 | doc_frequencies = {token: 0 for token in unique_tokens} 27 | 28 | for doc_tokens in tqdm( 29 | corpus_tokens, 30 | leave=leave_progress, 31 | disable=not show_progress, 32 | desc="BM25S Count Tokens", 33 | ): 34 | 35 | # get intersection of unique tokens and the tokens in the document 36 | shared_tokens = unique_tokens.intersection(doc_tokens) 37 | 38 | # for each token in the document, we increment the count of documents 39 | # This is a simple way to count the number of documents that contain each token 40 | for token in shared_tokens: 41 | doc_frequencies[token] += 1 42 | 43 | return doc_frequencies 44 | 45 | 46 | def _build_idf_array( 47 | doc_frequencies: dict, 48 | n_docs: int, 49 | compute_idf_fn: callable = None, 50 | dtype="float32", 51 | ) -> np.ndarray: 52 | n_vocab = len(doc_frequencies) 53 | idf_array = np.zeros(n_vocab, dtype=dtype) 54 | 55 | for token_id, df in doc_frequencies.items(): 56 | if df != 0: 57 | idf_array[token_id] = compute_idf_fn(df, N=n_docs) 58 | 59 | return idf_array 60 | 61 | 62 | def _build_nonoccurrence_array( 63 | doc_frequencies: dict, 64 | n_docs: int, 65 | compute_idf_fn: callable, 66 | calculate_tfc_fn: callable, 67 | l_d, 68 | l_avg, 69 | k1, 70 | b, 71 | delta, 72 | dtype="float32", 73 | ) -> np.ndarray: 74 | """ 75 | The non-occurrence array is used to store the idf score for tokens that do not occur in the 76 | document. This is useful for BM25L and BM25+ variants, where we need to calculate the idf 77 | score for tokens that do not occur in the document, which will be used to calculate the 78 | final score. 
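    For example, with BM25+ the term-frequency component reduces to `delta` when
    the term frequency is zero, so a token that is absent from a document still
    contributes `idf * delta` to that document's score.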
79 | 
80 |     The nonoccurrence array has length |V|, where V is the set of unique tokens in the corpus.
81 | 
82 |     The `compute_idf_fn` is the function to calculate the idf score for a token that does not occur
83 |     in the document. The `calculate_tfc_fn` is the function to calculate the term frequency component
84 |     of the BM25 score, which is used to calculate the final score for tokens that do not occur in the
85 |     document.
86 |     """
87 |     n_vocab = len(doc_frequencies)
88 |     nonoccurrence_array = np.zeros(n_vocab, dtype=dtype)
89 | 
90 |     for token_id, df in doc_frequencies.items():
91 |         if df != 0:
92 |             idf = compute_idf_fn(df, N=n_docs)
93 |             tfc = calculate_tfc_fn(
94 |                 tf_array=0, l_d=l_d, l_avg=l_avg, k1=k1, b=b, delta=delta
95 |             )
96 |             nonoccurrence_array[token_id] = idf * tfc
97 | 
98 |     return nonoccurrence_array
99 | 
100 | 
101 | def _score_tfc_robertson(tf_array, l_d, l_avg, k1, b, delta=None):
102 |     """
103 |     Computes the term frequency component of the BM25 score using Robertson+ (original) variant
104 |     Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf
105 |     """
106 |     # idf component is given by the idf_array
107 |     # we calculate the term-frequency component (tfc)
108 |     return tf_array / (k1 * ((1 - b) + b * l_d / l_avg) + tf_array)
109 | 
110 | 
111 | def _score_tfc_lucene(tf_array, l_d, l_avg, k1, b, delta=None):
112 |     """
113 |     Computes the term frequency component of the BM25 score using Lucene variant (accurate)
114 |     Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf
115 |     """
116 |     return _score_tfc_robertson(tf_array, l_d, l_avg, k1, b)
117 | 
118 | 
119 | def _score_tfc_atire(tf_array, l_d, l_avg, k1, b, delta=None):
120 |     """
121 |     Computes the term frequency component of the BM25 score using ATIRE variant
122 |     Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf
123 |     """
124 |     # idf component is given by the idf_array
125 |     # we calculate the term-frequency component (tfc)
126 |     return (tf_array * (k1 + 1)) / (tf_array + k1 * (1 - b + b * l_d / l_avg))
127 | 
128 | 
129 | def _score_tfc_bm25l(tf_array, l_d, l_avg, k1, b, delta):
130 |     """
131 |     Computes the term frequency component of the BM25 score using BM25L variant
132 |     Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf
133 |     """
134 |     c_array = tf_array / (1 - b + b * l_d / l_avg)
135 |     return ((k1 + 1) * (c_array + delta)) / (k1 + c_array + delta)
136 | 
137 | 
138 | def _score_tfc_bm25plus(tf_array, l_d, l_avg, k1, b, delta):
139 |     """
140 |     Computes the term frequency component of the BM25 score using BM25+ variant
141 |     Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf
142 |     """
143 |     num = (k1 + 1) * tf_array
144 |     den = k1 * (1 - b + b * l_d / l_avg) + tf_array
145 |     return (num / den) + delta
146 | 
147 | 
148 | def _select_tfc_scorer(method) -> callable:
149 |     if method == "robertson":
150 |         return _score_tfc_robertson
151 |     elif method == "lucene":
152 |         return _score_tfc_lucene
153 |     elif method == "atire":
154 |         return _score_tfc_atire
155 |     elif method == "bm25l":
156 |         return _score_tfc_bm25l
157 |     elif method == "bm25+":
158 |         return _score_tfc_bm25plus
159 |     else:
160 |         error_msg = f"Invalid score_tfc value: {method}. Choose from 'robertson', 'lucene', 'atire', 'bm25l', 'bm25+'."
161 | raise ValueError(error_msg) 162 | 163 | 164 | def _score_idf_robertson(df, N, allow_negative=False): 165 | """ 166 | Computes the inverse document frequency component of the BM25 score using Robertson+ (original) variant 167 | Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf 168 | """ 169 | inner = (N - df + 0.5) / (df + 0.5) 170 | if not allow_negative and inner < 1: 171 | inner = 1 172 | 173 | return math.log(inner) 174 | 175 | 176 | def _score_idf_lucene(df, N): 177 | """ 178 | Computes the inverse document frequency component of the BM25 score using Lucene variant (accurate) 179 | Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf 180 | """ 181 | return math.log(1 + (N - df + 0.5) / (df + 0.5)) 182 | 183 | 184 | def _score_idf_atire(df, N): 185 | """ 186 | Computes the inverse document frequency component of the BM25 score using ATIRE variant 187 | Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf 188 | """ 189 | return math.log(N / df) 190 | 191 | 192 | def _score_idf_bm25l(df, N): 193 | """ 194 | Computes the inverse document frequency component of the BM25 score using BM25L variant 195 | Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf 196 | """ 197 | return math.log((N + 1) / (df + 0.5)) 198 | 199 | 200 | def _score_idf_bm25plus(df, N): 201 | """ 202 | Computes the inverse document frequency component of the BM25 score using BM25+ variant 203 | Implementation: https://cs.uwaterloo.ca/~jimmylin/publications/Kamphuis_etal_ECIR2020_preprint.pdf 204 | """ 205 | return math.log((N + 1) / df) 206 | 207 | 208 | def _select_idf_scorer(method) -> callable: 209 | if method == "robertson": 210 | return _score_idf_robertson 211 | elif method == "lucene": 212 | return _score_idf_lucene 213 | elif method == "atire": 214 | return _score_idf_atire 215 | elif method == "bm25l": 216 | return _score_idf_bm25l 217 | elif method == "bm25+": 218 | return _score_idf_bm25plus 219 | else: 220 | error_msg = f"Invalid score_idf_inner value: {method}. Choose from 'robertson', 'lucene', 'atire', 'bm25l', 'bm25+'." 
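        # Worked example for the "lucene" variant: with N=4 documents and a token
        # appearing in df=1 of them, idf = log(1 + (4 - 1 + 0.5) / (1 + 0.5))
        # = log(10/3) ≈ 1.204; combined with a Lucene tfc of 2 / (1.5 + 2) ≈ 0.571
        # (tf=2, l_d=l_avg, k1=1.5, b=0.75), the token contributes ≈ 0.688 to the
        # document's score.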
221 |         raise ValueError(error_msg)
222 | 
223 | 
224 | def _get_counts_from_token_ids(token_ids, dtype, int_dtype):
225 |     token_counter = Counter(token_ids)
226 |     voc_ind = np.array(list(token_counter.keys()), dtype=int_dtype)
227 |     tf_array = np.array(list(token_counter.values()), dtype=dtype)
228 | 
229 |     return voc_ind, tf_array
230 | 
231 | 
232 | def _build_scores_and_indices_for_matrix(
233 |     corpus_token_ids,
234 |     idf_array,
235 |     avg_doc_len,
236 |     doc_frequencies,
237 |     k1,
238 |     b,
239 |     delta,
240 |     nonoccurrence_array,
241 |     method="robertson",
242 |     dtype="float32",
243 |     int_dtype="int32",
244 |     show_progress=True,
245 |     leave_progress=False,
246 | ):
247 |     array_size = sum(doc_frequencies.values())
248 | 
249 |     # We create 3 arrays to store the scores, document indices, and vocabulary indices
250 |     # The length is at most n_tokens, remaining elements will be truncated at the end
251 |     scores = np.empty(array_size, dtype=dtype)
252 |     doc_indices = np.empty(array_size, dtype=int_dtype)
253 |     voc_indices = np.empty(array_size, dtype=int_dtype)
254 | 
255 |     calculate_tfc = _select_tfc_scorer(method)
256 | 
257 |     i = 0
258 |     for doc_idx, token_ids in enumerate(
259 |         tqdm(
260 |             corpus_token_ids,
261 |             desc="BM25S Compute Scores",
262 |             disable=not show_progress,
263 |             leave=leave_progress,
264 |         )
265 |     ):
266 |         doc_len = len(token_ids)
267 | 
268 |         # Get the term frequency array for the document
269 |         # Note: tokens might contain duplicates, we use Counter to get the term freq
270 |         voc_ind_doc, tf_array = _get_counts_from_token_ids(
271 |             token_ids, dtype=dtype, int_dtype=int_dtype
272 |         )
273 | 
274 |         # Calculate the BM25 score for each token in the document
275 |         tfc = calculate_tfc(
276 |             tf_array=tf_array, l_d=doc_len, l_avg=avg_doc_len, k1=k1, b=b, delta=delta
277 |         )
278 |         idf = idf_array[voc_ind_doc]
279 |         scores_doc = idf * tfc
280 | 
281 |         # If the method uses a non-occurrence score array, then we need to subtract
282 |         # the non-occurrence score from the scores
283 |         if method in ("bm25l", "bm25+"):
284 |             scores_doc -= nonoccurrence_array[voc_ind_doc]
285 | 
286 |         # Update the arrays with the new scores, document indices, and vocabulary indices
287 |         doc_len = len(scores_doc)
288 |         start, end = i, i + doc_len
289 |         i = end
290 | 
291 |         doc_indices[start:end] = doc_idx
292 |         voc_indices[start:end] = voc_ind_doc
293 |         scores[start:end] = scores_doc
294 | 
295 |     return scores, doc_indices, voc_indices
296 | 
297 | 
298 | def _compute_relevance_from_scores_legacy(
299 |     data, indptr, indices, num_docs, query_tokens_ids, dtype
300 | ):
301 |     """
302 |     The legacy implementation of the `_compute_relevance_from_scores` function. This may
303 |     be faster than the new implementation for some cases, but it cannot benefit from
304 |     numba acceleration, as it uses python lists. This function is kept for reference
305 |     and comparison purposes.
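    Both this and the JIT-ready version below compute, for each document d,
    score(d) = sum over query tokens t of S[t, d], where the sparse score matrix
    S is stored in the CSC-style arrays (data, indptr, indices).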
306 |     """
307 |     # First, we use the query_token_ids to select the relevant columns from the score_matrix
308 |     query_tokens_ids = np.array(query_tokens_ids, dtype=int)
309 |     indptr_starts = indptr[query_tokens_ids]
310 |     indptr_ends = indptr[query_tokens_ids + 1]
311 | 
312 |     scores_lists = []
313 |     indices_lists = []
314 | 
315 |     for start, end in zip(indptr_starts, indptr_ends):
316 |         scores_lists.append(data[start:end])
317 |         indices_lists.append(indices[start:end])
318 | 
319 |     # combine the lists into a single array
320 | 
321 |     scores = np.zeros(num_docs, dtype=dtype)
322 |     if len(scores_lists) == 0:
323 |         return scores
324 | 
325 |     scores_flat = np.concatenate(scores_lists)
326 |     indices_flat = np.concatenate(indices_lists)
327 |     np.add.at(scores, indices_flat, scores_flat)
328 | 
329 |     return scores
330 | 
331 | def _compute_relevance_from_scores_jit_ready(
332 |     data: np.ndarray,
333 |     indptr: np.ndarray,
334 |     indices: np.ndarray,
335 |     num_docs: int,
336 |     query_tokens_ids: np.ndarray,
337 |     dtype: np.dtype,
338 | ) -> np.ndarray:
339 |     """
340 |     This internal static function calculates the relevance scores for a given query,
341 |     by using the BM25 scores that have been precomputed in the BM25 eager index.
342 |     This version is ready for JIT compilation with numba, but is slow if not compiled.
343 |     """
344 |     indptr_starts = indptr[query_tokens_ids]
345 |     indptr_ends = indptr[query_tokens_ids + 1]
346 | 
347 |     scores = np.zeros(num_docs, dtype=dtype)
348 |     for i in range(len(query_tokens_ids)):
349 |         start, end = indptr_starts[i], indptr_ends[i]
350 |         # The following code is slower with numpy, but faster after JIT compilation
351 |         for j in range(start, end):
352 |             scores[indices[j]] += data[j]
353 | 
354 |     return scores
--------------------------------------------------------------------------------
/bm25s/selection.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | try:
4 |     import jax.lax
5 | except ImportError:
6 |     JAX_IS_AVAILABLE = False
7 | else:
8 |     JAX_IS_AVAILABLE = True
9 |     # if JAX is available, we initialize it with dummy scores and capture
10 |     # any output, to keep it from printing that a GPU is not available
11 |     _ = jax.lax.top_k(np.array([0] * 5), 1)
12 | 
13 | 
14 | def _topk_numpy(query_scores, k, sorted):
15 |     # https://stackoverflow.com/questions/65038206/how-to-get-indices-of-top-k-values-from-a-numpy-array
16 |     # np.argpartition is faster than np.argsort, but does not return the values in order
17 |     partitioned_ind = np.argpartition(query_scores, -k)
18 |     # Since it's a single query, we can take the last k elements
19 |     partitioned_ind = partitioned_ind.take(indices=range(-k, 0))
20 |     # We use the newly selected indices to find the score of the top-k values
21 |     partitioned_scores = np.take(query_scores, partitioned_ind)
22 | 
23 |     if sorted:
24 |         # Since our top-k indices are not correctly ordered, we can sort them with argsort
25 |         # only if sorted=True (otherwise we keep it in an arbitrary order)
26 |         sorted_trunc_ind = np.flip(np.argsort(partitioned_scores))
27 | 
28 |         # We then index with the sorted positions to put the top-k indices and
29 |         # scores in descending order of score
30 |         ind = partitioned_ind[sorted_trunc_ind]
31 |         query_scores = partitioned_scores[sorted_trunc_ind]
32 | 
33 |     else:
34 |         ind = partitioned_ind
35 |         query_scores = partitioned_scores
36 | 
37 |     return query_scores, ind
38 | 
39 | 
40 | def _topk_jax(query_scores, k):
41 |     topk_scores, topk_indices =
jax.lax.top_k(query_scores, k) 42 | topk_scores = np.asarray(topk_scores) 43 | topk_indices = np.asarray(topk_indices) 44 | 45 | return topk_scores, topk_indices 46 | 47 | 48 | def topk(query_scores, k, backend="auto", sorted=True): 49 | """ 50 | This function is used to retrieve the top-k results for a single query. It will only work 51 | on a 1-dimensional array of scores. 52 | """ 53 | if backend == "auto": 54 | # if jax.lax is available, use it to speed up selection, otherwise use numpy 55 | backend = "jax" if JAX_IS_AVAILABLE else "numpy" 56 | 57 | if backend not in ["numpy", "jax"]: 58 | raise ValueError("Invalid backend. Please choose from 'numpy' or 'jax'.") 59 | elif backend == "jax": 60 | if not JAX_IS_AVAILABLE: 61 | raise ImportError("JAX is not available. Please install JAX with `pip install jax[cpu]` to use this backend.") 62 | return _topk_jax(query_scores, k) 63 | else: 64 | return _topk_numpy(query_scores, k, sorted) 65 | -------------------------------------------------------------------------------- /bm25s/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import benchmark, beir, corpus, json_functions -------------------------------------------------------------------------------- /bm25s/utils/beir.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from pathlib import Path 4 | from typing import Dict, List, Tuple 5 | 6 | try: 7 | from tqdm.auto import tqdm 8 | except ImportError: 9 | 10 | def tqdm(iterable, *args, **kwargs): 11 | return iterable 12 | 13 | 14 | from . import json_functions 15 | 16 | BASE_URL = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip" 17 | GH_URL = "https://github.com/xhluca/bm25s/releases/download/data/{}.zip" 18 | 19 | 20 | def clean_results_keys(beir_results): 21 | return {k.split("@")[-1]: v for k, v in beir_results.items()} 22 | 23 | 24 | def postprocess_results_for_eval(results, scores, query_ids): 25 | """ 26 | Given the queried results and scores output by BM25S, postprocess them 27 | to be compatible with BEIR evaluation functions. 28 | query_ids is a list of query ids in the same order as the results. 
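    For example (with illustrative ids), results=[["d1", "d7"]],
    scores=[[2.3, 1.1]] and query_ids=["q1"] yields
    {"q1": {"d1": 2.3, "d7": 1.1}}.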
29 | """ 30 | 31 | results_record = [ 32 | {"id": qid, "hits": results[i], "scores": list(scores[i])} 33 | for i, qid in enumerate(query_ids) 34 | ] 35 | 36 | result_dict_for_eval = { 37 | res["id"]: { 38 | docid: float(score) for docid, score in zip(res["hits"], res["scores"]) 39 | } 40 | for res in results_record 41 | } 42 | 43 | return result_dict_for_eval 44 | 45 | 46 | def merge_cqa_dupstack(data_path): 47 | data_path = Path(data_path) 48 | dataset = data_path.name 49 | assert dataset == "cqadupstack", "Dataset must be CQADupStack" 50 | 51 | # check if corpus.jsonl exists 52 | corpus_path = data_path / "corpus.jsonl" 53 | if not corpus_path.exists(): 54 | # combine all the corpus files into one 55 | # corpus files are located under cqadupstack//corpus.jsonl 56 | corpus_files = list(data_path.glob("*/corpus.jsonl")) 57 | with open(corpus_path, "w") as f: 58 | for file in tqdm(corpus_files, desc="Merging Corpus", leave=False): 59 | # get the name of the corpus 60 | corpus_name = file.parent.name 61 | 62 | with open(file, "r") as f2: 63 | for line in tqdm( 64 | f2, desc=f"Merging {corpus_name} Corpus", leave=False 65 | ): 66 | line = json_functions.loads(line) 67 | # add the corpus name to _id 68 | line["_id"] = f"{corpus_name}_{line['_id']}" 69 | # write back to file 70 | f.write(json.dumps(line)) # json_functions.dumps generates json that can't be read by beir 71 | f.write("\n") 72 | 73 | # now, do the same for queries.jsonl 74 | queries_path = data_path / "queries.jsonl" 75 | if not queries_path.exists(): 76 | queries_files = list(data_path.glob("*/queries.jsonl")) 77 | with open(queries_path, "w") as f: 78 | for file in tqdm(queries_files, desc="Merging Queries", leave=False): 79 | # get the name of the corpus 80 | corpus_name = file.parent.name 81 | 82 | with open(file, "r") as f2: 83 | for line in tqdm( 84 | f2, desc=f"Merging {corpus_name} Queries", leave=False 85 | ): 86 | line = json_functions.loads(line) 87 | # add the corpus name to _id 88 | line["_id"] = f"{corpus_name}_{line['_id']}" 89 | # write back to file 90 | f.write(json_functions.dumps(line)) 91 | f.write("\n") 92 | 93 | # now, do the same for qrels/test.tsv 94 | qrels_path = data_path / "qrels" / "test.tsv" 95 | qrels_path.parent.mkdir(parents=True, exist_ok=True) 96 | 97 | if not qrels_path.exists(): 98 | qrels_files = list(data_path.glob("*/qrels/test.tsv")) 99 | with open(qrels_path, "w") as f: 100 | # First, write the columns: query-id corpus-id score 101 | f.write("query-id\tcorpus-id\tscore\n") 102 | for file in tqdm(qrels_files, desc="Merging Qrels", leave=False): 103 | # get the name of the corpus 104 | corpus_name = file.parent.parent.name 105 | with open(file, "r") as f2: 106 | # skip first line 107 | next(f2) 108 | 109 | for line in tqdm( 110 | f2, desc=f"Merging {corpus_name} Qrels", leave=False 111 | ): 112 | # since it's a tsv, split by tab 113 | qid, cid, score = line.strip().split("\t") 114 | # add the corpus name to _id 115 | qid = f"{corpus_name}_{qid}" 116 | cid = f"{corpus_name}_{cid}" 117 | # write back to file 118 | f.write(f"{qid}\t{cid}\t{score}\n") 119 | 120 | 121 | def download_dataset( 122 | dataset, 123 | base_url=GH_URL, 124 | save_dir="./datasets", 125 | unzip=True, 126 | redownload=False, 127 | show_progress=True, 128 | ): 129 | import urllib.request 130 | import zipfile 131 | from pathlib import Path 132 | from tqdm.auto import tqdm 133 | 134 | save_dir = Path(save_dir) 135 | save_dir.mkdir(parents=True, exist_ok=True) 136 | 137 | url = base_url.format(dataset) 138 | # check if zip 
file already exists
139 |     save_zip_path = save_dir / "archive" / f"{dataset}.zip"
140 |     save_zip_path.parent.mkdir(parents=True, exist_ok=True)
141 | 
142 |     if not save_zip_path.exists() or redownload:
143 |         # download the zip file and save it with a tqdm progress bar
144 |         pbar = tqdm(
145 |             unit="B",
146 |             unit_scale=True,
147 |             desc=f"Downloading {dataset}",
148 |             leave=False,
149 |             disable=not show_progress,
150 |         )
151 |         with open(save_zip_path, "wb") as f:
152 |             response = urllib.request.urlopen(url)
153 |             total_size = int(response.headers.get("content-length", 0))
154 |             block_size = 8192 * 2
155 |             # set the tqdm total to the total size
156 |             pbar.total = total_size
157 |             while True:
158 |                 buffer = response.read(block_size)
159 |                 if not buffer:
160 |                     break
161 |                 f.write(buffer)
162 |                 pbar.update(len(buffer))
163 | 
164 |         pbar.close()
165 | 
166 |     # now that we have the zip file, extract it
167 |     if unzip:
168 |         with zipfile.ZipFile(save_zip_path, "r") as zip_ref:
169 |             zip_ref.extractall(save_dir)
170 | 
171 |         # if it's CQADupStack, merge the corpus, queries, and qrels
172 |         if dataset == "cqadupstack":
173 |             merge_cqa_dupstack(save_dir / dataset)
174 | 
175 |         return save_dir / dataset
176 |     else:
177 |         return save_zip_path
178 | 
179 | 
180 | def load_jsonl(
181 |     dataset,
182 |     fname,
183 |     save_dir="./datasets",
184 |     show_progress=True,
185 |     return_dict=True,
186 |     force_title=False,
187 |     remove=None,
188 | ):
189 |     dataset_path = Path(save_dir) / dataset
190 |     corpus_path = dataset_path / fname
191 | 
192 |     if not corpus_path.exists():
193 |         raise FileNotFoundError(f"Corpus file not found at {corpus_path}")
194 | 
195 |     corpus = []
196 |     with open(corpus_path, "r") as f:
197 |         # count the number of lines in the file to size the progress bar
198 |         num_lines = sum(1 for i in open(corpus_path, "rb"))
199 |         pbar = tqdm(
200 |             f,
201 |             desc="[{}] loading {}".format(dataset, fname),
202 |             leave=False,
203 |             disable=not show_progress,
204 |             total=num_lines,
205 |         )
206 |         for line in pbar:
207 |             line = json_functions.loads(line)
208 |             if force_title:
209 |                 # ensure the "title" key exists (set to None if missing)
210 |                 line["title"] = line.get("title")
211 |             if remove is not None:
212 |                 for key in remove:
213 |                     del line[key]
214 |             corpus.append(line)
215 | 
216 |     if return_dict:
217 |         corpus = {doc.pop("_id"): doc for doc in corpus}
218 | 
219 |     return corpus
220 | 
221 | 
222 | def load_corpus(dataset, save_dir="./datasets", show_progress=True, return_dict=True):
223 |     return load_jsonl(
224 |         dataset=dataset,
225 |         save_dir=save_dir,
226 |         show_progress=show_progress,
227 |         return_dict=return_dict,
228 |         fname="corpus.jsonl",
229 |         force_title=True,
230 |         remove=["metadata"],
231 |     )
232 | 
233 | 
234 | def load_queries(dataset, save_dir="./datasets", show_progress=True, return_dict=True):
235 |     return load_jsonl(
236 |         dataset=dataset,
237 |         save_dir=save_dir,
238 |         show_progress=show_progress,
239 |         return_dict=return_dict,
240 |         fname="queries.jsonl",
241 |         force_title=False,
242 |         remove=["metadata"],
243 |     )
244 | 
245 | 
246 | def load_qrels(
247 |     dataset, split="test", save_dir="./datasets", show_progress=True, return_dict=True
248 | ):
249 |     """
250 |     Load the relevance judgements (qrels), which are stored as TSV files.
251 |     """
252 |     if split not in ["train", "dev", "test"]:
253 |         raise ValueError("split must be one of ['train', 'dev', 'test']")
254 | 
255 |     dataset_path = Path(save_dir) / dataset
256 |     qrels_path = dataset_path / "qrels" / f"{split}.tsv"
257 | 
258 |     if not qrels_path.exists():
259 |         raise FileNotFoundError(f"Qrels file not found at {qrels_path}")
260 | 
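    # For reference, a qrels TSV file looks like this (illustrative values; the
    # columns are tab-separated and the header line is skipped below):
    #
    #   query-id	corpus-id	score
    #   q1	doc7	1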
261 |     qrels = []
262 |     with open(qrels_path, "r") as f:
263 |         # skip the header line
264 |         next(f)
265 |         for line in tqdm(
266 |             f,
267 |             desc="Loading Qrels {}".format(dataset),
268 |             leave=False,
269 |             disable=not show_progress,
270 |         ):
271 |             qid, cid, score = line.strip().split("\t")
272 |             qrels.append((qid, cid, int(score)))
273 | 
274 |     if return_dict:
275 |         qrels_dict = {}
276 |         for qid, cid, score in qrels:
277 |             qrels_dict.setdefault(qid, {})[cid] = score  # keep every judged doc per query
278 |         qrels = qrels_dict
279 |     return qrels
280 | 
281 | 
282 | def evaluate(
283 |     qrels: Dict[str, Dict[str, int]],
284 |     results: Dict[str, Dict[str, float]],
285 |     k_values: List[int],
286 |     ignore_identical_ids: bool = True,
287 | ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
288 |     """
289 |     Acknowledgement: This function is adapted from BEIR's EvaluateRetrieval class.
290 |     License for this function: Apache-2.0
291 |     """
292 |     try:
293 |         import pytrec_eval
294 | 
295 |     except ImportError:
296 |         raise ImportError(
297 |             "Please install pytrec_eval to use this function. You can install it via `pip install pytrec_eval`."
298 |         )
299 | 
300 |     if ignore_identical_ids:
301 |         logging.info(
302 |             "For evaluation, we ignore identical query and document ids (default); explicitly set ``ignore_identical_ids=False`` to disable this behavior."
303 |         )
304 |         popped = []
305 |         for qid, rels in results.items():
306 |             for pid in list(rels):
307 |                 if qid == pid:
308 |                     results[qid].pop(pid)
309 |                     popped.append(pid)
310 | 
311 |     ndcg = {}
312 |     _map = {}
313 |     recall = {}
314 |     precision = {}
315 | 
316 |     for k in k_values:
317 |         ndcg[f"NDCG@{k}"] = 0.0
318 |         _map[f"MAP@{k}"] = 0.0
319 |         recall[f"Recall@{k}"] = 0.0
320 |         precision[f"P@{k}"] = 0.0
321 | 
322 |     map_string = "map_cut." + ",".join([str(k) for k in k_values])
323 |     ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
324 |     recall_string = "recall." + ",".join([str(k) for k in k_values])
325 |     precision_string = "P." 
+ ",".join([str(k) for k in k_values]) 324 | evaluator = pytrec_eval.RelevanceEvaluator( 325 | qrels, {map_string, ndcg_string, recall_string, precision_string} 326 | ) 327 | scores = evaluator.evaluate(results) 328 | 329 | for query_id in scores.keys(): 330 | for k in k_values: 331 | ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)] 332 | _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)] 333 | recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)] 334 | precision[f"P@{k}"] += scores[query_id]["P_" + str(k)] 335 | 336 | for k in k_values: 337 | ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(scores), 5) 338 | _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(scores), 5) 339 | recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(scores), 5) 340 | precision[f"P@{k}"] = round(precision[f"P@{k}"] / len(scores), 5) 341 | 342 | for eval in [ndcg, _map, recall, precision]: 343 | logging.info("\n") 344 | for k in eval.keys(): 345 | logging.info("{}: {:.4f}".format(k, eval[k])) 346 | 347 | return ndcg, _map, recall, precision 348 | -------------------------------------------------------------------------------- /bm25s/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import time 3 | import sys 4 | 5 | try: 6 | import resource 7 | except ImportError: 8 | print("resource module not available on Windows") 9 | resource = None 10 | 11 | 12 | def get_max_memory_usage(format="GB"): 13 | if resource is None: 14 | return None 15 | if format not in ["GB", "MB", "KB"]: 16 | raise ValueError("format should be one of 'GB', 'MB', 'KB'") 17 | 18 | usage_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss 19 | # for mac, ru_maxrss is in bytes 20 | 21 | if sys.platform == "darwin": 22 | usage_kb /= 1024 23 | 24 | if format == "GB": 25 | return usage_kb / (1024**2) 26 | elif format == "MB": 27 | return usage_kb / 1024 28 | else: 29 | return usage_kb 30 | 31 | 32 | class Timer: 33 | def __init__(self, prefix="", precision=4): 34 | self.results = {} 35 | self.prefix = prefix 36 | self.precision = precision 37 | 38 | def start(self, name): 39 | if name in self.results: 40 | raise ValueError(f"Timer with name {name} already started.") 41 | start_time = time.monotonic() 42 | self.results[name] = {"start": start_time, "elapsed": 0, "last": start_time} 43 | return name 44 | 45 | def stop(self, name, show=False, n_total=None): 46 | if name not in self.results: 47 | raise ValueError(f"Timer with name {name} not started.") 48 | 49 | stop_time = time.monotonic() 50 | r = self.results[name] 51 | r["stopped"] = stop_time 52 | r["elapsed"] += stop_time - r.pop("last") 53 | 54 | if show: 55 | self.show(name, n_total=n_total) 56 | 57 | return self.results[name]["elapsed"] 58 | 59 | def pause(self, name): 60 | # if self.has_stopped(name): 61 | # raise ValueError(f"Timer with name {name} already stopped.") 62 | 63 | # if not self.has_started(name): 64 | # raise ValueError(f"Timer with name {name} not started.") 65 | 66 | paused_time = time.monotonic() 67 | r = self.results[name] 68 | 69 | r["elapsed"] += paused_time - r["last"] 70 | 71 | def resume(self, name): 72 | # if not self.has_started(name): 73 | # raise ValueError(f"Timer with name {name} not started.") 74 | 75 | # if not self.is_paused(name): 76 | # raise ValueError(f"Timer with name {name} not paused.") 77 | 78 | # if self.has_stopped(name): 79 | # raise ValueError(f"Timer with name {name} already stopped.") 80 | 81 | self.results[name]["last"] = 
time.monotonic() 82 | 83 | def is_paused(self, name): 84 | return name in self.results and "paused" in self.results[name] 85 | 86 | def is_resumed(self, name): 87 | return name in self.results and "resumed" in self.results[name] 88 | 89 | def has_started(self, name): 90 | return name in self.results 91 | 92 | def has_stopped(self, name): 93 | return self.has_started(name) and "stopped" in self.results[name] 94 | 95 | def elapsed(self, name, precision=None): 96 | if precision is None: 97 | precision = self.precision 98 | 99 | if not self.has_started(name): 100 | raise ValueError(f"Timer with name {name} not started.") 101 | if not self.has_stopped(name): 102 | raise ValueError(f"Timer with name {name} not stopped.") 103 | 104 | return round(self.results[name]["elapsed"], precision) 105 | 106 | def show(self, name, offset=0, n_total=None): 107 | t = self.elapsed(name) + offset 108 | s = f"{self.prefix} {name}: {t:.4f}s" 109 | if n_total is not None: 110 | # calculate throughput 111 | throughput = n_total / t 112 | s += f" ({throughput:.2f}/s)" 113 | print(s) 114 | 115 | def show_all(self): 116 | for name in self.results: 117 | if self.has_stopped(name): 118 | self.show(name) 119 | 120 | def to_dict(self, underscore=False, lowercase=False): 121 | results_to_save = deepcopy(self.results) 122 | if underscore: 123 | results_to_save = { 124 | k.replace(" ", "_"): v for k, v in results_to_save.items() 125 | } 126 | 127 | if lowercase: 128 | results_to_save = {k.lower(): v for k, v in results_to_save.items()} 129 | 130 | return results_to_save 131 | -------------------------------------------------------------------------------- /bm25s/utils/corpus.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import mmap 3 | import os 4 | 5 | import numpy as np 6 | 7 | try: 8 | import orjson as json 9 | except ImportError: 10 | import json 11 | 12 | try: 13 | from tqdm.auto import tqdm 14 | TQDM_AVAILABLE = True 15 | except ImportError: 16 | TQDM_AVAILABLE = False 17 | def tqdm(iterable=None, *args, **kwargs): 18 | return iterable 19 | 20 | from . 
import json_functions
21 | 
22 | def change_extension(path, new_extension):
23 |     path = str(path)
24 |     return path.rpartition(".")[0] + new_extension
25 | 
26 | 
27 | def find_newline_positions(path, show_progress=True, leave_progress=True, encoding="utf-8"):
28 |     path = str(path)
29 |     indexes = []
30 |     with open(path, "r", encoding=encoding) as f:
31 |         indexes.append(f.tell())
32 |         pbar = tqdm(
33 |             total=os.path.getsize(path),
34 |             desc="Finding newlines for mmindex",
35 |             unit="B",
36 |             unit_scale=True,
37 |             disable=not show_progress,
38 |             leave=leave_progress,
39 |         )
40 | 
41 |         while f.readline():
42 |             t = f.tell()
43 |             indexes.append(t)
44 | 
45 |             if pbar is not None:
46 |                 pbar.update(t - indexes[-2])
47 | 
48 |     if pbar is not None:
49 |         pbar.close()
50 | 
51 |     return indexes[:-1]
52 | 
53 | 
54 | def save_mmindex(indexes, path, encoding="utf-8"):
55 |     path = str(path)
56 |     index_file = change_extension(path, ".mmindex.json")
57 |     with open(index_file, "w", encoding=encoding) as f:
58 |         f.write(json_functions.dumps(indexes))
59 | 
60 | 
61 | def load_mmindex(path, encoding="utf-8"):
62 |     path = str(path)
63 |     index_file = change_extension(path, ".mmindex.json")
64 |     with open(index_file, "r", encoding=encoding) as f:
65 |         return json_functions.loads(f.read())
66 | 
67 | 
68 | # now we can jump to any line in the file thanks to the index and mmap
69 | def get_line(
70 |     path,
71 |     index,
72 |     mmindex,
73 |     encoding="utf-8",
74 |     file_obj=None,
75 |     mmap_obj=None,
76 | ) -> str:
77 |     path = str(path)
78 |     if file_obj is None:
79 |         file_obj = open(path, "r", encoding=encoding)
80 |         CLOSE_FILE = True
81 |     else:
82 |         CLOSE_FILE = False
83 | 
84 |     if mmap_obj is None:
85 |         mmap_obj = mmap.mmap(file_obj.fileno(), 0, access=mmap.ACCESS_READ)
86 |         CLOSE_MMAP = True
87 |     else:
88 |         CLOSE_MMAP = False
89 | 
90 |     mmap_obj.seek(mmindex[index])
91 |     result = mmap_obj.readline().decode(encoding)
92 | 
93 |     if CLOSE_MMAP:
94 |         mmap_obj.close()
95 | 
96 |     if CLOSE_FILE:
97 |         file_obj.close()
98 | 
99 |     return result
100 | 
101 | 
102 | class JsonlCorpus:
103 |     """
104 |     A class to read a jsonl file line by line using mmap, allowing extremely fast
105 |     access to any line in the file. For example, you could access the N-th line
106 |     of a 10GB file in a fraction of a second, returning a dictionary.
107 | 
108 |     Example
109 |     --------
110 | 
111 |     Traditionally, you would read a jsonl file line by line like this:
112 | 
113 |     ```python
114 |     import json
115 |     data = [json.loads(line) for line in open("file.jsonl")]
116 |     print(data[1000])
117 |     ```
118 | 
119 |     This is memory inefficient and has a large overhead. Instead, you can use this class:
120 | 
121 |     ```python
122 |     corpus = JsonlCorpus("file.jsonl")
123 |     print(corpus[1000])
124 |     ```
125 | 
126 |     Which only loads the line you need into memory, and is much faster.
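    You can also fetch several documents at once: `__getitem__` accepts slices,
    lists/tuples of indices, and numpy arrays. For example (illustrative):

    ```python
    corpus = JsonlCorpus("file.jsonl")
    print(corpus[10:13])      # slice -> list of 3 dicts
    print(corpus[[1, 5, 9]])  # list of indices -> list of 3 dicts
    ```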
127 | """ 128 | 129 | def __init__(self, path, show_progress=True, leave_progress=True, save_index=True, verbosity=1, encoding='utf-8'): 130 | self.path = path 131 | self.verbosity = verbosity 132 | self.encoding = encoding 133 | 134 | # if the index file does not exist, create it 135 | if os.path.exists(change_extension(path, ".mmindex.json")): 136 | self.mmindex = load_mmindex(path, encoding=self.encoding) 137 | else: 138 | logging.info("Creating index file for jsonl corpus") 139 | mmindex = find_newline_positions( 140 | path, show_progress=show_progress, leave_progress=leave_progress, encoding=self.encoding 141 | ) 142 | if save_index: 143 | save_mmindex(mmindex, path, encoding=self.encoding) 144 | 145 | self.mmindex = mmindex 146 | 147 | # Finally, open the file and mmap objects 148 | self.load() 149 | 150 | def __len__(self): 151 | return len(self.mmindex) 152 | 153 | def __getitem__(self, index): 154 | # handle multiple indices 155 | if isinstance(index, int): 156 | return json_functions.loads( 157 | get_line( 158 | self.path, 159 | index, 160 | self.mmindex, 161 | encoding=self.encoding, 162 | file_obj=self.file_obj, 163 | mmap_obj=self.mmap_obj, 164 | ) 165 | ) 166 | 167 | if isinstance(index, slice): 168 | return [self.__getitem__(i) for i in range(*index.indices(len(self)))] 169 | if isinstance(index, (list, tuple)): 170 | return [self.__getitem__(i) for i in index] 171 | if isinstance(index, np.ndarray): 172 | # if it's an ndarray, this means each element is an index, and the array can 173 | # be of any shape. thus, we should flatten it first, get the results as if it 174 | # was a list, and then reshape it back to the original shape 175 | index_flat = index.flatten().tolist() 176 | results = [self.__getitem__(i) for i in index_flat] 177 | reshaped = np.array(results).reshape(index.shape) 178 | return reshaped 179 | 180 | raise TypeError("Invalid index type") 181 | 182 | def close(self): 183 | """ 184 | Close the file and mmap objects. This is useful if you want to free up memory. To reopen them, use the `load` method. 185 | If you don't call this method, the objects will be closed automatically when the object is deleted. 186 | """ 187 | if hasattr(self, "file_obj") and self.file_obj is not None: 188 | self.file_obj.close() 189 | # delete the object 190 | del self.file_obj 191 | self.file_obj = None 192 | if hasattr(self, "mmap_obj") and self.mmap_obj is not None: 193 | self.mmap_obj.close() 194 | # delete the object 195 | del self.mmap_obj 196 | self.mmap_obj = None 197 | if self.verbosity >= 1: 198 | logging.info("Closed file and mmap objects") 199 | 200 | def load(self): 201 | """ 202 | Load the file and mmap objects. This is useful if you closed them and want to reopen them. 203 | 204 | Note 205 | ---- 206 | This is called automatically when the object is created. You don't need to call it manually. 207 | Also, if there is an existing file and mmap object, this will close them before reopening. 
208 | """ 209 | self.close() # close any existing file and mmap objects 210 | 211 | self.file_obj = open(self.path, "r", encoding=self.encoding) 212 | self.mmap_obj = mmap.mmap(self.file_obj.fileno(), 0, access=mmap.ACCESS_READ) 213 | if self.verbosity >= 1: 214 | logging.info("Opened file and mmap objects") 215 | 216 | def __del__(self): 217 | self.close() 218 | -------------------------------------------------------------------------------- /bm25s/utils/json_functions.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | try: 4 | import orjson 5 | ORJSON_AVAILABLE = True 6 | except ImportError: 7 | ORJSON_AVAILABLE = False 8 | 9 | 10 | def dumps_with_builtin(d: dict, **kwargs) -> str: 11 | return json.dumps(d, **kwargs) 12 | 13 | def dumps_with_orjson(d: dict, **kwargs) -> str: 14 | if kwargs.get("ensure_ascii", True): 15 | # Simulate `ensure_ascii=True` by escaping non-ASCII characters 16 | return orjson.dumps(d).decode("utf-8").encode("ascii", "backslashreplace").decode("utf-8") 17 | # Ignore other kwargs not supported by orjson 18 | return orjson.dumps(d).decode("utf-8") 19 | 20 | if ORJSON_AVAILABLE: 21 | def dumps(d: dict, **kwargs) -> str: 22 | return dumps_with_orjson(d, **kwargs) 23 | loads = orjson.loads 24 | else: 25 | def dumps(d: dict, **kwargs) -> str: 26 | return dumps_with_builtin(d, **kwargs) 27 | loads = json.loads 28 | 29 | 30 | -------------------------------------------------------------------------------- /bm25s/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1dev0" 2 | -------------------------------------------------------------------------------- /examples/evaluate_on_beir.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sometimes you might be interested in benchmarking BM25 on BEIR. bm25s makes this straightforward in Python. 3 | 4 | To install: 5 | 6 | ``` 7 | pip install bm25s[core] beir 8 | ``` 9 | 10 | Now, run this script, you can modify the `run_benchmark()` part to use the datase you want to test on. 11 | """ 12 | import json 13 | import os 14 | from pathlib import Path 15 | import time 16 | 17 | import beir.util 18 | from beir.datasets.data_loader import GenericDataLoader 19 | from beir.retrieval.evaluation import EvaluateRetrieval 20 | import numpy as np 21 | from tqdm.auto import tqdm 22 | import Stemmer 23 | 24 | import bm25s 25 | from bm25s.utils.benchmark import get_max_memory_usage, Timer 26 | from bm25s.utils.beir import ( 27 | BASE_URL, 28 | clean_results_keys, 29 | ) 30 | 31 | def postprocess_results_for_eval(results, scores, query_ids): 32 | """ 33 | Given the queried results and scores output by BM25S, postprocess them 34 | to be compatible with BEIR evaluation functions. 35 | query_ids is a list of query ids in the same order as the results. 
36 | """ 37 | 38 | results_record = [ 39 | {"id": qid, "hits": results[i], "scores": list(scores[i])} 40 | for i, qid in enumerate(query_ids) 41 | ] 42 | 43 | result_dict_for_eval = { 44 | res["id"]: { 45 | docid: float(score) for docid, score in zip(res["hits"], res["scores"]) 46 | } 47 | for res in results_record 48 | } 49 | 50 | return result_dict_for_eval 51 | 52 | def run_benchmark(dataset, save_dir="datasets"): 53 | #### Download dataset and unzip the dataset 54 | data_path = beir.util.download_and_unzip(BASE_URL.format(dataset), save_dir) 55 | split = "test" if dataset != "msmarco" else "dev" 56 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split) 57 | 58 | corpus_ids, corpus_lst = [], [] 59 | for key, val in corpus.items(): 60 | corpus_ids.append(key) 61 | corpus_lst.append(val["title"] + " " + val["text"]) 62 | del corpus 63 | 64 | qids, queries_lst = [], [] 65 | for key, val in queries.items(): 66 | qids.append(key) 67 | queries_lst.append(val) 68 | 69 | stemmer = Stemmer.Stemmer("english") 70 | 71 | corpus_tokens = bm25s.tokenize( 72 | corpus_lst, stemmer=stemmer, leave=False 73 | ) 74 | 75 | del corpus_lst 76 | 77 | query_tokens = bm25s.tokenize( 78 | queries_lst, stemmer=stemmer, leave=False 79 | ) 80 | 81 | model = bm25s.BM25(method="lucene", k1=1.2, b=0.75) 82 | model.index(corpus_tokens, leave_progress=False) 83 | 84 | ############## BENCHMARKING BEIR HERE ############## 85 | queried_results, queried_scores = model.retrieve( 86 | query_tokens, corpus=corpus_ids, k=1000, n_threads=4 87 | ) 88 | 89 | results_dict = postprocess_results_for_eval(queried_results, queried_scores, qids) 90 | ndcg, _map, recall, precision = EvaluateRetrieval.evaluate( 91 | qrels, results_dict, [1, 10, 100, 1000] 92 | ) 93 | 94 | print(ndcg) 95 | print(recall) 96 | 97 | return ndcg, _map, recall, precision 98 | 99 | if __name__ == "__main__": 100 | ndcg, _map, recall, precision = run_benchmark("scidocs") # Change to dataset you want 101 | -------------------------------------------------------------------------------- /examples/index_and_retrieve_with_numba.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Example: Use Numba to speed up the retrieval process 3 | 4 | ```bash 5 | pip install "bm25s[full]" numba 6 | ``` 7 | 8 | To build an index, please refer to the `examples/index_and_upload_to_hf.py` script. 
9 | 
10 | Now, to run this script, execute:
11 | ```bash
12 | python examples/index_and_retrieve_with_numba.py
13 | ```
14 | """
15 | import os
16 | import Stemmer
17 | 
18 | import bm25s.hf
19 | import bm25s
20 | 
21 | def main(dataset='scifact', dataset_dir='./datasets'):
22 |     queries = [
23 |         "Is chemotherapy effective for treating cancer?",
24 |         "Is cardiac injury common in critical cases of COVID-19?",
25 |     ]
26 | 
27 |     bm25s.utils.beir.download_dataset(dataset=dataset, save_dir=dataset_dir)
28 |     corpus: dict = bm25s.utils.beir.load_corpus(dataset=dataset, save_dir=dataset_dir)
29 |     corpus_records = [
30 |         {'id': k, 'title': v["title"], 'text': v["text"]} for k, v in corpus.items()
31 |     ]
32 |     corpus_lst = [r["title"] + " " + r["text"] for r in corpus_records]
33 | 
34 |     retriever = bm25s.BM25(corpus=corpus_records, backend='numba')
35 |     retriever.index(corpus_lst)
36 |     # corpus=corpus_records is optional; it is only used when you call retrieve and want the documents returned
37 | 
38 |     # Tokenize the queries
39 |     stemmer = Stemmer.Stemmer("english")
40 |     tokenizer = bm25s.tokenization.Tokenizer(stemmer=stemmer)
41 |     queries_tokenized = tokenizer.tokenize(queries)
42 |     # Retrieve the top-k results
43 |     results = retriever.retrieve(queries_tokenized, k=3)
44 |     # show the first results
45 |     result = results.documents[0]
46 |     print(f"First score (# 1 result): {results.scores[0, 0]:.4f}")
47 |     print(f"First result id (# 1 result): {result[0]['id']}")
48 |     print(f"First result title (# 1 result): {result[0]['title']}")
49 | 
50 | if __name__ == "__main__":
51 |     main()
--------------------------------------------------------------------------------
/examples/index_nq.py:
--------------------------------------------------------------------------------
1 | """
2 | # Example: Indexing Natural Questions
3 | 
4 | This shows how to build an index of the Natural Questions dataset using BM25S.
5 | 6 | To run this example, you need to install the following dependencies: 7 | 8 | ```bash 9 | pip install bm25s[core] 10 | ``` 11 | 12 | Then, run with: 13 | 14 | ```bash 15 | python examples/index_nq.py 16 | ``` 17 | """ 18 | 19 | from pathlib import Path 20 | import bm25s 21 | import Stemmer 22 | 23 | 24 | def main(save_dir="datasets", index_dir="bm25s_indices/", dataset="nq"): 25 | index_dir = Path(index_dir) / dataset 26 | index_dir.mkdir(parents=True, exist_ok=True) 27 | 28 | print("Downloading the dataset...") 29 | bm25s.utils.beir.download_dataset(dataset, save_dir=save_dir) 30 | print("Loading the corpus...") 31 | corpus = bm25s.utils.beir.load_corpus(dataset, save_dir=save_dir) 32 | corpus_records = [ 33 | {"id": k, "title": v["title"], "text": v["text"]} for k, v in corpus.items() 34 | ] 35 | corpus_lst = [r["title"] + " " + r["text"] for r in corpus_records] 36 | 37 | stemmer = Stemmer.Stemmer("english") 38 | tokenizer = bm25s.tokenization.Tokenizer(stemmer=stemmer) 39 | corpus_tokens = tokenizer.tokenize(corpus_lst, return_as="tuple") 40 | 41 | retriever = bm25s.BM25(corpus=corpus_records, backend="numba") 42 | retriever.index(corpus_tokens) 43 | 44 | retriever.save(index_dir) 45 | tokenizer.save_vocab(index_dir) 46 | tokenizer.save_stopwords(index_dir) 47 | print(f"Saved the index to {index_dir}.") 48 | 49 | # get memory usage 50 | mem_use = bm25s.utils.benchmark.get_max_memory_usage() 51 | print(f"Peak memory usage: {mem_use:.2f} GB") 52 | 53 | 54 | if __name__ == "__main__": 55 | main(dataset='msmarco') 56 | -------------------------------------------------------------------------------- /examples/index_to_hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Example: Indexing BEIR dataset and upload to Hugging Face Hub 3 | 4 | This will show how to index a dataset from BEIR and upload it to the Hugging Face Hub. 5 | 6 | To run this example, you need to install the following dependencies: 7 | 8 | ```bash 9 | pip install beir bm25s[full] 10 | ``` 11 | 12 | Make sure to replace `write-your-username-here` with your Hugging Face username, 13 | or set the `HF_USERNAME` environment variable. 
14 | 15 | Then, run with: 16 | 17 | ``` 18 | export HF_USERNAME="write-your-username-here" 19 | export HF_TOKEN="your-hf-token" 20 | python examples/index_and_upload_to_hf.py 21 | ``` 22 | """ 23 | import os 24 | import beir.util 25 | from beir.datasets.data_loader import GenericDataLoader 26 | import Stemmer 27 | 28 | import bm25s.hf 29 | from bm25s.utils.beir import BASE_URL 30 | 31 | 32 | def main(user, save_dir="datasets", repo_name="bm25s-scifact-testing", dataset="scifact"): 33 | # First, use the beir library to download the dataset, and process it 34 | data_path = beir.util.download_and_unzip(BASE_URL.format(dataset), save_dir) 35 | corpus, _, __ = GenericDataLoader(data_folder=data_path).load(split="test") 36 | corpus_records = [ 37 | {'id': k, 'title': v["title"], 'text': v["text"]} for k, v in corpus.items() 38 | ] 39 | corpus_lst = [r["title"] + " " + r["text"] for r in corpus_records] 40 | 41 | # We will use the snowball stemmer from the PyStemmer library and tokenize the corpus 42 | stemmer = Stemmer.Stemmer("english") 43 | corpus_tokenized = bm25s.tokenize(corpus_lst, stemmer=stemmer) 44 | 45 | # We create a BM25 retriever, index the corpus, and save to Hugging Face Hub 46 | retriever = bm25s.hf.BM25HF() 47 | retriever.index(corpus_tokenized) 48 | 49 | hf_token = os.getenv("HF_TOKEN") 50 | retriever.save_to_hub(repo_id=f"{user}/{repo_name}", token=hf_token, corpus=corpus_records) 51 | 52 | # you can do the same with a tokenizer class 53 | tokenizer = bm25s.hf.TokenizerHF(stemmer=stemmer) 54 | tokenizer.tokenize(corpus_lst, update_vocab=True) 55 | tokenizer.save_vocab_to_hub(repo_id=f"{user}/{repo_name}", token=hf_token) 56 | 57 | # you can also load the retriever and tokenizer from the hub 58 | tokenizer_new = bm25s.hf.TokenizerHF(stemmer=stemmer, stopwords=[]) 59 | tokenizer_new.load_vocab_from_hub(repo_id=f"{user}/{repo_name}", token=hf_token) 60 | 61 | # You can do the same for stopwords 62 | stopwords = tokenizer.stopwords 63 | tokenizer.save_stopwords_to_hub(repo_id=f"{user}/{repo_name}", token=hf_token) 64 | 65 | # you can also load the stopwords from the hub 66 | tokenizer_new.load_stopwords_from_hub(repo_id=f"{user}/{repo_name}", token=hf_token) 67 | 68 | print("Original stopwords:", stopwords) 69 | print("Reloaded stopwords:", tokenizer_new.stopwords) 70 | 71 | 72 | if __name__ == "__main__": 73 | user = os.getenv("HF_USERNAME", "write-your-username-here") 74 | cont = input(f"Are you sure you want to upload as user '{user}'? (yes/no): ") 75 | if cont.lower() == "yes": 76 | main(user=user) -------------------------------------------------------------------------------- /examples/index_with_metadata.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sometimes, you might want to have a corpus consisting of dict rather than pure text. 3 | 4 | dicts, and any json-serializable object, is supported by bm25s. This example shows you how to pass a list of dict. 5 | 6 | Note: If the elements in your corpus is not json serializable, it will not be properly saved. 
In those cases, you
7 | should avoid passing them as the corpus (or convert them to serializable types first).
8 | """
9 | import bm25s
10 | 
11 | # Create your corpus here
12 | 
13 | corpus_json = [
14 |     {"text": "a cat is a feline and likes to purr", "metadata": {"source": "internet"}},
15 |     {"text": "a dog is the human's best friend and loves to play", "metadata": {"source": "encyclopedia"}},
16 |     {"text": "a bird is a beautiful animal that can fly", "metadata": {"source": "cnn"}},
17 |     {"text": "a fish is a creature that lives in water and swims", "metadata": {"source": "i made it up"}},
18 | ]
19 | corpus_text = [doc["text"] for doc in corpus_json]
20 | 
21 | 
22 | # Tokenize the corpus and only keep the ids (faster and saves memory)
23 | corpus_tokens = bm25s.tokenize(corpus_text, stopwords="en")
24 | 
25 | # Create the BM25 retriever and attach your corpus_json to it
26 | retriever = bm25s.BM25(corpus=corpus_json)
27 | # Now, index the corpus_tokens (the corpus_json is not used yet)
28 | retriever.index(corpus_tokens)
29 | 
30 | # Query the corpus
31 | query = "does the fish purr like a cat?"
32 | query_tokens = bm25s.tokenize(query)
33 | 
34 | # Get top-k results as a tuple of (docs, scores). Note that each result
35 | # corresponds to the corpus item at the same index (you are responsible
36 | # for making sure each element in corpus_json corresponds to each element
37 | # in your tokenized corpus)
38 | results, scores = retriever.retrieve(query_tokens, k=2)
39 | 
40 | for i in range(results.shape[1]):
41 |     doc, score = results[0, i], scores[0, i]
42 |     print(f"Rank {i+1} (score: {score:.2f}): {doc}")
43 | 
44 | # You can save the arrays to a directory...
45 | # Note that this will fail if the corpus passed to `BM25(corpus=...)` is not serializable
46 | retriever.save("animal_index_bm25")
47 | 
48 | # ...and load them when you need them
49 | import bm25s
50 | reloaded_retriever = bm25s.BM25.load("animal_index_bm25", load_corpus=True)
51 | # set load_corpus=False if you don't need the corpus
--------------------------------------------------------------------------------
/examples/nltk_stemmer.py:
--------------------------------------------------------------------------------
1 | # This class provides a way to use NLTK stemming functions with the bm25s library
2 | 
3 | from nltk.stem.porter import PorterStemmer
4 | from nltk.stem.snowball import SnowballStemmer
5 | from nltk.stem.lancaster import LancasterStemmer
6 | 
7 | class NLTKMultiStemmer:
8 |     """
9 |     A class that provides a unified interface for using different stemming algorithms.
10 | 
11 |     Attributes:
12 |         stemmer_name (str): The name of the stemmer algorithm to use.
13 |         available_stemmers (dict): A dictionary that maps stemmer names to their corresponding stemmer objects.
14 |         stemmer (object): The current stemmer object being used.
15 | 
16 |     Methods:
17 |         stem(tokens): Applies the current stemmer to a list of tokens and returns the stemmed tokens.
18 |         set_stemmer(stemmer_name): Sets the current stemmer to the specified stemmer name.
19 |     """
20 |     def __init__(self, stemmer_name='porter', language='english'):
21 |         self.stemmer_name = stemmer_name
22 |         self.language = language
23 | 
24 |         self.available_stemmers = {
25 |             'porter': PorterStemmer(),
26 |             'snowball': SnowballStemmer('english'),
27 |             'lancaster': LancasterStemmer()
28 |         }
29 |         self.stemmer = self.available_stemmers[self.stemmer_name]
30 | 
31 |     def stem(self, tokens) -> list:
32 |         """
33 |         Applies the current stemmer to a list of tokens and returns the stemmed tokens.
34 |         This is done because bm25s passes a list of strings to the stemmer,
35 |         and the nltk function expects a single string per call.
36 | 
37 |         Args:
38 |             tokens (list): A list of tokens to be stemmed.
39 | 
40 |         Returns:
41 |             list: A list of stemmed tokens.
42 |         """
43 |         return [self.stemmer.stem(token) for token in tokens]
44 | 
45 |     def set_stemmer(self, stemmer_name, language='english'):
46 |         """
47 |         Sets the current stemmer to the specified stemmer name.
48 | 
49 |         Args:
50 |             stemmer_name (str): The name of the stemmer to use.
51 |         Raises:
52 |             ValueError: If the specified stemmer name is not available.
53 |         """
54 |         if stemmer_name in self.available_stemmers:
55 |             self.stemmer_name = stemmer_name
56 |             if stemmer_name == 'snowball':
57 |                 self.language = language
58 |                 self.stemmer = SnowballStemmer(self.language)
59 |             else:
60 |                 self.stemmer = self.available_stemmers[stemmer_name]
61 |         else:
62 |             raise ValueError(f"Invalid stemmer name: {stemmer_name}. Available stemmers: {list(self.available_stemmers.keys())}")
63 | # Usage (assumes `bm25s` is imported and `corpus` is a list of strings)
64 | 
65 | # Create a NLTKMultiStemmer instance with the default Porter stemmer
66 | nltk_stemmer = NLTKMultiStemmer()
67 | 
68 | # Tokenize and stem the corpus using the Porter stemmer
69 | corpus_tokens = bm25s.tokenize(corpus, stopwords=None, stemmer=nltk_stemmer.stem)
70 | 
71 | # Change the stemmer to Snowball
72 | nltk_stemmer.set_stemmer('snowball')
73 | 
74 | # Change the stemmer to Lancaster
75 | nltk_stemmer.set_stemmer('lancaster')
--------------------------------------------------------------------------------
/examples/retrieve_from_hf.py:
--------------------------------------------------------------------------------
1 | """
2 | # Example: Load index from Hugging Face Hub and retrieve from SciFact dataset
3 | 
4 | This shows how to load an index from the Hugging Face Hub created with BM25HF.index and
5 | saved with BM25HF.save_to_hub. We will retrieve the top-k results for custom queries.
6 | 
7 | To run this example, you need to install the following dependencies:
8 | 
9 | ```bash
10 | pip install bm25s[full]
11 | ```
12 | 
13 | To build an index, please refer to the `examples/index_to_hf.py` script.
You can run this script with:
15 | 
16 | ```bash
17 | python examples/index_to_hf.py
18 | ```
19 | 
20 | Then, run this script with:
21 | 
22 | ```bash
23 | python examples/retrieve_from_hf.py
24 | ```
25 | """
26 | import os
27 | import Stemmer
28 | 
29 | import bm25s.hf
30 | 
31 | def main(user, repo_name="bm25s-scifact-index"):
32 |     queries = [
33 |         "Is chemotherapy effective for treating cancer?",
34 |         "Is cardiac injury common in critical cases of COVID-19?",
35 |     ]
36 | 
37 |     # Load the BM25 index from the Hugging Face Hub
38 |     # mmap=True helps to reduce memory usage by memory-mapping the index
39 |     # load_corpus=True loads the corpus along with the index, so you can access the documents
40 |     retriever = bm25s.hf.BM25HF.load_from_hub(
41 |         f"{user}/{repo_name}", load_corpus=True, mmap=True
42 |     )
43 | 
44 |     # Tokenize the queries
45 |     stemmer = Stemmer.Stemmer("english")
46 |     queries_tokenized = bm25s.tokenize(queries, stemmer=stemmer)
47 | 
48 |     # Retrieve the top-k results
49 |     results = retriever.retrieve(queries_tokenized, k=3)
50 |     # show the first results
51 |     result = results.documents[0]
52 |     print(f"First score (# 1 result): {results.scores[0, 0]}")
53 |     print(f"First result (# 1 result):\n{result[0]}")
54 | 
55 | if __name__ == "__main__":
56 |     user = os.getenv("HF_USERNAME", "write-your-username-here")
57 |     main(user=user)
--------------------------------------------------------------------------------
/examples/retrieve_nq.py:
--------------------------------------------------------------------------------
1 | """
2 | # Example: Retrieve from pre-built index of Natural Questions
3 | 
4 | This shows how to load an index built with BM25.index and saved with BM25.save, and retrieve
5 | the top-k results for a set of queries from the Natural Questions dataset, via the BEIR library.
6 | 
7 | To run this example, you need to install the following dependencies:
8 | 
9 | ```bash
10 | pip install bm25s[core]
11 | ```
12 | 
13 | To build an index, please refer to the `examples/index_nq.py` script.
You 14 | can run this script with: 15 | 16 | ```bash 17 | python examples/index_nq.py 18 | ``` 19 | 20 | Then, run this script with: 21 | 22 | ```bash 23 | python examples/retrieve_nq.py 24 | ``` 25 | """ 26 | 27 | from pathlib import Path 28 | import numpy as np 29 | import bm25s 30 | import Stemmer 31 | from tqdm import tqdm 32 | 33 | 34 | def main(index_dir="bm25s_indices", data_dir="datasets", dataset="nq", split="test", mmap=True): 35 | index_dir = Path(index_dir) / dataset 36 | 37 | if mmap: 38 | print("Using memory-mapped index (mmap) to reduce memory usage.") 39 | 40 | timer = bm25s.utils.benchmark.Timer("[BM25S]") 41 | 42 | queries = bm25s.utils.beir.load_queries(dataset, save_dir=data_dir) 43 | qrels = bm25s.utils.beir.load_qrels(dataset, split=split, save_dir=data_dir) 44 | queries_lst = [v["text"] for k, v in queries.items() if k in qrels] 45 | print(f"Loaded {len(queries_lst)} queries.") 46 | 47 | stemmer = Stemmer.Stemmer("english") 48 | 49 | # Tokenize the queries 50 | queries_tokenized = bm25s.tokenize(queries_lst, stemmer=stemmer, return_ids=False) 51 | 52 | # # Alternatively, you can use the following code to tokenize the queries 53 | # # using the saved tokenizer from the index directory 54 | # tokenizer = bm25s.tokenization.Tokenizer(stemmer=stemmer) 55 | # tokenizer.load_stopwords(index_dir) 56 | # tokenizer.load_vocab(index_dir) 57 | # queries_tokenized = tokenizer.tokenize(queries_lst, update_vocab=False) 58 | 59 | mem_use = bm25s.utils.benchmark.get_max_memory_usage() 60 | print(f"Initial memory usage: {mem_use:.2f} GB") 61 | 62 | # Load the BM25 index and retrieve the top-k results 63 | print(f"Loading the BM25 index for: {dataset}") 64 | t = timer.start("Loading index") 65 | retriever = bm25s.BM25.load(index_dir, mmap=mmap, load_corpus=True) 66 | retriever.backend = "numba" 67 | num_docs = retriever.scores['num_docs'] 68 | timer.stop(t, show=True, n_total=num_docs) 69 | 70 | mem_use = bm25s.utils.benchmark.get_max_memory_usage() 71 | print(f"Memory usage after loading the index: {mem_use:.2f} GB") 72 | 73 | print("Retrieving the top-k results...") 74 | t = timer.start("Retrieving") 75 | results = retriever.retrieve(queries_tokenized, k=10) 76 | timer.stop(t, show=True, n_total=len(queries_lst)) 77 | 78 | # get memory usage 79 | mem_use = bm25s.utils.benchmark.get_max_memory_usage() 80 | print(f"Final (peak) memory usage: {mem_use:.2f} GB") 81 | 82 | print("-" * 50) 83 | first_result = results.documents[0] 84 | print(f"First score (# 1 result): {results.scores[0, 0]:.4f}") 85 | print(f"First result (# 1 result):\n{first_result[0]}") 86 | 87 | 88 | if __name__ == "__main__": 89 | main(mmap=True) 90 | -------------------------------------------------------------------------------- /examples/retrieve_nq_with_batching.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Example: Retrieve from pre-built index of Natural Questions 3 | 4 | This is a modified version of the `examples/retrieve_nq.py` script that uses batching to 5 | even reduce memory usage further. This script loads the queries in batches and retrieves 6 | the top-k results for each batch, clearing the memory after each batch. 7 | 8 | To run this example, you need to install the following dependencies: 9 | 10 | ```bash 11 | pip install bm25s[core] 12 | ``` 13 | 14 | To build an index, please refer to the `examples/index_nq.py` script. 
You can run this script with:
16 | 
17 | ```bash
18 | python examples/index_nq.py
19 | ```
20 | 
21 | Then, run this script with:
22 | 
23 | ```bash
24 | python examples/retrieve_nq_with_batching.py
25 | ```
26 | """
27 | 
28 | from pathlib import Path
29 | import bm25s
30 | import Stemmer
31 | from tqdm import tqdm
32 | 
33 | 
34 | def main(index_dir="bm25s_indices/", data_dir="datasets", dataset="nq", split="test", bsize=20):
35 |     index_dir = Path(index_dir) / dataset
36 |     mmap = True
37 |     print("Using memory-mapped index (mmap) to reduce memory usage.")
38 | 
39 |     timer = bm25s.utils.benchmark.Timer("[BM25S]")
40 | 
41 |     queries = bm25s.utils.beir.load_queries(dataset, save_dir=data_dir)
42 |     qrels = bm25s.utils.beir.load_qrels(dataset, split=split, save_dir=data_dir)
43 |     queries_lst = [v["text"] for k, v in queries.items() if k in qrels]
44 |     print(f"Loaded {len(queries_lst)} queries.")
45 | 
46 |     # Tokenize the queries
47 |     stemmer = Stemmer.Stemmer("english")
48 |     queries_tokenized = bm25s.tokenize(queries_lst, stemmer=stemmer, return_ids=False)
49 | 
50 |     mem_use = bm25s.utils.benchmark.get_max_memory_usage()
51 |     print(f"Initial memory usage: {mem_use:.2f} GB")
52 | 
53 |     # Load the BM25 index and retrieve the top-k results
54 |     print("Loading the BM25 index...")
55 |     t = timer.start("Loading index")
56 |     retriever = bm25s.BM25.load(index_dir, mmap=mmap, load_corpus=True)
57 |     retriever.backend = "numba"
58 |     num_docs = retriever.scores["num_docs"]
59 |     timer.stop(t, show=True, n_total=num_docs)
60 | 
61 |     mem_use = bm25s.utils.benchmark.get_max_memory_usage()
62 |     print(f"Memory usage after loading the index: {mem_use:.2f} GB")
63 | 
64 |     print("Retrieving the top-k results...")
65 |     t = timer.start("Retrieving")
66 | 
67 |     batches = []
68 | 
69 |     for i in tqdm(range(0, len(queries_lst), bsize)):
70 |         batches.append(retriever.retrieve(queries_tokenized[i : i + bsize], k=10))
71 | 
72 |         # reload the scores (and corpus) after each batch to release the mmap buffers and keep memory usage low
73 |         retriever.load_scores(save_dir=index_dir, mmap=mmap, num_docs=num_docs)
74 |         if isinstance(retriever.corpus, bm25s.utils.corpus.JsonlCorpus):
75 |             retriever.corpus.load()
76 | 
77 |     results = bm25s.Results.merge(batches)
78 | 
79 |     timer.stop(t, show=True, n_total=len(queries_lst))
80 | 
81 |     # get memory usage
82 |     mem_use = bm25s.utils.benchmark.get_max_memory_usage()
83 |     print(f"Final (peak) memory usage: {mem_use:.2f} GB")
84 | 
85 |     print("-" * 50)
86 |     first_result = results.documents[0]
87 |     print(f"First score (# 1 result): {results.scores[0, 0]:.4f}")
88 |     print(f"First result (# 1 result):\n{first_result[0]}")
89 | 
90 | 
91 | if __name__ == "__main__":
92 |     main()
--------------------------------------------------------------------------------
/examples/retrieve_with_numba_advanced.py:
--------------------------------------------------------------------------------
1 | """
2 | # Example: Use Numba to speed up the retrieval process
3 | 
4 | ```bash
5 | pip install "bm25s[full]" numba
6 | ```
7 | 
8 | To build an index, please refer to the `examples/index_to_hf.py` script.
9 | 
10 | Now, to run this script, execute:
11 | ```bash
12 | python examples/retrieve_with_numba_advanced.py
13 | ```
14 | """
15 | import os
16 | import Stemmer
17 | 
18 | import bm25s.hf
19 | 
20 | def main(repo_name="xhluca/bm25s-fiqa-index"):
21 |     queries = [
22 |         "Is chemotherapy effective for treating cancer?",
23 |         "Is cardiac injury common in critical cases of COVID-19?",
24 |     ]
25 | 
26 |     retriever = bm25s.hf.BM25HF.load_from_hub(
27 |         repo_name, load_corpus=False, mmap=False
28 |     )
29 | 
30 |     # Tokenize the queries
31 |     stemmer = Stemmer.Stemmer("english")
32 |     queries_tokenized = bm25s.tokenize(queries, stemmer=stemmer)
33 | 
34 |     # Retrieve the top-k results
35 |     retriever.activate_numba_scorer()
36 |     results = retriever.retrieve(queries_tokenized, k=3, backend_selection="numba")
37 |     # show the first results
38 |     result = results.documents[0]
39 |     print(f"First score (# 1 result): {results.scores[0, 0]}")
40 |     print(f"First result (# 1 result):\n{result[0]}")
41 | 
42 | if __name__ == "__main__":
43 |     main()
--------------------------------------------------------------------------------
/examples/retrieve_with_numba_hf.py:
--------------------------------------------------------------------------------
1 | """
2 | # Example: Use Numba to speed up the retrieval process
3 | 
4 | ```bash
5 | pip install "bm25s[full]" numba
6 | ```
7 | 
8 | To build an index, please refer to the `examples/index_to_hf.py` script.
9 | 
10 | Now, to run this script, execute:
11 | ```bash
12 | python examples/retrieve_with_numba_hf.py
13 | ```
14 | """
15 | import os
16 | import Stemmer
17 | 
18 | import bm25s.hf
19 | 
20 | def main(repo_name="xhluca/bm25s-fiqa-index"):
21 |     queries = [
22 |         "Is chemotherapy effective for treating cancer?",
23 |         "Is cardiac injury common in critical cases of COVID-19?",
24 |     ]
25 | 
26 |     retriever = bm25s.hf.BM25HF.load_from_hub(
27 |         repo_name, load_corpus=False, mmap=False
28 |     )
29 | 
30 |     retriever.backend = "numba"  # this can also be set during initialization of the retriever
31 | 
32 |     # Tokenize the queries
33 |     stemmer = Stemmer.Stemmer("english")
34 |     tokenizer = bm25s.tokenization.Tokenizer(stemmer=stemmer)
35 |     queries_tokenized = tokenizer.tokenize(queries)
36 | 
37 |     # Retrieve the top-k results
38 |     results = retriever.retrieve(queries_tokenized, k=3)
39 |     # show the first results
40 |     result = results.documents[0]
41 |     print(f"First score (# 1 result): {results.scores[0, 0]:.4f}")
42 |     print(f"First result (# 1 result): {result[0]}")
43 | 
44 | if __name__ == "__main__":
45 |     main()
--------------------------------------------------------------------------------
/examples/save_and_reload_end_to_end.py:
--------------------------------------------------------------------------------
1 | import bm25s
2 | from bm25s.tokenization import Tokenizer
3 | 
4 | corpus = [
5 |     "Welcome to bm25s, a library that implements BM25 in Python, allowing you to rank documents based on a query.",
6 |     "BM25 is a widely used ranking function used for text retrieval tasks, and is a core component of search services like Elasticsearch.",
7 |     "It is designed to be:",
8 |     "Fast: bm25s is implemented in pure Python and leverages Scipy sparse matrices to store eagerly computed scores for all document tokens.",
9 |     "This allows extremely fast scoring at query time, improving performance over popular libraries by orders of magnitude (see benchmarks below).",
10 |     "Simple: bm25s is designed to be easy to use and understand.",
11 |     "You can install it with pip and start using it in minutes.",
12 |     "There 
are no dependencies on Java or PyTorch - all you need is Scipy and Numpy, and optional lightweight dependencies for stemming.",
13 |     "Below, we compare bm25s with Elasticsearch in terms of speedup over rank-bm25, the most popular Python implementation of BM25.",
14 |     "We measure the throughput in queries per second (QPS) on a few popular datasets from BEIR in a single-threaded setting.",
15 |     "bm25s aims to offer a faster alternative for Python users who need efficient text retrieval.",
16 |     "It leverages modern Python libraries and data structures for performance optimization.",
17 |     "You can find more details in the documentation and example notebooks provided.",
18 |     "Installation and usage guidelines are simple and accessible for developers of all skill levels.",
19 |     "Try bm25s for a scalable and fast text ranking solution in your Python projects."
20 | ]
21 | 
22 | print(f"We have {len(corpus)} documents in the corpus.")
23 | 
24 | tokenizer = Tokenizer(splitter=lambda x: x.split())
25 | corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple")
26 | 
27 | retriever = bm25s.BM25(corpus=corpus)
28 | retriever.index(corpus_tokens)
29 | 
30 | retriever.save("bm25s_index_readme")
31 | tokenizer.save_vocab(save_dir="bm25s_index_readme")
32 | 
33 | # Let's reload the retriever and tokenizer and use them to retrieve documents based on a query
34 | 
35 | reloaded_retriever = bm25s.BM25.load("bm25s_index_readme", load_corpus=True)
36 | 
37 | reloaded_tokenizer = Tokenizer(splitter=lambda x: x.split())
38 | reloaded_tokenizer.load_vocab("bm25s_index_readme")
39 | 
40 | queries = ["widely used text ranking function"]
41 | 
42 | query_tokens = reloaded_tokenizer.tokenize(queries, update_vocab=False)
43 | results, scores = reloaded_retriever.retrieve(query_tokens, k=2)
44 | 
45 | for i in range(results.shape[1]):
46 |     doc, score = results[0, i], scores[0, i]
47 |     print(f"Rank {i+1} (score: {score:.2f}): {doc}")
--------------------------------------------------------------------------------
/examples/tokenize_multiprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | # Tokenize with multiprocessing
3 | 
4 | In this example, we see how to tokenize the NQ dataset using multiprocessing.Pool
5 | to parallelize the tokenization process. Note that this does not show how to use
6 | the snowball stemmer, as the C object is not picklable (this should be fixable in a
7 | future PR), and it returns strings instead of IDs/vocab, since each process cannot
8 | communicate with the other processes to use the same dictionary.
9 | 
10 | Note that we can observe a speedup, but the per-core efficiency will go down as you use more
11 | cores. For example, on NQ, we observe the following:
12 | 
13 | Single Process: 110.0863s (24357.87/s)
14 | Multiprocess (4x): 61.4338s (43648.09/s)
15 | 
16 | As you can see, the total time went down by about 50s, but at the cost of using 4x more processes.
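In other words, 43648.09 / 24357.87 ≈ 1.79x higher throughput from 4 processes,
i.e. roughly 45% per-process efficiency relative to the single-process run.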
17 | """ 18 | import multiprocessing as mp 19 | 20 | import beir.util 21 | from beir.datasets.data_loader import GenericDataLoader 22 | import bm25s 23 | from bm25s.utils.benchmark import Timer 24 | from bm25s.utils.beir import BASE_URL 25 | 26 | def tokenize_fn(texts): 27 | return bm25s.tokenize(texts=texts, return_ids=False, show_progress=False) 28 | 29 | def chunk(lst, n): 30 | return [lst[i : i + n] for i in range(0, len(lst), n)] 31 | 32 | def unchunk(lsts): 33 | # merge all lsts into one list 34 | return [item for lst in lsts for item in lst] 35 | 36 | if __name__ == "__main__": 37 | dataset = "nq" 38 | save_dir = "datasets" 39 | split = "test" 40 | num_processes = 4 41 | 42 | data_path = beir.util.download_and_unzip(BASE_URL.format(dataset), save_dir) 43 | corpus, _, __ = GenericDataLoader(data_folder=data_path).load(split=split) 44 | 45 | corpus_ids, corpus_lst = [], [] 46 | for key, val in corpus.items(): 47 | corpus_ids.append(key) 48 | corpus_lst.append(val["title"] + " " + val["text"]) 49 | 50 | del corpus 51 | 52 | timer = Timer("[Tokenization]") 53 | 54 | # let's try single process 55 | t = timer.start("single-threaded") 56 | tokens = bm25s.tokenize(texts=corpus_lst, return_ids=False) 57 | timer.stop(t, show=True, n_total=len(corpus_lst)) 58 | 59 | # we will use the tokenizer class here 60 | corpus_chunks = chunk(corpus_lst, 1000) 61 | t = timer.start(f"num_processes={num_processes}") 62 | with mp.Pool(processes=num_processes) as pool: 63 | tokens_lst_chunks = pool.map(tokenize_fn, corpus_chunks) 64 | timer.stop(t, show=True, n_total=len(corpus_lst)) 65 | 66 | tokens_lst_final = unchunk(tokens_lst_chunks) 67 | assert tokens == tokens_lst_final 68 | -------------------------------------------------------------------------------- /examples/tokenizer_class.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Example: Retrieve from pre-built index of SciFact 3 | 4 | This script shows how to load an index built with BM25.index and saved with BM25.save, and retrieve 5 | the top-k results for a set of queries from the SciFact dataset, via the BEIR library. 
6 | """ 7 | import shutil 8 | import tempfile 9 | import beir.util 10 | from beir.datasets.data_loader import GenericDataLoader 11 | import Stemmer 12 | 13 | import bm25s 14 | from bm25s.utils.beir import BASE_URL 15 | from bm25s.tokenization import Tokenizer, Tokenized 16 | 17 | 18 | def main(data_dir="datasets", dataset="scifact"): 19 | # Load the queries from BEIR 20 | data_path = beir.util.download_and_unzip(BASE_URL.format(dataset), data_dir) 21 | loader = GenericDataLoader(data_folder=data_path) 22 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split='test') 23 | corpus_lst = [doc["title"] + " " + doc["text"] for doc in corpus.values()] 24 | queries_lst = list(queries.values()) 25 | 26 | # Initialize the stemmer 27 | stemmer = Stemmer.Stemmer("english") 28 | 29 | # Initialize the Tokenizer with the stemmer 30 | tokenizer = Tokenizer( 31 | stemmer=stemmer, 32 | lower=True, # lowercase the tokens 33 | stopwords="english", # or pass a list of stopwords 34 | splitter=r"\w+", # by default r"(?u)\b\w\w+\b", can also be a function 35 | ) 36 | 37 | # Tokenize the corpus 38 | corpus_tokenized = tokenizer.tokenize( 39 | corpus_lst, 40 | update_vocab=True, # update the vocab as we tokenize 41 | return_as="ids" 42 | ) 43 | 44 | # stream tokenizing the queries, without updating the vocabulary 45 | # note: this cannot return as string due to the streaming nature 46 | tokenizer_stream = tokenizer.streaming_tokenize( 47 | queries_lst, 48 | update_vocab=False 49 | ) 50 | query_ids = [] 51 | 52 | for q in tokenizer_stream: 53 | # you can do something with the ids here, e.g. retrieve from the index 54 | if 1 in q: 55 | query_ids.append(q) 56 | 57 | # you can convert the ids to a Tokenized namedtuple ids and tokens... 58 | res = tokenizer.to_tokenized_tuple(query_ids) 59 | # ... which is equivalent to: 60 | # tokenizer.tokenize(your_query_lst, return_as="tuple", update_vocab=False) 61 | 62 | # You can verify the results 63 | assert res.ids == query_ids 64 | assert res.vocab == tokenizer.get_vocab_dict() 65 | assert isinstance(res, Tokenized) 66 | 67 | 68 | # You can also get strings 69 | query_strs = tokenizer.decode(query_ids) 70 | # ... which is equivalent to: 71 | # tokenizer.tokenize(your_query_lst, return_as="string", update_vocab=False) 72 | 73 | # let's verify the results 74 | assert isinstance(query_strs, list) 75 | assert isinstance(query_strs[0], list) 76 | assert isinstance(query_strs[0][0], str) 77 | 78 | # Let's see how it's all used 79 | retriever = bm25s.BM25() 80 | retriever.index(corpus_tokenized, leave_progress=False) 81 | 82 | # all of the above can be passed to index a bm25s model 83 | 84 | # e.g. using the ids directly 85 | results, scores = retriever.retrieve(query_ids, k=3) 86 | 87 | # or passing the strings 88 | results, scores = retriever.retrieve(query_strs, k=3) 89 | 90 | # or passing the Tokenized namedtuple 91 | results, scores = retriever.retrieve(res, k=3) 92 | 93 | # or passing a tuple of ids and vocab dict 94 | vocab_dict = tokenizer.get_vocab_dict() 95 | results, scores = retriever.retrieve((query_ids, vocab_dict), k=3) 96 | 97 | # If you want, you can save the vocab and stopwords, it can be the same dir as your index 98 | your_index_dir = tempfile.mkdtemp() 99 | tokenizer.save_vocab(save_dir=your_index_dir) 100 | 101 | # Unhappy with your vocab? 
97 |     # If you want, you can save the vocab and stopwords; it can be the same dir as your index
98 |     your_index_dir = tempfile.mkdtemp()
99 |     tokenizer.save_vocab(save_dir=your_index_dir)
100 | 
101 |     # Unhappy with your vocab? You can reset your tokenizer
102 |     tokenizer.reset_vocab()
103 | 
104 | 
105 |     # loading:
106 |     new_tokenizer = Tokenizer(
107 |         stemmer=stemmer,
108 |         lower=True,
109 |         stopwords=[],
110 |         splitter=r"\w+",
111 |     )
112 |     print("Vocabulary size before reloading:", len(new_tokenizer.get_vocab_dict()))
113 |     new_tokenizer.load_vocab(your_index_dir)
114 |     print("Vocabulary size after reloading:", len(new_tokenizer.get_vocab_dict()))
115 | 
116 |     # the same can be done for stopwords
117 |     print("stopwords before reloading:", new_tokenizer.stopwords)
118 |     tokenizer.save_stopwords(save_dir=your_index_dir)
119 |     new_tokenizer.load_stopwords(your_index_dir)
120 |     print("stopwords after reloading:", new_tokenizer.stopwords)
121 | 
122 |     # cleanup
123 |     shutil.rmtree(your_index_dir)
124 | 
125 | 
126 | if __name__ == "__main__":
127 |     main()
128 | 
129 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | package_name = "bm25s"
4 | version = {}
5 | with open(f"{package_name}/version.py", encoding="utf8") as fp:
6 |     exec(fp.read(), version)
7 | 
8 | with open("README.md", encoding="utf8") as fp:
9 |     long_description = fp.read()
10 | 
11 | extras_require = {
12 |     "core": ["orjson", "tqdm", "PyStemmer", "numba"],
13 |     "stem": ["PyStemmer"],
14 |     "hf": ["huggingface_hub"],
15 |     "dev": ["black"],
16 |     "selection": ["jax[cpu]"],
17 |     "evaluation": ["pytrec_eval"],
18 | }
19 | # Dynamically create the 'full' extra by combining all other extras
20 | extras_require["full"] = sum(extras_require.values(), [])
21 | 
22 | setup(
23 |     name=package_name,
24 |     version=version["__version__"],
25 |     author="Xing Han Lù",
26 |     author_email=f"{package_name}@googlegroups.com",
27 |     url=f"https://github.com/xhluca/{package_name}",
28 |     description="An ultra-fast implementation of BM25 based on sparse matrices.",
29 |     long_description=long_description,
30 |     packages=find_packages(include=[f"{package_name}*"]),
31 |     package_data={},
32 |     install_requires=['scipy', 'numpy'],
33 |     extras_require=extras_require,
34 |     classifiers=[
35 |         "Programming Language :: Python :: 3",
36 |         "License :: OSI Approved :: MIT License",
37 |         "Operating System :: OS Independent",
38 |     ],
39 |     python_requires=">=3.8",
40 |     # Cast long description to markdown
41 |     long_description_content_type="text/markdown",
42 | )
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # BM25S tests
2 | 
3 | Welcome to the test suite for BM25S! This test suite is designed to test the BM25S implementation in the `bm25s` package.
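Most suites need extra dependencies: the core and numba suites rely on `tests/requirements-core.txt` (the numba suite additionally needs `numba`), and the comparison suites rely on `tests/requirements-comparison.txt`. Install the matching file first, e.g. `pip install -r tests/requirements-core.txt`.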
4 | 
5 | ## Core tests
6 | 
7 | To run the core tests of the library (along with the stopwords tests), run the following commands:
8 | 
9 | ```bash
10 | python -m unittest tests/core/*.py
11 | python -m unittest tests/stopwords/*.py
12 | ```
13 | 
14 | For the Numba backend tests, run:
15 | 
16 | ```bash
17 | python -m unittest tests/numba/*.py
18 | ```
19 | 
20 | 
21 | ## Basic Comparisons
22 | 
23 | To run the basic comparison tests (against other BM25 implementations), run the following command:
24 | 
25 | ```bash
26 | python -m unittest tests/comparison/*.py
27 | ```
28 | 
29 | ## Multiple tests
30 | 
31 | To run all of the above test suites one after another, run the following commands:
32 | 
33 | ```bash
34 | python -m unittest tests/core/*.py
35 | python -m unittest tests/stopwords/*.py
36 | python -m unittest tests/numba/*.py
37 | python -m unittest tests/comparison/*.py
38 | ```
39 | 
40 | ## Full comparison tests
41 | 
42 | To run the full comparison tests (on the complete datasets, without subsampling), run the following command:
43 | 
44 | ```bash
45 | python -m unittest tests/comparison_full/*.py
46 | ```
47 | 
48 | ## Artifacts
49 | 
50 | By default, the artifacts (the timing results saved by the comparison tests) are stored in the `./artifacts` directory, which is created if it does not exist. To use a different directory, set the `BM25_ARTIFACTS_DIR` environment variable:
51 | 
52 | ```bash
53 | export BM25_ARTIFACTS_DIR=/path/to/artifacts
54 | ```
55 | 
56 | 
57 | ## Adding new tests
58 | 
59 | First, create a new file in tests/core, tests/comparison, tests/numba, tests/stopwords, or tests/comparison_full. Then, add the following code to the file:
60 | 
61 | ```python
62 | import os
63 | import shutil
64 | from pathlib import Path
65 | import unittest
66 | import tempfile
67 | import Stemmer  # optional: for stemming
68 | import unittest.mock
69 | import json
70 | 
71 | import bm25s
72 | 
73 | class TestYourName(unittest.TestCase):
74 |     def test_your_name(self):
75 |         # Your test code here
76 |         pass
77 | ```
78 | 
79 | Modify the `test_your_name` method to exercise your code using the `bm25s` package; you can also use `unittest.mock` to mock objects. A minimal end-to-end example is shown below.
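For instance, a minimal end-to-end test could look like this (the class, method, and query here are illustrative, not part of the suite):

```python
import unittest

import bm25s


class TestQuickRetrieve(unittest.TestCase):
    def test_best_match_is_dog_document(self):
        corpus = [
            "a cat is a feline and likes to purr",
            "a dog is the human's best friend and loves to play",
        ]
        # index the tiny corpus, then retrieve the single best match for a query
        retriever = bm25s.BM25()
        retriever.index(bm25s.tokenize(corpus, stopwords="en"))
        results, scores = retriever.retrieve(
            bm25s.tokenize(["does the dog love to play?"], stopwords="en"), k=1
        )
        self.assertEqual(results[0, 0], 1)  # document 1 is the one about the dog
```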
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import unittest
5 | from pathlib import Path
6 | import warnings
7 | 
8 | import numpy as np
9 | 
10 | import bm25s
11 | 
12 | 
13 | # Make sure to import or define the functions/classes you're going to use,
14 | # such as bm25s.tokenize and the bm25s.BM25 class, among others.
15 | def save_scores(scores, artifact_dir="tests/artifacts"):
16 |     if os.getenv("BM25_ARTIFACTS_DIR"):
17 |         artifacts_dir = Path(os.getenv("BM25_ARTIFACTS_DIR"))
18 |     elif artifact_dir is not None:
19 |         artifacts_dir = Path(artifact_dir)
20 |     else:
21 |         artifacts_dir = Path(__file__).parent / "artifacts"
22 | 
23 |     if "dataset" not in scores:
24 |         raise ValueError("scores must contain a 'dataset' key.")
25 |     if "model" not in scores:
26 |         raise ValueError("scores must contain a 'model' key.")
27 | 
28 |     artifacts_dir = artifacts_dir / scores["model"]
29 |     artifacts_dir.mkdir(exist_ok=True, parents=True)
30 | 
31 |     filename = f"{scores['dataset']}-{os.urandom(8).hex()}.json"
32 |     with open(artifacts_dir / filename, "w") as f:
33 |         json.dump(scores, f, indent=2)
34 | 
35 | 
36 | class BM25TestCase(unittest.TestCase):
37 |     def compare_with_rank_bm25(
38 |         self,
39 |         dataset,
40 |         artifact_dir="tests/artifacts",
41 |         rel_save_dir="datasets",
42 |         corpus_subsample=None,
43 |         queries_subsample=None,
44 |         method="rank",
45 |     ):
46 |         from beir.datasets.data_loader import GenericDataLoader
47 |         from beir.util import download_and_unzip
48 |         import rank_bm25
49 |         import Stemmer
50 | 
51 |         warnings.filterwarnings("ignore", category=ResourceWarning)
52 | 
53 |         if method not in ["rank", "bm25+", "bm25l"]:
54 |             raise ValueError("method must be one of 'rank', 'bm25+', or 'bm25l'.")
55 | 
56 |         # Download and prepare dataset
57 |         base_url = (
58 |             "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
59 |         )
60 |         url = base_url.format(dataset)
61 |         out_dir = Path(__file__).parent / rel_save_dir
62 |         data_path = download_and_unzip(url, str(out_dir))
63 | 
64 |         corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(
65 |             split="test"
66 |         )
67 | 
68 |         # Convert corpus and queries to lists
69 |         corpus_lst = [val["title"] + " " + val["text"] for val in corpus.values()]
70 |         queries_lst = list(queries.values())
71 | 
72 |         if corpus_subsample is not None:
73 |             corpus_lst = corpus_lst[:corpus_subsample]
74 | 
75 |         if queries_subsample is not None:
76 |             queries_lst = queries_lst[:queries_subsample]
77 | 
78 |         # Tokenize using sklearn-style tokenizer + PyStemmer
79 |         stemmer = Stemmer.Stemmer("english")
80 | 
81 |         corpus_token_strs = bm25s.tokenize(
82 |             corpus_lst, stopwords="en", stemmer=stemmer, return_ids=False
83 |         )
84 |         queries_token_strs = bm25s.tokenize(
85 |             queries_lst, stopwords="en", stemmer=stemmer, return_ids=False
86 |         )
87 |         print()
88 |         print(f"Dataset: {dataset}\n")
89 |         # print corpus and queries size
90 |         print(f"Corpus size: {len(corpus_lst)}")
91 |         print(f"Queries size: {len(queries_lst)}")
92 |         print()
93 | 
94 |         # Initialize and index bm25s with atire + robertson idf (to match rank-bm25)
95 |         if method == "rank":
96 |             bm25_sparse = bm25s.BM25(k1=1.5, b=0.75, method="atire", idf_method="robertson")
97 |         elif method in ["bm25+", "bm25l"]:
98 |             bm25_sparse = bm25s.BM25(k1=1.5, b=0.75, delta=0.5, method=method)
99 |         else:
100 |             raise ValueError("invalid method")
101 | 
102 |         start_time = time.monotonic()
103 |         bm25_sparse.index(corpus_token_strs)
104 |         bm25_sparse_index_time = time.monotonic() - start_time
105 |         print(f"bm25s index time: {bm25_sparse_index_time:.4f}s")
106 | 
107 |         # Scoring with bm25-sparse
108 |         start_time = time.monotonic()
109 |         bm25_sparse_scores = [bm25_sparse.get_scores(q) for q in queries_token_strs]
110 |         bm25_sparse_score_time = time.monotonic() - start_time
111 |         print(f"bm25s score time: {bm25_sparse_score_time:.4f}s")
112 | 
113 |         # Initialize and index
rank-bm25 114 | start_time = time.monotonic() 115 | if method == "rank": 116 | bm25_rank = rank_bm25.BM25Okapi(corpus_token_strs, k1=1.5, b=0.75, epsilon=0.0) 117 | elif method == "bm25+": 118 | bm25_rank = rank_bm25.BM25Plus(corpus_token_strs, k1=1.5, b=0.75, delta=0.5) 119 | elif method == "bm25l": 120 | bm25_rank = rank_bm25.BM25L(corpus_token_strs, k1=1.5, b=0.75, delta=0.5) 121 | else: 122 | raise ValueError("invalid method") 123 | 124 | bm25_rank_index_time = time.monotonic() - start_time 125 | print(f"rank-bm25 index time: {bm25_rank_index_time:.4f}s") 126 | 127 | # Scoring with rank-bm25 128 | start_time = time.monotonic() 129 | bm25_rank_scores = [bm25_rank.get_scores(q) for q in queries_token_strs] 130 | bm25_rank_score_time = time.monotonic() - start_time 131 | print(f"rank-bm25 score time: {bm25_rank_score_time:.4f}s") 132 | 133 | # print difference in time 134 | print( 135 | f"Index Time: BM25S is {bm25_rank_index_time / bm25_sparse_index_time:.2f}x faster than rank-bm25." 136 | ) 137 | print( 138 | f"Score Time: BM25S is {bm25_rank_score_time / bm25_sparse_score_time:.2f}x faster than rank-bm25." 139 | ) 140 | 141 | # Check if scores are exactly the same 142 | sparse_scores = np.array(bm25_sparse_scores) 143 | rank_scores = np.array(bm25_rank_scores) 144 | 145 | error_msg = f"\nScores between bm25-sparse and rank-bm25 are not exactly the same on dataset {dataset}." 146 | almost_equal = np.allclose(sparse_scores, rank_scores) 147 | self.assertTrue(almost_equal, error_msg) 148 | 149 | general_info = { 150 | "date": time.strftime("%Y-%m-%d %H:%M:%S"), 151 | "num_jobs": 1, 152 | "dataset": dataset, 153 | "corpus_size": len(corpus_lst), 154 | "queries_size": len(queries_lst), 155 | "corpus_subsampled": corpus_subsample is not None, 156 | "queries_subsampled": queries_subsample is not None, 157 | } 158 | # Save metrics 159 | res = { 160 | "model": "bm25s", 161 | "index_time": bm25_sparse_index_time, 162 | "score_time": bm25_sparse_score_time, 163 | } 164 | res.update(general_info) 165 | save_scores(res, artifact_dir=artifact_dir) 166 | 167 | res = { 168 | "model": "rank-bm25", 169 | "score_time": bm25_rank_score_time, 170 | "index_time": bm25_rank_index_time, 171 | } 172 | res.update(general_info) 173 | save_scores(res, artifact_dir=artifact_dir) 174 | 175 | def compare_with_bm25_pt( 176 | self, 177 | dataset, 178 | artifact_dir="tests/artifacts", 179 | rel_save_dir="datasets", 180 | corpus_subsample=None, 181 | queries_subsample=None, 182 | ): 183 | from beir.datasets.data_loader import GenericDataLoader 184 | from beir.util import download_and_unzip 185 | import bm25_pt 186 | import bm25s.hf 187 | 188 | from transformers import AutoTokenizer 189 | 190 | warnings.filterwarnings("ignore", category=ResourceWarning) 191 | warnings.filterwarnings("ignore", category=UserWarning) 192 | 193 | # Download and prepare dataset 194 | base_url = ( 195 | "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip" 196 | ) 197 | url = base_url.format(dataset) 198 | out_dir = Path(__file__).parent / rel_save_dir 199 | data_path = download_and_unzip(url, str(out_dir)) 200 | 201 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load( 202 | split="test" 203 | ) 204 | 205 | # Convert corpus and queries to lists 206 | corpus_lst = [val["title"] + " " + val["text"] for val in corpus.values()] 207 | queries_lst = list(queries.values()) 208 | 209 | if corpus_subsample is not None: 210 | corpus_lst = corpus_lst[:corpus_subsample] 211 | 212 | if queries_subsample is not 
None:
213 |             queries_lst = queries_lst[:queries_subsample]
214 | 
215 |         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
216 |         t0 = time.monotonic()
217 |         tokenized_corpus = bm25s.hf.batch_tokenize(tokenizer, corpus_lst)
218 |         time_corpus_tok = time.monotonic() - t0
219 | 
220 |         t0 = time.monotonic()
221 |         queries_tokenized = bm25s.hf.batch_tokenize(tokenizer, queries_lst)
222 |         time_query_tok = time.monotonic() - t0
223 | 
224 |         print()
225 |         print(f"Dataset: {dataset}\n")
226 |         # print corpus and queries size
227 |         print(f"Corpus size: {len(corpus_lst)}")
228 |         print(f"Queries size: {len(queries_lst)}")
229 |         print()
230 | 
231 |         # Initialize and index bm25-sparse
232 |         bm25_sparse = bm25s.BM25(k1=1.5, b=0.75, method="atire", idf_method="lucene")
233 |         start_time = time.monotonic()
234 |         bm25_sparse.index(tokenized_corpus)
235 |         bm25s_index_time = time.monotonic() - start_time
236 |         print(f"bm25s index time: {bm25s_index_time:.4f}s")
237 | 
238 |         # Scoring with bm25-sparse
239 |         start_time = time.monotonic()
240 |         bm25_sparse_scores = [bm25_sparse.get_scores(q) for q in queries_tokenized]
241 |         bm25s_score_time = time.monotonic() - start_time
242 |         print(f"bm25s score time: {bm25s_score_time:.4f}s")
243 | 
244 |         # Initialize and index bm25-pt
245 |         start_time = time.monotonic()
246 |         model_pt = bm25_pt.BM25(tokenizer=tokenizer, device="cpu", k1=1.5, b=0.75)
247 |         model_pt.index(corpus_lst)
248 |         bm25_pt_index_time = time.monotonic() - start_time
249 |         bm25_pt_index_time -= time_corpus_tok  # exclude tokenization time for a fair comparison
250 |         print(f"bm25-pt index time: {bm25_pt_index_time:.4f}s")
251 | 
252 |         # Scoring with bm25-pt
253 |         start_time = time.monotonic()
254 |         bm25_pt_scores = model_pt.score_batch(queries_lst)
255 |         bm25_pt_scores = bm25_pt_scores.cpu().numpy()
256 |         bm25_pt_score_time = time.monotonic() - start_time
257 |         bm25_pt_score_time -= time_query_tok  # exclude tokenization time here as well
258 |         print(f"bm25-pt score time: {bm25_pt_score_time:.4f}s")
259 | 
260 |         # print difference in time
261 |         print(
262 |             f"Index Time: BM25S is {bm25_pt_index_time / bm25s_index_time:.2f}x faster than bm25-pt."
263 |         )
264 |         print(
265 |             f"Score Time: BM25S is {bm25_pt_score_time / bm25s_score_time:.2f}x faster than bm25-pt."
266 |         )
267 | 
268 |         # Check if the scores are close
269 |         bm25_sparse_scores = np.array(bm25_sparse_scores)
270 |         bm25_pt_scores = np.array(bm25_pt_scores)
271 | 
272 |         error_msg = f"\nScores between bm25-sparse and bm25-pt are not close on dataset {dataset}."
273 |         almost_equal = np.allclose(bm25_sparse_scores, bm25_pt_scores, atol=1e-4)
274 |         self.assertTrue(almost_equal, error_msg)
275 | 
276 |         general_info = {
277 |             "date": time.strftime("%Y-%m-%d %H:%M:%S"),
278 |             "num_jobs": 1,
279 |             "dataset": dataset,
280 |             "corpus_size": len(corpus_lst),
281 |             "queries_size": len(queries_lst),
282 |             "corpus_was_subsampled": corpus_subsample is not None,
283 |             "queries_was_subsampled": queries_subsample is not None,
284 |         }
285 |         # Save metrics
286 |         res = {
287 |             "model": "bm25s",
288 |             "index_time": bm25s_index_time,
289 |             "score_time": bm25s_score_time,
290 |         }
291 |         res.update(general_info)
292 |         save_scores(res, artifact_dir=artifact_dir)
293 | 
294 |         res = {
295 |             "model": "bm25-pt",
296 |             "score_time": bm25_pt_score_time,
297 |             "index_time": bm25_pt_index_time,
298 |         }
299 |         res.update(general_info)
300 |         save_scores(res, artifact_dir=artifact_dir)
301 | 
--------------------------------------------------------------------------------
/tests/comparison/test_bm25_pt.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from .. import BM25TestCase
4 | 
5 | class TestBM25PTQuick(BM25TestCase):
6 |     def test_bm25_sparse_vs_bm25_pt_on_nfcorpus(self):
7 |         self.compare_with_bm25_pt("nfcorpus", corpus_subsample=4000, queries_subsample=1000)
8 | 
9 |     def test_bm25_sparse_vs_bm25_pt_on_scifact(self):
10 |         self.compare_with_bm25_pt("scifact", corpus_subsample=4000, queries_subsample=1000)
11 | 
12 |     def test_bm25_sparse_vs_bm25_pt_on_scidocs(self):
13 |         self.compare_with_bm25_pt("scidocs", corpus_subsample=4000, queries_subsample=1000)
14 | 
15 |     # fiqa
16 |     def test_bm25_sparse_vs_bm25_pt_on_fiqa(self):
17 |         self.compare_with_bm25_pt("fiqa", corpus_subsample=4000, queries_subsample=1000)
18 | 
19 |     # arguana
20 |     def test_bm25_sparse_vs_bm25_pt_on_arguana(self):
21 |         self.compare_with_bm25_pt("arguana", corpus_subsample=4000, queries_subsample=1000)
22 | 
23 | if __name__ == '__main__':
24 |     unittest.main()
25 | 
--------------------------------------------------------------------------------
/tests/comparison/test_bm25s_indexing.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import unittest
5 | from pathlib import Path
6 | import warnings
7 | import logging
8 | 
9 | import numpy as np
10 | from beir.datasets.data_loader import GenericDataLoader
11 | from beir.util import download_and_unzip
12 | import Stemmer
13 | 
14 | import bm25s
15 | 
16 | def check_scores_all_close(score1, score2, **kwargs):
17 |     for key in score1.keys():
18 |         matrix1 = score1[key]
19 |         matrix2 = score2[key]
20 | 
21 |         if matrix1.shape != matrix2.shape:
22 |             return False
23 |         if not np.allclose(matrix1, matrix2, **kwargs):
24 |             return False
25 | 
26 |     return True
27 | 
28 | class BM25SIndexing(unittest.TestCase):
29 |     def test_indexing_by_corpus_type(self):
30 |         warnings.filterwarnings("ignore", category=ResourceWarning)
31 |         class Tokenized:
32 |             def __init__(self, ids, vocab):
33 |                 self.ids = ids
34 |                 self.vocab = vocab
35 | 
36 |         dataset = "scifact"
37 |         rel_save_dir = "datasets"
38 |         # Download and prepare dataset
39 |         base_url = (
40 |             "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
41 |         )
42 |         url = base_url.format(dataset)
43 |         out_dir = Path(__file__).parent / rel_save_dir
44 |         data_path = download_and_unzip(url, str(out_dir))
45 | 
46 |         corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(
split="test" 48 | ) 49 | 50 | corpus_ids, corpus_lst = [], [] 51 | for key, val in corpus.items(): 52 | corpus_ids.append(key) 53 | corpus_lst.append(val["title"] + " " + val["text"]) 54 | 55 | stemmer = Stemmer.Stemmer("english") 56 | corpus_tokens_lst = bm25s.tokenize( 57 | corpus_lst, 58 | stopwords="en", 59 | stemmer=stemmer, 60 | leave=False, 61 | return_ids=False, 62 | ) 63 | 64 | corpus_tokenized = bm25s.tokenize( 65 | corpus_lst, 66 | stopwords="en", 67 | stemmer=stemmer, 68 | leave=False, 69 | return_ids=True, 70 | ) 71 | 72 | bm25_tokens = bm25s.BM25(k1=0.9, b=0.4) 73 | bm25_tokens.index(corpus_tokens_lst) 74 | 75 | bm25_tuples = bm25s.BM25(k1=0.9, b=0.4) 76 | bm25_tuples.index((corpus_tokenized.ids, corpus_tokenized.vocab)) 77 | 78 | bm25_objects = bm25s.BM25(k1=0.9, b=0.4) 79 | bm25_objects.index( 80 | Tokenized(ids=corpus_tokenized.ids, vocab=corpus_tokenized.vocab) 81 | ) 82 | 83 | bm25_namedtuple = bm25s.BM25(k1=0.9, b=0.4) 84 | named_tuple = bm25s.tokenization.Tokenized( 85 | ids=corpus_tokenized.ids, vocab=corpus_tokenized.vocab 86 | ) 87 | bm25_namedtuple.index(named_tuple) 88 | 89 | # now, verify that the sparse matrix matches 90 | self.assertTrue( 91 | check_scores_all_close( 92 | bm25_tokens.scores, bm25_tuples.scores 93 | ), 94 | "Tokenized and Tuple indexing do not match", 95 | ) 96 | self.assertTrue( 97 | check_scores_all_close( 98 | bm25_tokens.scores, bm25_objects.scores 99 | ), 100 | "Tokenized and Object indexing do not match", 101 | ) 102 | self.assertTrue( 103 | check_scores_all_close( 104 | bm25_tokens.scores, bm25_namedtuple.scores 105 | ), 106 | "Tokenized and NamedTuple indexing do not match", 107 | ) 108 | -------------------------------------------------------------------------------- /tests/comparison/test_jsonl_corpus.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import warnings 3 | import unittest 4 | 5 | import numpy as np 6 | import Stemmer 7 | import beir.util 8 | from beir.datasets.data_loader import GenericDataLoader 9 | 10 | import bm25s 11 | from bm25s.utils.beir import BASE_URL 12 | 13 | 14 | class TestBM25PTQuick(unittest.TestCase): 15 | def test_bm25_sparse_vs_rank_bm25_on_nfcorpus(self): 16 | data_dir="datasets" 17 | index_dir = "bm25s_indices" 18 | dataset="scifact" 19 | split = "test" 20 | 21 | index_path = Path(index_dir) / dataset 22 | data_path = beir.util.download_and_unzip(BASE_URL.format(dataset), data_dir) 23 | 24 | warnings.filterwarnings("ignore", category=ResourceWarning) 25 | warnings.filterwarnings("ignore", category=UserWarning) 26 | 27 | # Download and prepare dataset 28 | corpus, queries, _ = GenericDataLoader(data_folder=data_path).load(split=split) 29 | 30 | corpus_ids, corpus_lst = [], [] 31 | for key, val in corpus.items(): 32 | corpus_ids.append(key) 33 | corpus_lst.append(val["title"] + " " + val["text"]) 34 | query_lst = list(queries.values()) 35 | 36 | stemmer = Stemmer.Stemmer("english") 37 | model = bm25s.BM25(corpus=corpus) 38 | corpus_tokens = bm25s.tokenize(corpus_lst, stemmer=stemmer) 39 | model.index(corpus_tokens, show_progress=False) 40 | 41 | # Save the model 42 | model.save(index_path) 43 | 44 | # Load the model 45 | q_tokens = bm25s.tokenize(query_lst, stemmer=stemmer) 46 | model1 = bm25s.BM25.load(index_path, mmap=False, load_corpus=True) 47 | model2 = bm25s.BM25.load(index_path, mmap=True, load_corpus=True) 48 | 49 | res1 = model1.retrieve(q_tokens, show_progress=False) 50 | res2 = model2.retrieve(q_tokens, show_progress=False) 51 
| 52 | # make sure the results are the same 53 | self.assertTrue(np.all(res1.scores == res2.scores)) 54 | self.assertTrue(np.all(res1.documents == res2.documents)) 55 | -------------------------------------------------------------------------------- /tests/comparison/test_rank_bm25.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from .. import BM25TestCase 4 | 5 | class TestRankBM25Quick(BM25TestCase): 6 | def test_bm25_sparse_vs_rank_bm25_on_nfcorpus(self): 7 | self.compare_with_rank_bm25("nfcorpus", corpus_subsample=2000, queries_subsample=200) 8 | 9 | def test_bm25_sparse_vs_rank_bm25_on_scifact(self): 10 | self.compare_with_rank_bm25("scifact", corpus_subsample=2000, queries_subsample=200) 11 | 12 | def test_bm25_sparse_vs_rank_bm25_on_scidocs(self): 13 | self.compare_with_rank_bm25("scidocs", corpus_subsample=2000, queries_subsample=200) 14 | 15 | # fiqa 16 | def test_bm25_sparse_vs_rank_bm25_on_fiqa(self): 17 | self.compare_with_rank_bm25("fiqa", corpus_subsample=2000, queries_subsample=200) 18 | 19 | # arguana 20 | def test_bm25_sparse_vs_rank_bm25_on_arguana(self): 21 | self.compare_with_rank_bm25("arguana", queries_subsample=100, corpus_subsample=1000) 22 | 23 | if __name__ == '__main__': 24 | unittest.main() 25 | -------------------------------------------------------------------------------- /tests/comparison/test_rank_bm25l.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from .. import BM25TestCase 4 | 5 | 6 | class TestRankBM25Quick(BM25TestCase): 7 | def test_bm25_sparse_vs_rank_bm25_on_nfcorpus(self): 8 | self.compare_with_rank_bm25( 9 | "nfcorpus", corpus_subsample=2000, queries_subsample=200, method="bm25l" 10 | ) 11 | 12 | def test_bm25_sparse_vs_rank_bm25_on_scifact(self): 13 | self.compare_with_rank_bm25( 14 | "scifact", corpus_subsample=2000, queries_subsample=200, method="bm25l" 15 | ) 16 | 17 | def test_bm25_sparse_vs_rank_bm25_on_scidocs(self): 18 | self.compare_with_rank_bm25( 19 | "scidocs", corpus_subsample=2000, queries_subsample=200, method="bm25l" 20 | ) 21 | 22 | # fiqa 23 | def test_bm25_sparse_vs_rank_bm25_on_fiqa(self): 24 | self.compare_with_rank_bm25( 25 | "fiqa", corpus_subsample=2000, queries_subsample=200, method="bm25l" 26 | ) 27 | 28 | # arguana 29 | def test_bm25_sparse_vs_rank_bm25_on_arguana(self): 30 | self.compare_with_rank_bm25( 31 | "arguana", queries_subsample=100, corpus_subsample=1000, method="bm25l" 32 | ) 33 | 34 | 35 | if __name__ == "__main__": 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /tests/comparison/test_rank_bm25plus.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from .. 
import BM25TestCase 4 | 5 | 6 | class TestRankBM25Quick(BM25TestCase): 7 | def test_bm25_sparse_vs_rank_bm25_on_nfcorpus(self): 8 | self.compare_with_rank_bm25( 9 | "nfcorpus", corpus_subsample=2000, queries_subsample=200, method="bm25+" 10 | ) 11 | 12 | def test_bm25_sparse_vs_rank_bm25_on_scifact(self): 13 | self.compare_with_rank_bm25( 14 | "scifact", corpus_subsample=2000, queries_subsample=200, method="bm25+" 15 | ) 16 | 17 | def test_bm25_sparse_vs_rank_bm25_on_scidocs(self): 18 | self.compare_with_rank_bm25( 19 | "scidocs", corpus_subsample=2000, queries_subsample=200, method="bm25+" 20 | ) 21 | 22 | # fiqa 23 | def test_bm25_sparse_vs_rank_bm25_on_fiqa(self): 24 | self.compare_with_rank_bm25( 25 | "fiqa", corpus_subsample=2000, queries_subsample=200, method="bm25+" 26 | ) 27 | 28 | # arguana 29 | def test_bm25_sparse_vs_rank_bm25_on_arguana(self): 30 | self.compare_with_rank_bm25( 31 | "arguana", queries_subsample=100, corpus_subsample=1000, method="bm25+" 32 | ) 33 | 34 | 35 | if __name__ == "__main__": 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /tests/comparison/test_utils_corpus.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import unittest 4 | 5 | from tqdm import tqdm 6 | import beir.util 7 | 8 | from bm25s.utils.corpus import JsonlCorpus 9 | from bm25s.utils.beir import BASE_URL 10 | 11 | class TestTopKSingleQuery(unittest.TestCase): 12 | def test_utils_corpus(self): 13 | save_dir = "datasets" 14 | dataset = "scifact" 15 | data_path = beir.util.download_and_unzip(BASE_URL.format(dataset), save_dir) 16 | 17 | corpus_path = f"{data_path}/corpus.jsonl" 18 | 19 | nq = JsonlCorpus(corpus_path) 20 | 21 | # get all ids 22 | 23 | corpus_ids = [doc["_id"] for doc in tqdm(nq)] 24 | 25 | # alternatively, try opening the file and read the _ids as we go 26 | corpus_ids_2 = [] 27 | with open(corpus_path, "r") as f: 28 | for line in f: 29 | doc = json.loads(line) 30 | corpus_ids_2.append(doc["_id"]) 31 | 32 | self.assertListEqual(corpus_ids, corpus_ids_2) 33 | 34 | # check if jsonl corpus can be closed 35 | assert nq.file_obj is not None, "JsonlCorpus file_obj is None, expected file object" 36 | assert nq.mmap_obj is not None, "JsonlCorpus mmap_obj is None, expected mmap object" 37 | 38 | # now, we can close 39 | nq.close() 40 | 41 | assert nq.file_obj is None, "JsonlCorpus file_obj is not None, expected None" 42 | assert nq.mmap_obj is None, "JsonlCorpus mmap_obj is not None, expected None" 43 | 44 | # check if jsonl corpus can be loaded 45 | nq.load() 46 | 47 | assert nq.file_obj is not None, "JsonlCorpus file_obj is None, expected file object" 48 | assert nq.mmap_obj is not None, "JsonlCorpus mmap_obj is None, expected mmap object" 49 | 50 | corpus_ids = [doc["_id"] for doc in tqdm(nq)] 51 | self.assertListEqual(corpus_ids, corpus_ids_2) 52 | 53 | -------------------------------------------------------------------------------- /tests/comparison_full/test_bm25_pt.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from .. 
import BM25TestCase
4 | 
5 | class TestBM25PTFull(BM25TestCase):
6 |     def test_bm25_sparse_vs_bm25_pt_on_nfcorpus(self):
7 |         self.compare_with_bm25_pt("nfcorpus")
8 | 
9 |     def test_bm25_sparse_vs_bm25_pt_on_scifact(self):
10 |         self.compare_with_bm25_pt("scifact")
11 | 
12 |     def test_bm25_sparse_vs_bm25_pt_on_scidocs(self):
13 |         self.compare_with_bm25_pt("scidocs")
14 | 
15 |     # fiqa
16 |     def test_bm25_sparse_vs_bm25_pt_on_fiqa(self):
17 |         self.compare_with_bm25_pt("fiqa")
18 | 
19 |     # arguana
20 |     def test_bm25_sparse_vs_bm25_pt_on_arguana(self):
21 |         self.compare_with_bm25_pt("arguana")
22 | 
23 | if __name__ == '__main__':
24 |     unittest.main()
25 | 
--------------------------------------------------------------------------------
/tests/comparison_full/test_rank_bm25.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from .. import BM25TestCase
4 | 
5 | class TestRankBM25Full(BM25TestCase):
6 |     def test_bm25_sparse_vs_rank_bm25_on_nfcorpus(self):
7 |         self.compare_with_rank_bm25("nfcorpus")
8 | 
9 |     def test_bm25_sparse_vs_rank_bm25_on_scifact(self):
10 |         self.compare_with_rank_bm25("scifact")
11 | 
12 |     def test_bm25_sparse_vs_rank_bm25_on_scidocs(self):
13 |         self.compare_with_rank_bm25("scidocs")
14 | 
15 |     # fiqa
16 |     def test_bm25_sparse_vs_rank_bm25_on_fiqa(self):
17 |         self.compare_with_rank_bm25("fiqa")
18 | 
19 |     # arguana
20 |     def test_bm25_sparse_vs_rank_bm25_on_arguana(self):
21 |         self.compare_with_rank_bm25("arguana")
22 | 
23 | if __name__ == '__main__':
24 |     unittest.main()
25 | 
--------------------------------------------------------------------------------
/tests/core/test_allow_empty.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import bm25s
4 | 
5 | 
6 | class TestBM25SAllowEmpty(unittest.TestCase):
7 |     def test_simple(self):
8 |         all_scores = []
9 | 
10 |         for allow_empty in [(True, True), (True, False), (False, True), (False, False)]:
11 |             corpus = ['foo', 'dog', 'baz', 'quick god', 'quick fox']
12 |             query = 'quick'
13 |             tokenizer = bm25s.tokenization.Tokenizer(stopwords=["english",],)
14 |             corpus_tokens = tokenizer.tokenize(corpus, show_progress=False, allow_empty=allow_empty[0], return_as="ids")
15 |             retriever = bm25s.BM25(backend="numpy")
16 |             retriever.index(corpus_tokens, show_progress=False)
17 |             query_tokens = tokenizer.tokenize([query], show_progress=False, allow_empty=allow_empty[1], return_as="ids")
18 | 
19 |             results, scores = retriever.retrieve(query_tokens, k=len(corpus), show_progress=False, n_threads=1, sorted=True)
20 |             all_scores.append(scores)
21 | 
22 |         # Check that the scores are the same across every (corpus, query) allow_empty
23 |         # setting: each entry of all_scores should equal the first one
24 |         for s in all_scores[1:]:
25 |             self.assertTrue(np.array_equal(all_scores[0], s), "Scores should be the same across all allow_empty settings")
26 | 
27 | 
28 | if __name__ == "__main__":
29 |     unittest.main()
--------------------------------------------------------------------------------
/tests/core/test_retrieve.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 | import unittest
5 | import tempfile
6 | 
7 | import numpy as np
8 | import bm25s
9 | import Stemmer  # optional: for stemming
10 | 
11 | class
TestBM25SLoadingSaving(unittest.TestCase): 12 | @classmethod 13 | def setUpClass(cls): 14 | 15 | # Create your corpus here 16 | corpus = [ 17 | "a cat is a feline and likes to purr", 18 | "a dog is the human's best friend and loves to play", 19 | "a bird is a beautiful animal that can fly", 20 | "a fish is a creature that lives in water and swims", 21 | ] 22 | 23 | # optional: create a stemmer 24 | stemmer = Stemmer.Stemmer("english") 25 | 26 | # Tokenize the corpus and only keep the ids (faster and saves memory) 27 | corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer) 28 | 29 | # Create the BM25 model and index the corpus 30 | retriever = bm25s.BM25(method='bm25+') 31 | retriever.index(corpus_tokens) 32 | 33 | # Save the retriever to temp dir 34 | cls.retriever = retriever 35 | cls.corpus = corpus 36 | cls.corpus_tokens = corpus_tokens 37 | cls.stemmer = stemmer 38 | 39 | def test_retrieve(self): 40 | ground_truth = np.array([[0, 2]]) 41 | 42 | # first, try with default mode 43 | query = "a cat is a feline, it's sometimes beautiful but cannot fly" 44 | query_tokens_obj = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=True) 45 | 46 | # retrieve the top 2 documents 47 | results = self.retriever.retrieve(query_tokens_obj, k=2).documents 48 | 49 | # assert that the retrieved indices are correct 50 | self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}") 51 | 52 | # now, try tokenizing with text tokens 53 | query_tokens_texts = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=False) 54 | results = self.retriever.retrieve(query_tokens_texts, k=2).documents 55 | self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}") 56 | 57 | # now, try to pass a tuple of tokens 58 | ids, vocab = query_tokens_obj 59 | query_tokens_tuple = (ids, vocab) 60 | results = self.retriever.retrieve(query_tokens_tuple, k=2).documents 61 | self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}") 62 | 63 | # finally, try to pass a 2-tuple of tokens with text tokens to "try to trick the system" 64 | queries_as_tuple = (query_tokens_texts[0], query_tokens_texts[0]) 65 | # only retrieve 1 document 66 | ground_truth = np.array([[0], [0]]) 67 | results = self.retriever.retrieve(queries_as_tuple, k=1).documents 68 | self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}") 69 | 70 | def test_retrieve_with_different_return_types(self): 71 | queries = [ 72 | "a cat is a feline, it's sometimes beautiful but cannot fly", 73 | "a dog is the human's best friend and loves to play" 74 | ] 75 | for method in ['bm25+', 'lucene', 'bm25l', 'atire', 'robertson']: 76 | all_docs = [] 77 | all_scores = [] 78 | for return_type in ['ids', 'tuple', 'string']: 79 | tokenizer = bm25s.tokenization.Tokenizer(lower=True,stopwords="en", stemmer=self.stemmer) 80 | corpus_tokens = tokenizer.tokenize(self.corpus, return_as=return_type, show_progress=False, allow_empty=True) 81 | query_tokens = tokenizer.tokenize(queries, return_as=return_type, show_progress=False, allow_empty=True) 82 | # Create the BM25 model and index the corpus 83 | retriever = bm25s.BM25(method=method) 84 | retriever.index(corpus_tokens) 85 | 86 | docs, scores = retriever.retrieve(query_tokens, k=2, sorted=False) 87 | all_docs.append(docs) 88 | all_scores.append(scores) 89 | 90 | # Check if the results are the same for both return types 91 | for doc in 
all_docs[1:]: 92 | self.assertTrue(np.array_equal(all_docs[0], doc), f"Expected {all_docs[0]}, got {doc}") 93 | # Check if the scores are the same for both return types 94 | for score in all_scores[1:]: 95 | self.assertTrue(np.array_equal(all_scores[0], score), f"Expected {all_scores[0]}, got {score}") 96 | 97 | 98 | def test_retrieve_with_weight_mask(self): 99 | 100 | 101 | # first, try with default mode 102 | query = "cat feline dog bird fish" # weights should be [2, 1, 1, 1], but after masking should be [2, 0, 0, 1] 103 | 104 | for dt in [np.float32, np.int32, np.bool_]: 105 | weight_mask = np.array([1, 0, 0, 1], dtype=dt) 106 | ground_truth = np.array([[0, 3]]) 107 | 108 | query_tokens_obj = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=True) 109 | 110 | # retrieve the top 2 documents 111 | results = self.retriever.retrieve(query_tokens_obj, k=2, weight_mask=weight_mask).documents 112 | 113 | # assert that the retrieved indices are correct 114 | self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}") 115 | 116 | # now, try tokenizing with text tokens 117 | query_tokens_texts = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=False) 118 | results = self.retriever.retrieve(query_tokens_texts, k=2, weight_mask=weight_mask).documents 119 | self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}") 120 | 121 | # now, try to pass a tuple of tokens 122 | ids, vocab = query_tokens_obj 123 | query_tokens_tuple = (ids, vocab) 124 | results = self.retriever.retrieve(query_tokens_tuple, k=2, weight_mask=weight_mask).documents 125 | self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}") 126 | 127 | # finally, try to pass a 2-tuple of tokens with text tokens to "try to trick the system" 128 | queries_as_tuple = (query_tokens_texts[0], query_tokens_texts[0]) 129 | # only retrieve 1 document 130 | ground_truth = np.array([[0], [0]]) 131 | results = self.retriever.retrieve(queries_as_tuple, k=1, weight_mask=weight_mask).documents 132 | self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}") 133 | 134 | 135 | def test_failure_of_bad_tuple(self): 136 | # try to pass a tuple of tokens with different lengths 137 | query = "a cat is a feline, it's sometimes beautiful but cannot fly" 138 | query_tokens_obj = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=True) 139 | query_tokens_texts = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=False) 140 | ids, vocab = query_tokens_obj 141 | query_tokens_tuple = (vocab, ids) 142 | 143 | with self.assertRaises(ValueError): 144 | self.retriever.retrieve(query_tokens_tuple, k=2) 145 | 146 | # now, test if there's vocab twice or ids twice 147 | query_tokens_tuple = (ids, ids) 148 | with self.assertRaises(ValueError): 149 | self.retriever.retrieve(query_tokens_tuple, k=2) 150 | 151 | # finally, test only passing vocab 152 | query_tokens_tuple = (vocab, ) 153 | with self.assertRaises(ValueError): 154 | self.retriever.retrieve(query_tokens_tuple, k=2) 155 | 156 | def test_value_error_for_very_small_corpus(self): 157 | query = "a cat is a feline, it's sometimes beautiful but cannot fly" 158 | query_tokens = bm25s.tokenize( 159 | [query], stopwords="en", 160 | stemmer=self.stemmer, return_ids=True 161 | ) 162 | corpus_size = len(self.corpus) 163 | for k in range(0, 10): 164 | if k > corpus_size: 165 | with 
self.assertRaises(ValueError) as context:
166 |                     self.retriever.retrieve(query_tokens, k=k)
167 |                 exception_str_should_include =\
168 |                     "Please set with a smaller k or increase the size of corpus."
169 |                 self.assertIn(
170 |                     exception_str_should_include,
171 |                     str(context.exception),
172 |                     f"[k={k}] Expected the ValueError message to mention"
173 |                     f": {exception_str_should_include}"
174 |                 )
175 |             else:
176 |                 results, scores = self.retriever.retrieve(query_tokens, k=k)
177 |                 self.assertEqual(
178 |                     int(results.size), k,
179 |                     f"[k={k}] The number of retrieved results"
180 |                     f" should be {k}; but it was {results.size}"
181 |                 )
182 |                 self.assertEqual(
183 |                     int(scores.size), k,
184 |                     f"[k={k}] The number of retrieved scores"
185 |                     f" should be {k}; but it was {scores.size}"
186 |                 )
187 | 
188 |     @classmethod
189 |     def tearDownClass(cls):
190 |         pass
--------------------------------------------------------------------------------
/tests/core/test_save_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 | import unittest
5 | import tempfile
6 | import Stemmer  # optional: for stemming
7 | import unittest.mock
8 | import json
9 | 
10 | import bm25s
11 | from bm25s.utils import json_functions
12 | 
13 | class TestBM25SLoadingSaving(unittest.TestCase):
14 |     orjson_should_not_be_installed = False
15 |     orjson_should_be_installed = True
16 | 
17 |     @classmethod
18 |     def setUpClass(cls):
19 |         # ensure bm25s can be imported (orjson availability itself is checked in setUp)
20 |         import bm25s
21 | 
22 |         # Create your corpus here
23 |         corpus = [
24 |             "a cat is a feline and likes to purr",
25 |             "a dog is the human's best friend and loves to play",
26 |             "a bird is a beautiful animal that can fly",
27 |             "a fish is a creature that lives in water and swims",
28 |             "שלום חברים, איך אתם היום?",
29 |             "El café está muy caliente",
30 |             "今天的天气真好!",
31 |             "Как дела?",
32 |             "Türkçe öğreniyorum."
33 | ] 34 | 35 | # optional: create a stemmer 36 | stemmer = Stemmer.Stemmer("english") 37 | 38 | # Tokenize the corpus and only keep the ids (faster and saves memory) 39 | corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer) 40 | 41 | # Create the BM25 model and index the corpus 42 | retriever = bm25s.BM25(method='bm25+') 43 | retriever.index(corpus_tokens) 44 | 45 | # Save the retriever to temp dir 46 | cls.retriever = retriever 47 | cls.corpus = corpus 48 | cls.corpus_tokens = corpus_tokens 49 | cls.stemmer = stemmer 50 | cls.tmpdirname = tempfile.mkdtemp() 51 | 52 | def setUp(self): 53 | # verify that orjson is properly installed 54 | try: 55 | import orjson 56 | except ImportError: 57 | self.fail("orjson should be installed to run this test.") 58 | 59 | def test_a_save(self): 60 | # save the retriever to temp dir 61 | self.retriever.save( 62 | self.tmpdirname, 63 | data_name="data.index.csc.npy", 64 | indices_name="indices.index.csc.npy", 65 | indptr_name="indptr.index.csc.npy", 66 | vocab_name="vocab.json", 67 | nnoc_name="nonoccurrence_array.npy", 68 | params_name="params.json", 69 | ) 70 | 71 | # assert that the following files are saved 72 | fnames = [ 73 | "data.index.csc.npy", 74 | "indices.index.csc.npy", 75 | "indptr.index.csc.npy", 76 | "vocab.json", 77 | "nonoccurrence_array.npy", 78 | "params.json", 79 | ] 80 | 81 | for fname in fnames: 82 | error_msg = f"File {fname} not found in even though it should be saved by the .save() method" 83 | path_exists = os.path.exists(os.path.join(self.tmpdirname, fname)) 84 | self.assertTrue(path_exists, error_msg) 85 | 86 | def test_b_load(self): 87 | # load the retriever from temp dir 88 | r1 = self.retriever 89 | r2 = bm25s.BM25.load( 90 | self.tmpdirname, 91 | data_name="data.index.csc.npy", 92 | indices_name="indices.index.csc.npy", 93 | indptr_name="indptr.index.csc.npy", 94 | vocab_name="vocab.json", 95 | nnoc_name="nonoccurrence_array.npy", 96 | params_name="params.json", 97 | ) 98 | 99 | # for each of data, indices, indptr, vocab, nnoc, params 100 | # assert that the loaded object is the same as the original object 101 | # data, indices, indptr are stored in self.scores 102 | self.assertTrue((r1.scores['data'] == r2.scores['data']).all()) 103 | self.assertTrue((r1.scores['indices'] == r2.scores['indices']).all()) 104 | self.assertTrue((r1.scores['indptr'] == r2.scores['indptr']).all()) 105 | 106 | # vocab is stored in self.vocab 107 | self.assertEqual(r1.vocab_dict, r2.vocab_dict) 108 | 109 | # nnoc is stored in self.nnoc 110 | self.assertTrue((r1.nonoccurrence_array == r2.nonoccurrence_array).all()) 111 | 112 | @unittest.mock.patch("bm25s.utils.json_functions.dumps", json_functions.dumps_with_builtin) 113 | @unittest.mock.patch("bm25s.utils.json_functions.loads", json.loads) 114 | def test_c_save_no_orjson(self): 115 | self.assertEqual(json_functions.dumps_with_builtin, json_functions.dumps) 116 | self.assertEqual(json_functions.loads, json.loads) 117 | self.test_a_save() 118 | 119 | @unittest.mock.patch("bm25s.utils.json_functions.dumps", json_functions.dumps_with_builtin) 120 | @unittest.mock.patch("bm25s.utils.json_functions.loads", json.loads) 121 | def test_d_load_no_orjson(self): 122 | self.assertEqual(json_functions.dumps_with_builtin, json_functions.dumps) 123 | self.assertEqual(json_functions.loads, json.loads) 124 | self.test_b_load() 125 | 126 | @classmethod 127 | def tearDownClass(cls): 128 | # remove the temp dir with rmtree 129 | shutil.rmtree(cls.tmpdirname) 130 | 131 | 132 | class 
TestBM25SNonASCIILoadingSaving(unittest.TestCase):
133 |     orjson_should_not_be_installed = False
134 |     orjson_should_be_installed = True
135 | 
136 |     @classmethod
137 |     def setUpClass(cls):
138 |         # ensure bm25s can be imported (orjson availability itself is checked in setUp)
139 |         import bm25s
140 |         cls.corpus = [
141 |             "a cat is a feline and likes to purr",
142 |             "a dog is the human's best friend and loves to play",
143 |             "a bird is a beautiful animal that can fly",
144 |             "a fish is a creature that lives in water and swims",
145 |             "שלום חברים, איך אתם היום?",
146 |             "El café está muy caliente",
147 |             "今天的天气真好!",
148 |             "Как дела?",
149 |             "Türkçe öğreniyorum.",
150 |             'שלום חברים'
151 |         ]
152 |         corpus_tokens = bm25s.tokenize(cls.corpus, stopwords="en")
153 |         cls.retriever = bm25s.BM25(corpus=cls.corpus)
154 |         cls.retriever.index(corpus_tokens)
155 |         cls.tmpdirname = tempfile.mkdtemp()
156 | 
157 | 
158 |     def setUp(self):
159 |         # verify that orjson is properly installed
160 |         try:
161 |             import orjson
162 |         except ImportError:
163 |             self.fail("orjson should be installed to run this test.")
164 | 
165 |     def test_a_save_and_load(self):
166 |         # regression check: both calls below used to raise UnicodeEncodeError ("'charmap' codec can't encode characters") on non-ASCII corpora
167 |         self.retriever.save(self.tmpdirname, corpus=self.corpus)
168 |         self.retriever.load(self.tmpdirname, load_corpus=True)
169 | 
170 |     @classmethod
171 |     def tearDownClass(cls):
172 |         # remove the temp dir with rmtree
173 |         shutil.rmtree(cls.tmpdirname)
174 | 
175 | 
176 | class TestSaveAndReloadWithTokenizer(unittest.TestCase):
177 |     def setUp(self):
178 |         self.tmpdirname = tempfile.mkdtemp()
179 | 
180 |     def tearDown(self):
181 |         shutil.rmtree(self.tmpdirname)
182 | 
183 |     def test_save_and_reload_with_tokenizer(self):
184 |         import bm25s
185 |         from bm25s.tokenization import Tokenizer
186 | 
187 |         corpus = [
188 |             "Welcome to bm25s, a library that implements BM25 in Python, allowing you to rank documents based on a query.",
189 |             "BM25 is a widely used ranking function used for text retrieval tasks, and is a core component of search services like Elasticsearch.",
190 |             "It is designed to be:",
191 |             "Fast: bm25s is implemented in pure Python and leverage Scipy sparse matrices to store eagerly computed scores for all document tokens.",
192 |             "This allows extremely fast scoring at query time, improving performance over popular libraries by orders of magnitude (see benchmarks below).",
193 |             "Simple: bm25s is designed to be easy to use and understand.",
194 |             "You can install it with pip and start using it in minutes.",
195 |             "There is no dependencies on Java or Pytorch - all you need is Scipy and Numpy, and optional lightweight dependencies for stemming.",
196 |             "Below, we compare bm25s with Elasticsearch in terms of speedup over rank-bm25, the most popular Python implementation of BM25.",
197 |             "We measure the throughput in queries per second (QPS) on a few popular datasets from BEIR in a single-threaded setting.",
198 |             "bm25s aims to offer a faster alternative for Python users who need efficient text retrieval.",
199 |             "It leverages modern Python libraries and data structures for performance optimization.",
200 |             "You can find more details in the documentation and example notebooks provided.",
201 |             "Installation and usage guidelines are simple and accessible for developers of all skill levels.",
202 |             "Try bm25s for a scalable and fast text ranking solution in your Python projects."
203 | ] 204 | 205 | # print(f"We have {len(corpus)} documents in the corpus.") 206 | 207 | tokenizer = Tokenizer(stemmer=None, stopwords=None, splitter=lambda x: x.split()) 208 | corpus_tokens = tokenizer.tokenize(corpus, return_as='tuple') 209 | 210 | retriever = bm25s.BM25(corpus=corpus) 211 | retriever.index(corpus_tokens) 212 | 213 | index_path = os.path.join(self.tmpdirname, "bm25s_index_readme") 214 | 215 | retriever.save(index_path) 216 | tokenizer.save_vocab(save_dir=index_path) 217 | 218 | reloaded_retriever = bm25s.BM25.load(index_path, load_corpus=True) 219 | reloaded_tokenizer = Tokenizer(stemmer=None, stopwords=None, splitter=lambda x: x.split()) 220 | reloaded_tokenizer.load_vocab(index_path) 221 | 222 | queries = ["widely used text ranking function"] 223 | 224 | query_tokens = reloaded_tokenizer.tokenize(queries, update_vocab=False) 225 | results, scores = reloaded_retriever.retrieve(query_tokens, k=2) 226 | 227 | doc = results[0,0] 228 | score = scores[0,0] 229 | 230 | assert doc['id'] == 1 231 | assert score > 3 and score < 4 -------------------------------------------------------------------------------- /tests/core/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | from typing import Generator 5 | import unittest 6 | import Stemmer 7 | import re 8 | 9 | import numpy as np 10 | 11 | from bm25s.tokenization import Tokenizer 12 | 13 | class TestTokenizer(unittest.TestCase): 14 | @classmethod 15 | def setUpClass(cls): 16 | # Define a sample corpus 17 | cls.corpus = [ 18 | "This is a test sentence.", 19 | "Another sentence for testing.", 20 | "Machine learning is fun!", 21 | "The quick brown fox jumps over the lazy dog.", 22 | ] 23 | 24 | cls.corpus_large = [] 25 | 26 | # load tests/data/nfcorpus.txt 27 | with open("tests/data/nfcorpus.txt", "r") as file: 28 | for line in file: 29 | cls.corpus_large.append(line.strip()) 30 | 31 | # Initialize a stemmer 32 | cls.stemmer = Stemmer.Stemmer("english") 33 | 34 | # temp dir 35 | cls.tmpdir = tempfile.mkdtemp() 36 | 37 | def setUp(self): 38 | # Initialize the Tokenizer with default settings 39 | self.tokenizer = Tokenizer(stemmer=self.stemmer) 40 | 41 | def test_tokenize_with_default_settings(self): 42 | """Tests the `tokenize` method with default settings.""" 43 | result = self.tokenizer.tokenize(self.corpus, update_vocab=True, return_as="ids", show_progress=True) 44 | self.assertIsInstance(result, list) 45 | for doc in result: 46 | self.assertIsInstance(doc, list) 47 | for token_id in doc: 48 | self.assertIsInstance(token_id, int) 49 | 50 | def test_tokenize_with_custom_splitter(self): 51 | """Tests the `tokenize` method and `__init__` method with a custom splitter.""" 52 | custom_splitter = lambda text: re.findall(r"\w+", text) 53 | tokenizer = Tokenizer(splitter=custom_splitter, stemmer=self.stemmer) 54 | result = tokenizer.tokenize(self.corpus, update_vocab=True, return_as="ids", show_progress=False) 55 | self.assertIsInstance(result, list) 56 | 57 | def test_tokenize_with_stopwords(self): 58 | """Tests the `tokenize` method and `__init__` method with stopwords filtering.""" 59 | stopwords = ["is", "a", "for"] 60 | tokenizer = Tokenizer(stopwords=stopwords, stemmer=self.stemmer) 61 | result = tokenizer.tokenize(self.corpus, update_vocab=True, return_as="string", show_progress=False) 62 | for doc in result: 63 | for token in doc: 64 | self.assertNotIn(token, stopwords) 65 | 66 | def 
test_tokenize_with_never_update_vocab(self): 67 | """Tests the `tokenize` method with the `update_vocab="never"` parameter.""" 68 | tokenizer = Tokenizer(stemmer=self.stemmer) 69 | tokenizer.tokenize(self.corpus, update_vocab="never", show_progress=False) 70 | vocab_size = len(tokenizer.get_vocab_dict()) 71 | self.assertEqual(vocab_size, 0) 72 | 73 | def test_invalid_splitter(self): 74 | """Tests the `__init__` method for handling an invalid `splitter` input.""" 75 | with self.assertRaises(ValueError): 76 | Tokenizer(splitter=123) # type: ignore 77 | 78 | def test_invalid_stemmer(self): 79 | """Tests the `__init__` method for handling an invalid `stemmer` input.""" 80 | with self.assertRaises(ValueError): 81 | Tokenizer(stemmer="not_callable") # type: ignore 82 | 83 | def test_tokenize_with_empty_vocab(self): 84 | """Tests the `tokenize` method with the `update_vocab="if_empty"` parameter.""" 85 | tokenizer = Tokenizer(stemmer=self.stemmer) 86 | tokenizer.tokenize(self.corpus, update_vocab="if_empty", show_progress=False) 87 | vocab_size = len(tokenizer.get_vocab_dict()) 88 | self.assertGreater(vocab_size, 0) 89 | 90 | def test_streaming_tokenize(self): 91 | """Tests the `streaming_tokenize` method directly for its functionality.""" 92 | stream = self.tokenizer.streaming_tokenize(self.corpus) 93 | assert isinstance(stream, Generator) 94 | for doc_ids in stream: 95 | self.assertIsInstance(doc_ids, list) 96 | for token_id in doc_ids: 97 | self.assertIsInstance(token_id, int) 98 | 99 | def test_get_vocab_dict(self): 100 | """Tests the `get_vocab_dict` method to ensure it returns the correct vocabulary dictionary.""" 101 | self.tokenizer.tokenize(self.corpus, update_vocab=True, show_progress=False) 102 | vocab = self.tokenizer.get_vocab_dict() 103 | self.assertIsInstance(vocab, dict) 104 | self.assertGreater(len(vocab), 0) 105 | 106 | def test_tokenize_return_types(self): 107 | """Tests the `tokenize` method with different `return_as` parameter values (`ids`, `string`, `tuple`).""" 108 | result_ids = self.tokenizer.tokenize(self.corpus, return_as="ids", show_progress=False) 109 | result_strings = self.tokenizer.tokenize(self.corpus, return_as="string", show_progress=False) 110 | result_tuple = self.tokenizer.tokenize(self.corpus, return_as="tuple", show_progress=False) 111 | 112 | self.assertIsInstance(result_ids, list) 113 | self.assertIsInstance(result_strings, list) 114 | self.assertIsInstance(result_tuple, tuple) 115 | 116 | def test_tokenize_with_invalid_return_type(self): 117 | """Tests the `tokenize` method for handling an invalid `return_as` parameter value.""" 118 | with self.assertRaises(ValueError): 119 | self.tokenizer.tokenize(self.corpus, return_as="invalid_type") 120 | 121 | def test_reset_vocab(self): 122 | """Tests the `reset_vocab` method to ensure it properly clears all vocabulary dictionaries.""" 123 | self.tokenizer.tokenize(self.corpus, update_vocab=True, show_progress=False) 124 | self.tokenizer.reset_vocab() 125 | vocab = self.tokenizer.get_vocab_dict() 126 | self.assertEqual(len(vocab), 0) 127 | 128 | def test_to_tokenized_tuple(self): 129 | """Tests the `to_tokenized_tuple` method to ensure it correctly converts token IDs to a named tuple.""" 130 | docs = self.tokenizer.tokenize(self.corpus, return_as="ids", show_progress=False) 131 | tokenized_tuple = self.tokenizer.to_tokenized_tuple(docs) # type: ignore 132 | self.assertIsInstance(tokenized_tuple, tuple) 133 | self.assertEqual(len(tokenized_tuple.ids), len(docs)) # type: ignore 134 | 135 | def 
test_decode_method(self):
136 |         """Tests the `decode` method to ensure it converts token IDs back to strings properly."""
137 |         docs = self.tokenizer.tokenize(self.corpus_large[:1000], return_as="ids", show_progress=False)
138 |         strings = self.tokenizer.decode(docs)  # type: ignore
139 |         self.assertIsInstance(strings, list)
140 |         for doc in strings:
141 |             self.assertIsInstance(doc, list)
142 |             for token in doc:
143 |                 self.assertIsInstance(token, str)
144 | 
145 |     # compare return_ids with decode
146 |     def test_compare_class_with_functional(self):
147 |         """Tests that decoding token IDs yields the same strings as tokenizing directly with return_as="string"."""
148 |         docs = self.tokenizer.tokenize(self.corpus_large, return_as="ids", show_progress=False)
149 |         strings = self.tokenizer.decode(docs)  # type: ignore
150 | 
151 |         # now, get the strings directly by tokenizing with return_as="string"
152 |         strings2 = self.tokenizer.tokenize(self.corpus_large, return_as="string")
153 | 
154 |         # compare the two
155 |         self.assertEqual(strings, strings2)
156 | 
157 |     def test_save_load_vocab(self):
158 |         """
159 |         Tests the save_vocab and load_vocab methods to ensure that the vocabulary is saved and loaded correctly.
160 |         First, this test will tokenize a corpus and store the tokens for later comparison. Then, the vocabulary
161 |         will be saved to a file, and the tokenizer will be re-initialized. Finally, the vocabulary will be loaded
162 |         from the file, and the tokenization will be performed again. The tokens from the first tokenization and the
163 |         second tokenization should be the same.
164 |         """
165 |         corpus = self.corpus_large[:500]
166 |         # Tokenize the corpus and store the tokens
167 |         tokenizer = Tokenizer(stemmer=self.stemmer)
168 |         tokens_original = tokenizer.tokenize(corpus, return_as="ids", update_vocab=True, show_progress=False)
169 |         vocab = tokenizer.get_vocab_dict()
170 | 
171 |         # Save the vocabulary to a temp dir
172 |         tokenizer.save_vocab(save_dir=self.tmpdir, vocab_name="vocab.tokenizer.json")
173 | 
174 |         # Re-initialize the tokenizer and load the vocabulary from the file
175 |         tokenizer2 = Tokenizer(stemmer=self.stemmer)
176 |         tokenizer2.load_vocab(save_dir=self.tmpdir, vocab_name="vocab.tokenizer.json")
177 | 
178 |         # Tokenize the corpus again
179 |         tokens_new = tokenizer2.tokenize(corpus, return_as="ids", show_progress=False)
180 |         vocab_new = tokenizer2.get_vocab_dict()
181 | 
182 |         # Compare the tokens from the first and second tokenization
183 |         self.assertEqual(tokens_original, tokens_new)
184 |         # Compare the vocabularies from the first and second tokenization
185 |         self.assertEqual(vocab, vocab_new)
186 | 
187 |     def test_save_load_stopwords(self):
188 |         """
189 |         Tests the save_stopwords and load_stopwords methods to ensure that the stopwords are saved and loaded correctly.
190 |         First, this test will tokenize a corpus and store the tokens for later comparison. Then, the stopwords will be
191 |         saved to a file, and the tokenizer will be re-initialized. Finally, the stopwords will be loaded from the file,
192 |         and the tokenization will be performed again. The tokens from the first tokenization and the second tokenization
193 |         should be the same.
194 |         """
195 |         corpus = self.corpus_large[:500]
196 |         # Tokenize the corpus and store the tokens
197 |         tokenizer = Tokenizer(stemmer=self.stemmer, stopwords="english")
198 |         tokens_original = tokenizer.tokenize(corpus, return_as="ids", update_vocab=True, show_progress=False)
199 |         stopwords = tokenizer.stopwords
200 | 
201 |         # Save the stopwords to a temp dir
202 |         tokenizer.save_stopwords(save_dir=self.tmpdir, stopwords_name="stopwords.tokenizer.json")
203 | 
204 |         # Re-initialize the tokenizer and load the stopwords from the file
205 |         tokenizer2 = Tokenizer(stemmer=self.stemmer, stopwords=[])
206 | 
207 |         # check if stopwords are empty
208 |         self.assertEqual(tokenizer2.stopwords, [])
209 | 
210 |         tokenizer2.load_stopwords(save_dir=self.tmpdir, stopwords_name="stopwords.tokenizer.json")
211 | 
212 |         # Check if the stopwords are loaded correctly
213 |         self.assertEqual(stopwords, tuple(tokenizer2.stopwords))
214 | 
215 |     def test_empty_sentence_and_unknown_word(self):
216 |         """Tests that empty documents and out-of-vocabulary words are both mapped to the empty-string token id."""
217 |         corpus = [
218 |             "a cat is a feline and likes to purr",
219 |             "a dog is the human's best friend and loves to play",
220 |             "a bird is a beautiful animal that can fly",
221 |             "a fish is a creature that lives in water and swims",
222 |         ]
223 |         new_docs = ["cat", "", "potato"]
224 |         tokenizer = Tokenizer(stopwords="en")
225 |         corpus_tokens = tokenizer.tokenize(corpus)
226 |         new_docs_tokens = tokenizer.tokenize(new_docs)
227 | 
228 |         self.assertTrue(np.all(new_docs_tokens == np.array([[1], [0], [0]])))
229 | 
230 |     @classmethod
231 |     def tearDownClass(cls):
232 |         """Cleans up resources after all tests have run by deleting the temporary directory."""
233 |         # delete temp dir
234 |         shutil.rmtree(cls.tmpdir)
235 | 
236 | 
237 | if __name__ == "__main__":
238 |     unittest.main()
--------------------------------------------------------------------------------
/tests/core/test_tokenizer_misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Miscellaneous tests for the tokenizer module.
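They cover empty documents, unknown query tokens, and the truncated repr of large tokenized corpora.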
3 | """
4 | 
5 | import unittest
6 | 
7 | import numpy as np
8 | import Stemmer
9 | 
10 | import bm25s
11 | from bm25s.tokenization import Tokenizer
12 | 
13 | 
14 | class TestBM25SNewIds(unittest.TestCase):
15 |     def test_empty_string(self):
16 |         # Create a corpus of empty strings
17 |         corpus = ["", "", "", ""]
18 |         # Create a list of queries
19 |         queries = ["what is the meaning of life?"]
20 | 
21 |         # The tokenizer will return a list of lists of tokens
22 |         tokenizer = Tokenizer()
23 |         corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple", allow_empty=True)
24 | 
25 |         self.assertEqual(
26 |             corpus_tokens,
27 |             bm25s.tokenization.Tokenized(ids=[[0], [0], [0], [0]], vocab={"": 0}),
28 |             msg=f"Corpus tokens differ from expected: {corpus_tokens}",
29 |         )
30 | 
31 |         query_tokens = tokenizer.tokenize(
32 |             queries, return_as="ids", update_vocab=False, allow_empty=True
33 |         )
34 | 
35 |         self.assertEqual(
36 |             [[0]],
37 |             query_tokens,
38 |             msg=f"Query tokens differ from expected: {query_tokens}",
39 |         )
40 | 
41 |         retriever = bm25s.BM25()
42 |         retriever.index(corpus_tokens)
43 | 
44 |         results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=2)
45 |         self.assertTrue(
46 |             np.all(results == np.array([["", ""]])),
47 |             msg=f"Results differ from expected: {results}, {scores}",
48 |         )
49 | 
50 |     def test_new_ids(self):
51 |         corpus = [
52 |             "a cat is a feline and likes to purr",
53 |             "a dog is the human's best friend and loves to play",
54 |             "a bird is a beautiful animal that can fly",
55 |             "a fish is a creature that lives in water and swims",
56 |         ]
57 | 
58 |         tokenizer = Tokenizer(
59 |             stemmer=None, stopwords=None, splitter=lambda x: x.split()
60 |         )
61 |         corpus_tokens = tokenizer.tokenize(corpus, allow_empty=False)
62 | 
63 |         bm25 = bm25s.BM25()
64 |         bm25.index(corpus_tokens, create_empty_token=False)
65 | 
66 |         query = "What is a fly?"
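        # With update_vocab=True, unseen words ("What", "fly?") get fresh ids (27, 28), while known words keep theirs ("is"=2, "a"=0)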
67 |         query_tokens = tokenizer.tokenize([query], update_vocab=True, allow_empty=False)
68 |         self.assertListEqual([[27, 2, 0, 28]], query_tokens)
69 | 
70 |         results, scores = bm25.retrieve(query_tokens, k=3)
71 |         self.assertTrue(
72 |             np.all(np.array([[0, 2, 3]]) == results),
73 |             msg=f"Results differ from expected: {results}, {scores}",
74 |         )
75 | 
76 |     def test_failing_after_adding_new_tokens_query(self):
77 |         corpus = [
78 |             "a cat is a feline and likes to purr",
79 |             "a dog is the human's best friend and loves to play",
80 |             "a bird is a beautiful animal that can fly",
81 |             "a fish is a creature that lives in water and swims",
82 |         ]
83 | 
84 |         tokenizer = Tokenizer(
85 |             stemmer=None, stopwords=None, splitter=lambda x: x.split()
86 |         )
87 |         corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple", allow_empty=False)
88 | 
89 |         bm25 = bm25s.BM25()
90 |         bm25.index(corpus_tokens, create_empty_token=False)
91 | 
92 |         query = "unknownword"
93 |         query_tokens = tokenizer.tokenize([query], update_vocab=True, allow_empty=False)
94 | 
95 |         # assert that a ValueError is raised
96 |         with self.assertRaises(ValueError):
97 |             results, scores = bm25.retrieve(query_tokens, k=3)
98 | 
99 |     def test_only_unknown_token_query(self):
100 |         corpus = [
101 |             "a cat is a feline and likes to purr",
102 |             "a dog is the human's best friend and loves to play",
103 |             "a bird is a beautiful animal that can fly",
104 |             "a fish is a creature that lives in water and swims",
105 |         ]
106 | 
107 |         tokenizer = Tokenizer(
108 |             stemmer=None, stopwords=None, splitter=lambda x: x.split()
109 |         )
110 |         corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple")
111 | 
112 |         bm25 = bm25s.BM25()
113 |         bm25.index(corpus_tokens)
114 | 
115 |         query = "unknownword"
116 |         query_tokens = tokenizer.tokenize([query], update_vocab=False)
117 | 
118 |         results, scores = bm25.retrieve(query_tokens, k=3)
119 |         self.assertTrue(np.all(scores == 0.0))
120 | 
121 |     def test_only_unknown_token_query_stemmed(self):
122 |         corpus = [
123 |             "a cat is a feline and likes to purr",
124 |             "a dog is the human's best friend and loves to play",
125 |             "a bird is a beautiful animal that can fly",
126 |             "a fish is a creature that lives in water and swims",
127 |         ]
128 | 
129 |         stemmer = Stemmer.Stemmer("english")
130 | 
131 |         tokenizer = Tokenizer(
132 |             stemmer=stemmer, stopwords=None, splitter=lambda x: x.split()
133 |         )
134 |         corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple")
135 | 
136 |         bm25 = bm25s.BM25()
137 |         bm25.index(corpus_tokens)
138 | 
139 |         query = "unknownword"
140 |         query_tokens = tokenizer.tokenize([query], update_vocab=False)
141 | 
142 |         results, scores = bm25.retrieve(query_tokens, k=3)
143 |         self.assertTrue(np.all(scores == 0.0))
144 | 
145 |     def test_truncation_of_large_corpus(self):
146 |         small_unit = [
147 |             "a cat is a feline and likes to purr",
148 |             "a dog is the human's best friend and loves to play",
149 |             "a bird is a beautiful animal that can fly",
150 |             "a fish is a creature that lives in water and swims",
151 |             # a line with more than 10 tokens
152 |             "a cat is a feline and likes to purr and the cat can jump very high with so careful but casual manner",
153 |         ]
154 |         corpus = []
155 |         for i in range(1000):
156 |             corpus += small_unit
157 |         tokenized = bm25s.tokenize(corpus, stopwords="en", return_ids=True)
158 |         repr_tokenized = repr(tokenized)
159 |         self.assertIn("... (total", repr_tokenized,
160 |                       msg="the repr should include '... (total' to indicate that the corpus listing was truncated")
161 |         self.assertIn(", ...]", repr_tokenized,
162 |                       msg="the repr should include ', ...]' to indicate that long documents were truncated")
163 | 
164 |     def test_truncation_of_small_corpus(self):
165 |         corpus = ["a cat is a feline"]
166 |         tokenized = bm25s.tokenize(corpus, stopwords="en", return_ids=True)
167 |         repr_tokenized = repr(tokenized)
168 |         self.assertNotIn("... (total", repr_tokenized,
169 |                          msg="it should not include the '...' message when the corpus is small")
170 |         self.assertNotIn(", ...]", repr_tokenized,
171 |                          msg="it should not include the '...' message when the doc is short")
172 | 
--------------------------------------------------------------------------------
/tests/core/test_topk.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | 
4 | # JAX_IS_AVAILABLE tracks whether JAX can be imported; the JAX tests below are skipped otherwise
5 | JAX_IS_AVAILABLE = False  # flipped to True below if the import succeeds
6 | try:
7 |     import jax
8 |     import jax.numpy as jnp
9 |     JAX_IS_AVAILABLE = True
10 | except ImportError:
11 |     pass
12 | 
13 | from bm25s.selection import topk
14 | 
15 | 
16 | class TestTopKSingleQuery(unittest.TestCase):
17 |     def setUp(self):
18 |         np.random.seed(42)
19 |         self.k = 5
20 |         self.scores = np.random.uniform(-10, 10, 2000)
21 |         self.expected_scores = np.sort(self.scores)[-self.k:][::-1]
22 |         self.expected_indices = np.argsort(self.scores)[-self.k:][::-1]
23 | 
24 |     def check_results(self, result_scores, result_indices, sorted=True):
25 |         if sorted:
26 |             np.testing.assert_allclose(result_scores, self.expected_scores)
27 |             np.testing.assert_array_equal(result_indices, self.expected_indices)
28 |         else:
29 |             self.assertEqual(len(result_scores), self.k)
30 |             self.assertEqual(len(result_indices), self.k)
31 |             self.assertTrue(np.all(np.isin(result_scores, self.expected_scores)))
32 |             self.assertTrue(np.all(np.isin(result_indices, self.expected_indices)))
33 | 
34 |     def test_topk_numpy_sorted(self):
35 |         result_scores, result_indices = topk(self.scores, self.k, backend="numpy", sorted=True)
36 |         self.check_results(result_scores, result_indices, sorted=True)
37 | 
38 |     def test_topk_numpy_unsorted(self):
39 |         result_scores, result_indices = topk(self.scores, self.k, backend="numpy", sorted=False)
40 |         self.check_results(result_scores, result_indices, sorted=False)
41 | 
42 |     @unittest.skipUnless(JAX_IS_AVAILABLE, "JAX is not available")
43 |     def test_topk_jax_sorted(self):
44 |         result_scores, result_indices = topk(jnp.array(self.scores), self.k, backend="jax", sorted=True)
45 |         self.check_results(result_scores, result_indices, sorted=True)
46 | 
47 |     @unittest.skipUnless(JAX_IS_AVAILABLE, "JAX is not available")
48 |     def test_topk_jax_unsorted(self):
49 |         result_scores, result_indices = topk(jnp.array(self.scores), self.k, backend="jax", sorted=False)
50 |         self.check_results(result_scores, result_indices, sorted=True)  # the jax backend returns results already sorted, hence sorted=True here
51 | 
52 |     def test_topk_auto_backend(self):
53 |         result_scores, result_indices = topk(self.scores, self.k, backend="auto", sorted=True)
54 |         self.check_results(result_scores, result_indices, sorted=True)
55 | 
56 |     def test_jax_installed_but_unavailable(self):
57 |         global JAX_IS_AVAILABLE
58 |         original_jax_is_available = JAX_IS_AVAILABLE
59 |         JAX_IS_AVAILABLE = False  # Temporarily pretend JAX is not available (note: this flips only this module's copy of the flag)
60 | 
61 |         result_scores, result_indices = topk(self.scores, self.k, backend="auto", sorted=True)
62 |         self.check_results(result_scores, result_indices, sorted=True)
63 | 
64 |         JAX_IS_AVAILABLE = original_jax_is_available  # Restore the original value
65 | 
66 | 
67 | if __name__ == '__main__':
68 |     unittest.main()
69 | 
--------------------------------------------------------------------------------
/tests/core/test_utils_corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 | import unittest
5 | import tempfile
6 | import unittest.mock
7 | import Stemmer  # optional: for stemming
8 | import json
9 | 
10 | import bm25s
11 | from bm25s.utils import json_functions
12 | 
13 | class TestUtilsCorpus(unittest.TestCase):
14 |     def setUp(self):
15 |         # let's test the functions on a
16 |         # random jsonl file
17 |         self.tmpdirname = tempfile.mkdtemp()
18 |         file = os.path.join(self.tmpdirname, "file.jsonl")
19 |         self.file = file
20 |         # content is random uuids
21 |         import uuid
22 | 
23 |         self.strings = []
24 | 
25 |         with open(file, "w") as f:
26 |             for i in range(500):
27 |                 s = str(json.dumps({"uuid": str(uuid.uuid4())})) + "\n"
28 |                 self.strings.append(s)
29 |                 f.write(s)
30 | 
31 |     # see test_load_and_save_mmindex_no_orjson below for the variant that hides orjson
32 |     def test_load_and_save_mmindex(self):
33 |         import bm25s
34 | 
35 |         try:
36 |             import orjson
37 |         except ImportError:
38 |             self.fail("orjson is not installed")
39 | 
40 |         file = self.file
41 |         mmindex = bm25s.utils.corpus.find_newline_positions(file)
42 |         bm25s.utils.corpus.save_mmindex(mmindex, file)
43 | 
44 |         # read the lines back and compare them to the originals
45 |         mmindex = bm25s.utils.corpus.load_mmindex(file)
46 | 
47 |         for i in range(500):
48 |             self.assertEqual(bm25s.utils.corpus.get_line(file, i, mmindex), self.strings[i])
49 | 
50 |     @unittest.mock.patch("bm25s.utils.json_functions.dumps", json_functions.dumps_with_builtin)
51 |     @unittest.mock.patch("bm25s.utils.json_functions.loads", json.loads)
52 |     def test_load_and_save_mmindex_no_orjson(self):
53 |         self.assertEqual(json_functions.dumps_with_builtin, json_functions.dumps)
54 |         self.assertEqual(json_functions.loads, json.loads)
55 |         self.test_load_and_save_mmindex()
56 | 
57 |     def tearDown(self):
58 |         # remove the temp dir with rmtree
59 |         shutil.rmtree(self.tmpdirname)
--------------------------------------------------------------------------------
/tests/core/test_vocab_dict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 | import unittest
5 | import numpy as np
6 | import tempfile
7 | import Stemmer  # optional: for stemming
8 | import unittest.mock
9 | import json
10 | 
11 | import bm25s
12 | 
13 | class TestVocabDict(unittest.TestCase):
14 |     def test_vocab_dict(self):
15 | 
16 |         # Create the BM25 model and index the corpus
17 |         stemmer = Stemmer.Stemmer("english")
18 |         corpus = [
19 |             "a cat is a feline and likes to purr",
20 |             "a dog is the human's best friend and loves to play",
21 |             "a bird is a beautiful animal that can fly",
22 |             "a fish is a creature that lives in water and swims",
23 |         ]
24 |         # Note: allow_empty=False will ensure that "" is not in the vocab_dict
25 |         corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer, allow_empty=False)
26 | 
27 |         # check that "" is not in the vocab
28 |         self.assertFalse("" in corpus_tokens.vocab)
29 | 
30 |         # 1. index(corpus, create_empty_token=True) --> correct
31 |         retriever = bm25s.BM25(method='bm25+')
32 |         retriever.index(corpus_tokens, create_empty_token=True)
33 |         self.assertTrue(retriever.vocab_dict is not None)
34 |         self.assertTrue("" in retriever.vocab_dict)
35 | 
36 |         self.assertEqual(len(retriever.vocab_dict), len(retriever.unique_token_ids_set))
37 |         self.assertEqual(set(retriever.vocab_dict.values()), set(retriever.unique_token_ids_set))
38 | 
39 |         # empty_sentence = ["", "", ""]
40 |         # empty_sentence_tokens = bm25s.tokenize(empty_sentence, stopwords="en", stemmer=stemmer, allow_empty=True)
41 |         # check that "" is in the vocab_dict
42 | 
43 |         # 2. index(corpus, create_empty_token=True) --> throwing an error
44 |         # pass
45 | 
46 |         # 3. index(corpus, create_empty_token=False) --> correct
47 |         corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer, allow_empty=False)
48 | 
49 |         retriever = bm25s.BM25(method='lucene')
50 |         retriever.index(corpus_tokens, create_empty_token=False)
51 |         self.assertTrue(retriever.vocab_dict is not None)
52 |         self.assertFalse("" in retriever.vocab_dict)
53 | 
54 |         self.assertEqual(len(retriever.vocab_dict), len(retriever.unique_token_ids_set))
55 |         self.assertEqual(set(retriever.vocab_dict.values()), set(retriever.unique_token_ids_set))
56 | 
57 |         # 4. index(corpus, create_empty_token=False) --> throwing an error
58 |         retriever = bm25s.BM25(method='bm25+')
59 |         retriever.index(corpus_tokens, create_empty_token=False)
60 | 
61 |         new_docs = ["cat", "", "potato"]
62 | 
63 |         tokenizer = bm25s.tokenization.Tokenizer(stopwords="en", stemmer=stemmer)
64 |         corpus_tokens = tokenizer.tokenize(corpus, return_as='ids', allow_empty=True)
65 |         new_docs_tokens = tokenizer.tokenize(new_docs, return_as='ids', allow_empty=True)
66 | 
67 |         # create_empty_token=True
68 |         retriever = bm25s.BM25(method='bm25+')
69 |         retriever.index(corpus_tokens, create_empty_token=True)
70 |         retriever.retrieve(new_docs_tokens, k=1)
71 | 
72 |         # create_empty_token=False
73 |         retriever = bm25s.BM25(method='bm25+')
74 |         # assert that this will throw an error
75 |         with self.assertRaises(IndexError):
76 |             retriever.index(corpus_tokens, create_empty_token=False)
77 |             retriever.retrieve(new_docs_tokens, k=1)
--------------------------------------------------------------------------------
/tests/numba/test_numba_backend_retrieve.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 | import unittest
5 | import tempfile
6 | 
7 | import numpy as np
8 | import bm25s
9 | import Stemmer  # optional: for stemming
10 | 
11 | class TestNumbaBackendRetrieve(unittest.TestCase):
12 |     @classmethod
13 |     def setUpClass(cls):
14 | 
15 |         # Create your corpus here
16 |         corpus = [
17 |             "a cat is a feline and likes to purr",
18 |             "a dog is the human's best friend and loves to play",
19 |             "a bird is a beautiful animal that can fly",
20 |             "a fish is a creature that lives in water and swims",
21 |         ]
22 | 
23 |         # optional: create a stemmer
24 |         stemmer = Stemmer.Stemmer("english")
25 | 
26 |         # Tokenize the corpus and only keep the ids (faster and saves memory)
27 |         corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
28 | 
29 |         # Create the BM25 model and index the corpus
30 |         retriever = bm25s.BM25(method='bm25+', backend="numba", corpus=corpus)
31 |         retriever.index(corpus_tokens)
32 | 
33 |         # Store the retriever and corpus on the class for the tests below
34 |         cls.retriever = retriever
35 |         cls.corpus = corpus
36 |         cls.corpus_tokens = corpus_tokens
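        # keep the stemmer so the tests below can tokenize queries the same way the corpus was tokenized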
37 |         cls.stemmer = stemmer
38 |         cls.tmpdirname = tempfile.mkdtemp()
39 | 
40 |     def test_a_save(self):
41 |         # save the retriever to temp dir
42 |         self.retriever.save(
43 |             self.tmpdirname,
44 |             data_name="data.index.csc.npy",
45 |             indices_name="indices.index.csc.npy",
46 |             indptr_name="indptr.index.csc.npy",
47 |             vocab_name="vocab.json",
48 |             nnoc_name="nonoccurrence_array.npy",
49 |             params_name="params.json",
50 |         )
51 | 
52 |         # assert that the following files are saved
53 |         fnames = [
54 |             "data.index.csc.npy",
55 |             "indices.index.csc.npy",
56 |             "indptr.index.csc.npy",
57 |             "vocab.json",
58 |             "nonoccurrence_array.npy",
59 |             "params.json",
60 |         ]
61 | 
62 |         for fname in fnames:
63 |             error_msg = f"File {fname} not found, even though it should have been saved by the .save() method"
64 |             path_exists = os.path.exists(os.path.join(self.tmpdirname, fname))
65 |             self.assertTrue(path_exists, error_msg)
66 | 
67 |     def test_b_retrieve_with_numba(self):
68 |         # load the retriever from temp dir
69 |         retriever = bm25s.BM25.load(
70 |             self.tmpdirname,
71 |             data_name="data.index.csc.npy",
72 |             indices_name="indices.index.csc.npy",
73 |             indptr_name="indptr.index.csc.npy",
74 |             vocab_name="vocab.json",
75 |             nnoc_name="nonoccurrence_array.npy",
76 |             params_name="params.json",
77 |             load_corpus=True,
78 |         )
79 | 
80 |         self.assertTrue(retriever.backend == "numba", "The backend should be 'numba'")
81 | 
82 |         reloaded_corpus_text = [c["text"] for c in retriever.corpus]
83 |         self.assertTrue(reloaded_corpus_text == self.corpus, "The corpus should be the same as the original corpus")
84 | 
85 |         # now, let's retrieve the top-k results for a query
86 |         query = ["my cat loves to purr", "a fish likes swimming"]
87 |         query_tokens = bm25s.tokenize(query, stopwords="en", stemmer=self.stemmer)
88 | 
89 |         # retrieve the top-k results
90 |         top_k = 2
91 |         retrieved = retriever.retrieve(query_tokens, k=top_k, return_as="tuple")
92 |         retrieved_docs = retriever.retrieve(query_tokens, k=top_k, return_as="documents")
93 | 
94 |         # now, let's retrieve the top-k results for a query using numpy
95 |         retriever.backend = "numpy"
96 |         retrieved_np = retriever.retrieve(query_tokens, k=top_k, return_as="tuple")
97 |         retrieved_docs_np = retriever.retrieve(query_tokens, k=top_k, return_as="documents")
98 |         # assert that the results are the same
99 |         self.assertTrue(np.all(retrieved.scores == retrieved_np.scores), "The retrieved scores should be the same")
100 |         self.assertTrue(np.all(retrieved.documents == retrieved_np.documents), "The retrieved documents should be the same")
101 |         self.assertTrue(np.all(retrieved_docs == retrieved_docs_np), "The results should be the same")
102 | 
103 |     # finally, check when it's loaded with mmap
104 |     def test_c_mmap_retrieve_with_numba(self):
105 |         # load the retriever from temp dir
106 |         retriever = bm25s.BM25.load(
107 |             self.tmpdirname,
108 |             data_name="data.index.csc.npy",
109 |             indices_name="indices.index.csc.npy",
110 |             indptr_name="indptr.index.csc.npy",
111 |             vocab_name="vocab.json",
112 |             nnoc_name="nonoccurrence_array.npy",
113 |             params_name="params.json",
114 |             load_corpus=True,
115 |             mmap=True
116 |         )
117 | 
118 |         self.assertTrue(retriever.backend == "numba", "The backend should be 'numba'")
119 | 
120 |         reloaded_corpus_text = [c["text"] for c in retriever.corpus]
121 |         self.assertTrue(reloaded_corpus_text == self.corpus, "The corpus should be the same as the original corpus")
122 | 
123 |         # now, let's retrieve the top-k results for a query
124 |         query = ["my cat loves to purr", "a fish likes swimming"]
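        # tokenize the queries with the same stemmer used at indexing time so the ids line up with the stored vocab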
125 |         query_tokens = bm25s.tokenize(query, stopwords="en", stemmer=self.stemmer)
126 | 
127 |         # retrieve the top-k results
128 |         top_k = 2
129 |         retrieved = retriever.retrieve(query_tokens, k=top_k, return_as="tuple")
130 |         retrieved_docs = retriever.retrieve(query_tokens, k=top_k, return_as="documents")
131 | 
132 |         # now, let's retrieve the top-k results for a query using numpy
133 |         retriever.backend = "numpy"
134 |         retrieved_np = retriever.retrieve(query_tokens, k=top_k, return_as="tuple")
135 |         retrieved_docs_np = retriever.retrieve(query_tokens, k=top_k, return_as="documents")
136 |         # assert that the results are the same
137 |         self.assertTrue(np.all(retrieved.scores == retrieved_np.scores), "The retrieved scores should be the same")
138 |         self.assertTrue(np.all(retrieved.documents == retrieved_np.documents), "The retrieved documents should be the same")
139 |         self.assertTrue(np.all(retrieved_docs == retrieved_docs_np), "The results should be the same")
140 | 
141 |     # test weight_mask in retrieve()
142 |     def test_d_retrieve_with_weight_mask(self):
143 |         for dt in [np.float32, np.int32, np.bool_]:
144 |             weight_mask = np.array([1, 1, 0, 1], dtype=dt)
145 |             # load the retriever from temp dir
146 |             retriever = bm25s.BM25.load(
147 |                 self.tmpdirname,
148 |                 data_name="data.index.csc.npy",
149 |                 indices_name="indices.index.csc.npy",
150 |                 indptr_name="indptr.index.csc.npy",
151 |                 vocab_name="vocab.json",
152 |                 nnoc_name="nonoccurrence_array.npy",
153 |                 params_name="params.json",
154 |                 load_corpus=True,
155 |             )
156 | 
157 |             self.assertTrue(retriever.backend == "numba", "The backend should be 'numba'")
158 | 
159 |             # now, let's retrieve the top-k results for a query
160 |             query = ["my cat loves to purr", "a fish likes swimming"]
161 | 
162 |             query_tokens = bm25s.tokenize(query, stopwords="en", stemmer=self.stemmer)
163 | 
164 |             # retrieve the top-k results
165 |             top_k = 2
166 |             retrieved = retriever.retrieve(query_tokens, k=top_k, return_as="tuple", weight_mask=weight_mask)  # smoke test: should run without errors for each mask dtype
167 | 
168 |     @classmethod
169 |     def tearDownClass(cls):
170 |         # remove the temp dir with rmtree
171 |         shutil.rmtree(cls.tmpdirname)
--------------------------------------------------------------------------------
/tests/numba/test_topk_numba.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | 
4 | from bm25s.numba.selection import topk
5 | 
6 | 
7 | class TestTopKSingleQuery(unittest.TestCase):
8 |     def setUp(self):
9 |         np.random.seed(42)
10 |         self.k = 5
11 |         self.scores = np.random.uniform(-10, 10, 2000)
12 |         self.expected_scores = np.sort(self.scores)[-self.k:][::-1]
13 |         self.expected_indices = np.argsort(self.scores)[-self.k:][::-1]
14 | 
15 |     def check_results(self, result_scores, result_indices, sorted=True):
16 |         if sorted:
17 |             np.testing.assert_allclose(result_scores, self.expected_scores)
18 |             np.testing.assert_array_equal(result_indices, self.expected_indices)
19 |         else:
20 |             self.assertEqual(len(result_scores), self.k)
21 |             self.assertEqual(len(result_indices), self.k)
22 |             self.assertTrue(np.all(np.isin(result_scores, self.expected_scores)))
23 |             self.assertTrue(np.all(np.isin(result_indices, self.expected_indices)))
24 | 
25 |     def test_topk_numba_sorted(self):
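        # the numba backend should reproduce the numpy reference ranking computed in setUp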
26 |         result_scores, result_indices = topk(self.scores, self.k, backend="numba", sorted=True)
27 |         self.check_results(result_scores, result_indices, sorted=True)
28 | 
29 |     def test_topk_numba_unsorted(self):
30 |         result_scores, result_indices = topk(self.scores, self.k, backend="numba", sorted=False)
31 |         self.check_results(result_scores, result_indices, sorted=False)
32 | 
33 | if __name__ == '__main__':
34 |     unittest.main()
--------------------------------------------------------------------------------
/tests/requirements-comparison.txt:
--------------------------------------------------------------------------------
1 | beir
2 | PyStemmer
3 | torch
4 | numpy
5 | scipy
6 | tqdm
7 | transformers
8 | pyserini
9 | git+https://github.com/dorianbrown/rank_bm25@1abce6cb8bd4a4961f0958391b3eabb749483c01
10 | bm25_pt
11 | nltk
--------------------------------------------------------------------------------
/tests/requirements-core.txt:
--------------------------------------------------------------------------------
1 | -e .[full]
--------------------------------------------------------------------------------
/tests/stopwords/test_stopwords.py:
--------------------------------------------------------------------------------
1 | """
2 | Testing for stopwords still needs to be defined.
3 | """
--------------------------------------------------------------------------------
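The stopwords stub above could start with the following minimal sketch. It relies only on behavior already exercised elsewhere in this suite (bm25s.tokenize with stopwords="en" and the vocab attribute of the returned Tokenized tuple); the class and method names are placeholders:

import unittest

import bm25s


class TestStopwords(unittest.TestCase):
    def test_english_stopwords_are_removed(self):
        corpus = ["the cat is a feline and likes to purr"]
        tokenized = bm25s.tokenize(corpus, stopwords="en", return_ids=True)
        # common English stopwords should not make it into the vocabulary
        for stopword in ("the", "is", "and", "to", "a"):
            self.assertNotIn(stopword, tokenized.vocab)
        # content words should survive stopword removal
        self.assertIn("cat", tokenized.vocab)


if __name__ == "__main__":
    unittest.main()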