├── .github └── workflows │ ├── docs.yml │ ├── release.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── benchmarks └── .gitkeep ├── docs ├── bbm25_haystack.html ├── bbm25_haystack │ ├── __about__.html │ ├── bbm25_retriever.html │ ├── bbm25_store.html │ └── filters.html ├── index.html └── search.js ├── pyproject.toml ├── scripts └── benchmark_beir.py ├── src └── bbm25_haystack │ ├── __about__.py │ ├── __init__.py │ ├── bbm25_retriever.py │ ├── bbm25_store.py │ ├── default.model │ └── filters.py └── tests ├── __init__.py ├── test_document_store.py └── test_retriever.py /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: website 2 | 3 | # build the documentation whenever there are new commits on main 4 | on: 5 | push: 6 | branches: 7 | - main 8 | # Alternative: only build for tags. 9 | # tags: 10 | # - '*' 11 | 12 | # security: restrict permissions for CI jobs. 13 | permissions: 14 | contents: read 15 | 16 | jobs: 17 | # Build the documentation and upload the static HTML files as an artifact. 18 | build: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: '3.9' 25 | 26 | # ADJUST THIS: install all dependencies (including pdoc) 27 | - run: pip install -e . 28 | - run: pip install pdoc 29 | 30 | # ADJUST THIS: build your documentation into docs/. 31 | # We use a custom build script for pdoc itself, ideally you just run `pdoc -o docs/ ...` here. 32 | - run: pdoc src/bbm25_haystack -o docs --docformat restructuredtext 33 | 34 | - uses: actions/upload-pages-artifact@v3 35 | with: 36 | path: docs/ 37 | 38 | # Deploy the artifact to GitHub pages. 39 | # This is a separate job so that only actions/deploy-pages has the necessary permissions. 
40 | deploy: 41 | needs: build 42 | runs-on: ubuntu-latest 43 | permissions: 44 | pages: write 45 | id-token: write 46 | environment: 47 | name: github-pages 48 | url: ${{ steps.deployment.outputs.page_url }} 49 | steps: 50 | - id: deployment 51 | uses: actions/deploy-pages@v4 52 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v[0-9].[0-9]+.[0-9]+*" 7 | 8 | jobs: 9 | release-on-pypi: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v3 15 | 16 | - name: Install Hatch 17 | run: pip install hatch 18 | 19 | - name: Build 20 | run: hatch build 21 | 22 | - name: Publish on PyPi 23 | env: 24 | HATCH_INDEX_USER: __token__ 25 | HATCH_INDEX_AUTH: ${{ secrets.PYPI_API_TOKEN }} 26 | run: hatch publish -y -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow comes from https://github.com/ofek/hatch-mypyc 2 | # https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml 3 | name: test 4 | 5 | on: 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | 11 | concurrency: 12 | group: test-${{ github.head_ref }} 13 | cancel-in-progress: true 14 | 15 | env: 16 | PYTHONUNBUFFERED: "1" 17 | FORCE_COLOR: "1" 18 | 19 | jobs: 20 | run: 21 | name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} 22 | runs-on: ${{ matrix.os }} 23 | strategy: 24 | fail-fast: false 25 | matrix: 26 | os: [ubuntu-latest, windows-latest, macos-latest] 27 | python-version: ['3.9', '3.10', '3.11', '3.12'] 28 | 29 | steps: 30 | - name: Support longpaths 31 | if: 
matrix.os == 'windows-latest' 32 | run: git config --system core.longpaths true 33 | 34 | - uses: actions/checkout@v3 35 | 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v4 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | 41 | - name: Install Hatch 42 | run: pip install --upgrade hatch 43 | 44 | - name: Lint 45 | if: matrix.python-version == '3.9' && runner.os == 'Linux' 46 | run: hatch run lint:all 47 | 48 | - name: Run tests 49 | run: hatch run cov 50 | 51 | - name: Upload coverage reports to Codecov 52 | uses: codecov/codecov-action@v4.0.1 53 | with: 54 | token: ${{ secrets.CODECOV_TOKEN }} 55 | slug: Guest400123064/bbm25-haystack 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # VS Code 163 | .vscode 164 | 165 | # Benchmarking datasets 166 | benchmarks/beir/* 167 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![test](https://github.com/Guest400123064/bbm25-haystack/actions/workflows/test.yml/badge.svg)](https://github.com/Guest400123064/bbm25-haystack/actions/workflows/test.yml) 2 | [![codecov](https://codecov.io/gh/Guest400123064/bbm25-haystack/graph/badge.svg?token=IGRIRBHZ3U)](https://codecov.io/gh/Guest400123064/bbm25-haystack) 3 | [![code style - Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 4 | [![types - Mypy](https://img.shields.io/badge/types-Mypy-blue.svg)](https://github.com/python/mypy) 5 | [![Python 3.9](https://img.shields.io/badge/python-3.9%20|%203.10%20|%203.11%20|%203.12-blue.svg)](https://www.python.org/downloads/release/python-390/) 6 | 7 | # Better BM25 In-Memory Document Store 8 | 9 | An in-memory document store is a great starting point for prototyping and debugging before migrating to production-grade stores like Elasticsearch. 
However, [the original implementation](https://github.com/deepset-ai/haystack/blob/0dbb98c0a017b499560521aa93186d0640aab659/haystack/document_stores/in_memory/document_store.py#L148) of BM25 retrieval recreates an inverse index for the entire document store __on every new search__. Furthermore, the tokenization method is primitive, only permitting splitters based on regular expressions, making localization and domain adaptation challenging. Therefore, this implementation is a slight upgrade to the default BM25 in-memory document store by implementing incremental index update and incorporation of [SentencePiece](https://github.com/google/sentencepiece) statistical sub-word tokenization. 10 | 11 | ## Installation 12 | 13 | ```bash 14 | $ pip install bbm25-haystack 15 | ``` 16 | 17 | Alternatively, you can clone the repository and build from source to be able to reflect changes to the source code: 18 | 19 | ```bash 20 | $ git clone https://github.com/Guest400123064/bbm25-haystack.git 21 | $ cd bbm25-haystack 22 | $ pip install -e . 23 | ``` 24 | 25 | ## Usage 26 | 27 | ### Quick Start 28 | 29 | Below is an example of how you can build a minimal search engine with the `bbm25_haystack` components on their own. They are also compatible with [Haystack pipelines](https://docs.haystack.deepset.ai/docs/creating-pipelines). 
30 | 31 | ```python 32 | from haystack import Document 33 | from bbm25_haystack import BetterBM25DocumentStore, BetterBM25Retriever 34 | 35 | 36 | document_store = BetterBM25DocumentStore() 37 | document_store.write_documents([ 38 | Document(content="There are over 7,000 languages spoken around the world today."), 39 | Document(content="Elephants have been observed to behave in a way that indicates a high level of self-awareness, such as recognizing themselves in mirrors."), 40 | Document(content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bio-luminescent waves.") 41 | ]) 42 | 43 | retriever = BetterBM25Retriever(document_store) 44 | retriever.run(query="How many languages are spoken around the world today?") 45 | ``` 46 | 47 | ### API References 48 | 49 | You can find the full API references [here](https://guest400123064.github.io/bbm25-haystack/). In a hurry? Below are some of the most important document store parameters you might want to explore: 50 | 51 | - `k, b, delta` - the [three BM25+ hyperparameters](https://en.wikipedia.org/wiki/Okapi_BM25). 52 | - `sp_file` - a path to a trained SentencePiece tokenizer `.model` file. The default tokenizer is directly copied from [LLaMA-2-7B-32K tokenizer](https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/tokenizer.model) with a vocab size of 32,000. 53 | - `n_grams` - default to 1, which means text (both query and document) are tokenized into uni-grams. If set to 2, the tokenizer also augments the list of uni-grams with bi-grams, and so on. If specified as a tuple, e.g., (2, 3), the tokenizer only produces bi-grams and tri-grams, without any uni-gram. 54 | - `haystack_filter_logic` - see [below](#filtering-logic). 55 | 56 | The retriever parameters are largely the same as [`InMemoryBM25Retriever`](https://docs.haystack.deepset.ai/docs/inmemorybm25retriever). 
57 | 58 | ## Filtering Logic 59 | 60 | The current document store uses [`document_matches_filter`](https://github.com/deepset-ai/haystack/blob/main/haystack/utils/filters.py) shipped with Haystack to perform filtering by default, which is the same as [`InMemoryDocumentStore`](https://docs.haystack.deepset.ai/docs/inmemorydocumentstore). 61 | 62 | However, there is also an alternative filtering logic shipped with this implementation (unstable at this point). To use this alternative logic, initialize the document store with `haystack_filter_logic=False`. Please find comments and implementation details in [`filters.py`](./src/bbm25_haystack/filters.py). TL;DR: 63 | 64 | - Any comparison involving `None`, i.e., missing values, will always return `False`, regardless of whether the missing value is the document attribute value or the filter value. 65 | - Comparison with `pandas.DataFrame` is always prohibited to reduce surprises. 66 | - No implicit `datetime` conversion from string values. 67 | - `in` and `not in` allow any `Iterable` as filter value, without the `list` constraint. 68 | - Allowing custom comparison functions for more flexibility. Note that the custom comparison function inputs are NEVER checked, i.e., no missing value check, no ``DataFrame`` check, etc. Users should ensure that the input values are valid and that the return value is always a boolean. The inputs are always supplied in the order of document value and then filter value. 69 | 70 | In this case, the negation logic needs to be considered again because `False` can now issue from both the input nullity check and the actual comparisons. For instance, `in` and `not in` both yield non-matching upon missing values. But I think having input processing and comparisons separated makes the filtering behavior more transparent. 71 | 72 | ## Search Quality Evaluation 73 | 74 | This repo has [a simple script](./scripts/benchmark_beir.py) to help evaluate the search quality over the [BEIR](https://github.com/beir-cellar/beir/tree/main) benchmark. 
You need to clone the repository (you can also manually download the script and place it under a folder named `scripts`) and you have to install additional dependencies to run the script. 75 | 76 | ```bash 77 | $ pip install beir 78 | ``` 79 | 80 | To run the script, you may want to specify the dataset name and BM25 hyperparameters. For example: 81 | 82 | ```bash 83 | $ python scripts/benchmark_beir.py --datasets scifact arguana --bm25-k1 1.2 --n-grams 2 --output eval.csv 84 | ``` 85 | 86 | It automatically downloads the benchmarking dataset to `benchmarks/beir`, where `benchmarks` is at the same level as `scripts`. You may also check the help page for more information. 87 | 88 | ```bash 89 | $ python scripts/benchmark_beir.py --help 90 | ``` 91 | 92 | New benchmarking scripts are expected to be added in the future. 93 | 94 | ## License 95 | 96 | `bbm25-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 97 | -------------------------------------------------------------------------------- /benchmarks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guest400123064/bbm25-haystack/9906fa27ffc54f4fd92dfb5d717c15a12a69df0a/benchmarks/.gitkeep -------------------------------------------------------------------------------- /docs/bbm25_haystack/__about__.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | bbm25_haystack.__about__ API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 44 |
45 |
46 |

47 | bbm25_haystack.__about__

48 | 49 | 50 | 51 | 52 | 53 | 54 |
1# SPDX-FileCopyrightText: 2024-present Yuxuan Wang <wangy49@seas.upenn.edu>
 55 | 2#
 56 | 3# SPDX-License-Identifier: Apache-2.0
 57 | 4
 58 | 5__version__ = "0.2.0"
 59 | 
60 | 61 | 62 |
63 |
64 | 246 | -------------------------------------------------------------------------------- /docs/bbm25_haystack/filters.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | bbm25_haystack.filters API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 53 |
54 |
55 |

56 | bbm25_haystack.filters

57 | 58 | 59 | 60 | 61 | 62 | 63 |
  1# SPDX-FileCopyrightText: 2024-present Yuxuan Wang <wangy49@seas.upenn.edu>
 64 |   2#
 65 |   3# SPDX-License-Identifier: Apache-2.0
 66 |   4from collections.abc import Iterable
 67 |   5from functools import wraps
 68 |   6from typing import Any, Callable, Final, Optional
 69 |   7
 70 |   8import pandas as pd
 71 |   9from haystack.dataclasses import Document
 72 |  10from haystack.errors import FilterError
 73 |  11
 74 |  12
 75 |  13def apply_filters_to_document(
 76 |  14    filters: Optional[dict[str, Any]], document: Document
 77 |  15) -> bool:
 78 |  16    """
 79 |  17    Apply filters to a document.
 80 |  18
 81 |  19    :param filters: The filters to apply to the document.
 82 |  20    :type filters: dict[str, Any]
 83 |  21    :param document: The document to apply the filters to.
 84 |  22    :type document: Document
 85 |  23
 86 |  24    :return: True if the document passes the filters.
 87 |  25    :rtype: bool
 88 |  26    """
 89 |  27    if filters is None or not filters:
 90 |  28        return True
 91 |  29    return _run_comparison_condition(filters, document)
 92 |  30
 93 |  31
 94 |  32def _get_document_field(document: Document, field: str) -> Optional[Any]:
 95 |  33    """
 96 |  34    Get the value of a field in a document.
 97 |  35
 98 |  36    If the field is not found within the document then, instead of
 99 |  37    raising an error, `None` is returned. Note that here we do not
100 |  38    implicitly add 'meta' prefix for fields that are not a direct
101 |  39    attribute of the document, not supporting legacy behavior anymore.
102 |  40
103 |  41    :param document: The document to get the field value from.
104 |  42    :type document: Document
105 |  43    :param field: The field to get the value of.
106 |  44    :type field: str
107 |  45
108 |  46    :return: The value of the field in the document.
109 |  47    :rtype: Optional[Any]
110 |  48    """
111 |  49    if "." not in field:
112 |  50        return getattr(document, field)
113 |  51
114 |  52    attr = document.meta
115 |  53    for f in field.split(".")[1:]:
116 |  54        attr = attr.get(f)
117 |  55        if attr is None:
118 |  56            return None
119 |  57    return attr
120 |  58
121 |  59
122 |  60def _run_logical_condition(condition: dict[str, Any], document: Document) -> bool:
123 |  61    if "operator" not in condition:
124 |  62        msg = "Logical condition must have an 'operator' key."
125 |  63        raise FilterError(msg)
126 |  64    if "conditions" not in condition:
127 |  65        msg = "Logical condition must have a 'conditions' key."
128 |  66        raise FilterError(msg)
129 |  67
130 |  68    conditions = condition["conditions"]
131 |  69    reducer = LOGICAL_OPERATORS[condition["operator"]]
132 |  70
133 |  71    return reducer(document, conditions)
134 |  72
135 |  73
def _run_comparison_condition(condition: dict[str, Any], document: Document) -> bool:
    """
    Evaluate a comparison filter condition against a document.

    A condition without a 'field' key is assumed to be a logical
    condition and is delegated to the logical-condition evaluator.

    :param condition: The comparison condition to evaluate; must carry
        'field', 'operator', and 'value' keys.
    :type condition: dict[str, Any]
    :param document: The document to evaluate the condition against.
    :type document: Document

    :return: True if the document satisfies the condition.
    :rtype: bool

    :raises FilterError: If a required key is missing or the operator
        is not one of the supported comparison operators.
    """
    if "field" not in condition:
        return _run_logical_condition(condition, document)

    if "operator" not in condition:
        msg = "Comparison condition must have an 'operator' key."
        raise FilterError(msg)
    if "value" not in condition:
        msg = "Comparison condition must have a 'value' key."
        raise FilterError(msg)

    field: str = condition["field"]
    value: Any = condition["value"]

    # Surface an unknown operator as a FilterError instead of leaking
    # a bare KeyError from the dispatch-table lookup.
    comparator = COMPARISON_OPERATORS.get(condition["operator"])
    if comparator is None:
        msg = (
            f"Unknown comparison operator '{condition['operator']}'; "
            f"must be one of {list(COMPARISON_OPERATORS)}."
        )
        raise FilterError(msg)

    return comparator(_get_document_field(document, field), value)
152 |  90
153 |  91
def _and(document: Document, conditions: list[dict[str, Any]]) -> bool:
    """
    Return True if all conditions are met.

    :param document: The document to check the conditions against.
    :type document: Document
    :param conditions: The conditions to check against the document.
    :type conditions: list[dict[str, Any]]

    :return: True if all conditions are met.
    :rtype: bool
    """
    return all(
        _run_comparison_condition(condition, document) for condition in conditions
    )
169 | 107
170 | 108
def _or(document: Document, conditions: list[dict[str, Any]]) -> bool:
    """
    Return True if any condition is met.

    :param document: The document to check the conditions against.
    :type document: Document
    :param conditions: The conditions to check against the document.
    :type conditions: list[dict[str, Any]]

    :return: True if at least one condition is met.
    :rtype: bool
    """
    return any(_run_comparison_condition(cond, document) for cond in conditions)
184 | 122
185 | 123
def _not(document: Document, conditions: list[dict[str, Any]]) -> bool:
    """
    Negate the conjunction of the given conditions.

    When 'NOT' receives several conditions its meaning is ambiguous:
    it could demand that every condition fail, or merely that at least
    one does. To stay compatible with the official Haystack
    implementation, the 'at least one False' semantics is used here,
    i.e. the result is the negation of an AND over all conditions.

    :param document: The document to check the conditions against.
    :type document: Document
    :param conditions: The conditions to check against the document.
    :type conditions: list[dict[str, Any]]

    :return: True if not all conditions are met.
    :rtype: bool
    """
    all_conditions_met = _and(document, conditions)
    return not all_conditions_met
205 | 143
206 | 144
207 | 145def _check_comparator_inputs(
208 | 146    comparator: Callable[[Any, Any], bool]
209 | 147) -> Callable[[Any, Any], bool]:
210 | 148    """
211 | 149    A decorator to check and preprocess input attribute values.
212 | 150
213 | 151    ALL COMPARISON OPERATORS SHOULD BE WRAPPED WITH THIS DECORATOR.
214 | 152    because a `False` may be returned by both input validation and
215 | 153    the actual comparison. This decorator ensures that the comparison
216 | 154    function is only called if the input values are valid.
217 | 155
218 | 156    :param comparator: The comparator function to wrap.
219 | 157    :type comparator: Callable[[Any, Any], bool]
220 | 158
221 | 159    :return: The wrapped comparator function.
222 | 160    :rtype: Callable[[Any, Any], bool]
223 | 161    """
224 | 162
225 | 163    @wraps(comparator)
226 | 164    def _wrapper(dv: Any, fv: Any) -> bool:
227 | 165
228 | 166        # I think allowing comparison between DataFrames would
229 | 167        # be a really bad idea because it would create unexpected
230 | 168        # behavior, but I am open to discussion on this.
231 | 169        if isinstance(dv, pd.DataFrame) or isinstance(fv, pd.DataFrame):
232 | 170            msg = (
233 | 171                "Cannot compare DataFrames. Please convert them to "
234 | 172                "simpler data structures before comparing."
235 | 173            )
236 | 174            raise FilterError(msg)
237 | 175
238 | 176        # I think comparison between missing values is ambiguous,
239 | 177        # but again, I am open to discussion on this. Here I choose
240 | 178        # to return False if either value is None because from a
241 | 179        # logical perspective, we really cannot say anything about
242 | 180        # the comparison between a missing value and a non-missing.
243 | 181        if dv is None or fv is None:
244 | 182            return False
245 | 183
246 | 184        try:
247 | 185            return comparator(dv, fv)
248 | 186        except TypeError as exc:
249 | 187            msg = (
250 | 188                f"Cannot compare document value of {type(dv)} type "
251 | 189                f"with filter value of {type(fv)} type."
252 | 190            )
253 | 191            raise FilterError(msg) from exc
254 | 192
255 | 193    return _wrapper
256 | 194
257 | 195
@_check_comparator_inputs
def _eq(dv: Any, fv: Any) -> bool:
    """
    Conservative implementation of equal comparison.

    Two deliberate departures from the default Haystack filter
    implementation:
        - Two None values compare as False, not True (the decorator
          short-circuits on missing values).
        - DataFrame operands raise an error instead of being converted
          to JSON for comparison.
    """
    is_equal = dv == fv
    return is_equal
270 | 208
271 | 209
@_check_comparator_inputs
def _ne(dv: Any, fv: Any) -> bool:
    """Inverse of the conservative equality comparison `_eq`."""
    is_equal = _eq(dv, fv)
    return not is_equal
275 | 213
276 | 214
@_check_comparator_inputs
def _gt(dv: Any, fv: Any) -> bool:
    """
    A more liberal implementation with fewer surprises.

    The two values are compared with Python's native ordering and no
    implicit conversion, which keeps the behavior predictable. To
    compare dates, for instance, convert both the document value and
    the filter value to date objects explicitly before filtering.
    """
    return dv > fv
289 | 227
290 | 228
@_check_comparator_inputs
def _lt(dv: Any, fv: Any) -> bool:
    """Native less-than comparison; see `_gt` for the rationale."""
    return dv < fv
294 | 232
295 | 233
@_check_comparator_inputs
def _gte(dv: Any, fv: Any) -> bool:
    """Greater-than-or-equal, composed from `_gt` and `_eq`."""
    result = _gt(dv, fv) or _eq(dv, fv)
    return result
299 | 237
300 | 238
@_check_comparator_inputs
def _lte(dv: Any, fv: Any) -> bool:
    """Less-than-or-equal, composed from `_lt` and `_eq`."""
    result = _lt(dv, fv) or _eq(dv, fv)
    return result
304 | 242
305 | 243
@_check_comparator_inputs
def _in(dv: Any, fv: Any) -> bool:
    """
    Membership test allowing iterable filter values, not just lists.

    This implementation accepts a larger set of filter values such as
    tuples, sets, and other iterable objects.
    """
    if not isinstance(fv, Iterable):
        msg = "Filter value must be an iterable for 'in' comparison."
        raise FilterError(msg)

    for candidate in fv:
        if _eq(dv, candidate):
            return True
    return False
319 | 257
320 | 258
@_check_comparator_inputs
def _nin(dv: Any, fv: Any) -> bool:
    """Negated membership test; see `_in` for accepted filter values."""
    contained = _in(dv, fv)
    return not contained
324 | 262
325 | 263
# Dispatch table mapping Haystack logical operator names to the
# functions that combine sub-condition results for a document.
LOGICAL_OPERATORS: Final = {"NOT": _not, "AND": _and, "OR": _or}

# Dispatch table mapping Haystack comparison operator names to the
# decorated comparator functions applied to document field values.
COMPARISON_OPERATORS: Final = {
    "==": _eq,
    "!=": _ne,
    ">": _gt,
    "<": _lt,
    ">=": _gte,
    "<=": _lte,
    "in": _in,
    "not in": _nin,
}
338 | 
339 | 340 | 341 |
342 |
343 | 344 |
345 | 346 | def 347 | apply_filters_to_document( filters: Optional[dict[str, Any]], document: haystack.dataclasses.document.Document) -> bool: 348 | 349 | 350 | 351 |
352 | 353 |
def apply_filters_to_document(
    filters: Optional[dict[str, Any]], document: Document
) -> bool:
    """
    Apply filters to a document.

    An absent (None) or empty filter set matches every document.

    :param filters: The filters to apply to the document.
    :type filters: Optional[dict[str, Any]]
    :param document: The document to apply the filters to.
    :type document: Document

    :return: True if the document passes the filters.
    :rtype: bool
    """
    # ``not filters`` already covers both None and an empty dict.
    if not filters:
        return True
    return _run_comparison_condition(filters, document)
370 | 
371 | 372 | 373 |

Apply filters to a document.

374 | 375 |
Parameters
376 | 377 |
  • filters: The filters to apply to the document.
  • document: The document to apply the filters to.
381 | 382 |
Returns
383 | 384 |
385 |

True if the document passes the filters.

386 |
387 |
388 | 389 | 390 |
391 |
392 |
393 | LOGICAL_OPERATORS: Final = 394 | {'NOT': <function _not>, 'AND': <function _and>, 'OR': <function _or>} 395 | 396 | 397 |
398 | 399 | 400 | 401 | 402 |
403 |
404 |
405 | COMPARISON_OPERATORS: Final = 406 | 407 | {'==': <function _eq>, '!=': <function _ne>, '>': <function _gt>, '<': <function _lt>, '>=': <function _gte>, '<=': <function _lte>, 'in': <function _in>, 'not in': <function _nin>} 408 | 409 | 410 |
411 | 412 | 413 | 414 | 415 |
416 |
417 | 599 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "bbm25-haystack" 7 | dynamic = ["version"] 8 | description = 'Haystack 2.x In-memory BM25 Document Store with Enhanced Efficiency' 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | license = "Apache-2.0" 12 | keywords = ["Document Search", "BM25", "LLM Agent", "RAG", "Haystack"] 13 | authors = [ 14 | { name = "Guest400123064", email = "wangy49@seas.upenn.edu" }, 15 | ] 16 | classifiers = [ 17 | "Development Status :: 5 - Production/Stable", 18 | "Programming Language :: Python", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "Programming Language :: Python :: Implementation :: CPython", 24 | "Programming Language :: Python :: Implementation :: PyPy", 25 | ] 26 | dependencies = [ 27 | "haystack-ai", 28 | "sentencepiece", 29 | ] 30 | 31 | [project.urls] 32 | Documentation = "https://github.com/Guest400123064/bbm25-haystack#readme" 33 | Issues = "https://github.com/Guest400123064/bbm25-haystack/issues" 34 | Source = "https://github.com/Guest400123064/bbm25-haystack" 35 | 36 | [tool.hatch.version] 37 | path = "src/bbm25_haystack/__about__.py" 38 | 39 | [tool.hatch.envs.default] 40 | dependencies = [ 41 | "coverage[toml]>=6.5", 42 | "pytest", 43 | "pytest-cov", 44 | "hypothesis", 45 | ] 46 | [tool.hatch.envs.default.scripts] 47 | test = "pytest {args:tests}" 48 | test-cov = "coverage run -m pytest 
{args:tests}" 49 | cov-report = [ 50 | "- coverage combine", 51 | "coverage xml", 52 | ] 53 | cov = [ 54 | "test-cov", 55 | "cov-report", 56 | ] 57 | 58 | [[tool.hatch.envs.all.matrix]] 59 | python = ["3.9", "3.10", "3.11", "3.12"] 60 | 61 | [tool.hatch.envs.lint] 62 | detached = true 63 | dependencies = [ 64 | "black>=23.1.0", 65 | "mypy>=1.0.0", 66 | "ruff>=0.0.243", 67 | ] 68 | [tool.hatch.envs.lint.scripts] 69 | typing = "mypy --install-types --non-interactive {args:src/bbm25_haystack tests}" 70 | style = [ 71 | "ruff {args:check .}", 72 | "black --check --diff {args:.}", 73 | ] 74 | fmt = [ 75 | "black {args:.}", 76 | "ruff {args:check .} --fix", 77 | "style", 78 | ] 79 | all = [ 80 | "style", 81 | "typing", 82 | ] 83 | 84 | [tool.hatch.metadata] 85 | allow-direct-references = true 86 | 87 | [tool.black] 88 | target-version = ["py39"] 89 | line-length = 85 90 | skip-string-normalization = true 91 | 92 | [tool.ruff] 93 | target-version = "py39" 94 | line-length = 85 95 | select = [ 96 | "A", 97 | "ARG", 98 | "B", 99 | "C", 100 | "DTZ", 101 | "E", 102 | "EM", 103 | "F", 104 | "FBT", 105 | "I", 106 | "ICN", 107 | "ISC", 108 | "N", 109 | "PLC", 110 | "PLE", 111 | "PLR", 112 | "PLW", 113 | "Q", 114 | "RUF", 115 | "S", 116 | "T", 117 | "TID", 118 | "UP", 119 | "W", 120 | "YTT", 121 | ] 122 | ignore = [ 123 | # Allow non-abstract empty methods in abstract base classes 124 | "B027", 125 | # Allow boolean positional values in function calls, like `dict.get(... 
True)` 126 | "FBT003", 127 | # Ignore checks for possible passwords 128 | "S105", "S106", "S107", 129 | # Ignore complexity 130 | "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", 131 | # Ignore usage of `lambda` expression 132 | "E731", 133 | ] 134 | unfixable = [ 135 | # Don't touch unused imports 136 | "F401", 137 | ] 138 | 139 | [tool.ruff.isort] 140 | known-first-party = ["bbm25_haystack"] 141 | 142 | [tool.ruff.flake8-tidy-imports] 143 | ban-relative-imports = "all" 144 | 145 | [tool.ruff.per-file-ignores] 146 | # Tests can use magic values, assertions, and relative imports 147 | "tests/**/*" = ["PLR2004", "S101", "TID252"] 148 | 149 | [tool.coverage.run] 150 | source_pkgs = ["bbm25_haystack", "tests"] 151 | branch = true 152 | parallel = true 153 | omit = [ 154 | "src/bbm25_haystack/__about__.py", 155 | ] 156 | 157 | [tool.coverage.paths] 158 | bbm25_haystack = ["src/bbm25_haystack", "*/bbm25-haystack/src/bbm25_haystack"] 159 | tests = ["tests", "*/bbm25-haystack/tests"] 160 | 161 | [tool.coverage.report] 162 | exclude_lines = [ 163 | "no cov", 164 | "if __name__ == .__main__.:", 165 | "if TYPE_CHECKING:", 166 | ] 167 | 168 | [tool.pytest.ini_options] 169 | minversion = "6.0" 170 | markers = [ 171 | "unit: unit tests", 172 | "integration: integration tests" 173 | ] 174 | 175 | [[tool.mypy.overrides]] 176 | module = [ 177 | "haystack.*", 178 | "pytest.*" 179 | ] 180 | ignore_missing_imports = true 181 | -------------------------------------------------------------------------------- /scripts/benchmark_beir.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Yuxuan Wang 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import argparse 5 | import logging 6 | import os 7 | import pathlib 8 | from collections import deque 9 | 10 | import pandas as pd 11 | import tqdm 12 | from beir import LoggingHandler, util 13 | from beir.datasets.data_loader import GenericDataLoader 14 | from 
beir.retrieval.evaluation import EvaluateRetrieval 15 | from beir.retrieval.search import BaseSearch 16 | from haystack import Document 17 | 18 | from bbm25_haystack import BetterBM25DocumentStore 19 | 20 | DIR_PROJ = pathlib.Path(__file__).resolve().parent.parent 21 | DIR_BEIR_RAW = DIR_PROJ / "benchmarks" / "beir" 22 | 23 | URL_BEIR = ( 24 | "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{name}.zip" 25 | ) 26 | 27 | DATASETS = [ 28 | # General IR (in-domain) 29 | "msmarco", 30 | # Bio-medical IR 31 | "trec-covid", 32 | "nfcorpus", 33 | # Question answering 34 | "nq", 35 | "hotpotqa", 36 | "fiqa", 37 | # Citation prediction 38 | "scidocs", 39 | # Argument retrieval 40 | "arguana", 41 | "webis-touche2020", 42 | # Duplicate question retrieval 43 | "quora", 44 | "cqadupstack", 45 | # Fact checking 46 | "scifact", 47 | "fever", 48 | "climate-fever", 49 | # Entity retrieval 50 | "dbpedia-entity", 51 | ] 52 | 53 | logging.basicConfig( 54 | format="%(asctime)s - %(message)s", 55 | datefmt="%Y-%m-%d %H:%M:%S", 56 | level=logging.INFO, 57 | handlers=[LoggingHandler()], 58 | ) 59 | 60 | 61 | class BEIRWrapper(BaseSearch): 62 | """Wrapper for the BetterBM25DocumentStore to be compatible with BEIR.""" 63 | 64 | def __init__(self, store: BetterBM25DocumentStore) -> None: 65 | self._store = store 66 | self._indexed = False 67 | 68 | def index(self, corpus: dict[str, dict[str, str]]) -> None: 69 | """Index the corpus for retrieval.""" 70 | 71 | for idx, raw in tqdm.tqdm(corpus.items(), desc="Indexing corpus"): 72 | raw_title = raw.get("title", "") 73 | raw_text = raw.get("text", "") 74 | 75 | content = f"title: {raw_title}; text: {raw_text}" 76 | document = Document(idx, content=content) 77 | self._store.write_documents([document]) 78 | self._indexed = True 79 | 80 | def search( 81 | self, 82 | corpus: dict[str, dict[str]], 83 | queries: dict[str, str], 84 | top_k: int = 10, 85 | *args, 86 | **kwargs, 87 | ) -> dict[str, dict[str, float]]: 88 | """Search the 
corpus for relevant documents.""" 89 | 90 | _ = args 91 | _ = kwargs 92 | 93 | if not self._indexed: 94 | self.index(corpus) 95 | 96 | results = {} 97 | for idx, qry in tqdm.tqdm(queries.items(), desc="Searching queries"): 98 | result = self._store._retrieval(qry, top_k=top_k) 99 | results[idx] = {doc.id: scr for doc, scr in result if doc.id != idx} 100 | return results 101 | 102 | 103 | def download_dataset_from_beir(name: str) -> bool: 104 | """Download a dataset maintained by the UKP Lab.""" 105 | 106 | if os.path.isdir(DIR_BEIR_RAW / name): 107 | logging.info(f"Dataset {name} already downloaded. Skipping...") 108 | return True 109 | 110 | try: 111 | logging.info(f"Downloading dataset {name} from BEIR to {DIR_BEIR_RAW}...") 112 | util.download_and_unzip(URL_BEIR.format(name=name), DIR_BEIR_RAW) 113 | except Exception as exc: 114 | logging.warn(f"Failed to download dataset {name} from BEIR: {exc}") 115 | return False 116 | 117 | logging.info(f"Dataset {name} downloaded successfully.") 118 | return True 119 | 120 | 121 | def evaluate_retriever(args: argparse.Namespace) -> None: 122 | """Evaluate the retrieval performance of a query encoder over 123 | the BEIR benchmark.""" 124 | 125 | queue = deque() # [ local_save_dir_name... ] 126 | for name in args.datasets or DATASETS: 127 | download_dataset_from_beir(name) 128 | 129 | if name != "cqadupstack": 130 | queue.append(name) 131 | continue 132 | 133 | # Special handling for the CQADupStack dataset because the dataset has 134 | # subdirectories for each topic; so we need to flatten the directory. 
135 | for sub_name in os.listdir(DIR_BEIR_RAW / "cqadupstack"): 136 | sub_name_alt = str(os.path.join("cqadupstack", sub_name)) 137 | queue.append(sub_name_alt) 138 | 139 | records = [] 140 | while queue: 141 | ds_name = queue.popleft() 142 | dir_raw = DIR_BEIR_RAW / ds_name 143 | 144 | store = BetterBM25DocumentStore( 145 | k=args.bm25_k1, 146 | b=args.bm25_b, 147 | delta=args.bm25_delta, 148 | sp_file=args.sp_file, 149 | n_grams=args.n_grams, 150 | ) 151 | model = BEIRWrapper(store) 152 | retriever = EvaluateRetrieval(model) 153 | 154 | corpus, queries, qrels = GenericDataLoader(dir_raw).load(split=args.split) 155 | results = retriever.retrieve(corpus, queries) 156 | 157 | logging.info(f"Evaluating retriever over {ds_name}...") 158 | 159 | record = {} 160 | for metric in retriever.evaluate(qrels, results, k_values=args.k_values): 161 | record.update(metric) 162 | 163 | record.update( 164 | { 165 | "datetime": pd.Timestamp.now(), 166 | "dataset": ds_name.replace("/", "-"), 167 | } 168 | ) 169 | record.update(store.to_dict().get("init_parameters")) 170 | records.append(record) 171 | 172 | records = pd.DataFrame(records) 173 | records.to_csv(args.output, index=False) 174 | 175 | 176 | def get_args() -> argparse.Namespace: 177 | """Get command line arguments for evaluating retrieval performance.""" 178 | 179 | parser = argparse.ArgumentParser( 180 | prog="benchmark_beir.py", 181 | description="Evaluate retrieval performance over the BEIR benchmark.", 182 | epilog="Email wangy49@seas.upenn.edu for questions.", 183 | ) 184 | 185 | parser.add_argument( 186 | "--datasets", 187 | type=str, 188 | nargs="+", 189 | required=False, 190 | default=None, 191 | choices=DATASETS, 192 | help=( 193 | "Dataset names to evaluate on. 
All datasets will be used " 194 | "if not specified (default: None)" 195 | ), 196 | ) 197 | parser.add_argument( 198 | "--bm25-k1", 199 | type=float, 200 | required=False, 201 | default=1.5, 202 | help="The BM25+ k1 parameter; default to 1.5", 203 | ) 204 | parser.add_argument( 205 | "--bm25-b", 206 | type=float, 207 | default=0.75, 208 | required=False, 209 | help="The BM25+ b parameter; default to 0.75", 210 | ) 211 | parser.add_argument( 212 | "--bm25-delta", 213 | type=float, 214 | default=1.0, 215 | required=False, 216 | help="The BM25+ delta parameter; default to 1.0", 217 | ) 218 | parser.add_argument( 219 | "--sp-file", 220 | type=str, 221 | default=None, 222 | required=False, 223 | help="Path to the SentencePiece model file; default to None (LLaMA2)", 224 | ) 225 | parser.add_argument( 226 | "--n-grams", 227 | type=int, 228 | default=1, 229 | required=False, 230 | help="The n-gram size up to n for tokenizations; default to 1", 231 | ) 232 | parser.add_argument( 233 | "--split", 234 | type=str, 235 | default="test", 236 | required=False, 237 | choices=["train", "dev", "test"], 238 | help="Dataset split to evaluate on (default: 'test')", 239 | ) 240 | parser.add_argument( 241 | "--output", 242 | type=str, 243 | default="beir_evaluation_results.csv", 244 | help="Path to the evaluation result", 245 | ) 246 | parser.add_argument( 247 | "--k-values", 248 | type=int, 249 | nargs="+", 250 | required=False, 251 | default=[10, 100], 252 | help="Top-k values for evaluation (default: [10, 100])", 253 | ) 254 | 255 | args = parser.parse_args() 256 | return args 257 | 258 | 259 | if __name__ == "__main__": 260 | evaluate_retriever(get_args()) 261 | -------------------------------------------------------------------------------- /src/bbm25_haystack/__about__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Guest400123064 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | __version__ = 
"0.2.2" 6 | -------------------------------------------------------------------------------- /src/bbm25_haystack/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Guest400123064 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from bbm25_haystack.bbm25_retriever import BetterBM25Retriever 5 | from bbm25_haystack.bbm25_store import BetterBM25DocumentStore 6 | from bbm25_haystack.filters import apply_filters_to_document 7 | 8 | __all__ = [ 9 | "BetterBM25DocumentStore", 10 | "BetterBM25Retriever", 11 | "apply_filters_to_document", 12 | ] 13 | -------------------------------------------------------------------------------- /src/bbm25_haystack/bbm25_retriever.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Guest400123064 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Any, Optional 5 | 6 | from haystack import ( 7 | DeserializationError, 8 | Document, 9 | component, 10 | default_from_dict, 11 | default_to_dict, 12 | ) 13 | 14 | from bbm25_haystack.bbm25_store import BetterBM25DocumentStore 15 | 16 | 17 | def _validate_search_params(filters: Optional[dict[str, Any]], top_k: int) -> None: 18 | """ 19 | Validate the search parameters. 20 | 21 | :param filters: Haystack filters, a dictionary with filters to 22 | narrow down the search space. The filters are applied 23 | **before** similarity retrieval. 24 | :type filters: ``Optional[dict[str, Any]]`` 25 | :param top_k: The maximum number of documents to return. 26 | :type top_k: ``int`` 27 | 28 | :raises ValueError: If the specified top_k is not > 0. 29 | :raises TypeError: If filters is not a dictionary. 30 | """ 31 | if not isinstance(top_k, int): 32 | msg = f"'top_k' must be an integer; got '{type(top_k)}' instead." 33 | raise TypeError(msg) 34 | 35 | if top_k <= 0: 36 | msg = f"'top_k' must be > 0; got '{top_k}' instead." 
37 | raise ValueError(msg) 38 | 39 | if filters is not None and (not isinstance(filters, dict)): 40 | msg = f"'filters' must be a dictionary; got '{type(filters)}' instead." 41 | raise TypeError(msg) 42 | 43 | 44 | @component 45 | class BetterBM25Retriever: 46 | """ 47 | A component for retrieving documents from a ``BetterBM25DocumentStore``. 48 | """ 49 | 50 | def __init__( 51 | self, 52 | document_store: BetterBM25DocumentStore, 53 | *, 54 | filters: Optional[dict[str, Any]] = None, 55 | top_k: int = 10, 56 | set_score: bool = True, 57 | ) -> None: 58 | """ 59 | Create a ``BetterBM25Retriever`` component. 60 | 61 | :param document_store: A ``BetterBM25DocumentStore`` instance. 62 | :type document_store: ``BetterBM25DocumentStore`` 63 | :param filters: Haystack filters, a dictionary with filters to 64 | narrow down the search space. The filters are applied 65 | **before** similarity retrieval. 66 | :type filters: ``Optional[dict[str, Any]]`` 67 | :param top_k: The maximum number of documents to return. 68 | :type top_k: ``int`` 69 | :param set_score: Whether to set the similarity scores to returned 70 | documents under ``Document.score`` attribute. This is useful in 71 | hybrid retrieval setting where you may want to merge results. 72 | Note that returned documents are **copies** so that the original 73 | instances in the document store are not modified. 74 | :type set_score: ``bool`` 75 | 76 | :raises ValueError: If the ``filters`` or ``top_k`` is invalid. 77 | :raises TypeError: If the ``document_store`` is not an instance of 78 | ``BetterBM25DocumentStore``. 
79 | """ 80 | _validate_search_params(filters, top_k) 81 | 82 | self.filters = filters 83 | self.top_k = top_k 84 | self.set_score = set_score 85 | 86 | if not isinstance(document_store, BetterBM25DocumentStore): 87 | msg = "'document_store' must be of type 'BetterBM25DocumentStore'" 88 | raise TypeError(msg) 89 | 90 | self.document_store = document_store 91 | 92 | @component.output_types(documents=list[Document]) 93 | def run( 94 | self, 95 | query: str, 96 | *, 97 | filters: Optional[dict[str, Any]] = None, 98 | top_k: Optional[int] = None, 99 | ) -> dict[str, list[Document]]: 100 | """ 101 | Run the Retriever on the given query. This method always return 102 | copies of the documents retrieved from the document store. 103 | 104 | :param query: The text search term. 105 | :type query: ``str`` 106 | :param filters: Haystack filters, a dictionary with filters to 107 | narrow down the search space. The filters are applied 108 | **before** similarity retrieval. Use the value provided during 109 | initialization if not specified. 110 | :type filters: ``Optional[dict[str, Any]]`` 111 | :param top_k: The maximum number of documents to return. Use the 112 | value provided during initialization if not specified. 113 | :type top_k: ``Optional[int]`` 114 | 115 | :return: The retrieved documents in a dictionary with key "documents". 
116 | """ 117 | filters = filters or self.filters 118 | top_k = top_k or self.top_k 119 | 120 | _validate_search_params(filters, top_k) 121 | 122 | sim = self.document_store._retrieval(query, filters=filters, top_k=top_k) 123 | 124 | ret = [] 125 | for doc, scr in sim: 126 | data = doc.to_dict() 127 | if self.set_score: 128 | data["score"] = scr 129 | ret.append(Document.from_dict(data)) 130 | 131 | return {"documents": ret} 132 | 133 | def to_dict(self) -> dict[str, Any]: 134 | """Serializes the component to a dictionary.""" 135 | return default_to_dict( 136 | self, 137 | filters=self.filters, 138 | top_k=self.top_k, 139 | document_store=self.document_store.to_dict(), 140 | set_score=self.set_score, 141 | ) 142 | 143 | @classmethod 144 | def from_dict(cls, data: dict[str, Any]) -> "BetterBM25Retriever": 145 | """Deserializes the retriever from a dictionary.""" 146 | doc_store_params = data["init_parameters"].get("document_store") 147 | if doc_store_params is None: 148 | msg = "Missing 'document_store' in serialization data" 149 | raise DeserializationError(msg) 150 | 151 | if doc_store_params.get("type") is None: 152 | msg = "Missing 'type' in document store's serialization data" 153 | raise DeserializationError(msg) 154 | 155 | data["init_parameters"]["document_store"] = ( 156 | BetterBM25DocumentStore.from_dict(doc_store_params) 157 | ) 158 | return default_from_dict(cls, data) 159 | -------------------------------------------------------------------------------- /src/bbm25_haystack/bbm25_store.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Guest400123064 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import heapq 5 | import math 6 | import os 7 | from collections import Counter, deque 8 | from collections.abc import Iterable 9 | from itertools import chain 10 | from typing import Any, Final, Optional, Union 11 | 12 | import pandas as pd 13 | from haystack import Document, 
default_from_dict, default_to_dict, logging 14 | from haystack.document_stores.errors import ( 15 | DuplicateDocumentError, 16 | MissingDocumentError, 17 | ) 18 | from haystack.document_stores.types import DuplicatePolicy 19 | from haystack.utils.filters import document_matches_filter 20 | from sentencepiece import SentencePieceProcessor # type: ignore 21 | 22 | from bbm25_haystack.filters import apply_filters_to_document 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def _n_grams(seq: Iterable[str], n: int): 28 | """ 29 | Returns a sliding window (of width n) over data from the 30 | iterable. This solution is adapted from the StackOverflow 31 | answer [here](https://stackoverflow.com/a/6822773/13403958). 32 | 33 | :param seq: Input token sequence. 34 | :type seq: ``Iterable[str]`` 35 | :param n: Window size. 36 | :type n: ``int`` 37 | 38 | :return: The n-gram window generator. 39 | :rtype: ``Generator[tuple[str], None, None]`` 40 | """ 41 | it = iter(seq) 42 | wd = deque((next(it, None) for _ in range(n)), maxlen=n) 43 | 44 | yield tuple(wd) 45 | for el in it: 46 | wd.append(el) 47 | yield tuple(wd) 48 | 49 | 50 | class BetterBM25DocumentStore: 51 | """ 52 | An in-memory BM25 document store intended to improve the default 53 | ``InMemoryDocumentStore`` shipped with Haystack. 54 | """ 55 | 56 | _default_sp_file: Final = os.path.join( 57 | os.path.dirname(os.path.abspath(__file__)), "default.model" 58 | ) 59 | 60 | def __init__( 61 | self, 62 | *, 63 | k: float = 1.5, 64 | b: float = 0.75, 65 | delta: float = 1.0, 66 | sp_file: Optional[str] = None, 67 | n_grams: Union[int, tuple[int, int]] = 1, 68 | haystack_filter_logic: bool = True, 69 | ) -> None: 70 | """ 71 | Creates a new ``BetterBM25DocumentStore`` instance. 72 | 73 | :param k: k1 parameter in BM25+ formula. 74 | :type k: ``Optional[float]`` 75 | :param b: b parameter in BM25+ formula. 76 | :type b: ``Optional[float]`` 77 | :param delta: delta parameter in BM25+ formula. 
78 | :type delta: ``Optional[float]`` 79 | :param sp_file: ``SentencePiece`` tokenizer ``.model`` file to 80 | use. A default from LLaMA-2-32K is used if not provided. 81 | :type sp_file: ``Optional[str]`` 82 | :param n_grams: The n-gram window size. Can be a range of n-grams 83 | to include in text representation. If a single integer is 84 | provided, it will be treated as the maximum n-gram window size, 85 | which is equivalent to ``(1, n_grams)``. 86 | :type n_grams: ``Optional[Union[int, tuple[int, int]]]`` 87 | :param haystack_filter_logic: Whether to use the Haystack filter 88 | logic or the one implemented in this store. 89 | :type haystack_filter_logic: ``Optional[bool]`` 90 | """ 91 | self._k = k 92 | self._b = b 93 | 94 | # Adjust the delta value so that we can bring the ``(k1 + 1)`` 95 | # term out of the 'term frequency' term in BM25+ formula and 96 | # delete it; this will not affect the ranking. 97 | self._delta = delta / (self._k + 1.0) 98 | 99 | self._parse_sp_file(sp_file=sp_file) 100 | self._parse_n_grams(n_grams=n_grams) 101 | 102 | self._haystack_filter_logic = haystack_filter_logic 103 | self._filter_func = ( 104 | document_matches_filter 105 | if self._haystack_filter_logic 106 | else apply_filters_to_document 107 | ) 108 | 109 | self._avg_doc_len: float = 0.0 110 | self._freq_doc: Counter = Counter() 111 | self._index: dict[str, tuple[Document, dict[tuple[str], int], int]] = {} 112 | 113 | def _parse_sp_file(self, sp_file: Optional[str]) -> None: 114 | self._sp_file = sp_file 115 | 116 | if sp_file is None: 117 | self._sp_inst = SentencePieceProcessor(model_file=self._default_sp_file) 118 | return 119 | 120 | if not os.path.exists(sp_file) or not os.path.isfile(sp_file): 121 | msg = ( 122 | f"Tokenizer model file '{sp_file}' not accessible; " 123 | f"fallback to default {self._default_sp_file}." 
124 | ) 125 | logger.warn(msg) 126 | self._sp_inst = SentencePieceProcessor(model_file=self._default_sp_file) 127 | return 128 | 129 | try: 130 | self._sp_inst = SentencePieceProcessor(model_file=sp_file) 131 | except Exception as exc: 132 | msg = ( 133 | f"Failed to load tokenizer model file '{sp_file}': {exc}; " 134 | f"fallback to default {self._default_sp_file}." 135 | ) 136 | logger.error(msg) 137 | self._sp_inst = SentencePieceProcessor(model_file=self._default_sp_file) 138 | 139 | def _parse_n_grams(self, n_grams: Optional[Union[int, tuple[int, int]]]) -> None: 140 | self._n_grams = n_grams 141 | 142 | if isinstance(n_grams, int): 143 | self._n_grams_min = 1 144 | self._n_grams_max = n_grams 145 | return 146 | 147 | if isinstance(n_grams, tuple): 148 | self._n_grams_min, self._n_grams_max = n_grams 149 | if not all(isinstance(n, int) for n in n_grams): 150 | msg = f"Invalid n-gram window size: {n_grams}." 151 | raise ValueError(msg) 152 | return 153 | 154 | msg = f"Invalid n-gram window size: {n_grams}; expected int or tuple." 155 | raise ValueError(msg) 156 | 157 | def _tokenize(self, texts: Union[str, list[str]]) -> list[list[tuple[str]]]: 158 | """ 159 | Tokenize input text using SentencePiece model. 160 | 161 | The input text can either be a single string or a list of strings, 162 | such as a single user query or a group of raw document. The tokenized 163 | text will be augmented into set of n-grams based. 164 | 165 | :param texts: Input text to tokenize, queries or documents. 166 | :type texts: ``Union[str, list[str]]`` 167 | 168 | :return: Tokenized and n-gram augmented texts. 
169 | :rtype: ``list[list[tuple[str]]]`` 170 | """ 171 | 172 | def _augment_to_n_grams(tokens: list[str]) -> list[tuple[str]]: 173 | it = ( 174 | _n_grams(tokens, n) 175 | for n in range(self._n_grams_min, self._n_grams_max + 1) 176 | ) 177 | return list(chain(*it)) 178 | 179 | if isinstance(texts, str): 180 | texts = [texts] 181 | return [ 182 | _augment_to_n_grams(tokens) 183 | for tokens in self._sp_inst.encode(texts, out_type=str) 184 | ] 185 | 186 | def _compute_bm25plus( 187 | self, 188 | query: str, 189 | documents: list[Document], 190 | ) -> list[tuple[Document, float]]: 191 | """ 192 | Calculate the BM25+ score for all documents in this index. 193 | 194 | :param query: Query to calculate the BM25+ score for. 195 | :type query: ``str`` 196 | :param documents: Filtered pool of documents retrieve from. 197 | :type documents: ``list[Document]`` 198 | 199 | :return: Documents and corresponding BM25+ scores. 200 | :rtype: ``list[tuple[Document, float]]`` 201 | """ 202 | cnt = lambda ng: self._freq_doc.get(ng, 0) 203 | idf = { 204 | ng: math.log( 205 | 1 + (self.count_documents() - cnt(ng) + 0.5) / (cnt(ng) + 0.5) 206 | ) 207 | for ng in self._tokenize(query)[0] 208 | } 209 | 210 | sim = [] 211 | for doc in documents: 212 | _, freq, doc_len = self._index[doc.id] 213 | doc_len_scaled = doc_len / self._avg_doc_len 214 | 215 | scr = 0.0 216 | for token, idf_val in idf.items(): 217 | freq_term = freq.get(token, 0.0) 218 | freq_damp = self._k * (1 + self._b * (doc_len_scaled - 1)) 219 | 220 | tf_val = freq_term / (freq_term + freq_damp) + self._delta 221 | scr += idf_val * tf_val 222 | 223 | sim.append((doc, scr)) 224 | 225 | return sim 226 | 227 | def _retrieval( 228 | self, 229 | query: str, 230 | *, 231 | filters: Optional[dict[str, Any]] = None, 232 | top_k: Optional[int] = None, 233 | ) -> list[tuple[Document, float]]: 234 | """ 235 | Retrieve documents from the store using the given query. 236 | 237 | :param query: Query to search for. 
238 | :type query: ``str`` 239 | :param filters: Filters to apply to the document list. 240 | :type filters: ``Optional[dict[str, Any]]`` 241 | :param top_k: Number of documents to return. 242 | :type top_k: ``int`` 243 | 244 | :return: Top ``k`` documents and corresponding BM25+ scores. 245 | :rtype: ``list[tuple[Document, float]]`` 246 | """ 247 | documents = self.filter_documents(filters) 248 | if not documents: 249 | return [] 250 | 251 | sim = self._compute_bm25plus(query, documents) 252 | if top_k is None: 253 | return sorted(sim, key=lambda x: x[1], reverse=True) 254 | return heapq.nlargest(top_k, sim, key=lambda x: x[1]) 255 | 256 | def count_documents(self) -> int: 257 | """ 258 | Returns how many documents are present in this store. 259 | 260 | :return: Number of documents in the store. 261 | :rtype: ``int`` 262 | """ 263 | return len(self._index.keys()) 264 | 265 | def filter_documents( 266 | self, filters: Optional[dict[str, Any]] = None 267 | ) -> list[Document]: 268 | """ 269 | Filter documents in the store using the given filters. 270 | 271 | :param filters: Filters to apply to the document list. 272 | :type filters: ``Optional[dict[str, Any]]`` 273 | 274 | :return: List of documents that match the given filters. 275 | :rtype: ``list[Document]`` 276 | """ 277 | if filters is None or not filters: 278 | return [doc for doc, _, _ in self._index.values()] 279 | return [ 280 | doc 281 | for doc, _, _ in self._index.values() 282 | if self._filter_func(filters, doc) 283 | ] 284 | 285 | def write_documents( 286 | self, 287 | documents: list[Document], 288 | policy: DuplicatePolicy = DuplicatePolicy.NONE, 289 | ) -> int: 290 | """ 291 | Writes (or overwrites) documents into the store. 292 | 293 | :param documents: List of documents to write. 294 | :type documents: ``list[Document]`` 295 | :param policy: Documents with the same ``Document.id`` count as 296 | duplicates. 
When duplicates are met, the store can: 297 | - ``SKIP``: keep the existing document and ignore the new one. 298 | - ``OVERWRITE``: remove the old document and write the new one. 299 | - ``FAIL``: an error is raised (default behavior if not specified) 300 | :type policy: ``Optional[DuplicatePolicy]`` 301 | 302 | :raises ValueError: Exception trigger on invalid duplicate policy. 303 | :raises DuplicateDocumentError: Exception trigger on duplicate 304 | document if ``policy=DuplicatePolicy.FAIL`` 305 | 306 | :return: Number of documents written. 307 | :rtype: ``int`` 308 | """ 309 | if policy not in DuplicatePolicy: 310 | msg = f"Invalid duplicate policy: {policy}." 311 | raise ValueError(msg) 312 | 313 | if policy == DuplicatePolicy.NONE: 314 | policy = DuplicatePolicy.FAIL 315 | 316 | n_written = 0 317 | for doc in documents: 318 | if not isinstance(doc, Document): 319 | msg = f"Expected document type, got '{doc}' of type '{type(doc)}'." 320 | raise ValueError(msg) 321 | 322 | if doc.id in self._index.keys(): 323 | if policy == DuplicatePolicy.SKIP: 324 | continue 325 | elif policy == DuplicatePolicy.FAIL: 326 | msg = f"Document with ID '{doc.id}' already exists in the store." 327 | raise DuplicateDocumentError(msg) 328 | 329 | # Overwrite if exists; delete first to keep the statistics consistent 330 | logger.debug( 331 | f"Document '{doc.id}' already exists in the store, overwriting." 
332 | ) 333 | self.delete_documents([doc.id]) 334 | 335 | content = doc.content or "" 336 | if content == "" and isinstance(doc.dataframe, pd.DataFrame): 337 | content = doc.dataframe.astype(str).to_csv(index=False) 338 | 339 | tokens = self._tokenize(content)[0] 340 | 341 | self._index[doc.id] = (doc, Counter(tokens), len(tokens)) 342 | self._freq_doc.update(set(tokens)) 343 | self._avg_doc_len = ( 344 | len(tokens) + self._avg_doc_len * self.count_documents() 345 | ) / (self.count_documents() + 1) 346 | 347 | logger.debug(f"Document '{doc.id}' written to store.") 348 | n_written += 1 349 | 350 | return n_written 351 | 352 | def delete_documents(self, document_ids: list[str]) -> int: 353 | """ 354 | Deletes all documents with a matching ID. 355 | 356 | :param document_ids: List of ``object_id`` to delete 357 | :type document_ids: ``list[str]`` 358 | 359 | :raises MissingDocumentError: Triggered on document not found. 360 | 361 | :return: Number of documents deleted. 362 | :rtype: ``int`` 363 | """ 364 | n_removal = 0 365 | for doc_id in document_ids: 366 | try: 367 | _, freq, doc_len = self._index.pop(doc_id) 368 | self._freq_doc.subtract(Counter(freq.keys())) 369 | try: 370 | self._avg_doc_len = ( 371 | self._avg_doc_len * (self.count_documents() + 1) - doc_len 372 | ) / self.count_documents() 373 | except ZeroDivisionError: 374 | self._avg_doc_len = 0 375 | 376 | logger.debug(f"Document '{doc_id}' deleted from store.") 377 | n_removal += 1 378 | except KeyError as exc: 379 | msg = f"Document with ID '{doc_id}' not found, cannot delete it." 
380 | raise MissingDocumentError(msg) from exc 381 | 382 | return n_removal 383 | 384 | def to_dict(self) -> dict[str, Any]: 385 | """Serializes this store to a dictionary.""" 386 | return default_to_dict( 387 | self, 388 | k=self._k, 389 | b=self._b, 390 | delta=self._delta * (self._k + 1.0), # Because we scaled it on init 391 | sp_file=self._sp_file, 392 | n_grams=self._n_grams, 393 | haystack_filter_logic=self._haystack_filter_logic, 394 | ) 395 | 396 | @classmethod 397 | def from_dict(cls, data: dict[str, Any]) -> "BetterBM25DocumentStore": 398 | """Deserializes the store from a dictionary.""" 399 | return default_from_dict(cls, data) 400 | -------------------------------------------------------------------------------- /src/bbm25_haystack/default.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Guest400123064/bbm25-haystack/9906fa27ffc54f4fd92dfb5d717c15a12a69df0a/src/bbm25_haystack/default.model -------------------------------------------------------------------------------- /src/bbm25_haystack/filters.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Guest400123064 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from collections.abc import Iterable 5 | from functools import wraps 6 | from typing import Any, Callable, Final, Optional 7 | 8 | import pandas as pd 9 | from haystack.dataclasses import Document 10 | from haystack.errors import FilterError 11 | 12 | 13 | def apply_filters_to_document( 14 | filters: Optional[dict[str, Any]], document: Document 15 | ) -> bool: 16 | """ 17 | Apply filters to a document. Differences with the official 18 | Haystack implementation: 19 | 20 | - Comparison with ``None``, i.e., missing values, involved will 21 | always return ``False``, no matter missing the document 22 | attribute value or missing the filter value. 
23 | - Comparison with ``pandas.DataFrame`` is always prohibited to 24 | reduce surprises. 25 | - No implicit ``datetime`` conversion from string values. 26 | - ``in`` and ``not in`` allows any ``Iterable`` as filter value, 27 | without the ``list`` constraint. 28 | - Allowing custom comparison functions for more flexibility. Note 29 | that the custom comparison function inputs are NEVER checked, 30 | i.e., no missing value check, no ``DataFrame`` check, etc. User 31 | should ensure the input values are valid and return value is 32 | always a boolean. The inputs are always supplied in the order 33 | of document value and then filter value. 34 | 35 | :param filters: The filters to apply to the document. 36 | :type filters: ``dict[str, Any]`` 37 | :param document: The document to apply the filters to. 38 | :type document: ``Document`` 39 | 40 | :return: ``True`` if the document passes the filters. 41 | :rtype: ``bool`` 42 | """ 43 | if filters is None or not filters: 44 | return True 45 | return _run_comparison_condition(filters, document) 46 | 47 | 48 | def _get_document_field(document: Document, field: str) -> Optional[Any]: 49 | """ 50 | Get the value of a field in a document. 51 | 52 | If the field is not found within the document then, instead of 53 | raising an error, ``None`` is returned. Note that here we do not 54 | implicitly add ``'meta'`` prefix for fields that are not a direct 55 | attribute of the document, not supporting legacy behavior anymore. 56 | 57 | :param document: The document to get the field value from. 58 | :type document: ``Document`` 59 | :param field: The field to get the value of. 60 | :type field: ``str`` 61 | 62 | :return: The value of the field in the document. 63 | :rtype: ``Optional[Any]`` 64 | """ 65 | if "." 
not in field: 66 | return getattr(document, field) 67 | 68 | attr = document.meta 69 | for f in field.split(".")[1:]: 70 | attr = attr.get(f) 71 | if attr is None: 72 | return None 73 | return attr 74 | 75 | 76 | def _run_logical_condition(condition: dict[str, Any], document: Document) -> bool: 77 | if "operator" not in condition: 78 | msg = "Logical condition must have an 'operator' key." 79 | raise FilterError(msg) 80 | if "conditions" not in condition: 81 | msg = "Logical condition must have a 'conditions' key." 82 | raise FilterError(msg) 83 | 84 | conditions = condition["conditions"] 85 | reducer = LOGICAL_OPERATORS[condition["operator"]] 86 | 87 | return reducer(document, conditions) 88 | 89 | 90 | def _run_comparison_condition(condition: dict[str, Any], document: Document) -> bool: 91 | if "field" not in condition: 92 | return _run_logical_condition(condition, document) 93 | 94 | if "operator" not in condition: 95 | msg = "Comparison condition must have an 'operator' key." 96 | raise FilterError(msg) 97 | if "value" not in condition: 98 | msg = "Comparison condition must have a 'value' key." 99 | raise FilterError(msg) 100 | 101 | field: str = condition["field"] 102 | value: Any = condition["value"] 103 | 104 | # TODO: We may want to check if the supplied comparator is valid 105 | if callable(condition["operator"]): 106 | comparator = condition["operator"] 107 | else: 108 | comparator = COMPARISON_OPERATORS[condition["operator"]] 109 | 110 | return comparator(_get_document_field(document, field), value) 111 | 112 | 113 | def _and(document: Document, conditions: list[dict[str, Any]]) -> bool: 114 | """ 115 | Return True if all conditions are met. 116 | 117 | :param document: The document to check the conditions against. 118 | :type document: Document 119 | :param conditions: The conditions to check against the document. 120 | :type conditions: ``list[dict[str, Any]]`` 121 | 122 | :return: True if not all conditions are met. 
    :rtype: ``bool``
    """
    return all(
        _run_comparison_condition(condition, document) for condition in conditions
    )


def _or(document: Document, conditions: list[dict[str, Any]]) -> bool:
    """
    Return True if any condition is met.

    :param document: The document to check the conditions against.
    :type document: Document
    :param conditions: The conditions to check against the document.
    :type conditions: ``list[dict[str, Any]]``

    :return: True if any condition is met.
    :rtype: ``bool``
    """
    return any(_run_comparison_condition(cond, document) for cond in conditions)


def _not(document: Document, conditions: list[dict[str, Any]]) -> bool:
    """
    Return True if not all conditions are met.

    The 'NOT' operator is under-specified when supplied with a
    set of conditions instead of a single condition. Because we
    can have the semantics of 'at least one False' versus
    'all False'. Here we choose to comply with the official
    implementation of Haystack (the 'at least one False' semantics).

    :param document: The document to check the conditions against.
    :type document: ``Document``
    :param conditions: The conditions to check against the document.
    :type conditions: ``list[dict[str, Any]]``

    :return: True if not all conditions are met.
    :rtype: ``bool``
    """
    return not _and(document, conditions)


def _check_comparator_inputs(
    comparator: Callable[[Any, Any], bool]
) -> Callable[[Any, Any], bool]:
    """
    A decorator to check and preprocess input attribute values.

    ALL COMPARISON OPERATORS SHOULD BE WRAPPED WITH THIS DECORATOR,
    because a `False` may be returned by both input validation and
    the actual comparison.
This decorator ensures that the comparison 175 | function is only called if the input values are valid. 176 | 177 | :param comparator: The comparator function to wrap. 178 | :type comparator: ``Callable[[Any, Any], bool]`` 179 | 180 | :return: The wrapped comparator function. 181 | :rtype: ``Callable[[Any, Any], bool]`` 182 | """ 183 | 184 | @wraps(comparator) 185 | def _wrapper(dv: Any, fv: Any) -> bool: 186 | 187 | # I think allowing comparison between DataFrames would 188 | # be a really bad idea because it would create unexpected 189 | # behavior, but I am open to discussion on this. 190 | if isinstance(dv, pd.DataFrame) or isinstance(fv, pd.DataFrame): 191 | msg = ( 192 | "Cannot compare DataFrames. Please convert them to " 193 | "simpler data structures before comparing." 194 | ) 195 | raise FilterError(msg) 196 | 197 | # I think comparison between missing values is ambiguous, 198 | # but again, I am open to discussion on this. Here I choose 199 | # to return False if either value is None because from a 200 | # logical perspective, we really cannot say anything about 201 | # the comparison between a missing value and a non-missing. 202 | if dv is None or fv is None: 203 | return False 204 | 205 | try: 206 | return comparator(dv, fv) 207 | except TypeError as exc: 208 | msg = ( 209 | f"Cannot compare document value of {type(dv)} type " 210 | f"with filter value of {type(fv)} type." 211 | ) 212 | raise FilterError(msg) from exc 213 | 214 | return _wrapper 215 | 216 | 217 | @_check_comparator_inputs 218 | def _eq(dv: Any, fv: Any) -> bool: 219 | """ 220 | Conservative implementation of equal comparison. 221 | 222 | There are two major differences between this implementation 223 | and the default Haystack filter implementation: 224 | - If both values are None, we return False, instead of True. 225 | - If any value is a DataFrame, we raise an error, instead 226 | of converting them to JSON. 
227 | """ 228 | return dv == fv 229 | 230 | 231 | @_check_comparator_inputs 232 | def _ne(dv: Any, fv: Any) -> bool: 233 | return not _eq(dv, fv) 234 | 235 | 236 | @_check_comparator_inputs 237 | def _gt(dv: Any, fv: Any) -> bool: 238 | """ 239 | A more liberal implementation with less surprises. 240 | 241 | Simply compare the two values with default Python comparison. 242 | We do not perform any conversion here to have the behavior 243 | more predictable. If we want to compare the dates, we should 244 | just convert the document value and filter value explicitly 245 | to dates before comparing them. 246 | """ 247 | return dv > fv 248 | 249 | 250 | @_check_comparator_inputs 251 | def _lt(dv: Any, fv: Any) -> bool: 252 | return dv < fv 253 | 254 | 255 | @_check_comparator_inputs 256 | def _gte(dv: Any, fv: Any) -> bool: 257 | return _gt(dv, fv) or _eq(dv, fv) 258 | 259 | 260 | @_check_comparator_inputs 261 | def _lte(dv: Any, fv: Any) -> bool: 262 | return _lt(dv, fv) or _eq(dv, fv) 263 | 264 | 265 | @_check_comparator_inputs 266 | def _in(dv: Any, fv: Any) -> bool: 267 | """ 268 | Allowing iterable filter values not just lists. 269 | 270 | This implementation permits a larger set of filter values 271 | such as tuples, sets, and other iterable objects. 272 | """ 273 | if not isinstance(fv, Iterable): 274 | msg = "Filter value must be an iterable for 'in' comparison." 
275 | raise FilterError(msg) 276 | 277 | return any(_eq(dv, v) for v in fv) 278 | 279 | 280 | @_check_comparator_inputs 281 | def _nin(dv: Any, fv: Any) -> bool: 282 | return not _in(dv, fv) 283 | 284 | 285 | LOGICAL_OPERATORS: Final = {"NOT": _not, "AND": _and, "OR": _or} 286 | 287 | COMPARISON_OPERATORS: Final = { 288 | "==": _eq, 289 | "!=": _ne, 290 | ">": _gt, 291 | "<": _lt, 292 | ">=": _gte, 293 | "<=": _lte, 294 | "in": _in, 295 | "not in": _nin, 296 | } 297 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2023-present John Doe 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | -------------------------------------------------------------------------------- /tests/test_document_store.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Guest400123064 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import pandas as pd 5 | import pytest 6 | from haystack import Document 7 | from haystack.document_stores.errors import ( 8 | DuplicateDocumentError, 9 | MissingDocumentError, 10 | ) 11 | from haystack.document_stores.types import ( 12 | DocumentStore, 13 | DuplicatePolicy, 14 | ) 15 | from haystack.errors import FilterError 16 | from haystack.testing.document_store import ( 17 | DocumentStoreBaseTests, 18 | ) 19 | 20 | from bbm25_haystack.bbm25_store import BetterBM25DocumentStore 21 | 22 | 23 | @pytest.mark.integration 24 | class TestDocumentStore(DocumentStoreBaseTests): 25 | """Common test cases will be provided by `DocumentStoreBaseTests`.""" 26 | 27 | @pytest.fixture 28 | def document_store(self) -> BetterBM25DocumentStore: 29 | return BetterBM25DocumentStore() 30 | 31 | @pytest.fixture 32 | def document_store_bbm25_filter(self) -> BetterBM25DocumentStore: 33 | return 
BetterBM25DocumentStore(haystack_filter_logic=False) 34 | 35 | def test_write_documents(self, document_store: DocumentStore): 36 | docs = [Document(id="1")] 37 | assert document_store.write_documents(docs) == 1 38 | with pytest.raises(DuplicateDocumentError): 39 | document_store.write_documents(docs, DuplicatePolicy.FAIL) 40 | 41 | document_store.write_documents( 42 | [Document(id="1"), Document(id="2")], DuplicatePolicy.OVERWRITE 43 | ) 44 | assert document_store.count_documents() == 2 45 | 46 | def test_delete_documents_empty_document_store(self, document_store): 47 | """ 48 | This is different from the original implementation. 49 | 50 | One expects a MissingDocumentError to be raised when deleting a 51 | non-existing document, which is more intuitive. 52 | """ 53 | with pytest.raises(MissingDocumentError): 54 | document_store.delete_documents(["non_existing_id"]) 55 | 56 | def test_delete_documents_non_existing_document(self, document_store): 57 | """ 58 | This is different from the original implementation. 59 | 60 | One expects a MissingDocumentError to be raised when deleting a 61 | non-existing document, which is more intuitive. 62 | """ 63 | document_store.write_documents([Document(id="42")]) 64 | with pytest.raises(MissingDocumentError): 65 | document_store.delete_documents(["non_existing_id"]) 66 | 67 | assert document_store.count_documents() == 1 68 | 69 | def test_bm25_retrieval(self, document_store): 70 | docs = [ 71 | Document(content="Hello world"), 72 | Document(content="Haystack supports multiple languages"), 73 | ] 74 | document_store.write_documents(docs) 75 | 76 | results = document_store._retrieval(query="What languages?", top_k=1) 77 | 78 | assert len(results) == 1 79 | assert results[0][0].content == "Haystack supports multiple languages" 80 | 81 | # Override a few filter test cases to account for new comparison logic 82 | # Specifically, we alter the expected behavior when comparison involves 83 | # None, DataFrame, and Iterables. 
84 | def test_comparison_equal_with_none_bbm25_filter( 85 | self, document_store_bbm25_filter, filterable_docs 86 | ): 87 | document_store_bbm25_filter.write_documents(filterable_docs) 88 | result = document_store_bbm25_filter.filter_documents( 89 | filters={"field": "meta.number", "operator": "==", "value": None} 90 | ) 91 | self.assert_documents_are_equal(result, []) 92 | 93 | def test_comparison_not_equal_with_none_bbm25_filter( 94 | self, document_store_bbm25_filter, filterable_docs 95 | ): 96 | document_store_bbm25_filter.write_documents(filterable_docs) 97 | result = document_store_bbm25_filter.filter_documents( 98 | filters={"field": "meta.number", "operator": "!=", "value": None} 99 | ) 100 | self.assert_documents_are_equal(result, []) 101 | 102 | def test_comparison_not_equal_bbm25_filter( 103 | self, document_store_bbm25_filter, filterable_docs 104 | ): 105 | """Comparison with missing values will always return False. 106 | So the ground truth is that we should only return documents 107 | with a non-missing value.""" 108 | document_store_bbm25_filter.write_documents(filterable_docs) 109 | result = document_store_bbm25_filter.filter_documents( 110 | {"field": "meta.number", "operator": "!=", "value": 100} 111 | ) 112 | self.assert_documents_are_equal( 113 | result, 114 | [ 115 | d 116 | for d in filterable_docs 117 | if d.meta.get("number") != 100 and "number" in d.meta 118 | ], 119 | ) 120 | 121 | def test_comparison_not_in_bbm25_filter( 122 | self, document_store_bbm25_filter, filterable_docs 123 | ): 124 | """Similar to the test above.""" 125 | document_store_bbm25_filter.write_documents(filterable_docs) 126 | result = document_store_bbm25_filter.filter_documents( 127 | {"field": "meta.number", "operator": "not in", "value": [9, 10]} 128 | ) 129 | self.assert_documents_are_equal( 130 | result, 131 | [ 132 | d 133 | for d in filterable_docs 134 | if d.meta.get("number") not in [9, 10] and "number" in d.meta 135 | ], 136 | ) 137 | 138 | def 
test_comparison_equal_with_dataframe_bbm25_filter( 139 | self, document_store_bbm25_filter, filterable_docs 140 | ): 141 | document_store_bbm25_filter.write_documents(filterable_docs) 142 | with pytest.raises(FilterError): 143 | _ = document_store_bbm25_filter.filter_documents( 144 | filters={ 145 | "field": "dataframe", 146 | "operator": "==", 147 | "value": pd.DataFrame([1]), 148 | } 149 | ) 150 | 151 | def test_comparison_not_equal_with_dataframe_bbm25_filter( 152 | self, document_store_bbm25_filter, filterable_docs 153 | ): 154 | document_store_bbm25_filter.write_documents(filterable_docs) 155 | with pytest.raises(FilterError): 156 | _ = document_store_bbm25_filter.filter_documents( 157 | filters={ 158 | "field": "dataframe", 159 | "operator": "==", 160 | "value": pd.DataFrame([1]), 161 | } 162 | ) 163 | -------------------------------------------------------------------------------- /tests/test_retriever.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2024-present Guest400123064 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Any 5 | 6 | import pytest 7 | from haystack import DeserializationError, Pipeline 8 | from haystack.dataclasses import Document 9 | from haystack.testing.factory import document_store_class 10 | 11 | from bbm25_haystack.bbm25_retriever import BetterBM25Retriever 12 | from bbm25_haystack.bbm25_store import BetterBM25DocumentStore 13 | 14 | 15 | @pytest.fixture() 16 | def mock_docs(): 17 | return [ 18 | Document(content="Javascript is a popular programming language"), 19 | Document(content="Java is a popular programming language"), 20 | Document(content="Python is a popular programming language"), 21 | Document(content="Ruby is a popular programming language"), 22 | Document(content="PHP is a popular programming language"), 23 | ] 24 | 25 | 26 | class TestRetriever: 27 | def test_init_default(self): 28 | retriever = 
BetterBM25Retriever(BetterBM25DocumentStore()) 29 | assert retriever.filters is None 30 | assert retriever.top_k == 10 31 | 32 | def test_init_with_parameters(self): 33 | retriever = BetterBM25Retriever( 34 | BetterBM25DocumentStore(), filters={"name": "test.txt"}, top_k=5 35 | ) 36 | assert retriever.filters == {"name": "test.txt"} 37 | assert retriever.top_k == 5 38 | 39 | def test_init_with_invalid_top_k_parameter(self): 40 | with pytest.raises(ValueError): 41 | BetterBM25Retriever(BetterBM25DocumentStore(), top_k=-2) 42 | 43 | with pytest.raises(TypeError): 44 | BetterBM25Retriever(BetterBM25DocumentStore(), top_k="2") 45 | 46 | def test_init_with_invalid_filters_parameter(self): 47 | with pytest.raises(TypeError): 48 | BetterBM25Retriever(BetterBM25DocumentStore(), filters="invalid") 49 | 50 | def test_to_dict(self): 51 | store_class = document_store_class( 52 | "MyFakeStore", bases=(BetterBM25DocumentStore,) 53 | ) 54 | document_store = store_class() 55 | document_store.to_dict = lambda: { 56 | "type": "MyFakeStore", 57 | "init_parameters": {}, 58 | } 59 | component = BetterBM25Retriever(document_store=document_store) 60 | 61 | data = component.to_dict() 62 | assert data == { 63 | "type": "bbm25_haystack.bbm25_retriever.BetterBM25Retriever", 64 | "init_parameters": { 65 | "document_store": { 66 | "type": "MyFakeStore", 67 | "init_parameters": {}, 68 | }, 69 | "filters": None, 70 | "top_k": 10, 71 | "set_score": True, 72 | }, 73 | } 74 | 75 | def test_to_dict_with_custom_init_parameters(self): 76 | ds = BetterBM25DocumentStore() 77 | serialized_ds = ds.to_dict() 78 | 79 | component = BetterBM25Retriever( 80 | document_store=BetterBM25DocumentStore(), 81 | filters={"name": "test.txt"}, 82 | top_k=5, 83 | set_score=False, 84 | ) 85 | data = component.to_dict() 86 | assert data == { 87 | "type": "bbm25_haystack.bbm25_retriever.BetterBM25Retriever", 88 | "init_parameters": { 89 | "document_store": serialized_ds, 90 | "filters": {"name": "test.txt"}, 91 | "top_k": 
5, 92 | "set_score": False, 93 | }, 94 | } 95 | 96 | def test_from_dict(self): 97 | data = { 98 | "type": "bbm25_haystack.bbm25_retriever.BetterBM25Retriever", 99 | "init_parameters": { 100 | "document_store": { 101 | "type": "bbm25_haystack.bbm25_store.BetterBM25DocumentStore", 102 | "init_parameters": {}, 103 | }, 104 | "filters": {"name": "test.txt"}, 105 | "top_k": 5, 106 | }, 107 | } 108 | component = BetterBM25Retriever.from_dict(data) 109 | assert isinstance(component.document_store, BetterBM25DocumentStore) 110 | assert component.filters == {"name": "test.txt"} 111 | assert component.top_k == 5 112 | 113 | def test_from_dict_without_docstore(self): 114 | data = {"type": "BetterBM25Retriever", "init_parameters": {}} 115 | with pytest.raises( 116 | DeserializationError, 117 | match="Missing 'document_store' in serialization data", 118 | ): 119 | BetterBM25Retriever.from_dict(data) 120 | 121 | def test_from_dict_without_docstore_type(self): 122 | data = { 123 | "type": "BetterBM25Retriever", 124 | "init_parameters": {"document_store": {"init_parameters": {}}}, 125 | } 126 | with pytest.raises( 127 | DeserializationError, 128 | match="Missing 'type' in document store's serialization data", 129 | ): 130 | BetterBM25Retriever.from_dict(data) 131 | 132 | def test_from_dict_nonexisting_docstore(self): 133 | data = { 134 | "type": "bbm25_haystack.BetterBM25Retriever", 135 | "init_parameters": { 136 | "document_store": { 137 | "type": "Nonexisting.Docstore", 138 | "init_parameters": {}, 139 | } 140 | }, 141 | } 142 | with pytest.raises(DeserializationError): 143 | BetterBM25Retriever.from_dict(data) 144 | 145 | def test_retriever_valid_run(self, mock_docs): 146 | ds = BetterBM25DocumentStore() 147 | ds.write_documents(mock_docs) 148 | 149 | retriever = BetterBM25Retriever(ds, top_k=5) 150 | result = retriever.run(query="PHP") 151 | 152 | assert "documents" in result 153 | assert len(result["documents"]) == 5 154 | assert ( 155 | result["documents"][0].content == "PHP 
is a popular programming language" 156 | ) 157 | 158 | def test_invalid_run_wrong_store_type(self): 159 | store_class = document_store_class("SomeOtherDocumentStore") 160 | with pytest.raises( 161 | TypeError, 162 | match="'document_store' must be of type 'BetterBM25DocumentStore'", 163 | ): 164 | BetterBM25Retriever(store_class()) 165 | 166 | @pytest.mark.integration 167 | @pytest.mark.parametrize( 168 | "query, query_result", 169 | [ 170 | ("Javascript", "Javascript is a popular programming language"), 171 | ("Java", "Java is a popular programming language"), 172 | ], 173 | ) 174 | def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): 175 | ds = BetterBM25DocumentStore() 176 | ds.write_documents(mock_docs) 177 | retriever = BetterBM25Retriever(ds) 178 | 179 | pipeline = Pipeline() 180 | pipeline.add_component("retriever", retriever) 181 | result: dict[str, Any] = pipeline.run(data={"retriever": {"query": query}}) 182 | 183 | assert result 184 | assert "retriever" in result 185 | results_docs = result["retriever"]["documents"] 186 | assert results_docs 187 | assert results_docs[0].content == query_result 188 | 189 | @pytest.mark.integration 190 | @pytest.mark.parametrize( 191 | "query, query_result, top_k", 192 | [ 193 | ("Javascript", "Javascript is a popular programming language", 1), 194 | ("Java", "Java is a popular programming language", 2), 195 | ("Ruby", "Ruby is a popular programming language", 3), 196 | ], 197 | ) 198 | def test_run_with_pipeline_and_top_k( 199 | self, mock_docs, query: str, query_result: str, top_k: int 200 | ): 201 | ds = BetterBM25DocumentStore() 202 | ds.write_documents(mock_docs) 203 | retriever = BetterBM25Retriever(ds) 204 | 205 | pipeline = Pipeline() 206 | pipeline.add_component("retriever", retriever) 207 | result: dict[str, Any] = pipeline.run( 208 | data={"retriever": {"query": query, "top_k": top_k}} 209 | ) 210 | 211 | assert result 212 | assert "retriever" in result 213 | results_docs = 
result["retriever"]["documents"] 214 | assert results_docs 215 | assert len(results_docs) == top_k 216 | assert results_docs[0].content == query_result 217 | --------------------------------------------------------------------------------