├── .gitignore ├── LICENSE ├── README.md ├── apilist.txt ├── data └── state_of_the_union.txt ├── examples ├── langchain_integration.ipynb ├── overview.ipynb └── reranker_images.ipynb ├── pyproject.toml ├── rerankers ├── __init__.py ├── documents.py ├── integrations │ ├── __init__.py │ └── langchain.py ├── models │ ├── __init__.py │ ├── api_rankers.py │ ├── colbert_ranker.py │ ├── flashrank_ranker.py │ ├── llm_layerwise_ranker.py │ ├── llm_relevance_filter.py │ ├── monovlm_ranker.py │ ├── mxbai_v2.py │ ├── pylate_ranker.py │ ├── ranker.py │ ├── rankgpt_rankers.py │ ├── rankllm_ranker.py │ ├── t5ranker.py │ ├── transformer_ranker.py │ └── upr.py ├── reranker.py ├── results.py └── utils.py └── tests ├── consistency_notebooks ├── test_colbert.ipynb ├── test_crossenc.ipynb ├── test_inranker.ipynb └── test_t5.ipynb ├── test_crossenc.py └── test_results.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | .flashrank_cache 8 | 9 | # C extensions 10 | *.so 11 | 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | .pdm.toml 89 | __pypackages__/ 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | .envrc 100 | .conda/ 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | .ragatouille 106 | 107 | # mypy 108 | .mypy_cache/ 109 | .dmypy.json 110 | dmypy.json 111 | 112 | # Pyre type checker 113 | .pyre/ 114 | 115 | # pytype static type analyzer 116 | .pytype/ 117 | 118 | # Cython debug symbols 119 | cython_debug/ 120 | 121 | # data files 122 | *.tsv 123 | *.jsonl 124 | 125 | .mypy.ipynb_checkpoints 126 | .mkdocs.yml 127 | 128 | 129 | archive/ 130 | 131 | */.ragatouille 132 | 133 | local/ 134 | 135 | .vscode/ 136 | 137 | .devcontainer/ 138 | 139 | try*.ipynb 140 | 141 | data/*/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Thiago Laitz 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # rerankers 3 | 4 | ![Python Versions](https://img.shields.io/badge/Python-3.8_3.9_3.10_3.11-blue) 5 | [![Downloads](https://static.pepy.tech/badge/rerankers/month)](https://pepy.tech/project/rerankers) 6 | [![Twitter Follow](https://img.shields.io/twitter/follow/bclavie?style=social)](https://twitter.com/bclavie) 7 | 8 | 9 | _A lightweight unified API for various reranking models. Developed by [@bclavie](https://twitter.com/bclavie) as a member of [answer.ai](https://www.answer.ai)_ 10 | 11 | --- 12 | 13 | Welcome to `rerankers`! Our goal is to provide users with a simple API to use any reranking models. 14 | 15 | ## Recent Updates 16 | _A longer release history can be found in the [Release History](#release-history) section of this README._ 17 | 18 | - v0.10.0: Added support for PyLate ColBERT/late-interaction reranking and merged fixes for API rerankers, both thanks to community contributors! 19 | - v0.9.0: Added support for MXBai V2 reranker, which is a new version of the MXBai reranker based on Qwen and the current open-source state-of-the-art. 20 | - v0.7.0: Removing `pydantic` and `tqdm` dependencies, so `rerankers` is now dependency-free by default, avoiding any issues with Pydantic v1/v2! 21 | - v0.6.1: Added support for Pinecone's new rerankers via their API. 22 | - v0.6.0: `rerankers` goes multi-modal, with the support of the first MonoVLMRanker model, [MonoQwen2-VL-v0.1!](https://huggingface.co/lightonai/MonoQwen2-VL-v0.1)! + Many QoL fixes. 23 | 24 | ## Why `rerankers`? 25 | 26 | Rerankers are an important part of any retrieval architecture, but they're also often more obscure than other parts of the pipeline. 27 | 28 | Sometimes, it can be hard to even know which one to use. Every problem is different, and the best model for use X is not necessarily the same one as for use Y. 29 | 30 | Moreover, new reranking methods keep popping up: for example, RankGPT, using LLMs to rerank documents, appeared just last year, with very promising zero-shot benchmark results. 31 | 32 | All the different reranking approaches tend to be done in their own library, with varying levels of documentation. This results in an even higher barrier to entry. 
New users are required to swap between multiple unfamiliar input/output formats, all with their own quirks! 33 | 34 | `rerankers` seeks to address this problem by providing a simple API for all popular rerankers, no matter the architecture. 35 | 36 | `rerankers` aims to be: 37 | - 🪶 Lightweight. It ships with only the bare necessities as dependencies. 38 | - 📖 Easy-to-understand. There's just a handful of calls to learn, and you can then use the full range of provided reranking models. 39 | - 🔗 Easy-to-integrate. It should fit in just about any existing pipelines, with only a few lines of code! 40 | - 💪 Easy-to-expand. Any new reranking models can be added with very little knowledge of the codebase. All you need is a new class with a `rank()` function call mapping a (query, [documents]) input to a `RankedResults` output. 41 | - 🐛 Easy-to-debug. This is a beta release and there might be issues, but the codebase is conceived in such a way that most issues should be easy to track and fix ASAP. 42 | 43 | ## Get Started 44 | 45 | Installation is very simple. The core package ships with no dependencies, so as to avoid any conflict with your current environment. 46 | You may then install only the dependencies required by the models you want to try out: 47 | 48 | ```sh 49 | # Core package only, will require other dependencies already installed 50 | pip install rerankers 51 | 52 | # All transformers-based approaches (cross-encoders, t5, colbert) 53 | pip install "rerankers[transformers]" 54 | 55 | # RankGPT 56 | pip install "rerankers[gpt]" 57 | 58 | # API-based rerankers (Cohere, Jina, MixedBread, Pinecone, Isaacus) 59 | pip install "rerankers[api]" 60 | 61 | # FlashRank rerankers (ONNX-optimised, very fast on CPU) 62 | pip install "rerankers[flashrank]" 63 | 64 | # RankLLM rerankers (better RankGPT + support for local models such as RankZephyr and RankVicuna) 65 | # Note: RankLLM is only supported on Python 3.10+! This will not work with Python 3.9 66 | pip install "rerankers[rankllm]" 67 | 68 | # To support Multi-Modal rerankers such as MonoQwen2-VL and other MonoVLM models, which require flash-attention, peft, accelerate, and recent versions of `transformers` 69 | pip install "rerankers[monovlm]" 70 | 71 | 72 | # To support LLM-Layerwise rerankers (which need flash-attention installed) 73 | pip install "rerankers[llmlayerwise]" 74 | 75 | # All of the above 76 | pip install "rerankers[all]" 77 | ``` 78 | 79 | ## Usage 80 | 81 | Load any supported reranker in a single line, regardless of the architecture: 82 | ```python 83 | from rerankers import Reranker 84 | 85 | # Cross-encoder default. You can specify a 'lang' parameter to load a multilingual version! 86 | ranker = Reranker('cross-encoder') 87 | 88 | # Specific cross-encoder 89 | ranker = Reranker('mixedbread-ai/mxbai-rerank-large-v1', model_type='cross-encoder') 90 | 91 | # FlashRank default. You can specify a 'lang' parameter to load a multilingual version! 92 | ranker = Reranker('flashrank') 93 | 94 | # Specific flashrank model. 95 | ranker = Reranker('ce-esci-MiniLM-L12-v2', model_type='flashrank') 96 | 97 | # Default T5 Seq2Seq reranker 98 | ranker = Reranker("t5") 99 | 100 | # Specific T5 Seq2Seq reranker 101 | ranker = Reranker("unicamp-dl/InRanker-base", model_type = "t5") 102 | 103 | # API (Cohere) 104 | ranker = Reranker("cohere", lang='en' (or 'other'), api_key = API_KEY) 105 | 106 | # Custom Cohere model? No problem! 
107 | ranker = Reranker("my_model_name", api_provider = "cohere", api_key = API_KEY) 108 | 109 | # API (Pinecone) 110 | ranker = Reranker("pinecone", api_key = API_KEY) 111 | 112 | # API (Jina) 113 | ranker = Reranker("jina", api_key = API_KEY) 114 | 115 | # API (Isaacus) 116 | ranker = Reranker("isaacus", api_key = API_KEY) 117 | 118 | # RankGPT4-turbo 119 | ranker = Reranker("rankgpt", api_key = API_KEY) 120 | 121 | # RankGPT3-turbo 122 | ranker = Reranker("rankgpt3", api_key = API_KEY) 123 | 124 | # RankGPT with another LLM provider 125 | ranker = Reranker("MY_LLM_NAME" (check litellm docs), model_type = "rankgpt", api_key = API_KEY) 126 | 127 | # RankLLM with default GPT (GPT-4o) 128 | ranker = Reranker("rankllm", api_key = API_KEY) 129 | 130 | # RankLLM with specified GPT models 131 | ranker = Reranker('gpt-4-turbo', model_type="rankllm", api_key = API_KEY) 132 | 133 | # ColBERTv2 reranker 134 | ranker = Reranker("colbert") 135 | 136 | # LLM Layerwise Reranker 137 | ranker = Reranker('llm-layerwise') 138 | 139 | # ... Or a non-default colbert model: 140 | ranker = Reranker(model_name_or_path, model_type = "colbert") 141 | 142 | ``` 143 | 144 | _Rerankers will always try to infer the model you're trying to use based on its name, but it's always safer to pass a `model_type` argument to it if you can!_ 145 | 146 | Then, regardless of which reranker is loaded, use the loaded model to rank a query against documents: 147 | 148 | ```python 149 | > results = ranker.rank(query="I love you", docs=["I hate you", "I really like you"], doc_ids=[0,1]) 150 | > results 151 | RankedResults(results=[Result(document=Document(text='I really like you', doc_id=1), score=-2.453125, rank=1), Result(document=Document(text='I hate you', doc_id=0), score=-4.14453125, rank=2)], query='I love you', has_scores=True) 152 | ``` 153 | 154 | You don't need to pass `doc_ids`! If not provided, they'll be auto-generated as integers corresponding to the index of a document in `docs`. 155 | 156 | 157 | You're free to pass metadata too, and it'll be stored with the documents. It'll also be accessible in the results object: 158 | 159 | ```python 160 | > results = ranker.rank(query="I love you", docs=["I hate you", "I really like you"], doc_ids=[0,1], metadata=[{'source': 'twitter'}, {'source': 'reddit'}]) 161 | > results 162 | RankedResults(results=[Result(document=Document(text='I really like you', doc_id=1, metadata={'source': 'twitter'}), score=-2.453125, rank=1), Result(document=Document(text='I hate you', doc_id=0, metadata={'source': 'reddit'}), score=-4.14453125, rank=2)], query='I love you', has_scores=True) 163 | ``` 164 | 165 | If you'd like your code to be a bit cleaner, you can also directly construct `Document` objects yourself, and pass those instead. 
In that case, you don't need to pass separate `doc_ids` and `metadata`: 166 | 167 | ```python 168 | > from rerankers import Document 169 | > docs = [Document(text="I really like you", doc_id=0, metadata={'source': 'twitter'}), Document(text="I hate you", doc_id=1, metadata={'source': 'reddit'})] 170 | > results = ranker.rank(query="I love you", docs=docs) 171 | > results 172 | RankedResults(results=[Result(document=Document(text='I really like you', doc_id=0, metadata={'source': 'twitter'}), score=-2.453125, rank=1), Result(document=Document(text='I hate you', doc_id=1, metadata={'source': 'reddit'}), score=-4.14453125, rank=2)], query='I love you', has_scores=True) 173 | ``` 174 | 175 | You can also use `rank_async`, which is essentially just a wrapper to turn `rank()` into a coroutine. The result will be the same: 176 | 177 | ```python 178 | > results = await ranker.rank_async(query="I love you", docs=["I hate you", "I really like you"], doc_ids=[0,1]) 179 | > results 180 | RankedResults(results=[Result(document=Document(text='I really like you', doc_id=1, metadata={'source': 'twitter'}), score=-2.453125, rank=1), Result(document=Document(text='I hate you', doc_id=0, metadata={'source': 'reddit'}), score=-4.14453125, rank=2)], query='I love you', has_scores=True) 181 | ``` 182 | 183 | All rerankers will return a `RankedResults` object, which is a Python object containing a list of `Result` objects and some other useful information, such as the original query. You can retrieve the top `k` results from it by running `top_k()`: 184 | 185 | ```python 186 | > results.top_k(1) 187 | [Result(Document(doc_id=1, text='I really like you', metadata={}), score=0.26170814, rank=1)] 188 | ``` 189 | 190 | The Result objects are transparent when trying to access the documents they store, as `Document` objects simply exist as an easy way to store IDs and metadata. If you want to access a given result's text or metadata, you can directly access it as a property: 191 | 192 | ```python 193 | > results.top_k(1)[0].text 194 | 'I really like you' 195 | ``` 196 | 197 | And that's all you need to know to get started quickly! Check out the overview notebook for more information on the API and the different models, or the langchain example to see how to integrate this in your langchain pipeline. 198 | 199 | 200 | ## Features 201 | 202 | Legend: 203 | - ✅ Supported 204 | - 🟠 Implemented, but not fully fledged 205 | - 📍 Not supported but intended to be in the future 206 | - ⭐ Same as above, but **important**. 207 | - ❌ Not supported & not currently planned 208 | 209 | Models: 210 | - ✅ Any standard SentenceTransformer or Transformers cross-encoder 211 | - ✅ RankGPT (Available both via the original RankGPT implementation and the improved RankLLM one) 212 | - ✅ T5-based pointwise rankers (InRanker, MonoT5...) 213 | - ✅ LLM-based pointwise rankers (BAAI/bge-reranker-v2.5-gemma2-lightweight, etc...) 214 | - ✅ Cohere, Jina, Voyage, MixedBread, Pinecone and Isaacus API rerankers 215 | - ✅ [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers (ONNX-optimised models, very fast on CPU) 216 | - ✅ ColBERT-based reranker - not a model initially designed for reranking, but does perform quite strongly in some cases. Implementation is lightweight, based only on transformers. 217 | - 🟠⭐ RankLLM/RankZephyr: supported by wrapping the [rank-llm](https://github.com/castorini/rank_llm) library! Support for RankZephyr/RankVicuna is untested, but RankLLM + GPT models fully work!
218 | - ✅ 🆕 v0.6.0: MonoVLMRanker, multi-modal image reranker employing the MonoT5 method with a VLM backend. 219 | - 📍 LiT5 220 | 221 | Features: 222 | - ✅ Metadata! 223 | - ✅ Reranking 224 | - ✅ Consistency notebooks to ensure performance on `scifact` matches the literature for any given model implementation (except RankGPT, where results are harder to reproduce). 225 | - ✅ ONNX runtime support --> Offered through [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) -- in line with the philosophy of the lib, we won't reinvent the wheel when @PrithivirajDamodaran is doing amazing work! 226 | - 📍 Training on Python >=3.10 (via interfacing with other libraries) 227 | - ❌(📍Maybe?) Training via rerankers directly 228 | 229 | ## Reference 230 | 231 | If rerankers has been useful to you in academic work, please do feel free to cite the work below! 232 | 233 | ``` 234 | @misc{clavié2024rerankers, 235 | title={rerankers: A Lightweight Python Library to Unify Ranking Methods}, 236 | author={Benjamin Clavié}, 237 | year={2024}, 238 | eprint={2408.17344}, 239 | archivePrefix={arXiv}, 240 | primaryClass={cs.IR}, 241 | url={https://arxiv.org/abs/2408.17344}, 242 | } 243 | ``` 244 | 245 | ## Release History 246 | 247 | - v0.5.*: ColBERT fixes (0.5.1) & minor change making RankedResults subscriptable, meaning results[0] will return the result for the first document (0.5.2), etc... ⚠️ This is sorted by **passed document order**, not by results, you should use `.top_k()` to get sorted results! 248 | - v0.5.0: Added support for the current state-of-the-art rerankers, BAAI's series of `BGE` layerwise LLM rerankers, based on [Gemma](https://huggingface.co/BAAI/bge-reranker-v2.5-gemma2-lightweight) and MiniCPM. These are different from RankGPT, as they're not listwise: the models are repurposed as "cross-encoders", and do output logit scores. 249 | - v0.4.0: ColBERT performance improvement! It should now be faster and produce stronger results following implementation of the JaColBERTv2.5 dynamic query length method. This version also now supports HuggingFace's Text-Embedding-Server (TEI) inference as an API reranker option, thanks to [@srisudarsan](https://github.com/srisudarsan). 250 | - v0.3.1: T5 bugfix and native default support for new Portuguese T5 rerankers. 251 | - v0.3.0: Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated over directly...) 252 | - v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API 253 | - v0.1.2: Voyage reranking API 254 | - v0.1.1: Langchain integration fixed!
255 | - v0.1.0: Initial release 256 | -------------------------------------------------------------------------------- /apilist.txt: -------------------------------------------------------------------------------- 1 | # rerankers Module Documentation 2 | 3 | ## rerankers.documents 4 | 5 | - `class Document` 6 | - `@validator('text') def validate_text(cls, v, values)` 7 | - `def __init__(self, text, doc_id, metadata, document_type, image_path, base64)` 8 | 9 | ## rerankers.integrations.langchain 10 | 11 | - `class RerankerLangChainCompressor` 12 | - `def compress_documents(self, documents, query, callbacks, **kwargs)` 13 | Rerank a list of documents relevant to a query. 14 | 15 | 16 | ## rerankers.models.api_rankers 17 | 18 | - `class APIRanker` 19 | - `def __init__(self, model, api_key, api_provider, verbose, url)` 20 | - `def rank(self, query, docs, doc_ids, metadata)` 21 | - `def score(self, query, doc)` 22 | 23 | ## rerankers.models.colbert_ranker 24 | 25 | > Code from HotchPotch's JQaRa repository: https://github.com/hotchpotch/JQaRA/blob/main/evaluator/reranker/colbert_reranker.py 26 | > Modifications include packaging into a BaseRanker, dynamic query/doc length and batch size handling. 27 | 28 | - `class ColBERTModel` 29 | - `def __init__(self, config)` 30 | - `def forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, output_attentions, output_hidden_states)` 31 | 32 | - `class ColBERTRanker` 33 | - `def __init__(self, model_name, batch_size, dtype, device, verbose, query_token, document_token, **kwargs)` 34 | - `def rank(self, query, docs, doc_ids, metadata)` 35 | - `def score(self, query, doc)` 36 | 37 | ## rerankers.models.flashrank_ranker 38 | 39 | - `class FlashRankRanker` 40 | - `def __init__(self, model_name_or_path, verbose, cache_dir)` 41 | - `def tokenize(self, inputs)` 42 | - `def rank(self, query, docs, doc_ids, metadata)` 43 | - `def score(self, query, doc)` 44 | 45 | ## rerankers.models.llm_layerwise_ranker 46 | 47 | - `class LLMLayerWiseRanker` 48 | - `def __init__(self, model_name_or_path, max_sequence_length, dtype, device, batch_size, verbose, prompt, cutoff_layers, compress_ratio, compress_layer, **kwargs)` 49 | - `@torch.inference_mode() def rank(self, query, docs, doc_ids, metadata, batch_size, max_sequence_length)` 50 | - `@torch.inference_mode() def score(self, query, doc)` 51 | 52 | ## rerankers.models.monovlm_ranker 53 | 54 | - `class MonoVLMRanker` 55 | - `def __init__(self, model_name_or_path, processor_name, dtype, device, batch_size, verbose, token_false, token_true, return_logits, prompt_template, **kwargs)` 56 | - `def rank(self, query, docs, doc_ids, metadata)` 57 | - `def score(self, query, doc)` 58 | 59 | ## rerankers.models.ranker 60 | 61 | - `class BaseRanker` 62 | - `@abstractmethod def __init__(self, model_name_or_path, verbose)` 63 | - `@abstractmethod def score(self, query, doc)` 64 | - `@abstractmethod def rank(self, query, docs, doc_ids)` 65 | End-to-end reranking of documents. 
66 | 67 | - `def rank_async(self, query, docs, doc_ids)` 68 | - `def as_langchain_compressor(self, k)` 69 | 70 | ## rerankers.models.rankgpt_rankers 71 | 72 | > Full implementation is from the original RankGPT repository https://github.com/sunnweiwei/RankGPT under its Apache 2.0 License 73 | > 74 | > Changes made are: 75 | > - Truncating the file to only the relevant functions 76 | > - Using only LiteLLM 77 | > - make_item() added 78 | > - Packaging it onto RankGPTRanker 79 | 80 | - `class RankGPTRanker` 81 | - `def __init__(self, model, api_key, lang, verbose)` 82 | - `def rank(self, query, docs, doc_ids, metadata, rank_start, rank_end)` 83 | - `def score(self)` 84 | 85 | ## rerankers.models.rankllm_ranker 86 | 87 | - `class RankLLMRanker` 88 | - `def __init__(self, model, api_key, lang, verbose)` 89 | - `def rank(self, query, docs, doc_ids, metadata, rank_start, rank_end)` 90 | - `def score(self)` 91 | 92 | ## rerankers.models.t5ranker 93 | 94 | > Code for InRanker is taken from the excellent InRanker repo https://github.com/unicamp-dl/InRanker under its Apache 2.0 license. 95 | > The only change to the original implementation is the removal of InRanker's BaseRanker, replacing it with our own to support the unified API better. 96 | > The main purpose for adapting this code here rather than installing the InRanker library is to ensure greater version compatibility (InRanker requires Python >=3.10) 97 | 98 | - `class T5Ranker` 99 | - `def __init__(self, model_name_or_path, batch_size, dtype, device, verbose, token_false, token_true, return_logits, inputs_template, **kwargs)` 100 | Implementation of the key functions from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py 101 | Changes are detailed in the docstring for each relevant function. 102 | 103 | T5Ranker is a wrapper for using Seq2Seq models for ranking. 104 | Args: 105 | batch_size: The batch size to use when encoding. 106 | dtype: Data type for model weights. 107 | device: The device to use for inference ("cpu", "cuda", or "mps"). 108 | verbose: Verbosity level. 109 | silent: Whether to show progress bars. 110 | 111 | - `def rank(self, query, docs, doc_ids, metadata)` 112 | Ranks a list of documents based on their relevance to the query. 113 | 114 | - `def score(self, query, doc)` 115 | Scores a single document's relevance to a query. 116 | 117 | 118 | ## rerankers.models.transformer_ranker 119 | 120 | - `class TransformerRanker` 121 | - `def __init__(self, model_name_or_path, dtype, device, batch_size, verbose, **kwargs)` 122 | - `def tokenize(self, inputs)` 123 | - `@torch.inference_mode() def rank(self, query, docs, doc_ids, metadata, batch_size)` 124 | - `@torch.inference_mode() def score(self, query, doc)` 125 | 126 | ## rerankers.results 127 | 128 | - `class Result` 129 | - `@validator('rank', always=True) def check_score_or_rank_exists(cls, v, values)` 130 | - `def __getattr__(self, item)` 131 | 132 | - `class RankedResults` 133 | - `def __iter__(self)` 134 | Allows iteration over the results list. 135 | 136 | - `def __getitem__(self, index)` 137 | Allows indexing to access results directly. 138 | 139 | - `def results_count(self)` 140 | Returns the total number of results. 141 | 142 | - `def top_k(self, k)` 143 | Returns the top k results based on the score, if available, or rank. 144 | 145 | - `def get_score_by_docid(self, doc_id)` 146 | Fetches the score of a result by its doc_id using a more efficient approach. 
147 | 148 | - `def get_result_by_docid(self, doc_id)` 149 | Fetches a result by its doc_id using a more efficient approach. 150 | 151 | 152 | ## rerankers.utils 153 | 154 | - `def prep_image_docs(docs, doc_ids, metadata)` 155 | Prepare image documents for processing. Can handle base64 encoded images or file paths. 156 | Similar to prep_docs but specialized for image documents. 157 | 158 | - `def get_chunks(iterable, chunk_size)` 159 | Implementation from https://github.com/unicamp-dl/InRanker/blob/main/inranker/base.py with extra typing and more descriptive names. 160 | This method is used to split a list l into chunks of batch size n. 161 | 162 | -------------------------------------------------------------------------------- /examples/langchain_integration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Welcome! This is a quick notebook to introduce you to using rerankers in Langchain, at the end of a retrieval pipeline. It's heavily inspired by existing langchain examples.\n", 8 | "\n", 9 | "First, let's define a helper function for printing docs:" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "def pretty_print_docs(docs):\n", 19 | " print(\n", 20 | " f\"\\n{'-' * 100}\\n\".join(\n", 21 | " [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n", 22 | " )\n", 23 | " )" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "Then, let's set up a normal document retrieval pipeline, using the common OpenAI embeddings + FAISS combo. If you want to run this example yourself and don't have faiss installed, you'll need to install it for this example! (the document is very small, so `faiss-cpu` is largely enough)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "Document 1:\n", 43 | "\n", 44 | "And so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. \n", 45 | "\n", 46 | "I understand. \n", 47 | "\n", 48 | "I remember when my Dad had to leave our home in Scranton, Pennsylvania to find work. I grew up in a family where if the price of food went up, you felt it. \n", 49 | "\n", 50 | "That’s why one of the first things I did as President was fight to pass the American Rescue Plan. \n", 51 | "\n", 52 | "Because people were hurting. We needed to act, and we did. \n", 53 | "\n", 54 | "Few pieces of legislation have done more in a critical moment in our history to lift us out of crisis. \n", 55 | "\n", 56 | "It fueled our efforts to vaccinate the nation and combat COVID-19. It delivered immediate economic relief for tens of millions of Americans. \n", 57 | "\n", 58 | "Helped put food on their table, keep a roof over their heads, and cut the cost of health insurance. \n", 59 | "\n", 60 | "And as my Dad used to say, it gave people a little breathing room.\n", 61 | "----------------------------------------------------------------------------------------------------\n", 62 | "Document 2:\n", 63 | "\n", 64 | "We got more than 130 countries to agree on a global minimum tax rate so companies can’t get out of paying their taxes at home by shipping jobs and factories overseas. 
\n", 65 | "\n", 66 | "That’s why I’ve proposed closing loopholes so the very wealthy don’t pay a lower tax rate than a teacher or a firefighter. \n", 67 | "\n", 68 | "So that’s my plan. It will grow the economy and lower costs for families. \n", 69 | "\n", 70 | "So what are we waiting for? Let’s get this done. And while you’re at it, confirm my nominees to the Federal Reserve, which plays a critical role in fighting inflation. \n", 71 | "\n", 72 | "My plan will not only lower costs to give families a fair shot, it will lower the deficit. \n", 73 | "\n", 74 | "The previous Administration not only ballooned the deficit with tax cuts for the very wealthy and corporations, it undermined the watchdogs whose job was to keep pandemic relief funds from being wasted. \n", 75 | "\n", 76 | "But in my administration, the watchdogs have been welcomed back.\n", 77 | "----------------------------------------------------------------------------------------------------\n", 78 | "Document 3:\n", 79 | "\n", 80 | "Tonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers. \n", 81 | "\n", 82 | "And as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up. \n", 83 | "\n", 84 | "That ends on my watch. \n", 85 | "\n", 86 | "Medicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect. \n", 87 | "\n", 88 | "We’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees. \n", 89 | "\n", 90 | "Let’s pass the Paycheck Fairness Act and paid leave. \n", 91 | "\n", 92 | "Raise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty. \n", 93 | "\n", 94 | "Let’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges.\n", 95 | "----------------------------------------------------------------------------------------------------\n", 96 | "Document 4:\n", 97 | "\n", 98 | "My plan will cut the cost in half for most families and help parents, including millions of women, who left the workforce during the pandemic because they couldn’t afford child care, to be able to get back to work. \n", 99 | "\n", 100 | "My plan doesn’t stop there. It also includes home and long-term care. More affordable housing. And Pre-K for every 3- and 4-year-old. \n", 101 | "\n", 102 | "All of these will lower costs. \n", 103 | "\n", 104 | "And under my plan, nobody earning less than $400,000 a year will pay an additional penny in new taxes. Nobody. \n", 105 | "\n", 106 | "The one thing all Americans agree on is that the tax system is not fair. We have to fix it. \n", 107 | "\n", 108 | "I’m not looking to punish anyone. But let’s make sure corporations and the wealthiest Americans start paying their fair share. \n", 109 | "\n", 110 | "Just last year, 55 Fortune 500 corporations earned $40 billion in profits and paid zero dollars in federal income tax. \n", 111 | "\n", 112 | "That’s simply not fair. 
That’s why I’ve proposed a 15% minimum tax rate for corporations.\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "# Vanilla retrieval\n", 118 | "from langchain_community.document_loaders import TextLoader\n", 119 | "from langchain_community.vectorstores import FAISS\n", 120 | "from langchain_openai import OpenAIEmbeddings\n", 121 | "from langchain_text_splitters import CharacterTextSplitter\n", 122 | "\n", 123 | "documents = TextLoader(\"../data/state_of_the_union.txt\").load()\n", 124 | "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", 125 | "texts = text_splitter.split_documents(documents)\n", 126 | "retriever = FAISS.from_documents(texts, OpenAIEmbeddings()).as_retriever()\n", 127 | "\n", 128 | "docs = retriever.get_relevant_documents(\n", 129 | " \"What did the president say about the minimum wage?\"\n", 130 | ")\n", 131 | "pretty_print_docs(docs)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "These results are interesting, but nothing about the actual new minimum wage pledge in the top two documents! Let's see if a re-ranker could help...\n", 139 | "\n", 140 | "First, let's load a reranker. The one you load doesn't actually matter -- they all behave exactly the same. For this example, we're using MixedBread's excellent [mixedbread-ai/mxbai-rerank-base-v1](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 3, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stderr", 150 | "output_type": "stream", 151 | "text": [ 152 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 153 | " from .autonotebook import tqdm as notebook_tqdm\n" 154 | ] 155 | }, 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "Warning: Model type could not be auto-mapped with the defaults list. Defaulting to TransformerRanker.\n", 161 | "If your model is NOT intended to be ran as a one-label cross-encoder, please reload it and specify the model_type! Otherwise, you may ignore this warning. You may specify `model_type='cross-encoder'` to suppress this warning in the future.\n", 162 | "Loading TransformerRanker model mixedbread-ai/mxbai-rerank-base-v1\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "# Load a reranker and convert it to a LangChain compressor\n", 168 | "from rerankers import Reranker\n", 169 | "\n", 170 | "ranker = Reranker(\"mixedbread-ai/mxbai-rerank-base-v1\", verbose=0)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "Converting it to a Langchain compressor is very straightforward, all you have to do is call `as_langchain_compressor`. You can pass a `k` argument to define how many documents it should retrieve, otherwise, `k` will default to 5." 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 4, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "compressor = ranker.as_langchain_compressor(k=3)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "You're all set! 
Let's just add it to our pipeline and retrieve+rerank documents:" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 5, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "Document 1:\n", 206 | "\n", 207 | "Tonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers. \n", 208 | "\n", 209 | "And as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up. \n", 210 | "\n", 211 | "That ends on my watch. \n", 212 | "\n", 213 | "Medicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect. \n", 214 | "\n", 215 | "We’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees. \n", 216 | "\n", 217 | "Let’s pass the Paycheck Fairness Act and paid leave. \n", 218 | "\n", 219 | "Raise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty. \n", 220 | "\n", 221 | "Let’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges.\n", 222 | "----------------------------------------------------------------------------------------------------\n", 223 | "Document 2:\n", 224 | "\n", 225 | "And so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. \n", 226 | "\n", 227 | "I understand. \n", 228 | "\n", 229 | "I remember when my Dad had to leave our home in Scranton, Pennsylvania to find work. I grew up in a family where if the price of food went up, you felt it. \n", 230 | "\n", 231 | "That’s why one of the first things I did as President was fight to pass the American Rescue Plan. \n", 232 | "\n", 233 | "Because people were hurting. We needed to act, and we did. \n", 234 | "\n", 235 | "Few pieces of legislation have done more in a critical moment in our history to lift us out of crisis. \n", 236 | "\n", 237 | "It fueled our efforts to vaccinate the nation and combat COVID-19. It delivered immediate economic relief for tens of millions of Americans. \n", 238 | "\n", 239 | "Helped put food on their table, keep a roof over their heads, and cut the cost of health insurance. \n", 240 | "\n", 241 | "And as my Dad used to say, it gave people a little breathing room.\n", 242 | "----------------------------------------------------------------------------------------------------\n", 243 | "Document 3:\n", 244 | "\n", 245 | "My plan will cut the cost in half for most families and help parents, including millions of women, who left the workforce during the pandemic because they couldn’t afford child care, to be able to get back to work. \n", 246 | "\n", 247 | "My plan doesn’t stop there. It also includes home and long-term care. More affordable housing. And Pre-K for every 3- and 4-year-old. \n", 248 | "\n", 249 | "All of these will lower costs. \n", 250 | "\n", 251 | "And under my plan, nobody earning less than $400,000 a year will pay an additional penny in new taxes. Nobody. \n", 252 | "\n", 253 | "The one thing all Americans agree on is that the tax system is not fair. We have to fix it. \n", 254 | "\n", 255 | "I’m not looking to punish anyone. 
But let’s make sure corporations and the wealthiest Americans start paying their fair share. \n", 256 | "\n", 257 | "Just last year, 55 Fortune 500 corporations earned $40 billion in profits and paid zero dollars in federal income tax. \n", 258 | "\n", 259 | "That’s simply not fair. That’s why I’ve proposed a 15% minimum tax rate for corporations.\n" 260 | ] 261 | } 262 | ], 263 | "source": [ 264 | "from langchain.retrievers import ContextualCompressionRetriever\n", 265 | "from langchain_openai import OpenAI\n", 266 | "\n", 267 | "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", 268 | "texts = text_splitter.split_documents(documents)\n", 269 | "retriever = FAISS.from_documents(texts, OpenAIEmbeddings()).as_retriever()\n", 270 | "\n", 271 | "compression_retriever = ContextualCompressionRetriever(\n", 272 | " base_compressor=compressor, base_retriever=retriever\n", 273 | ")\n", 274 | "\n", 275 | "\n", 276 | "compressed_docs = compression_retriever.get_relevant_documents(\n", 277 | " \"What did the president say about the minimum wage?\"\n", 278 | ")\n", 279 | "\n", 280 | "pretty_print_docs(compressed_docs)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "Here it is! There's not much more to show -- just load any reranker you want and try it out!\n", 288 | "\n", 289 | "Remember, not all rerankers work in the same way. It's important to experiment to find out which one works best for your data, and even to fine-tune them if you have the data to do so. The point of this library is to make it easy to try many different approaches to find the best one for your usecase!" 290 | ] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "mcol", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.10.13" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 2 314 | } 315 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | [build-system] 3 | requires = ["setuptools"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [tool.setuptools] 7 | packages = [ 8 | "rerankers", 9 | "rerankers.models", 10 | "rerankers.integrations", 11 | ] 12 | 13 | [project] 14 | name = "rerankers" 15 | 16 | 17 | version = "0.10.0" 18 | 19 | description = "A unified API for various document re-ranking models." 20 | 21 | readme = "README.md" 22 | 23 | requires-python = ">=3.8" 24 | 25 | license = {file = "LICENSE"} 26 | 27 | keywords = ["reranking", "retrieval", "rag", "nlp"] 28 | 29 | authors = [ 30 | {name = "Ben Clavié", email = "bc@answer.ai" } 31 | ] 32 | maintainers = [ 33 | {name = "Ben Clavié", email = "bc@answer.ai" } 34 | ] 35 | 36 | classifiers = [ 37 | # Specify the Python versions you support here. In particular, ensure 38 | # that you indicate you support Python 3. These classifiers are *not* 39 | # checked by "pip install". See instead "requires-python" key in this file. 
40 | "Programming Language :: Python :: 3", 41 | "Programming Language :: Python :: 3.8", 42 | "Programming Language :: Python :: 3.9", 43 | "Programming Language :: Python :: 3.10", 44 | "Programming Language :: Python :: 3.11", 45 | "Programming Language :: Python :: 3.12", 46 | "Programming Language :: Python :: 3 :: Only", 47 | ] 48 | 49 | dependencies = [] 50 | 51 | [project.optional-dependencies] 52 | all = [ 53 | "transformers>=4.45.0", 54 | "torch", 55 | "litellm", 56 | "requests", 57 | "sentencepiece", 58 | "protobuf", 59 | "flashrank", 60 | "flash-attn", 61 | "pillow", 62 | "accelerate>=0.26.0", 63 | "peft>=0.13.0", 64 | "nmslib-metabrainz; python_version >= '3.10'", 65 | "rank-llm; python_version >= '3.10'" 66 | ] 67 | transformers = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf"] 68 | api = ["requests"] 69 | gpt = ["litellm"] 70 | flashrank = ["flashrank"] 71 | llmlayerwise = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf", "flash-attn"] 72 | monovlm = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf", "flash-attn", "pillow", "accelerate>=0.26.0", "peft>=0.13.0"] 73 | rankllm = [ 74 | "nmslib-metabrainz; python_version >= '3.10'", 75 | "rank-llm; python_version >= '3.10'" 76 | ] 77 | pylate = ["pylate"] 78 | dev = ["ruff", "isort", "pytest", "ipyprogress", "ipython", "ranx", "ir_datasets", "srsly"] 79 | 80 | [project.urls] 81 | "Homepage" = "https://github.com/answerdotai/rerankers" -------------------------------------------------------------------------------- /rerankers/__init__.py: -------------------------------------------------------------------------------- 1 | from rerankers.reranker import Reranker 2 | from rerankers.documents import Document 3 | 4 | __all__ = ["Reranker", "Document"] 5 | __version__ = "0.10.0" -------------------------------------------------------------------------------- /rerankers/documents.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Literal 2 | 3 | 4 | class Document: 5 | def __init__( 6 | self, 7 | text: Optional[str] = None, 8 | doc_id: Optional[Union[str, int]] = None, 9 | metadata: Optional[dict] = None, 10 | document_type: Literal["text", "image"] = "text", 11 | image_path: Optional[str] = None, 12 | base64: Optional[str] = None, 13 | ): 14 | self.attributes = ["text", "base64", "image_path", "doc_id", "metadata", "document_type"] 15 | self.document_type = document_type 16 | self.text = text 17 | self.base64 = base64 18 | self.image_path = image_path 19 | self.doc_id = doc_id 20 | self.metadata = metadata if metadata is not None else {} 21 | 22 | # Validation 23 | if self.document_type == "text" and self.text is None: 24 | raise ValueError("text field is required when document_type is 'text'") 25 | 26 | def __repr__(self) -> str: 27 | fields = { 28 | "text": self.text, 29 | "doc_id": self.doc_id, 30 | "metadata": self.metadata, 31 | "document_type": self.document_type, 32 | "image_path": self.image_path, 33 | "base64": self.base64, 34 | } 35 | field_str = ", ".join(f"{k}={v!r}" for k, v in fields.items()) 36 | return f"{self.__class__.__name__}({field_str})" 37 | -------------------------------------------------------------------------------- /rerankers/integrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/rerankers/7bb252179daaf1968eb9d49629c9cf48b0f9b5f2/rerankers/integrations/__init__.py 
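
Note on the `Document` class defined in `rerankers/documents.py` above: it is the unit every ranker consumes. The snippet below is a minimal usage sketch and is not part of the repository files; the model name is only one example of a supported cross-encoder, and the image path is a hypothetical local file. It exercises only fields defined in the class above and the `rank()` / `top_k()` calls documented in the README.

```python
from rerankers import Reranker, Document

# Example model only -- any supported reranker works the same way (requires the [transformers] extra).
ranker = Reranker("mixedbread-ai/mxbai-rerank-base-v1", model_type="cross-encoder")

# Text documents: doc_id and metadata are optional and travel with each document into the results.
docs = [
    Document(text="I really like you", doc_id=0, metadata={"source": "twitter"}),
    Document(text="I hate you", doc_id=1, metadata={"source": "reddit"}),
]
results = ranker.rank(query="I love you", docs=docs)
print(results.top_k(1)[0].text)  # Document fields are accessible directly on each Result

# Image documents (for multi-modal rankers such as MonoVLMRanker) set document_type="image"
# and pass image_path or base64 instead of text. "page_1.png" is a hypothetical path.
image_doc = Document(document_type="image", image_path="page_1.png", doc_id=2)
```
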
-------------------------------------------------------------------------------- /rerankers/integrations/langchain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Sequence 2 | 3 | from langchain.retrievers.document_compressors.base import BaseDocumentCompressor 4 | from langchain_core.callbacks.manager import Callbacks 5 | from langchain_core.documents import Document 6 | 7 | 8 | class RerankerLangChainCompressor(BaseDocumentCompressor): 9 | model: Any 10 | kwargs: dict = {} 11 | k: int = 5 12 | 13 | def compress_documents( 14 | self, 15 | documents: Sequence[Document], 16 | query: str, 17 | callbacks: Optional[Callbacks] = None, # noqa 18 | **kwargs, 19 | ) -> Any: 20 | """Rerank a list of documents relevant to a query.""" 21 | doc_list = list(documents) 22 | _docs = [d.page_content for d in doc_list] 23 | results = self.model.rank( 24 | query=query, 25 | docs=_docs, 26 | **self.kwargs, 27 | ) 28 | final_results = [] 29 | for r in results.top_k(kwargs.get("k", self.k)): 30 | doc = doc_list[r.doc_id] 31 | doc.metadata["relevance_score"] = r.score 32 | final_results.append(doc) 33 | return final_results 34 | -------------------------------------------------------------------------------- /rerankers/models/__init__.py: -------------------------------------------------------------------------------- 1 | AVAILABLE_RANKERS = {} 2 | 3 | try: 4 | from rerankers.models.transformer_ranker import TransformerRanker 5 | 6 | AVAILABLE_RANKERS["TransformerRanker"] = TransformerRanker 7 | except ImportError: 8 | pass 9 | try: 10 | from rerankers.models.api_rankers import APIRanker 11 | 12 | AVAILABLE_RANKERS["APIRanker"] = APIRanker 13 | except ImportError: 14 | pass 15 | try: 16 | from rerankers.models.rankgpt_rankers import RankGPTRanker 17 | 18 | AVAILABLE_RANKERS["RankGPTRanker"] = RankGPTRanker 19 | except ImportError: 20 | pass 21 | try: 22 | from rerankers.models.t5ranker import T5Ranker 23 | 24 | AVAILABLE_RANKERS["T5Ranker"] = T5Ranker 25 | except ImportError: 26 | pass 27 | 28 | try: 29 | from rerankers.models.colbert_ranker import ColBERTRanker 30 | 31 | AVAILABLE_RANKERS["ColBERTRanker"] = ColBERTRanker 32 | except ImportError: 33 | pass 34 | 35 | try: 36 | from rerankers.models.flashrank_ranker import FlashRankRanker 37 | 38 | AVAILABLE_RANKERS["FlashRankRanker"] = FlashRankRanker 39 | except ImportError: 40 | pass 41 | 42 | try: 43 | from rerankers.models.rankllm_ranker import RankLLMRanker 44 | 45 | AVAILABLE_RANKERS["RankLLMRanker"] = RankLLMRanker 46 | except ImportError: 47 | pass 48 | 49 | try: 50 | from rerankers.models.llm_layerwise_ranker import LLMLayerWiseRanker 51 | 52 | AVAILABLE_RANKERS["LLMLayerWiseRanker"] = LLMLayerWiseRanker 53 | except ImportError: 54 | pass 55 | 56 | try: 57 | from rerankers.models.monovlm_ranker import MonoVLMRanker 58 | AVAILABLE_RANKERS["MonoVLMRanker"] = MonoVLMRanker 59 | except ImportError: 60 | pass 61 | 62 | try: 63 | from rerankers.models.llm_relevance_filter import LLMRelevanceFilter 64 | AVAILABLE_RANKERS["LLMRelevanceFilter"] = LLMRelevanceFilter 65 | except ImportError: 66 | pass 67 | 68 | try: 69 | from rerankers.models.upr import UPRRanker 70 | AVAILABLE_RANKERS["UPRRanker"] = UPRRanker 71 | except ImportError: 72 | pass 73 | 74 | try: 75 | from rerankers.models.mxbai_v2 import MxBaiV2Ranker 76 | AVAILABLE_RANKERS["MxBaiV2Ranker"] = MxBaiV2Ranker 77 | except ImportError: 78 | pass 79 | 80 | try: 81 | from rerankers.models.pylate_ranker import PyLateRanker 82 | 83 | 
AVAILABLE_RANKERS["PyLateRanker"] = PyLateRanker 84 | except ImportError: 85 | pass 86 | -------------------------------------------------------------------------------- /rerankers/models/api_rankers.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Optional 2 | from rerankers.models.ranker import BaseRanker 3 | from rerankers.results import RankedResults, Result 4 | from rerankers.utils import prep_docs 5 | from rerankers.documents import Document 6 | from string import Template 7 | 8 | 9 | import requests 10 | import json 11 | 12 | 13 | URLS = { 14 | "cohere": "https://api.cohere.ai/v1/rerank", 15 | "jina": "https://api.jina.ai/v1/rerank", 16 | "isaacus": "https://api.isaacus.com/v1/rerankings", 17 | "voyage": "https://api.voyageai.com/v1/rerank", 18 | "mixedbread.ai": "https://api.mixedbread.ai/v1/reranking", 19 | "pinecone": "https://api.pinecone.io/rerank", 20 | } 21 | AUTHORIZATION_KEY_MAPPING = { 22 | "pinecone": "Api-Key" 23 | } 24 | API_VERSION_MAPPING = { 25 | "pinecone": {"X-Pinecone-API-Version": "2024-10"} 26 | } 27 | 28 | API_KEY_MAPPING = { 29 | "pinecone": Template("$api_key") 30 | } 31 | 32 | DOCUMENT_KEY_MAPPING = { 33 | "mixedbread.ai": "input", 34 | "text-embeddings-inference":"texts", 35 | "isaacus": "texts", 36 | } 37 | RETURN_DOCUMENTS_KEY_MAPPING = { 38 | "mixedbread.ai":"return_input", 39 | "text-embeddings-inference":"return_text" 40 | } 41 | RESULTS_KEY_MAPPING = { 42 | "voyage": "data", 43 | "mixedbread.ai": "data", 44 | "pinecone": "data", 45 | "text-embeddings-inference": None 46 | } 47 | SCORE_KEY_MAPPING = { 48 | "mixedbread.ai": "score", 49 | "pinecone": "score", 50 | "text-embeddings-inference":"score", 51 | "isaacus":"score", 52 | } 53 | 54 | class APIRanker(BaseRanker): 55 | def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1, url: str = None): 56 | self.api_provider = api_provider.lower() 57 | self.api_key = API_KEY_MAPPING.get(self.api_provider,Template("Bearer $api_key")).substitute(api_key=api_key) 58 | authorization_key = AUTHORIZATION_KEY_MAPPING.get(self.api_provider,"Authorization") 59 | api_version = API_VERSION_MAPPING.get(self.api_provider,None) 60 | self.model = model 61 | self.verbose = verbose 62 | self.ranking_type = "pointwise" 63 | self.headers = { 64 | "accept": "application/json", 65 | "content-type": "application/json", 66 | authorization_key: self.api_key, 67 | } 68 | if api_version: 69 | self.headers.update(api_version) 70 | self.url = url if url else URLS[self.api_provider] 71 | 72 | 73 | def _get_document_text(self, r: dict) -> str: 74 | if self.api_provider == "voyage": 75 | return r["document"] 76 | elif self.api_provider == "mixedbread.ai": 77 | return r["input"] 78 | elif self.api_provider == "text-embeddings-inference": 79 | return r["text"] 80 | else: 81 | return r["document"]["text"] 82 | 83 | def _get_score(self, r: dict) -> float: 84 | score_key = SCORE_KEY_MAPPING.get(self.api_provider,"relevance_score") 85 | return r[score_key] 86 | 87 | def _parse_response( 88 | self, response: dict, docs: List[Document], 89 | ) -> RankedResults: 90 | ranked_docs = [] 91 | results_key = RESULTS_KEY_MAPPING.get(self.api_provider,"results") 92 | 93 | for i, r in enumerate(response[results_key] if results_key else response): 94 | ranked_docs.append( 95 | Result( 96 | document=docs[r["index"]], 97 | score=self._get_score(r), 98 | rank=i + 1, 99 | ) 100 | ) 101 | 102 | return ranked_docs 103 | 104 | def rank( 105 | self, 106 | query: str, 107 
| docs: Union[str, List[str], Document, List[Document]], 108 | doc_ids: Optional[Union[List[str], List[int]]] = None, 109 | metadata: Optional[List[dict]] = None, 110 | ) -> RankedResults: 111 | docs = prep_docs(docs, doc_ids, metadata) 112 | payload = self._format_payload(query, docs) 113 | response = requests.post(self.url, headers=self.headers, data=payload) 114 | results = self._parse_response(response.json(), docs) 115 | return RankedResults(results=results, query=query, has_scores=True) 116 | 117 | 118 | def _format_payload(self, query: str, docs: List[str]) -> str: 119 | top_key = ( 120 | "top_n" if self.api_provider not in ["voyage", "mixedbread.ai"] else "top_k" 121 | ) 122 | documents_key = DOCUMENT_KEY_MAPPING.get(self.api_provider,"documents") 123 | return_documents_key = RETURN_DOCUMENTS_KEY_MAPPING.get(self.api_provider,"return_documents") 124 | 125 | documents = ( 126 | [d.text for d in docs] if self.api_provider not in ["pinecone"] else [{"text": d.text} for d in docs] 127 | ) 128 | 129 | payload = { 130 | "model": self.model, 131 | "query": query, 132 | documents_key: documents, 133 | top_key: len(docs), 134 | return_documents_key: True, 135 | } 136 | return json.dumps(payload) 137 | 138 | def score(self, query: str, doc: str) -> float: 139 | payload = self._format_payload(query, [doc]) 140 | response = requests.post(self.url, headers=self.headers, data=payload) 141 | results = self._parse_response(response.json(), [doc]) 142 | return results[0].score 143 | -------------------------------------------------------------------------------- /rerankers/models/colbert_ranker.py: -------------------------------------------------------------------------------- 1 | """Code from HotchPotch's JQaRa repository: https://github.com/hotchpotch/JQaRA/blob/main/evaluator/reranker/colbert_reranker.py 2 | Modifications include packaging into a BaseRanker, dynamic query/doc length and batch size handling.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers import BertPreTrainedModel, BertModel, AutoModel, AutoTokenizer 7 | from typing import List, Optional, Union 8 | from math import ceil 9 | 10 | from rerankers.models.ranker import BaseRanker 11 | from rerankers.documents import Document 12 | from rerankers.results import RankedResults, Result 13 | from rerankers.utils import vprint, get_device, get_dtype, prep_docs 14 | 15 | 16 | def _insert_token( 17 | output: dict, 18 | insert_token_id: int, 19 | insert_position: int = 1, 20 | token_type_id: int = 0, 21 | attention_value: int = 1, 22 | ): 23 | """ 24 | Inserts a new token at a specified position into the sequences of a tokenized representation. 25 | 26 | This function takes a dictionary containing tokenized representations 27 | (e.g., 'input_ids', 'token_type_ids', 'attention_mask') as PyTorch tensors, 28 | and inserts a specified token into each sequence at the given position. 29 | This can be used to add special tokens or other modifications to tokenized inputs. 30 | 31 | Parameters: 32 | - output (dict): A dictionary containing the tokenized representations. Expected keys 33 | are 'input_ids', 'token_type_ids', and 'attention_mask'. Each key 34 | is associated with a PyTorch tensor. 35 | - insert_token_id (int): The token ID to be inserted into each sequence. 36 | - insert_position (int, optional): The position in the sequence where the new token 37 | should be inserted. Defaults to 1, which typically 38 | follows a special starting token like '[CLS]' or '[BOS]'. 
39 | - token_type_id (int, optional): The token type ID to assign to the inserted token. 40 | Defaults to 0. 41 | - attention_value (int, optional): The attention mask value to assign to the inserted token. 42 | Defaults to 1. 43 | 44 | Returns: 45 | - updated_output (dict): A dictionary containing the updated tokenized representations, 46 | with the new token inserted at the specified position in each sequence. 47 | The structure and keys of the output dictionary are the same as the input. 48 | """ 49 | updated_output = {} 50 | for key in output: 51 | updated_tensor_list = [] 52 | for seqs in output[key]: 53 | if len(seqs.shape) == 1: 54 | seqs = seqs.unsqueeze(0) 55 | for seq in seqs: 56 | first_part = seq[:insert_position] 57 | second_part = seq[insert_position:] 58 | new_element = ( 59 | torch.tensor([insert_token_id]) 60 | if key == "input_ids" 61 | else torch.tensor([token_type_id]) 62 | ) 63 | if key == "attention_mask": 64 | new_element = torch.tensor([attention_value]) 65 | updated_seq = torch.cat((first_part, new_element, second_part), dim=0) 66 | updated_tensor_list.append(updated_seq) 67 | updated_output[key] = torch.stack(updated_tensor_list) 68 | return updated_output 69 | 70 | 71 | def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor): 72 | # calc max sim 73 | # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py 74 | 75 | # Assert that all q_reps are at least as long as the query length 76 | assert ( 77 | q_reps.shape[1] >= q_mask.shape[1] 78 | ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}" 79 | 80 | token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps) 81 | token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4) 82 | scores, _ = token_scores.max(-1) 83 | scores = scores.sum(1) / q_mask.sum(-1, keepdim=True) 84 | return scores 85 | 86 | 87 | class ColBERTModel(BertPreTrainedModel): 88 | def __init__(self, config, verbose: int): 89 | super().__init__(config) 90 | self.bert = BertModel(config) 91 | self.verbose = verbose 92 | # TODO: Load from artifact.metadata 93 | if "small" in config._name_or_path: 94 | linear_dim = 96 95 | else: 96 | linear_dim = 128 97 | vprint(f"Linear Dim set to: {linear_dim} for downcasting", self.verbose) 98 | self.linear = nn.Linear(config.hidden_size, linear_dim, bias=False) 99 | self.init_weights() 100 | 101 | def forward( 102 | self, 103 | input_ids=None, 104 | attention_mask=None, 105 | token_type_ids=None, 106 | position_ids=None, 107 | head_mask=None, 108 | inputs_embeds=None, 109 | encoder_hidden_states=None, 110 | encoder_attention_mask=None, 111 | output_attentions=None, 112 | output_hidden_states=None, 113 | ): 114 | outputs = self.bert( 115 | input_ids, 116 | attention_mask=attention_mask, 117 | token_type_ids=token_type_ids, 118 | position_ids=position_ids, 119 | head_mask=head_mask, 120 | inputs_embeds=inputs_embeds, 121 | encoder_hidden_states=encoder_hidden_states, 122 | encoder_attention_mask=encoder_attention_mask, 123 | output_attentions=output_attentions, 124 | output_hidden_states=True, # Always output hidden states 125 | ) 126 | 127 | sequence_output = outputs[0] 128 | 129 | return self.linear(sequence_output) 130 | 131 | def _encode(self, texts: list[str], insert_token_id: int, is_query: bool = False): 132 | encoding = self.tokenizer( 133 | texts, 134 | return_tensors="pt", 135 | padding=True, 136 | max_length=self.max_length - 1, # for insert token 137 | truncation=True, 138 | ) 
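# The tokenizer call above reserves one position (max_length - 1) so that _insert_token, on the next line,
# can splice the query/document marker token in at position 1, right after [CLS]. This follows the usual
# ColBERT convention of prefixing queries and documents with distinct marker tokens ([unused0] / [unused1] here).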
139 | encoding = _insert_token(encoding, insert_token_id) # type: ignore 140 | 141 | if is_query: 142 | mask_token_id = self.tokenizer.mask_token_id 143 | 144 | new_encodings = {"input_ids": [], "attention_mask": []} 145 | 146 | for i, input_ids in enumerate(encoding["input_ids"]): 147 | original_length = ( 148 | (input_ids != self.tokenizer.pad_token_id).sum().item() 149 | ) 150 | 151 | # Calculate QLEN dynamically for each query 152 | if original_length % 32 <= 8: 153 | QLEN = original_length + 8 154 | else: 155 | QLEN = ceil(original_length / 32) * 32 156 | 157 | if original_length < QLEN: 158 | pad_length = QLEN - original_length 159 | padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length 160 | padded_attention_mask = ( 161 | encoding["attention_mask"][i].tolist() + [0] * pad_length 162 | ) 163 | else: 164 | padded_input_ids = input_ids[:QLEN].tolist() 165 | padded_attention_mask = encoding["attention_mask"][i][ 166 | :QLEN 167 | ].tolist() 168 | 169 | new_encodings["input_ids"].append(padded_input_ids) 170 | new_encodings["attention_mask"].append(padded_attention_mask) 171 | 172 | for key in new_encodings: 173 | new_encodings[key] = torch.tensor( 174 | new_encodings[key], device=self.device 175 | ) 176 | 177 | encoding = new_encodings 178 | 179 | encoding = {key: value.to(self.device) for key, value in encoding.items()} 180 | return encoding 181 | 182 | def _query_encode(self, query: list[str]): 183 | return self._encode(query, self.query_token_id, is_query=True) 184 | 185 | def _document_encode(self, documents: list[str]): 186 | return self._encode(documents, self.document_token_id) 187 | 188 | def _to_embs(self, encoding) -> torch.Tensor: 189 | with torch.inference_mode(): 190 | # embs = self.model(**encoding).last_hidden_state.squeeze(1) 191 | embs = self.model(**encoding) 192 | if self.normalize: 193 | embs = embs / embs.norm(dim=-1, keepdim=True) 194 | return embs 195 | 196 | def _rerank(self, query: str, documents: list[str]) -> list[float]: 197 | query_encoding = self._query_encode([query]) 198 | documents_encoding = self._document_encode(documents) 199 | query_embeddings = self._to_embs(query_encoding) 200 | document_embeddings = self._to_embs(documents_encoding) 201 | scores = ( 202 | _colbert_score( 203 | query_embeddings, 204 | document_embeddings, 205 | query_encoding["attention_mask"], 206 | documents_encoding["attention_mask"], 207 | ) 208 | .cpu() 209 | .tolist()[0] 210 | ) 211 | return scores 212 | 213 | 214 | class ColBERTRanker(BaseRanker): 215 | def __init__( 216 | self, 217 | model_name: str, 218 | batch_size: int = 32, 219 | dtype: Optional[Union[str, torch.dtype]] = None, 220 | device: Optional[Union[str, torch.device]] = None, 221 | verbose: int = 1, 222 | query_token: str = "[unused0]", 223 | document_token: str = "[unused1]", 224 | **kwargs, 225 | ): 226 | self.verbose = verbose 227 | self.device = get_device(device, self.verbose) 228 | self.dtype = get_dtype(dtype, self.device, self.verbose) 229 | self.batch_size = batch_size 230 | vprint( 231 | f"Loading model {model_name}, this might take a while...", 232 | self.verbose, 233 | ) 234 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) 235 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs) 236 | model_kwargs = kwargs.get("model_kwargs", {}) 237 | self.model = ( 238 | ColBERTModel.from_pretrained( 239 | model_name, 240 | verbose=self.verbose, 241 | **model_kwargs 242 | ) 243 | .to(self.device) 244 | .to(self.dtype) 245 | ) 246 | self.model.eval() 247 | 
self.query_max_length = 32 # Lower bound 248 | self.doc_max_length = ( 249 | self.model.config.max_position_embeddings - 2 250 | ) # Upper bound 251 | self.query_token_id: int = self.tokenizer.convert_tokens_to_ids(query_token) # type: ignore 252 | self.document_token_id: int = self.tokenizer.convert_tokens_to_ids( 253 | document_token 254 | ) # type: ignore 255 | self.normalize = True 256 | 257 | def rank( 258 | self, 259 | query: str, 260 | docs: Union[Document, str, List[Document], List[str]], 261 | doc_ids: Optional[Union[List[str], List[int]]] = None, 262 | metadata: Optional[List[dict]] = None, 263 | ) -> RankedResults: 264 | docs = prep_docs(docs, doc_ids, metadata) 265 | 266 | scores = self._colbert_rank(query, [d.text for d in docs]) 267 | ranked_results = [ 268 | Result(document=doc, score=score, rank=idx + 1) 269 | for idx, (doc, score) in enumerate( 270 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) 271 | ) 272 | ] 273 | return RankedResults(results=ranked_results, query=query, has_scores=True) 274 | 275 | def score(self, query: str, doc: str) -> float: 276 | scores = self._colbert_rank(query, [doc]) 277 | return scores[0] if scores else 0.0 278 | 279 | @torch.inference_mode() 280 | def _colbert_rank( 281 | self, 282 | query: str, 283 | docs: List[str], 284 | ) -> List[float]: 285 | query_encoding = self._query_encode([query]) 286 | documents_encoding = self._document_encode(docs) 287 | query_embeddings = self._to_embs(query_encoding) 288 | document_embeddings = self._to_embs(documents_encoding) 289 | scores = ( 290 | _colbert_score( 291 | query_embeddings, 292 | document_embeddings, 293 | query_encoding["attention_mask"], 294 | documents_encoding["attention_mask"], 295 | ) 296 | .cpu() 297 | .tolist()[0] 298 | ) 299 | return scores 300 | 301 | def _query_encode(self, query: list[str]): 302 | return self._encode( 303 | query, self.query_token_id, max_length=self.doc_max_length, is_query=True 304 | ) 305 | 306 | def _document_encode(self, documents: list[str]): 307 | tokenized_doc_lengths = [ 308 | len( 309 | self.tokenizer.encode( 310 | doc, max_length=self.doc_max_length, truncation=True 311 | ) 312 | ) 313 | for doc in documents 314 | ] 315 | max_length = max(tokenized_doc_lengths) 316 | max_length = ( 317 | ceil(max_length / 32) * 32 318 | ) # Round up to the nearest multiple of 32 319 | max_length = max( 320 | max_length, self.query_max_length 321 | ) # Ensure not smaller than query_max_length 322 | max_length = int( 323 | min(max_length, self.doc_max_length) 324 | ) # Ensure not larger than doc_max_length 325 | return self._encode(documents, self.document_token_id, max_length) 326 | 327 | def _encode( 328 | self, 329 | texts: list[str], 330 | insert_token_id: int, 331 | max_length: int, 332 | is_query: bool = False, 333 | ): 334 | encoding = self.tokenizer( 335 | texts, 336 | return_tensors="pt", 337 | padding=True, 338 | max_length=max_length - 1, # for insert token 339 | truncation=True, 340 | ) 341 | encoding = _insert_token(encoding, insert_token_id) # type: ignore 342 | 343 | if is_query: 344 | mask_token_id = self.tokenizer.mask_token_id 345 | 346 | new_encodings = {"input_ids": [], "attention_mask": []} 347 | 348 | for i, input_ids in enumerate(encoding["input_ids"]): 349 | original_length = ( 350 | (input_ids != self.tokenizer.pad_token_id).sum().item() 351 | ) 352 | 353 | # Calculate QLEN dynamically for each query 354 | if original_length % 16 <= 8: 355 | QLEN = original_length + 8 356 | else: 357 | QLEN = ceil(original_length / 16) * 16 358 | 359 
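# Queries shorter than QLEN are padded below with [MASK] tokens (ColBERT-style query augmentation),
# while the padded positions keep attention_mask = 0; _colbert_score then normalises by q_mask.sum(),
# i.e. by the real query length rather than the padded one.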
| if original_length < QLEN: 360 | pad_length = QLEN - original_length 361 | padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length 362 | padded_attention_mask = ( 363 | encoding["attention_mask"][i].tolist() + [0] * pad_length 364 | ) 365 | else: 366 | padded_input_ids = input_ids[:QLEN].tolist() 367 | padded_attention_mask = encoding["attention_mask"][i][ 368 | :QLEN 369 | ].tolist() 370 | 371 | new_encodings["input_ids"].append(padded_input_ids) 372 | new_encodings["attention_mask"].append(padded_attention_mask) 373 | 374 | for key in new_encodings: 375 | new_encodings[key] = torch.tensor( 376 | new_encodings[key], device=self.device 377 | ) 378 | 379 | encoding = new_encodings 380 | 381 | encoding = {key: value.to(self.device) for key, value in encoding.items()} 382 | return encoding 383 | 384 | def _to_embs(self, encoding) -> torch.Tensor: 385 | with torch.inference_mode(): 386 | batched_embs = [] 387 | for i in range(0, encoding["input_ids"].size(0), self.batch_size): 388 | batch_encoding = { 389 | key: val[i : i + self.batch_size] for key, val in encoding.items() 390 | } 391 | batch_embs = self.model(**batch_encoding) 392 | batched_embs.append(batch_embs) 393 | embs = torch.cat(batched_embs, dim=0) 394 | if self.normalize: 395 | embs = embs / embs.norm(dim=-1, keepdim=True) 396 | return embs 397 | -------------------------------------------------------------------------------- /rerankers/models/flashrank_ranker.py: -------------------------------------------------------------------------------- 1 | from rerankers.models.ranker import BaseRanker 2 | 3 | from flashrank import Ranker, RerankRequest 4 | 5 | 6 | from typing import Union, List, Optional, Tuple 7 | from rerankers.utils import vprint, prep_docs 8 | from rerankers.results import RankedResults, Result 9 | from rerankers.documents import Document 10 | 11 | 12 | class FlashRankRanker(BaseRanker): 13 | def __init__( 14 | self, 15 | model_name_or_path: str, 16 | verbose: int = 1, 17 | cache_dir: str = "./.flashrank_cache", 18 | **kwargs 19 | ): 20 | self.verbose = verbose 21 | vprint( 22 | f"Loading model FlashRank model {model_name_or_path}...", verbose=verbose 23 | ) 24 | self.model = Ranker(model_name=model_name_or_path, cache_dir=cache_dir) 25 | self.ranking_type = "pointwise" 26 | 27 | def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]): 28 | return self.tokenizer( 29 | inputs, return_tensors="pt", padding=True, truncation=True 30 | ).to(self.device) 31 | 32 | def rank( 33 | self, 34 | query: str, 35 | docs: Union[str, List[str], Document, List[Document]], 36 | doc_ids: Optional[Union[List[str], List[int]]] = None, 37 | metadata: Optional[List[dict]] = None, 38 | ) -> RankedResults: 39 | docs = prep_docs(docs, doc_ids, metadata) 40 | passages = [ 41 | {"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs) 42 | ] 43 | 44 | rerank_request = RerankRequest(query=query, passages=passages) 45 | flashrank_results = self.model.rerank(rerank_request) 46 | 47 | ranked_results = [ 48 | Result( 49 | document=docs[result["id"]], # Returns reranked documents. 
50 | score=result["score"], 51 | rank=idx + 1, 52 | ) 53 | for idx, result in enumerate(flashrank_results) 54 | ] 55 | 56 | return RankedResults(results=ranked_results, query=query, has_scores=True) 57 | 58 | def score(self, query: str, doc: str) -> float: 59 | rerank_request = RerankRequest( 60 | query=query, passages=[{"id": "temp_id", "text": doc}] 61 | ) 62 | flashrank_result = self.model.rerank(rerank_request) 63 | score = flashrank_result[0]["score"] 64 | return score 65 | -------------------------------------------------------------------------------- /rerankers/models/llm_layerwise_ranker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from rerankers.models.ranker import BaseRanker 4 | from rerankers.documents import Document 5 | from typing import Union, List, Optional 6 | from rerankers.utils import vprint, get_device, get_dtype, prep_docs 7 | from rerankers.results import RankedResults, Result 8 | 9 | 10 | PROMPTS = { 11 | "BAAI/bge-reranker-v2.5-gemma2-lightweight": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.", 12 | "default": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.", 13 | } 14 | 15 | DEFAULT_PARAMS = { 16 | "default": {}, 17 | "BAAI/bge-multilingual-gemma2": {}, 18 | "BAAI/bge-reranker-v2-gemma": {}, 19 | "BAAI/bge-reranker-v2-minicpm-layerwise": {"cutoff_layers": [28]}, 20 | "BAAI/bge-reranker-v2.5-gemma2-lightweight": { 21 | "cutoff_layers": [28], 22 | "compress_ratio": 2, 23 | "compress_layer": [24, 40], 24 | }, 25 | } 26 | 27 | 28 | class LLMLayerWiseRanker(BaseRanker): 29 | def __init__( 30 | self, 31 | model_name_or_path: str = "BAAI/bge-reranker-v2.5-gemma2-lightweight", 32 | max_sequence_length: int = 512, 33 | dtype: Optional[Union[str, torch.dtype]] = None, 34 | device: Optional[Union[str, torch.device]] = None, 35 | batch_size: int = 16, 36 | verbose: int = 1, 37 | prompt: Optional[str] = None, 38 | cutoff_layers: Optional[List[int]] = None, 39 | compress_ratio: Optional[int] = None, 40 | compress_layer: Optional[List[int]] = None, 41 | **kwargs, 42 | ): 43 | self.verbose = verbose 44 | self.device = get_device(device, verbose=self.verbose) 45 | self.dtype = get_dtype(dtype, self.device, self.verbose) 46 | self.batch_size = batch_size 47 | 48 | vprint( 49 | f"Loading model {model_name_or_path}, this might take a while...", 50 | self.verbose, 51 | ) 52 | vprint(f"Using device {self.device}.", self.verbose) 53 | vprint(f"Using dtype {self.dtype}.", self.verbose) 54 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) 55 | tokenizer_trust_remote_code = tokenizer_kwargs.pop("trust_remote_code", True) 56 | self.tokenizer = AutoTokenizer.from_pretrained( 57 | model_name_or_path, 58 | trust_remote_code=tokenizer_trust_remote_code, 59 | **tokenizer_kwargs, 60 | ) 61 | self.max_sequence_length = max_sequence_length 62 | self.tokenizer.model_max_length = self.max_sequence_length 63 | self.tokenizer.padding_side = "right" 64 | model_kwargs = kwargs.get("model_kwargs", {}) 65 | model_trust_remote_code = model_kwargs.pop("trust_remote_code", True) 66 | 67 | self.model = AutoModelForCausalLM.from_pretrained( 68 | model_name_or_path, 69 | trust_remote_code=model_trust_remote_code, 70 | torch_dtype=self.dtype, 71 | **model_kwargs, 72 | 
).to(self.device) 73 | self.model.eval() 74 | 75 | # Create params dict based on specified values or defaults 76 | params = {} 77 | if cutoff_layers is not None: 78 | params["cutoff_layers"] = cutoff_layers 79 | if compress_ratio is not None: 80 | params["compress_ratio"] = compress_ratio 81 | if compress_layer is not None: 82 | params["compress_layer"] = compress_layer 83 | if not params: 84 | params = DEFAULT_PARAMS.get(model_name_or_path, DEFAULT_PARAMS["default"]) 85 | self.params = params 86 | 87 | self.prompt = prompt 88 | if self.prompt is None: 89 | self.prompt = PROMPTS.get(model_name_or_path, PROMPTS["default"]) 90 | 91 | def _get_inputs(self, pairs, max_sequence_length: int): 92 | prompt = self.prompt 93 | sep = "\n" 94 | prompt_inputs = self.tokenizer( 95 | prompt, return_tensors=None, add_special_tokens=False 96 | )["input_ids"] 97 | sep_inputs = self.tokenizer(sep, return_tensors=None, add_special_tokens=False)[ 98 | "input_ids" 99 | ] 100 | inputs = [] 101 | for query, passage in pairs: 102 | query_inputs = self.tokenizer( 103 | f"A: {query}", 104 | return_tensors=None, 105 | add_special_tokens=False, 106 | max_length=max_sequence_length * 3 // 4, 107 | truncation=True, 108 | ) 109 | passage_inputs = self.tokenizer( 110 | f"B: {passage}", 111 | return_tensors=None, 112 | add_special_tokens=False, 113 | max_length=max_sequence_length, 114 | truncation=True, 115 | ) 116 | item = self.tokenizer.prepare_for_model( 117 | [self.tokenizer.bos_token_id] + query_inputs["input_ids"], 118 | sep_inputs + passage_inputs["input_ids"], 119 | truncation="only_second", 120 | max_length=max_sequence_length, 121 | padding=False, 122 | return_attention_mask=False, 123 | return_token_type_ids=False, 124 | add_special_tokens=False, 125 | ) 126 | item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs 127 | item["attention_mask"] = [1] * len(item["input_ids"]) 128 | inputs.append(item) 129 | 130 | return self.tokenizer.pad( 131 | inputs, 132 | padding=True, 133 | max_length=max_sequence_length + len(sep_inputs) + len(prompt_inputs), 134 | pad_to_multiple_of=8, 135 | return_tensors="pt", 136 | ) 137 | 138 | @torch.inference_mode() 139 | def rank( 140 | self, 141 | query: str, 142 | docs: Union[str, List[str], Document, List[Document]], 143 | doc_ids: Optional[Union[List[str], List[int]]] = None, 144 | metadata: Optional[List[dict]] = None, 145 | batch_size: Optional[int] = None, 146 | max_sequence_length: Optional[int] = None, 147 | ) -> RankedResults: 148 | docs = prep_docs(docs, doc_ids, metadata) 149 | pairs = [(query, doc.text) for doc in docs] 150 | 151 | # Override self.batch_size if explicitly set 152 | if batch_size is None: 153 | batch_size = self.batch_size 154 | 155 | # Same for max_sequence_length 156 | if max_sequence_length is None: 157 | max_sequence_length = self.max_sequence_length 158 | 159 | batched_pairs = [ 160 | pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size) 161 | ] 162 | scores = [] 163 | 164 | for batch in batched_pairs: 165 | inputs = self._get_inputs(batch, max_sequence_length=max_sequence_length) 166 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 167 | 168 | outputs = self.model(**inputs, return_dict=True, **self.params) 169 | all_scores = [ 170 | scores[:, -1] 171 | .view( 172 | -1, 173 | ) 174 | .float() 175 | for scores in outputs[0] 176 | ] 177 | batch_scores = all_scores[-1].cpu().numpy().tolist() 178 | 179 | scores.extend(batch_scores) 180 | 181 | ranked_results = [ 182 | Result(document=doc, score=score, rank=idx + 1) 
183 | for idx, (doc, score) in enumerate( 184 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) 185 | ) 186 | ] 187 | return RankedResults(results=ranked_results, query=query, has_scores=True) 188 | 189 | @torch.inference_mode() 190 | def score(self, query: str, doc: str) -> float: 191 | inputs = self._get_inputs( 192 | [(query, doc)], max_sequence_length=self.max_sequence_length 193 | ) 194 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 195 | 196 | outputs = self.model(**inputs, return_dict=True, **self.params) 197 | all_scores = [ 198 | scores[:, -1] 199 | .view( 200 | -1, 201 | ) 202 | .float() 203 | for scores in outputs[0] 204 | ] 205 | score = all_scores[-1].item() 206 | 207 | return score 208 | -------------------------------------------------------------------------------- /rerankers/models/llm_relevance_filter.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union, Dict 2 | import warnings 3 | import re 4 | 5 | try: 6 | from litellm import completion 7 | except ImportError: 8 | pass 9 | 10 | from rerankers.models.ranker import BaseRanker 11 | from rerankers.documents import Document 12 | from rerankers.results import RankedResults, Result 13 | from rerankers.utils import prep_docs, vprint 14 | 15 | SUPPORTED_BACKENDS = ["litellm"] 16 | UNSUPPORTED_BACKENDS = ["gaspard", "claudette", "cosette"] 17 | 18 | SYSTEM = ( 19 | "You are a friendly AI assistant, working on document relevance filtering. Your task is " 20 | "to determine if a document is relevant to answering a given query. You must assign a binary " 21 | "RELEVANT or NOT_RELEVANT label to each document by carefully analysing them and the query." 22 | ) 23 | DEFAULT_PROMPT_TEMPLATE = """ 24 | Think carefully about whether the following documents would be useful to answer the query. 25 | For each document, explain your reasoning and then provide a binary decision (RELEVANT or NOT_RELEVANT). If a document is partially relevant, you will assign the RELEVANT label. 26 | 27 | The documents will be given to you in the following format: 28 | 29 | 30 | 31 | Text of the query. 32 | 33 | 34 | 35 | 36 | Text of the first document. 37 | 38 | 39 | Text of the second document. 40 | 41 | 42 | 43 | And you will respond in the following format: 44 | 45 | 46 | 47 | Your reasoning regarding the document's relevance. 48 | 49 | 50 | RELEVANT or NOT_RELEVANT 51 | 52 | 53 | 54 | 55 | Here is the query and documents: 56 | 57 | 58 | 59 | {query} 60 | 61 | 62 | 63 | {docu_inputs} 64 | 65 | 66 | 67 | Analyse the above documents and provide your responses using the provided format. You must assign either the RELEVANT or NOT_RELEVANT label, no other option is permitted.""" 68 | 69 | class LLMRelevanceFilter(BaseRanker): 70 | def __init__( 71 | self, 72 | model_name_or_path: str, 73 | backend: str = "litellm", 74 | prompt_template: Optional[str] = None, 75 | temperature: float = 0.0, 76 | verbose: int = 1, 77 | default_label: str = "RELEVANT", 78 | **kwargs 79 | ): 80 | """Initialize the LLM Relevance Filter. 81 | 82 | Args: 83 | model_name_or_path: Name of the model to use (e.g. "gpt-4") 84 | backend: One of "litellm", "gaspard", "claudette", "cosette" 85 | prompt_template: Optional custom prompt template. Must include {query} and {docu_inputs} placeholders. 
86 | temperature: Temperature for LLM sampling (default 0.0 for deterministic outputs) 87 | verbose: Verbosity level 88 | **kwargs: Additional kwargs passed to the backend 89 | """ 90 | super().__init__(model_name_or_path, verbose) 91 | if backend not in SUPPORTED_BACKENDS: 92 | raise ValueError(f"Backend must be one of {SUPPORTED_BACKENDS}") 93 | 94 | if backend != "litellm": 95 | warnings.warn(f"Backend {backend} is experimental and may not work as expected") 96 | 97 | self.backend = backend 98 | self.model_name = model_name_or_path 99 | self.temperature = temperature 100 | self.prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE 101 | self.verbose = verbose 102 | self.additional_kwargs = kwargs 103 | self.default_label = default_label 104 | 105 | vprint(f"Initialized LLMRelevanceFilter with {backend} backend using model {model_name_or_path}", verbose) 106 | 107 | # Verify backend is available 108 | if backend == "litellm" and "completion" not in globals(): 109 | raise ImportError("litellm is required for the litellm backend. Install with pip install litellm") 110 | 111 | def _get_completion(self, prompt: str) -> str: 112 | """Get completion from the selected backend.""" 113 | if self.backend == "litellm": 114 | response = completion( 115 | model=self.model_name, 116 | messages=[{"role": "user", "content": prompt}], 117 | temperature=self.temperature, 118 | **self.additional_kwargs 119 | ) 120 | return response.choices[0].message.content.strip() 121 | else: 122 | raise NotImplementedError(f"Backend {self.backend} not yet implemented") 123 | 124 | def _parse_response(self, response: str) -> str: 125 | """ 126 | Parse an XML response to extract the answer from within the tags. 127 | If no answer is found, defaults to "NOT_RELEVANT". 128 | """ 129 | match = re.search(r'\s*(RELEVANT|NOT_RELEVANT)\s*', response, re.IGNORECASE) 130 | if match: 131 | return match.group(1).upper() 132 | print(response) 133 | print("MALFORMATTED RESPONSE!") 134 | return self.default_label 135 | 136 | def _format_doc_inputs(self, docs: List[str]) -> str: 137 | """ 138 | Format a list of document texts into an XML string with enumerated document IDs. 139 | Each document is wrapped in a ... block. 140 | """ 141 | formatted_docs = [] 142 | for i, text in enumerate(docs): 143 | formatted_docs.append(f"\n{text}\n") 144 | return "\n".join(formatted_docs) 145 | 146 | def score(self, query: str, doc: str) -> float: 147 | """Score a single document.""" 148 | # Format the single document as an XML input. 149 | doc_xml = self._format_doc_inputs([doc]) 150 | prompt = self.prompt_template.format(query=query, docu_inputs=doc_xml) 151 | response = self._get_completion(prompt) 152 | print(response) 153 | answer = self._parse_response(response) 154 | print(answer) 155 | return 1.0 if answer == "RELEVANT" else 0.0 156 | 157 | def rank( 158 | self, 159 | query: str, 160 | docs: Union[str, List[str], Document, List[Document]], 161 | doc_ids: Optional[Union[List[str], List[int]]] = None, 162 | metadata: Optional[List[dict]] = None, 163 | ) -> RankedResults: 164 | """Rank a list of documents based on relevance to query.""" 165 | docs = prep_docs(docs, doc_ids, metadata) 166 | doc_texts = [doc.text for doc in docs] 167 | # Format all document texts into one XML block. 
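# A single completion covers the whole batch: every document goes into one prompt, the response is
# parsed back per document id, any document the model fails to mention defaults to a score of 0.0,
# and ties are broken by the original input order so the final ranking stays deterministic.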
168 | docs_xml = self._format_doc_inputs(doc_texts) 169 | prompt = self.prompt_template.format(query=query, docu_inputs=docs_xml) 170 | response = self._get_completion(prompt) 171 | print(response) 172 | 173 | pattern = re.compile(r'(.*?)', re.DOTALL) 174 | matches = pattern.findall(response) 175 | doc_scores = {} 176 | for doc_id, content in matches: 177 | ans = self._parse_response(content) 178 | doc_scores[int(doc_id)] = 1.0 if ans == "RELEVANT" else 0.0 179 | 180 | # Preserve original order while sorting by score descending. 181 | scores_with_index = [] 182 | for i, doc in enumerate(docs): 183 | score = doc_scores.get(i, 0.0) 184 | scores_with_index.append((score, i, doc)) 185 | 186 | scores_with_index.sort(key=lambda x: (-x[0], x[1])) 187 | 188 | ranked_results = [ 189 | Result(document=doc, score=score, rank=idx + 1) 190 | for idx, (score, _, doc) in enumerate(scores_with_index) 191 | ] 192 | 193 | return RankedResults(results=ranked_results, query=query, has_scores=True) 194 | -------------------------------------------------------------------------------- /rerankers/models/monovlm_ranker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import base64 4 | import io 5 | # TODO: Support more than Qwen 6 | from transformers import AutoProcessor, Qwen2VLForConditionalGeneration 7 | from rerankers.models.ranker import BaseRanker 8 | from rerankers.documents import Document 9 | from typing import Union, List, Optional 10 | from rerankers.utils import vprint, get_device, get_dtype, prep_image_docs 11 | from rerankers.results import RankedResults, Result 12 | 13 | PREDICTION_TOKENS = { 14 | "default": ["False", "True"], 15 | "lightonai/MonoQwen2-VL-v0.1": ["False", "True"] 16 | } 17 | 18 | def _get_output_tokens(model_name_or_path, token_false: str, token_true: str): 19 | if token_false == "auto": 20 | if model_name_or_path in PREDICTION_TOKENS: 21 | token_false = PREDICTION_TOKENS[model_name_or_path][0] 22 | else: 23 | token_false = PREDICTION_TOKENS["default"][0] 24 | print( 25 | f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_false to `{token_false}`." 26 | ) 27 | if token_true == "auto": 28 | if model_name_or_path in PREDICTION_TOKENS: 29 | token_true = PREDICTION_TOKENS[model_name_or_path][1] 30 | else: 31 | token_true = PREDICTION_TOKENS["default"][1] 32 | print( 33 | f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_true to `{token_true}`." 34 | ) 35 | 36 | return token_false, token_true 37 | 38 | class MonoVLMRanker(BaseRanker): 39 | def __init__( 40 | self, 41 | model_name_or_path: str, 42 | processor_name: Optional[str] = None, 43 | dtype: Optional[Union[str, torch.dtype]] = 'bf16', 44 | device: Optional[Union[str, torch.device]] = None, 45 | batch_size: int = 1, 46 | verbose: int = 1, 47 | token_false: str = "auto", 48 | token_true: str = "auto", 49 | return_logits: bool = False, 50 | prompt_template: str = "Assert the relevance of the previous image document to the following query, answer True or False. The query is: {query}", 51 | **kwargs 52 | ): 53 | self.verbose = verbose 54 | self.device = get_device(device, verbose=self.verbose) 55 | if self.device == 'mps': 56 | print("WARNING: MPS is not supported by MonoVLMRanker due to PyTorch limitations. 
Falling back to CPU.") 57 | self.device = 'cpu' 58 | print(dtype) 59 | self.dtype = get_dtype(dtype, self.device, self.verbose) 60 | self.batch_size = batch_size 61 | self.return_logits = return_logits 62 | self.prompt_template = prompt_template 63 | 64 | vprint(f"Loading model {model_name_or_path}, this might take a while...", self.verbose) 65 | vprint(f"Using device {self.device}.", self.verbose) 66 | vprint(f"Using dtype {self.dtype}.", self.verbose) 67 | 68 | processor_name = processor_name or "Qwen/Qwen2-VL-2B-Instruct" 69 | processor_kwargs = kwargs.get("processor_kwargs", {}) 70 | model_kwargs = kwargs.get("model_kwargs", {}) 71 | attention_implementation = kwargs.get("attention_implementation", "flash_attention_2") 72 | self.processor = AutoProcessor.from_pretrained(processor_name, **processor_kwargs) 73 | self.model = Qwen2VLForConditionalGeneration.from_pretrained( 74 | model_name_or_path, 75 | device_map=self.device, 76 | torch_dtype=self.dtype, 77 | attn_implementation=attention_implementation, 78 | **model_kwargs 79 | ) 80 | self.model.eval() 81 | 82 | token_false, token_true = _get_output_tokens( 83 | model_name_or_path=model_name_or_path, 84 | token_false=token_false, 85 | token_true=token_true, 86 | ) 87 | self.token_false_id = self.processor.tokenizer.convert_tokens_to_ids(token_false) 88 | self.token_true_id = self.processor.tokenizer.convert_tokens_to_ids(token_true) 89 | 90 | vprint(f"VLM true token set to {token_true}", self.verbose) 91 | vprint(f"VLM false token set to {token_false}", self.verbose) 92 | 93 | @torch.inference_mode() 94 | def _get_scores(self, query: str, docs: List[Document]) -> List[float]: 95 | scores = [] 96 | for doc in docs: 97 | if doc.document_type != "image" or not doc.base64: 98 | raise ValueError("MonoVLMRanker requires image documents with base64 data") 99 | 100 | # Convert base64 to PIL Image 101 | image_io = io.BytesIO(base64.b64decode(doc.base64)) 102 | image_io.seek(0) # Reset file pointer to start 103 | image = Image.open(image_io).convert('RGB') 104 | 105 | # Prepare prompt 106 | prompt = self.prompt_template.format(query=query) 107 | messages = [ 108 | { 109 | "role": "user", 110 | "content": [ 111 | {"type": "image", "image": image}, 112 | {"type": "text", "text": prompt}, 113 | ], 114 | } 115 | ] 116 | 117 | # Process inputs 118 | text = self.processor.apply_chat_template( 119 | messages, 120 | tokenize=False, 121 | add_generation_prompt=True 122 | ) 123 | inputs = self.processor( 124 | text=text, 125 | images=image, 126 | return_tensors="pt" 127 | ).to(self.device).to(self.dtype) 128 | 129 | # Get model outputs 130 | outputs = self.model(**inputs) 131 | logits = outputs.logits[:, -1, :] 132 | 133 | # Calculate scores 134 | relevant_logits = logits[:, [self.token_false_id, self.token_true_id]] 135 | if self.return_logits: 136 | score = relevant_logits[0, 1].cpu().item() # True logit 137 | else: 138 | probs = torch.softmax(relevant_logits, dim=-1) 139 | score = probs[0, 1].cpu().item() # True probability 140 | 141 | scores.append(score) 142 | 143 | return scores 144 | 145 | def rank( 146 | self, 147 | query: str, 148 | docs: Union[str, List[str], Document, List[Document]], 149 | doc_ids: Optional[Union[List[str], List[int]]] = None, 150 | metadata: Optional[List[dict]] = None, 151 | ) -> RankedResults: 152 | docs = prep_image_docs(docs, doc_ids, metadata) 153 | scores = self._get_scores(query, docs) 154 | ranked_results = [ 155 | Result(document=doc, score=score, rank=idx + 1) 156 | for idx, (doc, score) in enumerate( 157 | 
sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) 158 | ) 159 | ] 160 | return RankedResults(results=ranked_results, query=query, has_scores=True) 161 | 162 | def score(self, query: str, doc: Union[str, Document]) -> float: 163 | scores = self._get_scores(query, [doc]) 164 | return scores[0] 165 | -------------------------------------------------------------------------------- /rerankers/models/mxbai_v2.py: -------------------------------------------------------------------------------- 1 | """ 2 | MXBai V2 Reranker implementation 3 | 4 | Parts of the code borrowed/adapted from the Apache 2.0 licensed original codebase: https://github.com/mixedbread-ai/mxbai-rerank 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import torch 10 | from typing import Union, List, Optional 11 | 12 | # Rerankers base imports 13 | from rerankers.models.ranker import BaseRanker 14 | from rerankers.documents import Document 15 | from rerankers.results import RankedResults, Result 16 | from rerankers.utils import vprint, prep_docs, get_device, get_dtype 17 | 18 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 19 | 20 | 21 | # Default prompt templates and tokens 22 | SEPS = { 23 | "mixedbread-ai/mxbai-rerank-large-v2": "\n", 24 | "mixedbread-ai/mxbai-rerank-base-v2": "\n", 25 | "default": "\n" 26 | } 27 | 28 | INSTRUCTION_PROMPT = { 29 | "mixedbread-ai/mxbai-rerank-large-v2": "instruction: {instruction}", 30 | "mixedbread-ai/mxbai-rerank-base-v2": "instruction: {instruction}", 31 | "default": "instruction: {instruction}" 32 | } 33 | 34 | QUERY_PROMPT = { 35 | "mixedbread-ai/mxbai-rerank-large-v2": "query: {query}", 36 | "mixedbread-ai/mxbai-rerank-base-v2": "query: {query}", 37 | "default": "query: {query}" 38 | } 39 | 40 | DOC_PROMPT = { 41 | "mixedbread-ai/mxbai-rerank-large-v2": "document: {document}", 42 | "mixedbread-ai/mxbai-rerank-base-v2": "document: {document}", 43 | "default": "document: {document}" 44 | } 45 | 46 | TASK_PROMPT = { 47 | "mixedbread-ai/mxbai-rerank-large-v2": """You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant). 48 | Relevance:""", 49 | "mixedbread-ai/mxbai-rerank-base-v2": """You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant). 50 | Relevance:""", 51 | "default": """You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant). 52 | Relevance:""" 53 | } 54 | 55 | CHAT_TEMPLATE = { 56 | "mixedbread-ai/mxbai-rerank-large-v2": { 57 | "prefix": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n", 58 | "suffix": "<|im_end|>\n<|im_start|>assistant\n", 59 | }, 60 | "mixedbread-ai/mxbai-rerank-base-v2": { 61 | "prefix": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\n<|im_start|>user\n", 62 | "suffix": "<|im_end|>\n<|im_start|>assistant\n", 63 | }, 64 | "default": { 65 | "prefix": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n", 66 | "suffix": "<|im_end|>\n<|im_start|>assistant\n", 67 | } 68 | } 69 | 70 | POS_TOKEN = { 71 | "mixedbread-ai/mxbai-rerank-large-v2": "1", 72 | "mixedbread-ai/mxbai-rerank-base-v2": "1", 73 | "default": "1" 74 | } 75 | 76 | NEG_TOKEN = { 77 | "mixedbread-ai/mxbai-rerank-large-v2": "0", 78 | "mixedbread-ai/mxbai-rerank-base-v2": "0", 79 | "default": "0" 80 | } 81 | 82 | 83 | def _ensure_multiple_of_8(x: int, max_value: Optional[int] = None) -> int: 84 | """Make x a multiple of 8, respecting optional max_value""" 85 | if max_value is not None: 86 | max_value = max_value - max_value % 8 87 | x = min(x, max_value) 88 | return x - x % 8 89 | 90 | 91 | class MxBaiV2Ranker(BaseRanker): 92 | """ 93 | A reranker that uses MxBai models for yes/no-based relevance classification. 94 | 95 | This ranker uses causal language models from the MxBai family to determine 96 | document relevance by predicting binary relevance scores (0/1). 97 | """ 98 | 99 | def __init__( 100 | self, 101 | model_name_or_path: str = "mixedbread-ai/mxbai-rerank-base-v2", 102 | device: Optional[Union[str, torch.device]] = None, 103 | dtype: Optional[Union[str, torch.dtype]] = None, 104 | batch_size: int = 16, 105 | verbose: int = 1, 106 | max_length: int = 8192, 107 | **kwargs 108 | ): 109 | """ 110 | Initialize the MxBai reranker. 111 | 112 | Args: 113 | model_name_or_path: Path or name of the MxBai model. 114 | device: Device to use (e.g. 'cpu', 'cuda:0', or 'auto'). 115 | dtype: Torch dtype or 'auto'. 116 | batch_size: Batch size for processing multiple documents. 117 | verbose: Verbosity level. 118 | max_length: Maximum token length for inputs. 119 | **kwargs: Additional kwargs for model and tokenizer. 120 | """ 121 | super().__init__(model_name_or_path=model_name_or_path, verbose=verbose) 122 | self.verbose = verbose 123 | self.device = get_device(device, verbose=self.verbose) 124 | self.dtype = get_dtype(dtype, self.device, self.verbose) 125 | self.batch_size = batch_size 126 | self.max_length = max_length 127 | self.cfg = AutoConfig.from_pretrained(model_name_or_path) 128 | 129 | vprint(f"Loading MxBai model from {model_name_or_path}", self.verbose) 130 | vprint(f"Device: {self.device}, Dtype: {self.dtype}", self.verbose) 131 | 132 | # Extract model kwargs 133 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) 134 | model_kwargs = kwargs.get("model_kwargs", {}) 135 | 136 | # Try to use flash attention if available 137 | try: 138 | import flash_attn 139 | attn_impl = "flash_attention_2" 140 | self.dtype = 'bfloat16' 141 | vprint("Flash attention is available. 
Setting dtype to bfloat16.", self.verbose) 142 | except ImportError: 143 | attn_impl = None 144 | 145 | # Load model and tokenizer 146 | self.model = AutoModelForCausalLM.from_pretrained( 147 | model_name_or_path, 148 | attn_implementation=attn_impl, 149 | torch_dtype=self.dtype if isinstance(self.dtype, torch.dtype) else "auto", 150 | device_map=str(self.device), 151 | **model_kwargs, 152 | ) 153 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs) 154 | self.tokenizer.padding_side = "left" 155 | 156 | # Get model-specific templates and tokens 157 | model_key = model_name_or_path.split("/")[-1] 158 | self._setup_templates(model_key) 159 | 160 | # Pre-tokenize static prompts for efficiency 161 | self._prepare_tokenized_templates() 162 | 163 | # Switch to eval mode 164 | self.model.eval() 165 | vprint("MxBaiV2Ranker ready.", self.verbose) 166 | 167 | def _setup_templates(self, model_key: str): 168 | """Set up the templates and tokens for the specific model""" 169 | # Helper function to get template with fallback to default 170 | def get_template(template_dict, key): 171 | return template_dict.get(key, template_dict["default"]) 172 | 173 | # Set up all templates 174 | self.task_prompt = get_template(TASK_PROMPT, model_key) 175 | self.chat_template = get_template(CHAT_TEMPLATE, model_key) 176 | self.query_prompt = get_template(QUERY_PROMPT, model_key) 177 | self.doc_prompt = get_template(DOC_PROMPT, model_key) 178 | self.instruction_prompt = get_template(INSTRUCTION_PROMPT, model_key) 179 | self.pos_token = get_template(POS_TOKEN, model_key) 180 | self.neg_token = get_template(NEG_TOKEN, model_key) 181 | self.sep = get_template(SEPS, model_key) 182 | 183 | if not any(model_key in template_dict for template_dict in [TASK_PROMPT, CHAT_TEMPLATE, QUERY_PROMPT, DOC_PROMPT, INSTRUCTION_PROMPT, POS_TOKEN, NEG_TOKEN, SEPS]): 184 | vprint("Model name did not have all necessary instructions. 
Using default prompt formats which might not be suitable for this model!", self.verbose) 185 | 186 | def _prepare_tokenized_templates(self): 187 | """Pre-tokenize templates for efficiency""" 188 | # Get token IDs for positive and negative tokens 189 | self.pos_id = self.tokenizer(self.pos_token, return_tensors=None, add_special_tokens=False)["input_ids"][0] 190 | self.neg_id = self.tokenizer(self.neg_token, return_tensors=None, add_special_tokens=False)["input_ids"][0] 191 | 192 | # Pre-tokenize chat template parts 193 | self.prefix_ids = self.tokenizer(self.chat_template["prefix"], return_tensors=None, add_special_tokens=False)["input_ids"] 194 | self.suffix_ids = self.tokenizer(self.chat_template["suffix"], return_tensors=None, add_special_tokens=False)["input_ids"] 195 | 196 | # Pre-tokenize task prompt and separator 197 | self.task_prompt_ids = self.tokenizer(self.task_prompt, return_tensors=None, add_special_tokens=False)["input_ids"] 198 | self.sep_ids = self.tokenizer(self.sep, return_tensors=None, add_special_tokens=False)["input_ids"] 199 | 200 | # Calculate total length of static tokens 201 | self.static_tokens_length = ( 202 | len(self.prefix_ids) + 203 | len(self.task_prompt_ids) + 204 | len(self.suffix_ids) + 205 | len(self.sep_ids) 206 | ) 207 | 208 | # Set model max length 209 | self.model_max_length = self.cfg.max_position_embeddings 210 | 211 | # Adjust max_length to account for static tokens 212 | if self.max_length + self.static_tokens_length > self.model_max_length: 213 | self.max_length = self.model_max_length - self.static_tokens_length 214 | 215 | # Ensure padding length is a multiple of 8 for efficiency 216 | self.padding_length = _ensure_multiple_of_8( 217 | max(self.model_max_length, self.max_length + self.static_tokens_length), 218 | max_value=self.model_max_length 219 | ) 220 | 221 | def _create_full_input_ids(self, content_ids: List[int]) -> List[int]: 222 | """ 223 | Create the full input by combining content with template parts. 224 | 225 | Args: 226 | content_ids: Token IDs for the query-document content 227 | 228 | Returns: 229 | List of token IDs for the complete input 230 | """ 231 | return ( 232 | self.prefix_ids + 233 | content_ids + 234 | self.sep_ids + 235 | self.task_prompt_ids + 236 | self.suffix_ids 237 | ) 238 | 239 | def _prepare_batch( 240 | self, 241 | queries: List[str], 242 | documents: List[str], 243 | instruction: Optional[str] = None, 244 | ) -> dict: 245 | """ 246 | Prepare a batch of query-document pairs for the model. 
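Roughly three quarters of the max_length token budget goes to the (optionally instruction-prefixed)
query and the remainder to the document; the pre-tokenized chat-template prefix, task prompt and
suffix are then wrapped around each pair.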
247 | 248 | Args: 249 | queries: List of query strings 250 | documents: List of document strings 251 | instruction: Optional instruction to prepend 252 | 253 | Returns: 254 | Dictionary with input_ids and attention_mask tensors 255 | """ 256 | batch_inputs = [] 257 | 258 | for query, document in zip(queries, documents): 259 | # Format query with template 260 | query_text = self.query_prompt.format(query=query) 261 | 262 | # Add instruction if provided 263 | if instruction: 264 | instruction_text = self.instruction_prompt.format(instruction=instruction) 265 | query_text = instruction_text + self.sep + query_text 266 | 267 | # Tokenize query with length limit 268 | query_ids = self.tokenizer( 269 | query_text, 270 | return_tensors=None, 271 | add_special_tokens=False, 272 | max_length=self.max_length * 3 // 4, # Use 3/4 of tokens for query 273 | truncation=True, 274 | )["input_ids"] 275 | 276 | # Calculate remaining tokens for document 277 | available_tokens = self.model_max_length - len(query_ids) - self.static_tokens_length 278 | doc_max_length = min(available_tokens, self.max_length // 4) # Use 1/4 of tokens for document 279 | 280 | # Tokenize document 281 | doc_text = self.doc_prompt.format(document=document) 282 | doc_ids = self.tokenizer( 283 | doc_text, 284 | return_tensors=None, 285 | add_special_tokens=False, 286 | max_length=doc_max_length, 287 | truncation=True, 288 | )["input_ids"] 289 | 290 | # Combine query and document 291 | combined = self.tokenizer.prepare_for_model( 292 | query_ids, 293 | self.sep_ids + doc_ids, 294 | truncation="only_second", 295 | max_length=self.max_length, 296 | padding=False, 297 | return_attention_mask=False, 298 | return_token_type_ids=False, 299 | add_special_tokens=False, 300 | ) 301 | 302 | # Create full input with template 303 | full_input_ids = self._create_full_input_ids(combined["input_ids"]) 304 | 305 | # Add to batch 306 | batch_inputs.append({ 307 | "input_ids": full_input_ids, 308 | "attention_mask": [1] * len(full_input_ids), 309 | }) 310 | 311 | # Pad all inputs to the same length 312 | padded_batch = self.tokenizer.pad( 313 | batch_inputs, 314 | padding="longest", 315 | max_length=self.padding_length, 316 | pad_to_multiple_of=8, 317 | return_tensors="pt", 318 | ) 319 | 320 | return padded_batch 321 | 322 | @torch.inference_mode() 323 | def _predict( 324 | self, 325 | queries: List[str], 326 | documents: List[str], 327 | instruction: Optional[str] = None, 328 | ) -> torch.Tensor: 329 | """ 330 | Get relevance scores for query-document pairs. 
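Each score is the difference between the final-position logits of the positive ('1') and
negative ('0') tokens, so higher means more relevant; no softmax or sigmoid is applied.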
331 | 332 | Args: 333 | queries: List of query strings 334 | documents: List of document strings 335 | instruction: Optional instruction to prepend 336 | 337 | Returns: 338 | Tensor of relevance scores 339 | """ 340 | # Prepare inputs 341 | inputs = self._prepare_batch(queries, documents, instruction=instruction) 342 | 343 | # Move to device 344 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 345 | 346 | # Run model 347 | outputs = self.model( 348 | input_ids=inputs["input_ids"], 349 | attention_mask=inputs["attention_mask"], 350 | output_hidden_states=True, 351 | ) 352 | 353 | score = outputs.logits[:, -1, self.pos_id] - outputs.logits[:, -1, self.neg_id] 354 | 355 | # Return scores as CPU tensor 356 | return score.detach().cpu().float() 357 | 358 | @torch.inference_mode() 359 | def rank( 360 | self, 361 | query: str, 362 | docs: Union[str, List[str], Document, List[Document]], 363 | doc_ids: Optional[Union[List[str], List[int]]] = None, 364 | metadata: Optional[List[dict]] = None, 365 | batch_size: Optional[int] = None, 366 | instruction: Optional[str] = None, 367 | ) -> RankedResults: 368 | """ 369 | Rank documents by relevance to the query. 370 | 371 | Args: 372 | query: Query string 373 | docs: Documents to rank 374 | doc_ids: Optional document IDs 375 | metadata: Optional document metadata 376 | batch_size: Optional batch size override 377 | instruction: Optional instruction to prepend 378 | 379 | Returns: 380 | RankedResults with documents sorted by relevance 381 | """ 382 | # Prepare documents 383 | docs = prep_docs(docs, doc_ids, metadata) 384 | 385 | # Use default batch size if not specified 386 | if batch_size is None: 387 | batch_size = self.batch_size 388 | 389 | all_docs = [] 390 | all_scores = [] 391 | 392 | # Process in batches 393 | for i in range(0, len(docs), batch_size): 394 | batch_docs = docs[i:i + batch_size] 395 | batch_scores = self._predict( 396 | queries=[query] * len(batch_docs), 397 | documents=[d.text for d in batch_docs], 398 | instruction=instruction, 399 | ) 400 | 401 | all_docs.extend(batch_docs) 402 | all_scores.extend(batch_scores.tolist()) 403 | 404 | # Sort by descending score 405 | scored_docs = sorted(zip(all_docs, all_scores), key=lambda x: x[1], reverse=True) 406 | 407 | # Create ranked results 408 | results = [ 409 | Result(document=doc, score=score, rank=idx + 1) 410 | for idx, (doc, score) in enumerate(scored_docs) 411 | ] 412 | 413 | return RankedResults(results=results, query=query, has_scores=True) 414 | 415 | @torch.inference_mode() 416 | def score( 417 | self, 418 | query: str, 419 | doc: Union[str, Document], 420 | instruction: Optional[str] = None, 421 | ) -> float: 422 | """ 423 | Score a single query-document pair. 
424 | 425 | Args: 426 | query: Query string 427 | doc: Document to score 428 | instruction: Optional instruction to prepend 429 | 430 | Returns: 431 | Relevance score as float 432 | """ 433 | # Extract text if document is a Document object 434 | doc_text = doc.text if isinstance(doc, Document) else doc 435 | 436 | # Get score 437 | scores = self._predict( 438 | queries=[query], 439 | documents=[doc_text], 440 | instruction=instruction, 441 | ) 442 | 443 | # Return as float 444 | return float(scores[0]) -------------------------------------------------------------------------------- /rerankers/models/pylate_ranker.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | 5 | from pylate import models, rank 6 | from rerankers.documents import Document 7 | from rerankers.models.ranker import BaseRanker 8 | from rerankers.results import RankedResults, Result 9 | from rerankers.utils import get_device, get_dtype, prep_docs, vprint 10 | 11 | 12 | class PyLateRanker(BaseRanker): 13 | def __init__( 14 | self, 15 | model_name: str, 16 | batch_size: int = 32, 17 | dtype: Optional[Union[str, torch.dtype]] = None, 18 | device: Optional[Union[str, torch.device]] = None, 19 | verbose: int = 1, 20 | **kwargs, 21 | ): 22 | self.verbose = verbose 23 | self.device = get_device(device, self.verbose) 24 | self.dtype = get_dtype(dtype, self.device, self.verbose) 25 | self.batch_size = batch_size 26 | vprint( 27 | f"Loading model {model_name}, this might take a while...", 28 | self.verbose, 29 | ) 30 | kwargs = kwargs.get("kwargs", {}) 31 | kwargs["device"] = self.device 32 | model_kwargs = kwargs.get("model_kwargs", {}) 33 | model_kwargs["torch_dtype"] = self.dtype 34 | self.model = models.ColBERT( 35 | model_name_or_path=model_name, 36 | model_kwargs=model_kwargs, 37 | **kwargs, 38 | ) 39 | 40 | def rank( 41 | self, 42 | query: str, 43 | docs: Union[Document, str, List[Document], List[str]], 44 | doc_ids: Optional[Union[List[str], List[int]]] = None, 45 | metadata: Optional[List[dict]] = None, 46 | ) -> RankedResults: 47 | docs = prep_docs(docs, doc_ids, metadata) 48 | documents_embeddings = self.model.encode( 49 | [[d.text for d in docs]], 50 | is_query=False, 51 | ) 52 | 53 | query_embeddings = self.model.encode( 54 | [query], 55 | is_query=True, 56 | ) 57 | scores = rank.rerank( 58 | documents_ids=[doc_ids], 59 | queries_embeddings=query_embeddings, 60 | documents_embeddings=documents_embeddings, 61 | ) 62 | 63 | ranked_results = [ 64 | Result( 65 | document=doc, 66 | score=score["score"] / len(query_embeddings[0]), 67 | rank=idx + 1, 68 | ) 69 | for idx, (doc, score) in enumerate(zip(docs, scores[0])) 70 | ] 71 | return RankedResults(results=ranked_results, query=query, has_scores=True) 72 | 73 | def score(self, query: str, doc: str) -> float: 74 | document_embeddings = self.model.encode( 75 | doc, 76 | is_query=False, 77 | ) 78 | 79 | query_embeddings = self.model.encode( 80 | query, 81 | is_query=True, 82 | ) 83 | # This is shamefull, I really need to provide a scoring method with padding inside 84 | scores = rank.rerank( 85 | documents_ids=["0"], 86 | queries_embeddings=query_embeddings, 87 | documents_embeddings=document_embeddings, 88 | ) 89 | return scores[0][0]["score"] if scores else 0.0 90 | -------------------------------------------------------------------------------- /rerankers/models/ranker.py: -------------------------------------------------------------------------------- 1 | from abc import 
ABC, abstractmethod 2 | from asyncio import get_event_loop 3 | from functools import partial 4 | from typing import List, Optional, Union 5 | from rerankers.results import RankedResults 6 | from rerankers.documents import Document 7 | 8 | 9 | class BaseRanker(ABC): 10 | @abstractmethod 11 | def __init__(self, model_name_or_path: str, verbose: int): 12 | pass 13 | 14 | @abstractmethod 15 | def score(self, query: str, doc: str) -> float: 16 | pass 17 | 18 | @abstractmethod 19 | def rank( 20 | self, 21 | query: str, 22 | docs: Union[str, List[str], Document, List[Document]], 23 | doc_ids: Optional[Union[List[str], List[int]]] = None, 24 | ) -> RankedResults: 25 | """ 26 | End-to-end reranking of documents. 27 | """ 28 | pass 29 | 30 | async def rank_async( 31 | self, 32 | query: str, 33 | docs: List[str], 34 | doc_ids: Optional[Union[List[str], str]] = None, 35 | ) -> RankedResults: 36 | 37 | 38 | loop = get_event_loop() 39 | return await loop.run_in_executor(None, partial(self.rank, query, docs, doc_ids)) 40 | 41 | def as_langchain_compressor(self, k: int = 10): 42 | try: 43 | from rerankers.integrations.langchain import RerankerLangChainCompressor 44 | 45 | return RerankerLangChainCompressor(model=self, k=k) 46 | except ImportError: 47 | print( 48 | "You need to install langchain and langchain_core to export a reranker as a LangChainCompressor!" 49 | ) 50 | print( 51 | 'Please run `pip install "rerankers[langchain]"` to get all the required dependencies.' 52 | ) 53 | -------------------------------------------------------------------------------- /rerankers/models/rankgpt_rankers.py: -------------------------------------------------------------------------------- 1 | """Full implementation is from the original RankGPT repository https://github.com/sunnweiwei/RankGPT under its Apache 2.0 License 2 | 3 | Changes made are: 4 | - Truncating the file to only the relevant functions 5 | - Using only LiteLLM 6 | - make_item() added 7 | - Packaging it onto RankGPTRanker""" 8 | 9 | import copy 10 | from typing import Optional, Union, List, Dict 11 | from litellm import completion 12 | from rerankers.models.ranker import BaseRanker 13 | from rerankers.documents import Document 14 | from rerankers.results import RankedResults, Result 15 | from rerankers.utils import vprint, prep_docs 16 | 17 | 18 | def get_prefix_prompt(query, num): 19 | return [ 20 | { 21 | "role": "system", 22 | "content": "You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.", 23 | }, 24 | { 25 | "role": "user", 26 | "content": f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}.", 27 | }, 28 | {"role": "assistant", "content": "Okay, please provide the passages."}, 29 | ] 30 | 31 | 32 | def get_post_prompt(query, num): 33 | return f"Search Query: {query}. \nRank the {num} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only response the ranking results, do not say any word or explain." 
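# Illustrative sketch of the prompt flow (example values only, not part of the original code):
# for a query and three passages, create_permutation_instruction() below assembles
#   get_prefix_prompt(query, 3)   -> system message + user intro + assistant acknowledgement
#   "[1] <passage text>" / "Received passage [1]."   (one user/assistant pair per passage)
#   get_post_prompt(query, 3)     -> final user instruction asking for the ordering
# and the LLM is expected to answer with a permutation string such as "[2] > [1] > [3]",
# which clean_response() / receive_permutation() further down parse back into a new ordering.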
34 | 35 | 36 | def create_permutation_instruction( 37 | item=None, 38 | rank_start=0, 39 | rank_end=100, 40 | lang: str = "en", 41 | ): 42 | query = item["query"] 43 | num = len(item["hits"][rank_start:rank_end]) 44 | 45 | max_length = 300 46 | 47 | messages = get_prefix_prompt(query, num) 48 | rank = 0 49 | for hit in item["hits"][rank_start:rank_end]: 50 | rank += 1 51 | content = hit["content"] 52 | content = content.replace("Title: Content: ", "") 53 | content = content.strip() 54 | # For Japanese should cut by character: content = content[:int(max_length)] 55 | if lang in ["zh", "ja"]: 56 | content = content[: int(max_length)] 57 | else: 58 | content = " ".join(content.split()[: int(max_length)]) 59 | messages.append({"role": "user", "content": f"[{rank}] {content}"}) 60 | messages.append({"role": "assistant", "content": f"Received passage [{rank}]."}) 61 | messages.append({"role": "user", "content": get_post_prompt(query, num)}) 62 | 63 | return messages 64 | 65 | 66 | def clean_response(response: str): 67 | new_response = "" 68 | for c in response: 69 | if not c.isdigit(): 70 | new_response += " " 71 | else: 72 | new_response += c 73 | new_response = new_response.strip() 74 | return new_response 75 | 76 | 77 | def remove_duplicate(response): 78 | new_response = [] 79 | for c in response: 80 | if c not in new_response: 81 | new_response.append(c) 82 | return new_response 83 | 84 | 85 | def receive_permutation(item, permutation, rank_start=0, rank_end=100): 86 | response = clean_response(permutation) 87 | response = [int(x) - 1 for x in response.split()] 88 | response = remove_duplicate(response) 89 | cut_range = copy.deepcopy(item["hits"][rank_start:rank_end]) 90 | original_rank = [tt for tt in range(len(cut_range))] 91 | response = [ss for ss in response if ss in original_rank] 92 | response = response + [tt for tt in original_rank if tt not in response] 93 | for j, x in enumerate(response): 94 | item["hits"][j + rank_start] = copy.deepcopy(cut_range[x]) 95 | if "rank" in item["hits"][j + rank_start]: 96 | item["hits"][j + rank_start]["rank"] = cut_range[j]["rank"] 97 | if "score" in item["hits"][j + rank_start]: 98 | item["hits"][j + rank_start]["score"] = cut_range[j]["score"] 99 | return item 100 | 101 | 102 | def make_item( 103 | query: str, docs: List[str] 104 | ) -> Dict[str, Union[List[Dict[str, str]], str]]: 105 | return { 106 | "query": query, 107 | "hits": [{"content": doc} for doc in docs], 108 | } 109 | 110 | 111 | class RankGPTRanker(BaseRanker): 112 | def __init__( 113 | self, model: str, api_key: str, lang: str = "en", verbose: int = 1 114 | ) -> "RankGPTRanker": 115 | self.api_key = api_key 116 | self.model = model 117 | self.verbose = verbose 118 | self.lang = lang 119 | 120 | def _query_llm(self, messages: List[Dict[str, str]]) -> str: 121 | response = completion( 122 | api_key=self.api_key, model=self.model, messages=messages, temperature=0 123 | ) 124 | return response.choices[0].message.content 125 | 126 | def rank( 127 | self, 128 | query: str, 129 | docs: Union[str, List[str], Document, List[Document]], 130 | doc_ids: Optional[Union[List[str], List[int]]] = None, 131 | metadata: Optional[List[dict]] = None, 132 | rank_start: int = 0, 133 | rank_end: int = 0, 134 | ) -> RankedResults: 135 | docs = prep_docs(docs, doc_ids, metadata) 136 | 137 | item = make_item(query, [d.text for d in docs]) 138 | messages = create_permutation_instruction( 139 | item=item, 140 | rank_start=rank_start, 141 | rank_end=rank_end, 142 | lang=self.lang, 143 | ) 144 | vprint(f"Querying 
model {self.model} with via LiteLLM...", self.verbose) 145 | permutation = self._query_llm(messages) 146 | item = receive_permutation( 147 | item, permutation, rank_start=rank_start, rank_end=rank_end 148 | ) 149 | ranked_docs = [] 150 | for idx, doc in enumerate(item["hits"]): 151 | ranked_docs.append( 152 | Result( 153 | document=list(filter(lambda x: x.text == doc["content"], docs))[0], 154 | rank=idx + 1, 155 | ) 156 | ) 157 | ranked_results = RankedResults( 158 | results=ranked_docs, query=query, has_scores=False 159 | ) 160 | return ranked_results 161 | 162 | def score(self): 163 | print("Listwise ranking models like RankGPT-4 cannot output scores!") 164 | return None 165 | -------------------------------------------------------------------------------- /rerankers/models/rankllm_ranker.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | from rerankers.models.ranker import BaseRanker 3 | from rerankers.documents import Document 4 | from rerankers.results import RankedResults, Result 5 | from rerankers.utils import prep_docs 6 | 7 | # from rerankers import Reranker 8 | 9 | from rank_llm.rerank.reranker import Reranker as rankllm_Reranker 10 | from rank_llm.rerank import PromptMode, get_azure_openai_args, get_genai_api_key, get_openai_api_key 11 | from rank_llm.data import Candidate, Query, Request 12 | 13 | 14 | class RankLLMRanker(BaseRanker): 15 | def __init__( 16 | self, 17 | model: str = "rank_zephyr", 18 | api_key: Optional[str] = None, 19 | lang: str = "en", 20 | verbose: int = 1, 21 | # RankLLM specific arguments 22 | window_size: int = 20, 23 | context_size: int = 4096, 24 | prompt_mode: PromptMode = PromptMode.RANK_GPT, 25 | num_few_shot_examples: int = 0, 26 | few_shot_file: Optional[str] = None, 27 | num_gpus: int = 1, 28 | variable_passages: bool = False, 29 | use_logits: bool = False, 30 | use_alpha: bool = False, 31 | stride: int = 10, 32 | use_azure_openai: bool = False, 33 | ) -> "RankLLMRanker": 34 | self.api_key = api_key 35 | self.model = model 36 | self.verbose = verbose 37 | self.lang = lang 38 | 39 | # RankLLM-specific parameters 40 | self.window_size = window_size 41 | self.context_size = context_size 42 | self.prompt_mode = prompt_mode 43 | self.num_few_shot_examples = num_few_shot_examples 44 | self.few_shot_file = few_shot_file 45 | self.num_gpus = num_gpus 46 | self.variable_passages = variable_passages 47 | self.use_logits = use_logits 48 | self.use_alpha = use_alpha 49 | self.stride = stride 50 | self.use_azure_openai = use_azure_openai 51 | 52 | kwargs = { 53 | "model_path": self.model, 54 | "default_model_coordinator": None, 55 | "context_size": self.context_size, 56 | "prompt_mode": self.prompt_mode, 57 | "num_gpus": self.num_gpus, 58 | "use_logits": self.use_logits, 59 | "use_alpha": self.use_alpha, 60 | "num_few_shot_examples": self.num_few_shot_examples, 61 | "few_shot_file": self.few_shot_file, 62 | "variable_passages": self.variable_passages, 63 | "interactive": False, 64 | "window_size": self.window_size, 65 | "stride": self.stride, 66 | "use_azure_openai": self.use_azure_openai, 67 | } 68 | model_coordinator = rankllm_Reranker.create_model_coordinator(**kwargs) 69 | self.reranker = rankllm_Reranker(model_coordinator) 70 | 71 | def rank( 72 | self, 73 | query: str, 74 | docs: Union[str, List[str], Document, List[Document]], 75 | doc_ids: Optional[Union[List[str], List[int]]] = None, 76 | metadata: Optional[List[dict]] = None, 77 | rank_start: int = 0, 78 | rank_end: 
int = 0, 79 | ) -> RankedResults: 80 | docs = prep_docs(docs, doc_ids, metadata) 81 | 82 | request = Request( 83 | query=Query(text=query, qid=1), 84 | candidates=[ 85 | Candidate(doc={"text": doc.text}, docid=doc_idx, score=1) 86 | for doc_idx, doc in enumerate(docs) 87 | ], 88 | ) 89 | 90 | rankllm_results = self.reranker.rerank( 91 | request, 92 | rank_end=len(docs) if rank_end == 0 else rank_end, 93 | window_size=min(20, len(docs)), 94 | step=10, 95 | ) 96 | 97 | ranked_docs = [] 98 | 99 | for rank, result in enumerate(rankllm_results.candidates, start=rank_start): 100 | ranked_docs.append( 101 | Result( 102 | document=docs[result.docid], 103 | rank=rank, 104 | ) 105 | ) 106 | 107 | return RankedResults(results=ranked_docs, query=query, has_scores=False) 108 | 109 | def score(self): 110 | print("Listwise ranking models like RankLLM cannot output scores!") 111 | return None 112 | -------------------------------------------------------------------------------- /rerankers/models/t5ranker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for InRanker is taken from the excellent InRanker repo https://github.com/unicamp-dl/InRanker under its Apache 2.0 license. 3 | The only change to the original implementation is the removal of InRanker's BaseRanker, replacing it with our own to support the unified API better. 4 | The main purpose for adapting this code here rather than installing the InRanker library is to ensure greater version compatibility (InRanker requires Python >=3.10) 5 | """ 6 | 7 | from typing import List, Optional, Union 8 | from math import ceil 9 | 10 | import torch 11 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 12 | from rerankers.models.ranker import BaseRanker 13 | from rerankers.documents import Document 14 | 15 | 16 | from rerankers.results import RankedResults, Result 17 | from rerankers.utils import ( 18 | vprint, 19 | get_device, 20 | get_dtype, 21 | prep_docs, 22 | get_chunks, 23 | ) 24 | try: 25 | from tqdm.auto import tqdm 26 | except ImportError: 27 | def tqdm(iterable, *args, **kwargs): 28 | return iterable 29 | 30 | PREDICTION_TOKENS = { 31 | "default": ["▁false", "▁true"], 32 | "castorini/monot5-base-msmarco": ["▁false", "▁true"], 33 | "castorini/monot5-base-msmarco-10k": ["▁false", "▁true"], 34 | "castorini/monot5-large-msmarco": ["▁false", "▁true"], 35 | "castorini/monot5-large-msmarco-10k": ["▁false", "▁true"], 36 | "castorini/monot5-base-med-msmarco": ["▁false", "▁true"], 37 | "castorini/monot5-3b-med-msmarco": ["▁false", "▁true"], 38 | "castorini/monot5-3b-msmarco-10k": ["▁false", "▁true"], 39 | "unicamp-dl/InRanker-small": ["▁false", "▁true"], 40 | "unicamp-dl/InRanker-base": ["▁false", "▁true"], 41 | "unicamp-dl/InRanker-3B": ["▁false", "▁true"], 42 | "unicamp-dl/mt5-base-en-msmarco": ["▁no", "▁yes"], 43 | "unicamp-dl/ptt5-base-pt-msmarco-10k-v2": ["▁não", "▁sim"], 44 | "unicamp-dl/ptt5-base-pt-msmarco-100k-v2": ["▁não", "▁sim"], 45 | "unicamp-dl/ptt5-base-en-pt-msmarco-100k-v2": ["▁não", "▁sim"], 46 | "unicamp-dl/mt5-base-en-pt-msmarco-v2": ["▁no", "▁yes"], 47 | "unicamp-dl/mt5-base-mmarco-v2": ["▁no", "▁yes"], 48 | "unicamp-dl/mt5-base-en-pt-msmarco-v1": ["▁no", "▁yes"], 49 | "unicamp-dl/mt5-base-mmarco-v1": ["▁no", "▁yes"], 50 | "unicamp-dl/ptt5-base-pt-msmarco-10k-v1": ["▁não", "▁sim"], 51 | "unicamp-dl/ptt5-base-pt-msmarco-100k-v1": ["▁não", "▁sim"], 52 | "unicamp-dl/ptt5-base-en-pt-msmarco-10k-v1": ["▁não", "▁sim"], 53 | "unicamp-dl/mt5-3B-mmarco-en-pt": ["▁", "▁true"], 54 | 
"unicamp-dl/mt5-13b-mmarco-100k": ["▁", "▁true"], 55 | "unicamp-dl/monoptt5-small": ["▁Não", "▁Sim"], 56 | "unicamp-dl/monoptt5-base": ["▁Não", "▁Sim"], 57 | "unicamp-dl/monoptt5-large": ["▁Não", "▁Sim"], 58 | "unicamp-dl/monoptt5-3b": ["▁Não", "▁Sim"], 59 | "Dundalia/TWOLAR-large": [6136, 1176], 60 | "Dundalia/TWOLAR-xl": [6136, 1176], 61 | } 62 | 63 | 64 | def _get_output_tokens(model_name_or_path, token_false: str, token_true: str): 65 | if token_false == "auto": 66 | if model_name_or_path in PREDICTION_TOKENS: 67 | token_false = PREDICTION_TOKENS[model_name_or_path][0] 68 | else: 69 | token_false = PREDICTION_TOKENS["default"][0] 70 | print( 71 | f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_false to `{token_false}`." 72 | ) 73 | if token_true == "auto": 74 | if model_name_or_path in PREDICTION_TOKENS: 75 | token_true = PREDICTION_TOKENS[model_name_or_path][1] 76 | else: 77 | token_true = PREDICTION_TOKENS["default"][1] 78 | print( 79 | f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_true to `{token_true}`." 80 | ) 81 | 82 | return token_false, token_true 83 | 84 | 85 | class T5Ranker(BaseRanker): 86 | def __init__( 87 | self, 88 | model_name_or_path: str, 89 | batch_size: int = 32, 90 | dtype: Optional[Union[str, torch.dtype]] = None, 91 | device: Optional[Union[str, torch.device]] = None, 92 | verbose: int = 1, 93 | token_false: str = "auto", 94 | token_true: str = "auto", 95 | return_logits: bool = False, 96 | inputs_template: str = "Query: {query} Document: {text} Relevant:", 97 | **kwargs, 98 | ): 99 | """ 100 | Implementation of the key functions from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py 101 | Changes are detailed in the docstring for each relevant function. 102 | 103 | T5Ranker is a wrapper for using Seq2Seq models for ranking. 104 | Args: 105 | batch_size: The batch size to use when encoding. 106 | dtype: Data type for model weights. 107 | device: The device to use for inference ("cpu", "cuda", or "mps"). 108 | verbose: Verbosity level. 109 | silent: Whether to show progress bars. 
110 | """ 111 | self.verbose = verbose 112 | self.device = get_device(device, self.verbose, no_mps=True) 113 | self.dtype = get_dtype(dtype, self.device, self.verbose) 114 | self.batch_size = batch_size 115 | vprint( 116 | f"Loading model {model_name_or_path}, this might take a while...", 117 | self.verbose, 118 | ) 119 | vprint(f"Using device {self.device}.", self.verbose) 120 | vprint(f"Using dtype {self.dtype}.", self.verbose) 121 | model_kwargs = kwargs.get("model_kwargs", {}) 122 | self.model = AutoModelForSeq2SeqLM.from_pretrained( 123 | model_name_or_path, 124 | torch_dtype=self.dtype, 125 | **model_kwargs, 126 | ).to(self.device) 127 | self.model.eval() 128 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) 129 | self.tokenizer = AutoTokenizer.from_pretrained( 130 | model_name_or_path, 131 | **tokenizer_kwargs, 132 | ) 133 | 134 | token_false, token_true = _get_output_tokens( 135 | model_name_or_path=model_name_or_path, 136 | token_false=token_false, 137 | token_true=token_true, 138 | ) 139 | if isinstance(token_false, int): 140 | self.token_false_id = token_false 141 | else: 142 | self.token_false_id = self.tokenizer.convert_tokens_to_ids(token_false) 143 | if isinstance(token_true, int): 144 | self.token_true_id = token_true 145 | else: 146 | self.token_true_id = self.tokenizer.convert_tokens_to_ids(token_true) 147 | vprint(f"T5 true token set to {token_true}", self.verbose) 148 | vprint(f"T5 false token set to {token_false}", self.verbose) 149 | 150 | self.return_logits = return_logits 151 | if self.return_logits: 152 | vprint( 153 | f"Returning raw logits for `{token_true}` as scores...", self.verbose 154 | ) 155 | else: 156 | vprint("Returning normalised scores...", self.verbose) 157 | self.inputs_template = inputs_template 158 | vprint(f"Inputs template set to {inputs_template}", self.verbose) 159 | 160 | def rank( 161 | self, 162 | query: str, 163 | docs: Union[str, List[str], Document, List[Document]], 164 | doc_ids: Optional[Union[List[str], List[int]]] = None, 165 | metadata: Optional[List[dict]] = None, 166 | ) -> RankedResults: 167 | """ 168 | Ranks a list of documents based on their relevance to the query. 169 | """ 170 | docs = prep_docs(docs, doc_ids, metadata) 171 | scores = self._get_scores(query, [d.text for d in docs]) 172 | ranked_results = [ 173 | Result(document=doc, score=score, rank=idx + 1) 174 | for idx, (doc, score) in enumerate( 175 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) 176 | ) 177 | ] 178 | return RankedResults(results=ranked_results, query=query, has_scores=True) 179 | 180 | def score(self, query: str, doc: str) -> float: 181 | """ 182 | Scores a single document's relevance to a query. 183 | """ 184 | scores = self._get_scores(query, [doc]) 185 | return scores[0] if scores else 0.0 186 | 187 | @torch.inference_mode() 188 | def _get_scores( 189 | self, 190 | query: str, 191 | docs: List[str], 192 | max_length: int = 512, 193 | batch_size: Optional[int] = None, 194 | ) -> List[float]: 195 | """ 196 | Implementation from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py. 197 | Lightly modified so only the positive logits are returned and renamed the chunking function. 198 | 199 | Given a query and a list of documents, return a list of scores. 200 | Args: 201 | query: The query string. 202 | docs: A list of document strings. 203 | max_length: The maximum length of the input sequence. 
204 | """ 205 | if self.return_logits: 206 | logits = [] 207 | else: 208 | scores = [] 209 | if batch_size is None: 210 | batch_size = self.batch_size 211 | for batch in tqdm( 212 | get_chunks(docs, batch_size), 213 | disable=not self.verbose, 214 | desc="Scoring...", 215 | total=ceil(len(docs) / batch_size), 216 | ): 217 | queries_documents = [ 218 | self.inputs_template.format(query=query, text=text) 219 | for text in batch 220 | ] 221 | tokenized = self.tokenizer( 222 | queries_documents, 223 | padding=True, 224 | truncation="longest_first", 225 | return_tensors="pt", 226 | max_length=max_length, 227 | ).to(self.device) 228 | input_ids = tokenized["input_ids"] 229 | attention_mask = tokenized["attention_mask"] 230 | 231 | _, batch_scores = self._greedy_decode( 232 | model=self.model, 233 | input_ids=input_ids, 234 | length=1, 235 | attention_mask=attention_mask, 236 | return_last_logits=True, 237 | ) 238 | batch_scores = batch_scores[ 239 | :, [self.token_false_id, self.token_true_id] 240 | ].cpu() 241 | if self.return_logits: 242 | logits.extend(batch_scores[:, 1].tolist()) 243 | else: 244 | batch_scores = torch.log_softmax(batch_scores, dim=-1) 245 | batch_scores = torch.exp(batch_scores[:, 1]) 246 | scores.extend(batch_scores.tolist()) 247 | 248 | if self.return_logits: 249 | return logits 250 | return scores 251 | 252 | @torch.inference_mode() 253 | def _greedy_decode( 254 | self, 255 | model, 256 | input_ids: torch.Tensor, 257 | length: int, 258 | attention_mask: torch.Tensor = None, 259 | return_last_logits: bool = True, 260 | ): 261 | """Implementation from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py""" 262 | decode_ids = torch.full( 263 | (input_ids.size(0), 1), 264 | model.config.decoder_start_token_id, 265 | dtype=torch.long, 266 | ).to(input_ids.device) 267 | encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask) 268 | next_token_logits = None 269 | for _ in range(length): 270 | try: 271 | model_inputs = model.prepare_inputs_for_generation( 272 | decode_ids, 273 | encoder_outputs=encoder_outputs, 274 | past=None, 275 | attention_mask=attention_mask, 276 | use_cache=True, 277 | ) 278 | outputs = model(**model_inputs) 279 | except TypeError: 280 | # Newer transformers versions have deprecated `past` 281 | # Our aim is to maintain pipeline compatibility for as many people as possible 282 | # So currently, we maintain a forking path with this error. Might need to do it more elegantly later on (TODO). 
283 | model_inputs = model.prepare_inputs_for_generation( 284 | decode_ids, 285 | encoder_outputs=encoder_outputs, 286 | attention_mask=attention_mask, 287 | use_cache=True, 288 | ) 289 | outputs = model(**model_inputs) # (batch_size, cur_len, vocab_size) 290 | next_token_logits = outputs[0][:, -1, :] # (batch_size, vocab_size) 291 | decode_ids = torch.cat( 292 | [decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1 293 | ) 294 | if return_last_logits: 295 | return decode_ids, next_token_logits 296 | return decode_ids 297 | -------------------------------------------------------------------------------- /rerankers/models/transformer_ranker.py: -------------------------------------------------------------------------------- 1 | from rerankers.models.ranker import BaseRanker 2 | from rerankers.documents import Document 3 | 4 | 5 | import torch 6 | from typing import Union, List, Optional, Tuple 7 | from transformers import ( 8 | AutoModelForSequenceClassification, 9 | AutoTokenizer, 10 | ) 11 | from rerankers.utils import ( 12 | vprint, 13 | get_device, 14 | get_dtype, 15 | prep_docs, 16 | ) 17 | from rerankers.results import RankedResults, Result 18 | 19 | 20 | class TransformerRanker(BaseRanker): 21 | def __init__( 22 | self, 23 | model_name_or_path: str, 24 | dtype: Optional[Union[str, torch.dtype]] = None, 25 | device: Optional[Union[str, torch.device]] = None, 26 | batch_size: int = 16, 27 | verbose: int = 1, 28 | **kwargs, 29 | ): 30 | self.verbose = verbose 31 | self.device = get_device(device, verbose=self.verbose) 32 | self.dtype = get_dtype(dtype, self.device, self.verbose) 33 | self.is_monobert = "monobert" in model_name_or_path.lower() 34 | model_kwargs = kwargs.get("model_kwargs", {}) 35 | self.model = AutoModelForSequenceClassification.from_pretrained( 36 | model_name_or_path, 37 | torch_dtype=self.dtype, 38 | **model_kwargs, 39 | ).to(self.device) 40 | vprint(f"Loaded model {model_name_or_path}", self.verbose) 41 | vprint(f"Using device {self.device}.", self.verbose) 42 | vprint(f"Using dtype {self.dtype}.", self.verbose) 43 | self.model.eval() 44 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) 45 | self.tokenizer = AutoTokenizer.from_pretrained( 46 | model_name_or_path, 47 | **tokenizer_kwargs, 48 | ) 49 | self.ranking_type = "pointwise" 50 | self.batch_size = batch_size 51 | 52 | def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]): 53 | return self.tokenizer( 54 | inputs, return_tensors="pt", padding=True, truncation=True 55 | ).to(self.device) 56 | 57 | @torch.inference_mode() 58 | def rank( 59 | self, 60 | query: str, 61 | docs: Union[str, List[str], Document, List[Document]], 62 | doc_ids: Optional[Union[List[str], List[int]]] = None, 63 | metadata: Optional[List[dict]] = None, 64 | batch_size: Optional[int] = None, 65 | ) -> RankedResults: 66 | docs = prep_docs(docs, doc_ids, metadata) 67 | inputs = [(query, doc.text) for doc in docs] 68 | 69 | # Override self.batch_size if explicitely set 70 | if batch_size is None: 71 | batch_size = self.batch_size 72 | batched_inputs = [ 73 | inputs[i : i + batch_size] for i in range(0, len(inputs), batch_size) 74 | ] 75 | scores = [] 76 | for batch in batched_inputs: 77 | tokenized_inputs = self.tokenize(batch) 78 | batch_scores = self.model(**tokenized_inputs).logits.squeeze() 79 | if self.dtype != torch.float32: 80 | batch_scores = batch_scores.float() 81 | batch_scores = batch_scores.detach().cpu().numpy().tolist() 82 | if isinstance(batch_scores, float): # Handling the case of single score 
83 | scores.append(batch_scores) 84 | else: 85 | scores.extend(batch_scores) 86 | if self.is_monobert: scores = [x[1] - x[0] for x in scores] 87 | if len(scores) == 1: 88 | return RankedResults(results=[Result(document=docs[0], score=scores[0])], query=query, has_scores=True) 89 | else: 90 | ranked_results = [ 91 | Result(document=doc, score=score, rank=idx + 1) 92 | for idx, (doc, score) in enumerate( 93 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) 94 | ) 95 | ] 96 | return RankedResults(results=ranked_results, query=query, has_scores=True) 97 | 98 | @torch.inference_mode() 99 | def score(self, query: str, doc: str) -> float: 100 | inputs = self.tokenize((query, doc)) 101 | outputs = self.model(**inputs) 102 | score = outputs.logits.squeeze().detach().cpu().numpy().astype(float) 103 | return score 104 | -------------------------------------------------------------------------------- /rerankers/models/upr.py: -------------------------------------------------------------------------------- 1 | # models/upr.py 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from transformers import T5ForConditionalGeneration, T5Tokenizer 6 | from math import ceil 7 | from typing import List, Optional, Union 8 | 9 | from rerankers.models.ranker import BaseRanker 10 | from rerankers.documents import Document 11 | from rerankers.results import RankedResults, Result 12 | from rerankers.utils import ( 13 | vprint, 14 | get_device, 15 | get_dtype, 16 | prep_docs, 17 | get_chunks, 18 | ) 19 | 20 | 21 | class UPRRanker(BaseRanker): 22 | """ 23 | UPR (Unsupervised Passage Reranker) replicates the negative log-likelihood 24 | approach from the authors' code. The doc is passed as the encoder input, 25 | and the query is the decoder label. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | model_name_or_path: str, 31 | verbose: int = 1, 32 | device: Optional[Union[str, torch.device]] = None, 33 | dtype: Optional[Union[str, torch.dtype]] = None, 34 | batch_size: int = 16, 35 | verbalizer_head: str = "Passage:", 36 | verbalizer: str = "Please write a question based on this passage.", 37 | max_input_length: int = 512, 38 | max_query_length: int = 128, 39 | **kwargs 40 | ): 41 | """ 42 | Args: 43 | model_name_or_path: A T5 checkpoint name or path (e.g., 't5-large', 'google/t5-xxl-lm-adapt', etc.) 44 | verbose: Verbosity level. 45 | device: "cuda", "cpu", or None for auto. 46 | dtype: e.g. "float32", "float16", "bf16", or a torch.dtype. 47 | batch_size: How many documents to process at once. 48 | verbalizer_head: Prefixed to the doc text to mimic the 'Passage: ' from the original code. 49 | verbalizer: A short instruction appended to the doc text. The original UPR default is 50 | "Please write a question based on this passage." 51 | max_input_length: Maximum tokens for the encoder side (document). 52 | max_query_length: Maximum tokens for the decoder side (query). 
53 | """ 54 | self.verbose = verbose 55 | self.device = get_device(device, self.verbose) 56 | self.dtype = get_dtype(dtype, self.device, self.verbose) 57 | self.batch_size = batch_size 58 | self.verbalizer_head = verbalizer_head 59 | self.verbalizer = verbalizer 60 | self.max_input_length = max_input_length 61 | self.max_query_length = max_query_length 62 | 63 | vprint(f"[UPR] Loading T5 model: {model_name_or_path}", self.verbose) 64 | vprint(f"[UPR] device={self.device}, dtype={self.dtype}, batch_size={batch_size}", self.verbose) 65 | 66 | # Load T5 67 | model_kwargs = kwargs.get("model_kwargs", {}) 68 | self.model = T5ForConditionalGeneration.from_pretrained( 69 | model_name_or_path, torch_dtype=self.dtype, **model_kwargs 70 | ).to(self.device) 71 | self.model.eval() 72 | 73 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) 74 | self.tokenizer = T5Tokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs) 75 | 76 | def score(self, query: str, doc: str) -> float: 77 | """ 78 | Score a single document. Negative log-likelihood of 'query' given 'doc'. 79 | Higher means more relevant (score = -NLL). 80 | """ 81 | scores = self._get_scores(query, [doc]) 82 | return scores[0] if scores else 0.0 83 | 84 | def rank( 85 | self, 86 | query: str, 87 | docs: Union[str, List[str], Document, List[Document]], 88 | doc_ids: Optional[Union[List[str], List[int]]] = None, 89 | metadata: Optional[List[dict]] = None, 90 | ) -> RankedResults: 91 | """ 92 | Ranks a list of documents by the negative log-likelihood of the query given the doc. 93 | """ 94 | # Convert user inputs into a list of Document objects 95 | docs = prep_docs(docs, doc_ids, metadata) 96 | 97 | # Score them 98 | doc_texts = [d.text for d in docs] 99 | scores = self._get_scores(query, doc_texts) 100 | 101 | # Sort in descending order of score 102 | ranked_results = [ 103 | Result(document=doc, score=score, rank=idx + 1) 104 | for idx, (doc, score) in enumerate( 105 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) 106 | ) 107 | ] 108 | return RankedResults(results=ranked_results, query=query, has_scores=True) 109 | 110 | @torch.inference_mode() 111 | def _get_scores(self, query: str, docs: List[str]) -> List[float]: 112 | """ 113 | Batched negative log-likelihood scoring: 114 | score = - sum_{tokens in query} [ log P(token | doc) ]. 115 | """ 116 | all_scores = [] 117 | # Create mini-batches of docs 118 | for batch in get_chunks(docs, self.batch_size): 119 | # 1) Build the T5 encoder inputs for the doc 120 | # (mimicking "Passage: {doc_text}. Please write a question..." from the original code) 121 | encoder_texts = [ 122 | f"{self.verbalizer_head} {doc_text}. {self.verbalizer}" 123 | for doc_text in batch 124 | ] 125 | 126 | encoder_enc = self.tokenizer( 127 | encoder_texts, 128 | padding=True, 129 | truncation=True, 130 | max_length=self.max_input_length, 131 | return_tensors="pt", 132 | ).to(self.device) 133 | 134 | # 2) Build the T5 decoder labels for the query 135 | # (the question is now the *label*, exactly as in original UPR). 136 | decoder_enc = self.tokenizer( 137 | [query] * len(batch), 138 | padding=True, 139 | truncation=True, 140 | max_length=self.max_query_length, 141 | return_tensors="pt", 142 | ).to(self.device) 143 | 144 | # 3) forward pass with `labels=...` so that T5 returns cross-entropy 145 | # but we want the per-token log-likelihood to replicate the approach exactly. 
146 | logits = self.model( 147 | input_ids=encoder_enc.input_ids, 148 | attention_mask=encoder_enc.attention_mask, 149 | labels=decoder_enc.input_ids, 150 | ).logits # shape: [batch, seq_len, vocab_size] 151 | 152 | # 4) Compute log-softmax for each token 153 | log_probs = F.log_softmax(logits, dim=-1) # [batch, seq_len, vocab_size] 154 | 155 | # 5) Gather the probabilities at each gold label token => negative log-likelihood 156 | # next_token = decoder_enc.input_ids[..., 1:] 157 | # but T5 shifts internally. We'll simply do gather on 158 | # the label tokens and sum up, replicating the original. 159 | labels = decoder_enc.input_ids.unsqueeze(-1) # [batch, seq_len, 1] 160 | token_log_probs = log_probs.gather(-1, labels).squeeze(-1) # [batch, seq_len] 161 | 162 | # T5 shifts internally, so the first token is the "start token." The original UPR code 163 | # just sums everything. We'll do the same. 164 | nll = -token_log_probs # [batch, seq_len] 165 | sum_nll = torch.sum(nll, dim=1) # sum over query tokens 166 | 167 | # final score = - sum_nll (which is +ve if the NLL is large) 168 | # we want "best doc" to have the largest score => doc that yields the *lowest* NLL 169 | batch_scores = (-sum_nll).tolist() 170 | 171 | all_scores.extend(batch_scores) 172 | 173 | return all_scores 174 | -------------------------------------------------------------------------------- /rerankers/reranker.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import warnings 3 | from rerankers.models import AVAILABLE_RANKERS 4 | from rerankers.models.ranker import BaseRanker 5 | from rerankers.utils import vprint 6 | 7 | DEFAULTS = { 8 | "jina": {"en": "jina-reranker-v1-base-en"}, 9 | "isaacus": {"en": "kanon-universal-classifier"}, 10 | "pinecone": {"en": "pinecone-rerank-v0"}, 11 | "cohere": {"en": "rerank-english-v3.0", "other": "rerank-multilingual-v3.0"}, 12 | "voyage": {"en": "rerank-lite-1"}, 13 | "mixedbread.ai": {"en": "mixedbread-ai/mxbai-rerank-large-v1"}, 14 | "cross-encoder": { 15 | "en": "mixedbread-ai/mxbai-rerank-base-v1", 16 | "fr": "antoinelouis/crossencoder-camembert-base-mmarcoFR", 17 | "zh": "BAAI/bge-reranker-base", 18 | "other": "corrius/cross-encoder-mmarco-mMiniLMv2-L12-H384-v1", 19 | }, 20 | "t5": {"en": "unicamp-dl/InRanker-base", "other": "unicamp-dl/mt5-base-mmarco-v2"}, 21 | "lit5": {"en": "castorini/LiT5-Distill-base"}, 22 | "rankgpt": {"en": "gpt-4-turbo-preview", "other": "gpt-4-turbo-preview"}, 23 | "rankgpt3": {"en": "gpt-3.5-turbo", "other": "gpt-3.5-turbo"}, 24 | "rankgpt4": {"en": "gpt-4", "other": "gpt-4"}, 25 | "rankllm": {"en": "rank_zephyr", "other": "rank_zephyr"}, 26 | "colbert": { 27 | "en": "colbert-ir/colbertv2.0", 28 | "fr": "bclavie/FraColBERTv2", 29 | "ja": "bclavie/JaColBERTv2", 30 | "es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES", 31 | }, 32 | "flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"}, 33 | "text-embeddings-inference": {"other": "BAAI/bge-reranker-base"}, 34 | "llm-layerwise": { 35 | "en": "BAAI/bge-reranker-v2.5-gemma2-lightweight", 36 | "other": "BAAI/bge-reranker-v2.5-gemma2-lightweight", 37 | }, 38 | "monovlm": { 39 | "en": "lightonai/MonoQwen2-VL-v0.1", 40 | "other": "lightonai/MonoQwen2-VL-v0.1" 41 | }, 42 | "llm-relevance-filter": { 43 | "en": "gpt-4-turbo-preview", 44 | "other": "gpt-4-turbo-preview" 45 | }, 46 | "upr": {"en": "google/t5-large-lm-adapt"}, 47 | "mxbaiv2": {"en": "mixedbread-ai/mxbai-rerank-base-v2"}, 48 | "pylate": { 49 | "en": 
"lightonai/GTE-ModernColBERT-v1", 50 | "other": "lightonai/GTE-ModernColBERT-v1", 51 | }, 52 | } 53 | 54 | DEPS_MAPPING = { 55 | "TransformerRanker": "transformers", 56 | "T5Ranker": "transformers", 57 | "LiT5Ranker": "lit5", 58 | "RankGPTRanker": "gpt", 59 | "APIRanker": "api", 60 | "ColBERTRanker": "transformers", 61 | "FlashRankRanker": "flashrank", 62 | "RankLLMRanker": "rankllm", 63 | "LLMLayerWiseRanker": "transformers", 64 | "MonoVLMRanker": "transformers", 65 | "LLMRelevanceFilter": "litellm", 66 | "UPRRanker": "transformers", 67 | "MxBaiV2Ranker": "transformers", 68 | "PyLateRanker": "pylate", 69 | } 70 | 71 | PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "pinecone", "isaacus", "text-embeddings-inference"] 72 | 73 | def _get_api_provider(model_name: str, model_type: Optional[str] = None) -> Optional[str]: 74 | # If an explicit model_type is provided and it isn't one of the known API providers, 75 | # then we skip auto-detection of an API provider. 76 | if model_type is not None and model_type not in PROVIDERS: 77 | return None 78 | if (model_type in PROVIDERS) or any(provider in model_name for provider in PROVIDERS): 79 | return model_type if model_type in PROVIDERS else next( 80 | (provider for provider in PROVIDERS if provider in model_name), None 81 | ) 82 | # Check if the model_name is a key in DEFAULTS to set the provider correctly 83 | return next( 84 | ( 85 | provider 86 | for provider in PROVIDERS 87 | if model_name in DEFAULTS and any(provider in values for values in DEFAULTS[model_name].values()) 88 | ), 89 | None, 90 | ) 91 | 92 | def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None) -> str: 93 | if explicit_model_type: 94 | model_mapping = { 95 | "cohere": "APIRanker", 96 | "pinecone": "APIRanker", 97 | "jina": "APIRanker", 98 | "isaacus": "APIRanker", 99 | "voyage": "APIRanker", 100 | "text-embeddings-inference": "APIRanker", 101 | "rankgpt": "RankGPTRanker", 102 | "lit5": "LiT5Ranker", 103 | "t5": "T5Ranker", 104 | "colbert": "ColBERTRanker", 105 | "cross-encoder": "TransformerRanker", 106 | "flashrank": "FlashRankRanker", 107 | "rankllm": "RankLLMRanker", 108 | "llm-layerwise": "LLMLayerWiseRanker", 109 | "monovlm": "MonoVLMRanker", 110 | "llm-relevance-filter": "LLMRelevanceFilter", 111 | "upr": "UPRRanker", 112 | "mxbaiv2": "MxBaiV2Ranker", 113 | "pylate": "PyLateRanker", 114 | } 115 | return model_mapping.get(explicit_model_type, explicit_model_type) 116 | else: 117 | model_name = model_name.lower() 118 | model_mapping = { 119 | "lit5": "LiT5Ranker", 120 | "t5": "T5Ranker", 121 | "inranker": "T5Ranker", 122 | "rankllm": "RankLLMRanker", 123 | "rankgpt": "RankGPTRanker", 124 | "gpt": "RankGPTRanker", 125 | "colbert": "ColBERTRanker", 126 | "cohere": "APIRanker", 127 | "pinecone": "APIRanker", 128 | "jina": "APIRanker", 129 | "isaacus": "APIRanker", 130 | "voyage": "APIRanker", 131 | "text-embeddings-inference": "APIRanker", 132 | "ms-marco-minilm-l-12-v2": "FlashRankRanker", 133 | "ms-marco-multibert-l-12": "FlashRankRanker", 134 | "vicuna": "RankLLMRanker", 135 | "zephyr": "RankLLMRanker", 136 | "bge-reranker-v2.5-gemma2-lightweight": "LLMLayerWiseRanker", 137 | "monovlm": "MonoVLMRanker", 138 | "monoqwen2-vl": "MonoVLMRanker", 139 | "llm-relevance-filter": "LLMRelevanceFilter", 140 | "upr": "UPRRanker", 141 | "mxbaiv2": "MxBaiV2Ranker", 142 | "mxbai-rerank-base-v2": "MxBaiV2Ranker", 143 | "mxbai-rerank-large-v2": "MxBaiV2Ranker", 144 | "pylate": "PyLateRanker", 145 | } 146 | for key, value in model_mapping.items(): 147 | if 
key in model_name: 148 | if key == "gpt": 149 | warnings.warn( 150 | "The key 'gpt' currently defaults to the rough rankGPT implementation. From version 0.0.5 onwards, 'gpt' will default to RankLLM instead. Please specify the 'rankgpt' `model_type` if you want to keep the current behaviour", 151 | DeprecationWarning, 152 | ) 153 | return value 154 | if ( 155 | any( 156 | keyword in model_name 157 | for keyword in ["minilm", "bert", "cross-encoders/"] 158 | ) 159 | and "/" in model_name 160 | ): 161 | return "TransformerRanker" 162 | print( 163 | "Warning: Model type could not be auto-mapped with the defaults list. Defaulting to TransformerRanker." 164 | ) 165 | print( 166 | "If your model is NOT intended to be ran as a one-label cross-encoder, please reload it and specify the model_type!", 167 | "Otherwise, you may ignore this warning. You may specify `model_type='cross-encoder'` to suppress this warning in the future.", 168 | ) 169 | return "TransformerRanker" 170 | 171 | def _get_defaults( 172 | model_name: str, 173 | model_type: Optional[str] = None, 174 | lang: str = "en", 175 | verbose: int = 1, 176 | ) -> str: 177 | if model_name in DEFAULTS.keys(): 178 | print(f"Loading default {model_name} model for language {lang}") 179 | try: 180 | model_name = DEFAULTS[model_name][lang] 181 | except KeyError: 182 | if "other" not in DEFAULTS[model_name]: 183 | print( 184 | f"Model family {model_name} does not have a default for language {lang}" 185 | ) 186 | print( 187 | "Aborting now... Please retry with another model family or by specifying a model" 188 | ) 189 | return None, None 190 | model_name = DEFAULTS[model_name]["other"] 191 | model_type = _get_model_type(model_name, model_type) 192 | vprint(f"Default Model: {model_name}", verbose) 193 | 194 | return model_name, model_type 195 | 196 | def Reranker( 197 | model_name: str, 198 | lang: str = "en", 199 | model_type: Optional[str] = None, 200 | verbose: int = 1, 201 | **kwargs, 202 | ) -> Optional[BaseRanker]: 203 | original_model_name = model_name 204 | api_provider = _get_api_provider(model_name, model_type) 205 | if api_provider or model_name.lower() in PROVIDERS: 206 | if model_name.lower() in PROVIDERS: 207 | api_provider = model_name.lower() 208 | # Only override model_type to APIRanker if it hasn't been explicitly set 209 | if model_type is None: 210 | model_type = "APIRanker" 211 | model_name = ( 212 | DEFAULTS[api_provider][lang] 213 | if lang in DEFAULTS[api_provider] 214 | else DEFAULTS[api_provider]["other"] 215 | ) 216 | print( 217 | f"Auto-updated model_name to {model_name} for API provider {api_provider}" 218 | ) 219 | else: 220 | if model_type is None: 221 | model_type = "APIRanker" 222 | else: 223 | if original_model_name in DEFAULTS.keys(): 224 | model_name, model_type = _get_defaults(original_model_name, model_type, lang, verbose) 225 | if model_name is None: 226 | return None 227 | api_provider = _get_api_provider(model_name, model_type) 228 | if api_provider and model_type is None: 229 | model_type = "APIRanker" 230 | 231 | if api_provider: 232 | kwargs["api_provider"] = api_provider 233 | 234 | model_type = _get_model_type(model_name, model_type) 235 | 236 | try: 237 | vprint(f"Loading {model_type} model {model_name} (this message can be suppressed by setting verbose=0)", verbose) 238 | return AVAILABLE_RANKERS[model_type](model_name, verbose=verbose, **kwargs) 239 | except KeyError: 240 | print( 241 | f"You don't have the necessary dependencies installed to use {model_type}." 
242 | ) 243 | print( 244 | f'Please install the necessary dependencies for {model_type} by running `pip install "rerankers[{DEPS_MAPPING[model_type]}]"`', 245 | 'or `pip install "rerankers[all]"` to install the dependencies for all reranker types.', 246 | ) 247 | return None 248 | -------------------------------------------------------------------------------- /rerankers/results.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from rerankers.documents import Document 4 | 5 | 6 | class Result: 7 | def __init__(self, document: Document, score: Optional[float] = None, rank: Optional[int] = None): 8 | self.document = document 9 | self.score = score 10 | self.rank = rank 11 | 12 | if rank is None and score is None: 13 | raise ValueError("Either score or rank must be provided.") 14 | 15 | def __getattr__(self, item): 16 | if hasattr(self.document, item): 17 | return getattr(self.document, item) 18 | elif item in ["document", "score", "rank"]: 19 | return getattr(self, item) 20 | elif item in self.document.attributes: 21 | return getattr(self.document, item) 22 | elif item in self.document.metadata: 23 | return self.document.metadata[item] 24 | raise AttributeError( 25 | f"'{self.__class__.__name__}' object has no attribute '{item}'" 26 | ) 27 | 28 | def __repr__(self) -> str: 29 | fields = { 30 | "document": self.document, 31 | "score": self.score, 32 | "rank": self.rank, 33 | } 34 | field_str = ", ".join(f"{k}={v!r}" for k, v in fields.items()) 35 | return f"{self.__class__.__name__}({field_str})" 36 | 37 | 38 | class RankedResults: 39 | def __init__(self, results: List[Result], query: str, has_scores: bool = False): 40 | self.results = results 41 | self.query = query 42 | self.has_scores = has_scores 43 | 44 | def __iter__(self): 45 | """Allows iteration over the results list.""" 46 | return iter(self.results) 47 | 48 | def __getitem__(self, index): 49 | """Allows indexing to access results directly.""" 50 | return self.results[index] 51 | 52 | def results_count(self) -> int: 53 | """Returns the total number of results.""" 54 | return len(self.results) 55 | 56 | def top_k(self, k: int) -> List[Result]: 57 | """Returns the top k results based on the score, if available, or rank.""" 58 | if self.has_scores: 59 | return sorted( 60 | self.results, 61 | key=lambda x: x.score if x.score is not None else float("-inf"), 62 | reverse=True, 63 | )[:k] 64 | else: 65 | return sorted( 66 | self.results, 67 | key=lambda x: x.rank if x.rank is not None else float("inf"), 68 | )[:k] 69 | 70 | def get_score_by_docid(self, doc_id: Union[int, str]) -> Optional[float]: 71 | """Fetches the score of a result by its doc_id using a more efficient approach.""" 72 | result = next((r for r in self.results if r.document.doc_id == doc_id), None) 73 | return result.score if result else None 74 | 75 | def get_result_by_docid(self, doc_id: Union[int, str]) -> Optional[Result]: 76 | """Fetches a result by its doc_id using a more efficient approach.""" 77 | result = next((r for r in self.results if r.document.doc_id == doc_id), None) 78 | return result if result else None 79 | 80 | def __repr__(self) -> str: 81 | fields = { 82 | "results": self.results, 83 | "query": self.query, 84 | "has_scores": self.has_scores, 85 | } 86 | field_str = ", ".join(f"{k}={v!r}" for k, v in fields.items()) 87 | return f"{self.__class__.__name__}({field_str})" 88 | -------------------------------------------------------------------------------- 
/rerankers/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import binascii 3 | from typing import Union, Optional, List, Iterable 4 | try: 5 | import io 6 | from PIL import Image 7 | except ImportError: 8 | pass 9 | from rerankers.documents import Document 10 | 11 | def vprint(txt: str, verbose: int) -> None: 12 | if verbose > 0: 13 | print(txt) 14 | 15 | 16 | try: 17 | import torch 18 | 19 | def get_dtype( 20 | dtype: Optional[Union[str, torch.dtype]], 21 | device: Optional[Union[str, torch.device]], 22 | verbose: int = 1, 23 | ) -> torch.dtype: 24 | if dtype is None: 25 | vprint("No dtype set", verbose) 26 | # if device == "cpu": 27 | # vprint("Device set to `cpu`, setting dtype to `float32`", verbose) 28 | # dtype = torch.float32 29 | if not isinstance(dtype, torch.dtype): 30 | if dtype == "fp16" or dtype == "float16": 31 | dtype = torch.float16 32 | elif dtype == "bf16" or dtype == "bfloat16": 33 | dtype = torch.bfloat16 34 | else: 35 | dtype = torch.float32 36 | vprint(f"Using dtype {dtype}", verbose) 37 | return dtype 38 | 39 | def get_device( 40 | device: Optional[Union[str, torch.device]], 41 | verbose: int = 1, 42 | no_mps: bool = False, 43 | ) -> Union[str, torch.device]: 44 | if not device: 45 | vprint("No device set", verbose) 46 | if torch.cuda.is_available(): 47 | device = "cuda" 48 | elif torch.backends.mps.is_available() and not no_mps: 49 | device = "mps" 50 | else: 51 | device = "cpu" 52 | vprint(f"Using device {device}", verbose) 53 | return device 54 | 55 | except ImportError: 56 | pass 57 | 58 | 59 | def make_documents( 60 | docs: List[str], 61 | doc_ids: Optional[Union[List[str], List[int]]] = None, 62 | ): 63 | if doc_ids is None: 64 | doc_ids = list(range(len(docs))) 65 | return [Document(doc, doc_id=doc_ids[i]) for i, doc in enumerate(docs)] 66 | 67 | 68 | def prep_docs( 69 | docs: Union[str, List[str], Document, List[Document]], 70 | doc_ids: Optional[Union[List[str], List[int]]] = None, 71 | metadata: Optional[List[dict]] = None, 72 | ): 73 | if isinstance(docs, Document) or ( 74 | isinstance(docs, List) and isinstance(docs[0], Document) 75 | ): 76 | if isinstance(docs, Document): 77 | docs = [docs] 78 | if doc_ids is not None: 79 | if docs[0].doc_id is not None: 80 | print( 81 | "Overriding doc_ids passed within the Document objects with explicitly passed doc_ids!" 82 | ) 83 | print( 84 | "This is not the preferred way of doing so, please double-check your code." 85 | ) 86 | for i, doc in enumerate(docs): 87 | doc.doc_id = doc_ids[i] 88 | 89 | elif doc_ids is None: 90 | doc_ids = [doc.doc_id for doc in docs] 91 | if doc_ids[0] is None: 92 | print( 93 | "'None' doc_ids detected, reverting to auto-generated integer ids..." 94 | ) 95 | doc_ids = list(range(len(docs))) 96 | 97 | if metadata is not None: 98 | if docs[0].metadata is not None: 99 | print( 100 | "Overriding metadata passed within the Document objects with explicitly passed metadata!" 101 | ) 102 | print( 103 | "This is not the preferred way of doing so, please double-check your code.
104 | ) 105 | for i, doc in enumerate(docs): 106 | doc.metadata = metadata[i] 107 | 108 | return docs 109 | 110 | if isinstance(docs, str): 111 | docs = [docs] 112 | if doc_ids is None: 113 | doc_ids = list(range(len(docs))) 114 | if metadata is None: 115 | metadata = [{} for _ in docs] 116 | 117 | return [ 118 | Document(doc, doc_id=doc_ids[i], metadata=metadata[i]) 119 | for i, doc in enumerate(docs) 120 | ] 121 | 122 | 123 | def prep_image_docs( 124 | docs: Union[str, List[str], Document, List[Document]], 125 | doc_ids: Optional[Union[List[str], List[int]]] = None, 126 | metadata: Optional[List[dict]] = None, 127 | ) -> List[Document]: 128 | """ 129 | Prepare image documents for processing. Can handle base64 encoded images or file paths. 130 | Similar to prep_docs but specialized for image documents. 131 | """ 132 | # If already Document objects, handle similarly to prep_docs 133 | if isinstance(docs, Document) or ( 134 | isinstance(docs, List) and isinstance(docs[0], Document) 135 | ): 136 | if isinstance(docs, Document): 137 | docs = [docs] 138 | # Validate all docs are image type 139 | for doc in docs: 140 | if doc.document_type != "image": 141 | raise ValueError("All documents must be of type 'image'") 142 | return prep_docs(docs, doc_ids, metadata) 143 | 144 | # Handle string inputs (paths or base64) 145 | if isinstance(docs, str): 146 | docs = [docs] 147 | 148 | processed_docs = [] 149 | for doc in docs: 150 | # Check if input is base64 by attempting to decode 151 | try: 152 | # Try to decode and verify it's an image 153 | decoded = base64.b64decode(doc) 154 | try: 155 | Image.open(io.BytesIO(decoded)).verify() 156 | b64 = doc 157 | image_path = None 158 | except: 159 | raise binascii.Error("Invalid image data") 160 | except binascii.Error: 161 | # If decode fails, treat as file path 162 | try: 163 | image_path = doc 164 | with open(doc, 'rb') as img_file: 165 | b64 = base64.b64encode(img_file.read()).decode('utf-8') 166 | except Exception as e: 167 | raise ValueError(f"Could not process image input {doc}: {str(e)}") 168 | 169 | processed_docs.append( 170 | Document( 171 | document_type="image", 172 | base64=b64, 173 | image_path=image_path 174 | ) 175 | ) 176 | 177 | # Handle doc_ids and metadata 178 | if doc_ids is None: 179 | doc_ids = list(range(len(processed_docs))) 180 | if metadata is None: 181 | metadata = [{} for _ in processed_docs] 182 | 183 | # Set doc_ids and metadata 184 | for i, doc in enumerate(processed_docs): 185 | doc.doc_id = doc_ids[i] 186 | doc.metadata = metadata[i] 187 | 188 | 189 | return processed_docs 190 | 191 | 192 | 193 | 194 | def get_chunks(iterable: Iterable, chunk_size: int): # noqa: E741 195 | """ 196 | Implementation from https://github.com/unicamp-dl/InRanker/blob/main/inranker/base.py with extra typing and more descriptive names. 197 | This method is used to split a list l into chunks of batch size n. 
198 | """ 199 | for i in range(0, len(iterable), chunk_size): 200 | yield iterable[i : i + chunk_size] 201 | -------------------------------------------------------------------------------- /tests/consistency_notebooks/test_colbert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from ranx import Qrels, Run" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "qrels = Qrels.from_ir_datasets(\"beir/scifact/test\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stderr", 28 | "output_type": "stream", 29 | "text": [ 30 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 31 | " from .autonotebook import tqdm as notebook_tqdm\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "from rerankers import Reranker" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/plain": [ 47 | "{'_id': '4983',\n", 48 | " 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',\n", 49 | " 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and preterm infants at term showed marked differences in white matter fiber organization. 
The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural development in cerebral white matter in living infants.',\n", 50 | " 'metadata': {}}" 51 | ] 52 | }, 53 | "execution_count": 4, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "import srsly\n", 60 | "\n", 61 | "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n", 62 | "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n", 63 | "\n", 64 | "corpus[0]" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "Loading default colbert model for language en\n", 77 | "Loading ColBERTRanker model colbert-ir/colbertv2.0\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "ranker = Reranker('colbert', device='cuda', verbose=0)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 6, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "top100 = srsly.read_json('./data/scifact/scifact_top_100.json')\n", 92 | "\n" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 7, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 8, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | " 75%|███████▌ | 226/300 [03:36<01:14, 1.01s/it]" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "qrels_dict = dict(qrels)\n", 119 | "queries = [q for q in queries if q['_id'] in qrels_dict]\n", 120 | "from tqdm import tqdm\n", 121 | "\n", 122 | "scores = {}\n", 123 | "for q in tqdm(queries):\n", 124 | " doc_ids = top100[q['_id']]\n", 125 | " docs = [corpus_map[x] for x in doc_ids]\n", 126 | " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "scores_dict = {}\n", 136 | "for q_id, ranked_results in scores.items():\n", 137 | " top_10_results = ranked_results.top_k(10)\n", 138 | " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n", 139 | "run = Run(scores_dict)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Score is within 0.01 NDCG@10 of the reported score!\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "from ranx import evaluate\n", 157 | "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n", 158 | "litterature_result = 0.693 # From ColBERTv2 Paper https://arxiv.org/abs/2112.01488\n", 159 | "if abs(evaluation_score - litterature_result) > 0.01:\n", 160 | " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the reported score.\")\n", 161 | "else:\n", 162 | " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n" 163 | ] 164 | } 165 | ], 166 | "metadata": { 167 | "kernelspec": { 168 | "display_name": "rerankers", 169 | "language": "python", 170 | "name": "python3" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | "file_extension": ".py", 178 | 
"mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer": "ipython3", 182 | "version": "3.10.13" 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 2 187 | } 188 | -------------------------------------------------------------------------------- /tests/consistency_notebooks/test_crossenc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from ranx import Qrels, Run\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "qrels = Qrels.from_ir_datasets(\"beir/scifact/test\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stderr", 28 | "output_type": "stream", 29 | "text": [ 30 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 31 | " from .autonotebook import tqdm as notebook_tqdm\n" 32 | ] 33 | }, 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "T5Ranker\n", 39 | "{'TransformerRanker': , 'APIRanker': , 'RankGPTRanker': , 'T5Ranker': }\n", 40 | "No dtype set\n", 41 | "Using dtype torch.float16\n", 42 | "Loading model castorini/monot5-base-msmarco-10k, this might take a while...\n", 43 | "Using device cuda.\n", 44 | "Using dtype torch.float16.\n" 45 | ] 46 | }, 47 | { 48 | "name": "stderr", 49 | "output_type": "stream", 50 | "text": [ 51 | "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n" 52 | ] 53 | }, 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "T5 true token set to ▁true\n", 59 | "T5 false token set to ▁false\n", 60 | "Returning normalised scores...\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "from rerankers import Reranker\n", 66 | "ranker = Reranker('castorini/monot5-base-msmarco-10k', device='cuda', batch_size=128)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "{'_id': '4983',\n", 78 | " 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',\n", 79 | " 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. 
In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and preterm infants at term showed marked differences in white matter fiber organization. The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural development in cerebral white matter in living infants.',\n", 80 | " 'metadata': {}}" 81 | ] 82 | }, 83 | "execution_count": 4, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "import srsly\n", 90 | "\n", 91 | "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n", 92 | "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n", 93 | "\n", 94 | "corpus[0]" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "Warning: Model type could not be auto-mapped. 
Defaulting to TransformerRanker.\n", 107 | "If your model is NOT intended to be ran as a one-label cross-encoder, please reload it and specify the model_type!\n", 108 | "TransformerRanker\n", 109 | "{'TransformerRanker': , 'APIRanker': , 'RankGPTRanker': , 'T5Ranker': }\n", 110 | "No dtype set\n", 111 | "Using dtype torch.float16\n" 112 | ] 113 | }, 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "Loaded model mixedbread-ai/mxbai-rerank-base-v1\n", 119 | "Using device cuda.\n", 120 | "Using dtype torch.float16.\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "ranker = Reranker('mixedbread-ai/mxbai-rerank-base-v1', device='cuda')" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 8, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stderr", 153 | "output_type": "stream", 154 | "text": [ 155 | "100%|██████████| 300/300 [02:55<00:00, 1.71it/s]\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "qrels_dict = dict(qrels)\n", 161 | "queries = [q for q in queries if q['_id'] in qrels_dict]\n", 162 | "from tqdm import tqdm\n", 163 | "\n", 164 | "scores = {}\n", 165 | "for q in tqdm(queries):\n", 166 | " doc_ids = top100[q['_id']]\n", 167 | " docs = [corpus_map[x] for x in doc_ids]\n", 168 | " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 9, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "scores_dict = {}\n", 178 | "for q_id, ranked_results in scores.items():\n", 179 | " top_10_results = ranked_results.top_k(10)\n", 180 | " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 10, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "run = Run(scores_dict)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 11, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/plain": [ 200 | "300" 201 | ] 202 | }, 203 | "execution_count": 11, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | } 207 | ], 208 | "source": [ 209 | "len(run)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 12, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "Score is within 0.01 NDCG@10 of the reported score!\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "from ranx import evaluate\n", 227 | "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n", 228 | "litterature_result = 0.724 # from MXBAI https://docs.google.com/spreadsheets/d/15ELkSMFv-oHa5TRiIjDvhIstH9dlc3pnZeO-iGz4Ld4/edit#gid=0\n", 229 | "if abs(evaluation_score - litterature_result) > 0.01:\n", 230 | " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the reported score.\")\n", 231 | "else:\n", 232 | " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n" 233 | ] 234 | } 235 | ], 236 | "metadata": { 237 | "kernelspec": { 238 | 
"display_name": "rerankers", 239 | "language": "python", 240 | "name": "python3" 241 | }, 242 | "language_info": { 243 | "codemirror_mode": { 244 | "name": "ipython", 245 | "version": 3 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython3", 252 | "version": "3.10.13" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 2 257 | } 258 | -------------------------------------------------------------------------------- /tests/consistency_notebooks/test_inranker.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from ranx import Qrels, Run\n", 10 | "\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "qrels = Qrels.from_ir_datasets(\"beir/scifact/test\")" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 32 | " from .autonotebook import tqdm as notebook_tqdm\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "from rerankers import Reranker" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 4, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "{'_id': '4983',\n", 49 | " 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',\n", 50 | " 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). 
Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and preterm infants at term showed marked differences in white matter fiber organization. The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural development in cerebral white matter in living infants.',\n", 51 | " 'metadata': {}}" 52 | ] 53 | }, 54 | "execution_count": 4, 55 | "metadata": {}, 56 | "output_type": "execute_result" 57 | } 58 | ], 59 | "source": [ 60 | "import srsly\n", 61 | "\n", 62 | "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n", 63 | "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n", 64 | "\n", 65 | "corpus[0]" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "T5Ranker\n", 78 | "{'TransformerRanker': , 'APIRanker': , 'RankGPTRanker': , 'T5Ranker': }\n" 79 | ] 80 | }, 81 | { 82 | "name": "stderr", 83 | "output_type": "stream", 84 | "text": [ 85 | "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n" 86 | ] 87 | } 88 | ], 89 | "source": [ 90 | "ranker = Reranker('unicamp-dl/InRanker-base', device='cuda', batch_size=32, verbose=0)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 6, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n", 100 | "\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 7, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 8, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stderr", 119 | "output_type": "stream", 120 | "text": [ 121 | "100%|██████████| 300/300 [05:18<00:00, 1.06s/it]\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "qrels_dict = dict(qrels)\n", 127 | "queries = [q for q in queries if q['_id'] in qrels_dict]\n", 128 | "from tqdm import tqdm\n", 129 | "\n", 130 | "scores = {}\n", 131 | "for q in tqdm(queries):\n", 132 | " doc_ids = top100[q['_id']]\n", 133 | " docs = [corpus_map[x] for x in doc_ids]\n", 134 | " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 9, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "scores_dict = {}\n", 144 | "for q_id, ranked_results in scores.items():\n", 145 | " top_10_results = ranked_results.top_k(10)\n", 146 | " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n", 147 | "run = Run(scores_dict)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 12, 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "name": "stdout", 157 | "output_type": "stream", 158 | "text": [ 159 | "Score is within 0.01 NDCG@10 of the reported score!\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "from ranx import evaluate\n", 165 | 
"evaluation_score = evaluate(qrels, run, 'ndcg@10')\n", 166 | "litterature_result = 0.7618 # From InRanker Paper https://arxiv.org/pdf/2401.06910.pdf\n", 167 | "if abs(evaluation_score - litterature_result) > 0.01:\n", 168 | " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\")\n", 169 | "else:\n", 170 | " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n" 171 | ] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "rerankers", 177 | "language": "python", 178 | "name": "python3" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | "name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": "python", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.10.13" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 2 195 | } 196 | -------------------------------------------------------------------------------- /tests/consistency_notebooks/test_t5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from ranx import Qrels, Run" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "qrels = Qrels.from_ir_datasets(\"beir/scifact/test\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stderr", 28 | "output_type": "stream", 29 | "text": [ 30 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 31 | " from .autonotebook import tqdm as notebook_tqdm\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "from rerankers import Reranker" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/plain": [ 47 | "{'_id': '4983',\n", 48 | " 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',\n", 49 | " 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater absolute values in the internal capsule than in the central white matter. 
Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and preterm infants at term showed marked differences in white matter fiber organization. The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural development in cerebral white matter in living infants.',\n", 50 | " 'metadata': {}}" 51 | ] 52 | }, 53 | "execution_count": 4, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "import srsly\n", 60 | "\n", 61 | "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n", 62 | "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n", 63 | "\n", 64 | "corpus[0]" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "T5Ranker\n", 77 | "{'TransformerRanker': , 'APIRanker': , 'T5Ranker': }\n" 78 | ] 79 | }, 80 | { 81 | "name": "stderr", 82 | "output_type": "stream", 83 | "text": [ 84 | "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "ranker = Reranker('castorini/monot5-base-msmarco-10k', device='cuda', batch_size=128, verbose=0)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 6, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n", 99 | "\n" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 7, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 8, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stderr", 118 | "output_type": "stream", 119 | "text": [ 120 | "100%|██████████| 300/300 [04:44<00:00, 1.06it/s]\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "qrels_dict = dict(qrels)\n", 126 | "queries = [q for q in queries if q['_id'] in qrels_dict]\n", 127 | "from tqdm import tqdm\n", 128 | "\n", 129 | "scores = {}\n", 130 | "for q in tqdm(queries):\n", 131 | " doc_ids = top100[q['_id']]\n", 132 | " docs = [corpus_map[x] for x in doc_ids]\n", 133 | " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 17, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "scores_dict = {}\n", 143 | "for q_id, ranked_results in scores.items():\n", 144 | " top_10_results = ranked_results.top_k(10)\n", 145 | " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n", 146 | "run = Run(scores_dict)" 147 
| ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 18, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "Score 0.731 is within 0.005 NDCG@10 of the reported score!\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "from ranx import evaluate\n", 164 | "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n", 165 | "litterature_result = 0.734 # From RankGPT Paper https://arxiv.org/pdf/2304.09542.pdf\n", 166 | "if abs(evaluation_score - litterature_result) > 0.01:\n", 167 | " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the reported score.\")\n", 168 | "else:\n", 169 | " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n" 170 | ] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": "rerankers", 176 | "language": "python", 177 | "name": "python3" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.9.18" 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 2 194 | } 195 | -------------------------------------------------------------------------------- /tests/test_crossenc.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | import torch 3 | from rerankers import Reranker 4 | from rerankers.models.transformer_ranker import TransformerRanker 5 | from rerankers.results import Result, RankedResults 6 | from rerankers.documents import Document 7 | 8 | @patch("rerankers.models.transformer_ranker.TransformerRanker.rank") 9 | def test_transformer_ranker_rank(mock_rank): 10 | query = "Gone with the wind is an absolute masterpiece" 11 | docs = [ 12 | "Gone with the wind is a masterclass in bad storytelling.", 13 | "Gone with the wind is an all-time classic", 14 | ] 15 | expected_results = RankedResults( 16 | results=[ 17 | Result( 18 | document=Document( 19 | doc_id=1, text="Gone with the wind is an all-time classic" 20 | ), 21 | score=1.6181640625, 22 | rank=1, 23 | ), 24 | Result( 25 | document=Document( 26 | doc_id=0, 27 | text="Gone with the wind is a masterclass in bad storytelling.", 28 | ), 29 | score=0.88427734375, 30 | rank=2, 31 | ), 32 | ], 33 | query=query, 34 | has_scores=True, 35 | ) 36 | mock_rank.return_value = expected_results 37 | ranker = TransformerRanker("mixedbread-ai/mxbai-rerank-xsmall-v1") 38 | results = ranker.rank(query=query, docs=docs) 39 | assert results == expected_results 40 | -------------------------------------------------------------------------------- /tests/test_results.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rerankers.results import Result, RankedResults 3 | from rerankers.documents import Document 4 | 5 | 6 | def test_ranked_results_functions(): 7 | results = RankedResults( 8 | results=[ 9 | Result(document=Document(doc_id=0, text="Doc 0"), score=0.9, rank=2), 10 | Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1), 11 | ], 12 | query="Test Query", 13 | has_scores=True, 14 | ) 15 | assert results.results_count() == 2 16 | top_k = results.top_k(1) 17 | assert len(top_k) == 1 18 | assert top_k[0].doc_id == 1 19 | assert results.get_score_by_docid(0) == 0.9 20 | 21 | 22 | def 
test_result_attributes(): 23 | result = Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1) 24 | assert result.doc_id == 1 25 | assert result.text == "Doc 1" 26 | assert result.score == 0.95 27 | assert result.rank == 1 28 | 29 | 30 | def test_result_validation_error(): 31 | with pytest.raises(ValueError) as excinfo: 32 | Result(document=Document(doc_id=2, text="Doc 2")) 33 | assert "Either score or rank must be provided." in str(excinfo.value) 34 | --------------------------------------------------------------------------------