├── .gitignore
├── LICENSE
├── README.md
├── apilist.txt
├── data
└── state_of_the_union.txt
├── examples
├── langchain_integration.ipynb
├── overview.ipynb
└── reranker_images.ipynb
├── pyproject.toml
├── rerankers
├── __init__.py
├── documents.py
├── integrations
│ ├── __init__.py
│ └── langchain.py
├── models
│ ├── __init__.py
│ ├── api_rankers.py
│ ├── colbert_ranker.py
│ ├── flashrank_ranker.py
│ ├── llm_layerwise_ranker.py
│ ├── llm_relevance_filter.py
│ ├── monovlm_ranker.py
│ ├── mxbai_v2.py
│ ├── pylate_ranker.py
│ ├── ranker.py
│ ├── rankgpt_rankers.py
│ ├── rankllm_ranker.py
│ ├── t5ranker.py
│ ├── transformer_ranker.py
│ └── upr.py
├── reranker.py
├── results.py
└── utils.py
└── tests
├── consistency_notebooks
├── test_colbert.ipynb
├── test_crossenc.ipynb
├── test_inranker.ipynb
└── test_t5.ipynb
├── test_crossenc.py
└── test_results.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | .flashrank_cache
8 |
9 | # C extensions
10 | *.so
11 |
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 | .pdm.toml
89 | __pypackages__/
90 |
91 | # Environments
92 | .env
93 | .venv
94 | env/
95 | venv/
96 | ENV/
97 | env.bak/
98 | venv.bak/
99 | .envrc
100 | .conda/
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | .ragatouille
106 |
107 | # mypy
108 | .mypy_cache/
109 | .dmypy.json
110 | dmypy.json
111 |
112 | # Pyre type checker
113 | .pyre/
114 |
115 | # pytype static type analyzer
116 | .pytype/
117 |
118 | # Cython debug symbols
119 | cython_debug/
120 |
121 | # data files
122 | *.tsv
123 | *.jsonl
124 |
125 | .mypy.ipynb_checkpoints
126 | .mkdocs.yml
127 |
128 |
129 | archive/
130 |
131 | */.ragatouille
132 |
133 | local/
134 |
135 | .vscode/
136 |
137 | .devcontainer/
138 |
139 | try*.ipynb
140 |
141 | data/*/
142 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Thiago Laitz
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # rerankers
3 |
4 | 
5 | [](https://pepy.tech/project/rerankers)
6 | [](https://twitter.com/bclavie)
7 |
8 |
9 | _A lightweight unified API for various reranking models. Developed by [@bclavie](https://twitter.com/bclavie) as a member of [answer.ai](https://www.answer.ai)_
10 |
11 | ---
12 |
13 | Welcome to `rerankers`! Our goal is to provide users with a simple API to use any reranking models.
14 |
15 | ## Recent Updates
16 | _A longer release history can be found in the [Release History](#release-history) section of this README._
17 |
18 | - v0.10.0: Added support for PyLate ColBERT/late-interaction reranking and merged fixes for API rerankers, both thanks to community contributors!
19 | - v0.9.0: Added support for MXBai V2 reranker, which is a new version of the MXBai reranker based on Qwen and the current open-source state-of-the-art.
20 | - v0.7.0: Removing `pydantic` and `tqdm` dependencies, so `rerankers` is now dependency-free by default, avoiding any issues with Pydantic v1/v2!
21 | - v0.6.1: Added support for Pinecone's new rerankers via their API.
22 | - v0.6.0: `rerankers` goes multi-modal, with the support of the first MonoVLMRanker model, [MonoQwen2-VL-v0.1!](https://huggingface.co/lightonai/MonoQwen2-VL-v0.1)! + Many QoL fixes.
23 |
24 | ## Why `rerankers`?
25 |
26 | Rerankers are an important part of any retrieval architecture, but they're also often more obscure than other parts of the pipeline.
27 |
28 | Sometimes, it can be hard to even know which one to use. Every problem is different, and the best model for use X is not necessarily the same one as for use Y.
29 |
30 | Moreover, new reranking methods keep popping up: for example, RankGPT, using LLMs to rerank documents, appeared just last year, with very promising zero-shot benchmark results.
31 |
32 | All the different reranking approaches tend to be done in their own library, with varying levels of documentation. This results in an even higher barrier to entry. New users are required to swap between multiple unfamiliar input/output formats, all with their own quirks!
33 |
34 | `rerankers` seeks to address this problem by providing a simple API for all popular rerankers, no matter the architecture.
35 |
36 | `rerankers` aims to be:
37 | - 🪶 Lightweight. It ships with only the bare necessities as dependencies.
38 | - 📖 Easy-to-understand. There's just a handful of calls to learn, and you can then use the full range of provided reranking models.
39 | - 🔗 Easy-to-integrate. It should fit in just about any existing pipelines, with only a few lines of code!
40 | - 💪 Easy-to-expand. Any new reranking models can be added with very little knowledge of the codebase. All you need is a new class with a `rank()` function call mapping a (query, [documents]) input to a `RankedResults` output.
41 | - 🐛 Easy-to-debug. This is a beta release and there might be issues, but the codebase is conceived in such a way that most issues should be easy to track and fix ASAP.
42 |
43 | ## Get Started
44 |
45 | Installation is very simple. The core package ships with no dependencies, so as to avoid any conflict with your current environment.
46 | You may then install only the dependencies required by the models you want to try out:
47 |
48 | ```sh
49 | # Core package only, will require other dependencies already installed
50 | pip install rerankers
51 |
52 | # All transformers-based approaches (cross-encoders, t5, colbert)
53 | pip install "rerankers[transformers]"
54 |
55 | # RankGPT
56 | pip install "rerankers[gpt]"
57 |
58 | # API-based rerankers (Cohere, Jina, MixedBread, Pinecone, Isaacus)
59 | pip install "rerankers[api]"
60 |
61 | # FlashRank rerankers (ONNX-optimised, very fast on CPU)
62 | pip install "rerankers[flashrank]"
63 |
64 | # RankLLM rerankers (better RankGPT + support for local models such as RankZephyr and RankVicuna)
65 | # Note: RankLLM is only supported on Python 3.10+! This will not work with Python 3.9
66 | pip install "rerankers[rankllm]"
67 |
68 | # To support Multi-Modal rerankers such as MonoQwen2-VL and other MonoVLM models, which require flash-attention, peft, accelerate, and recent versions of `transformers`
69 | pip install "rerankers[monovlm]"
70 |
71 |
72 | # To support LLM-Layerwise rerankers (which need flash-attention installed)
73 | pip install "rerankers[llmlayerwise]"
74 |
75 | # All of the above
76 | pip install "rerankers[all]"
77 | ```
78 |
79 | ## Usage
80 |
81 | Load any supported reranker in a single line, regardless of the architecture:
82 | ```python
83 | from rerankers import Reranker
84 |
85 | # Cross-encoder default. You can specify a 'lang' parameter to load a multilingual version!
86 | ranker = Reranker('cross-encoder')
87 |
88 | # Specific cross-encoder
89 | ranker = Reranker('mixedbread-ai/mxbai-rerank-large-v1', model_type='cross-encoder')
90 |
91 | # FlashRank default. You can specify a 'lang' parameter to load a multilingual version!
92 | ranker = Reranker('flashrank')
93 |
94 | # Specific flashrank model.
95 | ranker = Reranker('ce-esci-MiniLM-L12-v2', model_type='flashrank')
96 |
97 | # Default T5 Seq2Seq reranker
98 | ranker = Reranker("t5")
99 |
100 | # Specific T5 Seq2Seq reranker
101 | ranker = Reranker("unicamp-dl/InRanker-base", model_type = "t5")
102 |
103 | # API (Cohere)
104 | ranker = Reranker("cohere", lang='en' (or 'other'), api_key = API_KEY)
105 |
106 | # Custom Cohere model? No problem!
107 | ranker = Reranker("my_model_name", api_provider = "cohere", api_key = API_KEY)
108 |
109 | # API (Pinecone)
110 | ranker = Reranker("pinecone", api_key = API_KEY)
111 |
112 | # API (Jina)
113 | ranker = Reranker("jina", api_key = API_KEY)
114 |
115 | # API (Isaacus)
116 | ranker = Reranker("isaacus", api_key = API_KEY)
117 |
118 | # RankGPT4-turbo
119 | ranker = Reranker("rankgpt", api_key = API_KEY)
120 |
121 | # RankGPT3-turbo
122 | ranker = Reranker("rankgpt3", api_key = API_KEY)
123 |
124 | # RankGPT with another LLM provider
125 | ranker = Reranker("MY_LLM_NAME" (check litellm docs), model_type = "rankgpt", api_key = API_KEY)
126 |
127 | # RankLLM with default GPT (GPT-4o)
128 | ranker = Reranker("rankllm", api_key = API_KEY)
129 |
130 | # RankLLM with specified GPT models
131 | ranker = Reranker('gpt-4-turbo', model_type="rankllm", api_key = API_KEY)
132 |
133 | # ColBERTv2 reranker
134 | ranker = Reranker("colbert")
135 |
136 | # LLM Layerwise Reranker
137 | ranker = Reranker('llm-layerwise')
138 |
139 | # ... Or a non-default colbert model:
140 | ranker = Reranker(model_name_or_path, model_type = "colbert")
141 |
142 | ```
143 |
144 | _Rerankers will always try to infer the model you're trying to use based on its name, but it's always safer to pass a `model_type` argument to it if you can!_
145 |
146 | Then, regardless of which reranker is loaded, use the loaded model to rank a query against documents:
147 |
148 | ```python
149 | > results = ranker.rank(query="I love you", docs=["I hate you", "I really like you"], doc_ids=[0,1])
150 | > results
151 | RankedResults(results=[Result(document=Document(text='I really like you', doc_id=1), score=-2.453125, rank=1), Result(document=Document(text='I hate you', doc_id=0), score=-4.14453125, rank=2)], query='I love you', has_scores=True)
152 | ```
153 |
154 | You don't need to pass `doc_ids`! If not provided, they'll be auto-generated as integers corresponding to the index of a document in `docs`.
155 |
156 |
157 | You're free to pass metadata too, and it'll be stored with the documents. It'll also be accessible in the results object:
158 |
159 | ```python
160 | > results = ranker.rank(query="I love you", docs=["I hate you", "I really like you"], doc_ids=[0,1], metadata=[{'source': 'twitter'}, {'source': 'reddit'}])
161 | > results
162 | RankedResults(results=[Result(document=Document(text='I really like you', doc_id=1, metadata={'source': 'twitter'}), score=-2.453125, rank=1), Result(document=Document(text='I hate you', doc_id=0, metadata={'source': 'reddit'}), score=-4.14453125, rank=2)], query='I love you', has_scores=True)
163 | ```
164 |
165 | If you'd like your code to be a bit cleaner, you can also directly construct `Document` objects yourself, and pass those instead. In that case, you don't need to pass separate `doc_ids` and `metadata`:
166 |
167 | ```python
168 | > from rerankers import Document
169 | > docs = [Document(text="I really like you", doc_id=0, metadata={'source': 'twitter'}), Document(text="I hate you", doc_id=1, metadata={'source': 'reddit'})]
170 | > results = ranker.rank(query="I love you", docs=docs)
171 | > results
172 | RankedResults(results=[Result(document=Document(text='I really like you', doc_id=0, metadata={'source': 'twitter'}), score=-2.453125, rank=1), Result(document=Document(text='I hate you', doc_id=1, metadata={'source': 'reddit'}), score=-4.14453125, rank=2)], query='I love you', has_scores=True)
173 | ```
174 |
175 | You can also use `rank_async`, which is essentially just a wrapper to turn `rank()` into a coroutine. The result will be the same:
176 |
177 | ```python
178 | > results = await ranker.rank_async(query="I love you", docs=["I hate you", "I really like you"], doc_ids=[0,1])
179 | > results
180 | RankedResults(results=[Result(document=Document(text='I really like you', doc_id=1, metadata={'source': 'twitter'}), score=-2.453125, rank=1), Result(document=Document(text='I hate you', doc_id=0, metadata={'source': 'reddit'}), score=-4.14453125, rank=2)], query='I love you', has_scores=True)
181 | ```
182 |
183 | All rerankers will return a `RankedResults` object, which is a Python object containing a list of `Result` objects and some other useful information, such as the original query. You can retrieve the top `k` results from it by running `top_k()`:
184 |
185 | ```python
186 | > results.top_k(1)
187 | [Result(Document(doc_id=1, text='I really like you', metadata={}), score=0.26170814, rank=1)]
188 | ```
189 |
190 | The Result objects are transparent when trying to access the documents they store, as `Document` objects simply exist as an easy way to store IDs and metadata. If you want to access a given result's text or metadata, you can directly access it as a property:
191 |
192 | ```python
193 | > results.top_k(1)[0].text
194 | 'I really like you'
195 | ```
196 |
197 | And that's all you need to know to get started quickly! Check out the overview notebook for more information on the API and the different models, or the langchain example to see how to integrate this in your langchain pipeline.
198 |
199 |
200 | ## Features
201 |
202 | Legend:
203 | - ✅ Supported
204 | - 🟠 Implemented, but not fully fledged
205 | - 📍 Not supported but intended to be in the future
206 | - ⭐ Same as above, but **important**.
207 | - ❌ Not supported & not currently planned
208 |
209 | Models:
210 | - ✅ Any standard SentenceTransformer or Transformers cross-encoder
211 | - ✅ RankGPT (Available both via the original RankGPT implementation and the improved RankLLM one)
212 | - ✅ T5-based pointwise rankers (InRanker, MonoT5...)
213 | - ✅ LLM-based pointwise rankers (BAAI/bge-reranker-v2.5-gemma2-lightweight, etc...)
214 | - ✅ Cohere, Jina, Voyage, MixedBread, Pinecone and Isaacus API rerankers
215 | - ✅ [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers (ONNX-optimised models, very fast on CPU)
216 | - ✅ ColBERT-based reranker - not a model initially designed for reranking, but does perform quite strongly in some cases. Implementation is lightweight, based only on transformers.
217 | - 🟠⭐ RankLLM/RankZephyr: supported by wrapping the [rank-llm library](https://github.com/castorini/rank_llm) library! Support for RankZephyr/RankVicuna is untested, but RankLLM + GPT models fully works!
218 | - ✅ 🆕 v0.6.0: MonoVLMRanker, multi-modal image reranker employing the MonoT5 method with a VLM backnd.
219 | - 📍 LiT5
220 |
221 | Features:
222 | - ✅ Metadata!
223 | - ✅ Reranking
224 | - ✅ Consistency notebooks to ensure performance on `scifact` matches the litterature for any given model implementation (Except RankGPT, where results are harder to reproduce).
225 | - ✅ ONNX runtime support --> Offered through [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) -- in line with the philosophy of the lib, we won't reinvent the wheel when @PrithivirajDamodaran is doing amazing work!
226 | - 📍 Training on Python >=3.10 (via interfacing with other libraries)
227 | - ❌(📍Maybe?) Training via rerankers directly
228 |
229 | ## Reference
230 |
231 | If rerankers has been useful to you in academic work, please do feel free to cite the work below!
232 |
233 | ```
234 | @misc{clavié2024rerankers,
235 | title={rerankers: A Lightweight Python Library to Unify Ranking Methods},
236 | author={Benjamin Clavié},
237 | year={2024},
238 | eprint={2408.17344},
239 | archivePrefix={arXiv},
240 | primaryClass={cs.IR},
241 | url={https://arxiv.org/abs/2408.17344},
242 | }
243 | ```
244 |
245 | ## Release History
246 |
247 | - v0.5.*: ColBERT fixes (0.5.1) & Minor change making RankedResults subscribable, meaning results[0] will return the result for the first document (0.5.2), etc... ⚠️ This is sorted by **passed document order**, not by results, you should use `.top_k()` to get sorted results!
248 | - v0.5.0: Added support for the current state-of-the-art rerankers, BAAI's series of `BGE` layerwise LLM rerankers, based on [Gemma](https://huggingface.co/BAAI/bge-reranker-v2.5-gemma2-lightweight) and MiniCPM. These are different from RankGPT, as they're not listwise: the models are repurposed as "cross-encoders", and do output logit scores.
249 | - v0.4.0: ColBERT performance improvement! It should now be faster and result in stronger results following implementation of the JaColBERTv2.5 dynamic query length method. This version also now supports HuggingFace's Text-Embedding-Server (TEI) inference as an API reranker option, thanks to [@srisudarsan](https://github.com/srisudarsan).
250 | - v0.3.1: T5 bugfix and native default support for new Portuguese T5 rerankers.
251 | - v0.3.0: Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint-work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be itered on directly...)
252 | - v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, Basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
253 | - v0.1.2: Voyage reranking API
254 | - v0.1.1: Langchain integration fixed!
255 | - v0.1.0: Initial release
256 |
--------------------------------------------------------------------------------
/apilist.txt:
--------------------------------------------------------------------------------
1 | # rerankers Module Documentation
2 |
3 | ## rerankers.documents
4 |
5 | - `class Document`
6 | - `@validator('text') def validate_text(cls, v, values)`
7 | - `def __init__(self, text, doc_id, metadata, document_type, image_path, base64)`
8 |
9 | ## rerankers.integrations.langchain
10 |
11 | - `class RerankerLangChainCompressor`
12 | - `def compress_documents(self, documents, query, callbacks, **kwargs)`
13 | Rerank a list of documents relevant to a query.
14 |
15 |
16 | ## rerankers.models.api_rankers
17 |
18 | - `class APIRanker`
19 | - `def __init__(self, model, api_key, api_provider, verbose, url)`
20 | - `def rank(self, query, docs, doc_ids, metadata)`
21 | - `def score(self, query, doc)`
22 |
23 | ## rerankers.models.colbert_ranker
24 |
25 | > Code from HotchPotch's JQaRa repository: https://github.com/hotchpotch/JQaRA/blob/main/evaluator/reranker/colbert_reranker.py
26 | > Modifications include packaging into a BaseRanker, dynamic query/doc length and batch size handling.
27 |
28 | - `class ColBERTModel`
29 | - `def __init__(self, config)`
30 | - `def forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, output_attentions, output_hidden_states)`
31 |
32 | - `class ColBERTRanker`
33 | - `def __init__(self, model_name, batch_size, dtype, device, verbose, query_token, document_token, **kwargs)`
34 | - `def rank(self, query, docs, doc_ids, metadata)`
35 | - `def score(self, query, doc)`
36 |
37 | ## rerankers.models.flashrank_ranker
38 |
39 | - `class FlashRankRanker`
40 | - `def __init__(self, model_name_or_path, verbose, cache_dir)`
41 | - `def tokenize(self, inputs)`
42 | - `def rank(self, query, docs, doc_ids, metadata)`
43 | - `def score(self, query, doc)`
44 |
45 | ## rerankers.models.llm_layerwise_ranker
46 |
47 | - `class LLMLayerWiseRanker`
48 | - `def __init__(self, model_name_or_path, max_sequence_length, dtype, device, batch_size, verbose, prompt, cutoff_layers, compress_ratio, compress_layer, **kwargs)`
49 | - `@torch.inference_mode() def rank(self, query, docs, doc_ids, metadata, batch_size, max_sequence_length)`
50 | - `@torch.inference_mode() def score(self, query, doc)`
51 |
52 | ## rerankers.models.monovlm_ranker
53 |
54 | - `class MonoVLMRanker`
55 | - `def __init__(self, model_name_or_path, processor_name, dtype, device, batch_size, verbose, token_false, token_true, return_logits, prompt_template, **kwargs)`
56 | - `def rank(self, query, docs, doc_ids, metadata)`
57 | - `def score(self, query, doc)`
58 |
59 | ## rerankers.models.ranker
60 |
61 | - `class BaseRanker`
62 | - `@abstractmethod def __init__(self, model_name_or_path, verbose)`
63 | - `@abstractmethod def score(self, query, doc)`
64 | - `@abstractmethod def rank(self, query, docs, doc_ids)`
65 | End-to-end reranking of documents.
66 |
67 | - `def rank_async(self, query, docs, doc_ids)`
68 | - `def as_langchain_compressor(self, k)`
69 |
70 | ## rerankers.models.rankgpt_rankers
71 |
72 | > Full implementation is from the original RankGPT repository https://github.com/sunnweiwei/RankGPT under its Apache 2.0 License
73 | >
74 | > Changes made are:
75 | > - Truncating the file to only the relevant functions
76 | > - Using only LiteLLM
77 | > - make_item() added
78 | > - Packaging it onto RankGPTRanker
79 |
80 | - `class RankGPTRanker`
81 | - `def __init__(self, model, api_key, lang, verbose)`
82 | - `def rank(self, query, docs, doc_ids, metadata, rank_start, rank_end)`
83 | - `def score(self)`
84 |
85 | ## rerankers.models.rankllm_ranker
86 |
87 | - `class RankLLMRanker`
88 | - `def __init__(self, model, api_key, lang, verbose)`
89 | - `def rank(self, query, docs, doc_ids, metadata, rank_start, rank_end)`
90 | - `def score(self)`
91 |
92 | ## rerankers.models.t5ranker
93 |
94 | > Code for InRanker is taken from the excellent InRanker repo https://github.com/unicamp-dl/InRanker under its Apache 2.0 license.
95 | > The only change to the original implementation is the removal of InRanker's BaseRanker, replacing it with our own to support the unified API better.
96 | > The main purpose for adapting this code here rather than installing the InRanker library is to ensure greater version compatibility (InRanker requires Python >=3.10)
97 |
98 | - `class T5Ranker`
99 | - `def __init__(self, model_name_or_path, batch_size, dtype, device, verbose, token_false, token_true, return_logits, inputs_template, **kwargs)`
100 | Implementation of the key functions from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py
101 | Changes are detailed in the docstring for each relevant function.
102 |
103 | T5Ranker is a wrapper for using Seq2Seq models for ranking.
104 | Args:
105 | batch_size: The batch size to use when encoding.
106 | dtype: Data type for model weights.
107 | device: The device to use for inference ("cpu", "cuda", or "mps").
108 | verbose: Verbosity level.
109 | silent: Whether to show progress bars.
110 |
111 | - `def rank(self, query, docs, doc_ids, metadata)`
112 | Ranks a list of documents based on their relevance to the query.
113 |
114 | - `def score(self, query, doc)`
115 | Scores a single document's relevance to a query.
116 |
117 |
118 | ## rerankers.models.transformer_ranker
119 |
120 | - `class TransformerRanker`
121 | - `def __init__(self, model_name_or_path, dtype, device, batch_size, verbose, **kwargs)`
122 | - `def tokenize(self, inputs)`
123 | - `@torch.inference_mode() def rank(self, query, docs, doc_ids, metadata, batch_size)`
124 | - `@torch.inference_mode() def score(self, query, doc)`
125 |
126 | ## rerankers.results
127 |
128 | - `class Result`
129 | - `@validator('rank', always=True) def check_score_or_rank_exists(cls, v, values)`
130 | - `def __getattr__(self, item)`
131 |
132 | - `class RankedResults`
133 | - `def __iter__(self)`
134 | Allows iteration over the results list.
135 |
136 | - `def __getitem__(self, index)`
137 | Allows indexing to access results directly.
138 |
139 | - `def results_count(self)`
140 | Returns the total number of results.
141 |
142 | - `def top_k(self, k)`
143 | Returns the top k results based on the score, if available, or rank.
144 |
145 | - `def get_score_by_docid(self, doc_id)`
146 | Fetches the score of a result by its doc_id using a more efficient approach.
147 |
148 | - `def get_result_by_docid(self, doc_id)`
149 | Fetches a result by its doc_id using a more efficient approach.
150 |
151 |
152 | ## rerankers.utils
153 |
154 | - `def prep_image_docs(docs, doc_ids, metadata)`
155 | Prepare image documents for processing. Can handle base64 encoded images or file paths.
156 | Similar to prep_docs but specialized for image documents.
157 |
158 | - `def get_chunks(iterable, chunk_size)`
159 | Implementation from https://github.com/unicamp-dl/InRanker/blob/main/inranker/base.py with extra typing and more descriptive names.
160 | This method is used to split a list l into chunks of batch size n.
161 |
162 |
--------------------------------------------------------------------------------
/examples/langchain_integration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Welcome! This is a quick notebook to introduce you to using rerankers in Langchain, at the end of a retrieval pipeline. It's heavily inspired by existing langchain examples.\n",
8 | "\n",
9 | "First, let's define a helper function for printing docs:"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "def pretty_print_docs(docs):\n",
19 | " print(\n",
20 | " f\"\\n{'-' * 100}\\n\".join(\n",
21 | " [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n",
22 | " )\n",
23 | " )"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "Then, let's set up a normal document retrieval pipeline, using the common OpenAI embeddings + FAISS combo. If you want to run this example yourself and don't have faiss installed, you'll need to install it for this example! (the document is very small, so `faiss-cpu` is largely enough)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 2,
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "name": "stdout",
40 | "output_type": "stream",
41 | "text": [
42 | "Document 1:\n",
43 | "\n",
44 | "And so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. \n",
45 | "\n",
46 | "I understand. \n",
47 | "\n",
48 | "I remember when my Dad had to leave our home in Scranton, Pennsylvania to find work. I grew up in a family where if the price of food went up, you felt it. \n",
49 | "\n",
50 | "That’s why one of the first things I did as President was fight to pass the American Rescue Plan. \n",
51 | "\n",
52 | "Because people were hurting. We needed to act, and we did. \n",
53 | "\n",
54 | "Few pieces of legislation have done more in a critical moment in our history to lift us out of crisis. \n",
55 | "\n",
56 | "It fueled our efforts to vaccinate the nation and combat COVID-19. It delivered immediate economic relief for tens of millions of Americans. \n",
57 | "\n",
58 | "Helped put food on their table, keep a roof over their heads, and cut the cost of health insurance. \n",
59 | "\n",
60 | "And as my Dad used to say, it gave people a little breathing room.\n",
61 | "----------------------------------------------------------------------------------------------------\n",
62 | "Document 2:\n",
63 | "\n",
64 | "We got more than 130 countries to agree on a global minimum tax rate so companies can’t get out of paying their taxes at home by shipping jobs and factories overseas. \n",
65 | "\n",
66 | "That’s why I’ve proposed closing loopholes so the very wealthy don’t pay a lower tax rate than a teacher or a firefighter. \n",
67 | "\n",
68 | "So that’s my plan. It will grow the economy and lower costs for families. \n",
69 | "\n",
70 | "So what are we waiting for? Let’s get this done. And while you’re at it, confirm my nominees to the Federal Reserve, which plays a critical role in fighting inflation. \n",
71 | "\n",
72 | "My plan will not only lower costs to give families a fair shot, it will lower the deficit. \n",
73 | "\n",
74 | "The previous Administration not only ballooned the deficit with tax cuts for the very wealthy and corporations, it undermined the watchdogs whose job was to keep pandemic relief funds from being wasted. \n",
75 | "\n",
76 | "But in my administration, the watchdogs have been welcomed back.\n",
77 | "----------------------------------------------------------------------------------------------------\n",
78 | "Document 3:\n",
79 | "\n",
80 | "Tonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers. \n",
81 | "\n",
82 | "And as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up. \n",
83 | "\n",
84 | "That ends on my watch. \n",
85 | "\n",
86 | "Medicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect. \n",
87 | "\n",
88 | "We’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees. \n",
89 | "\n",
90 | "Let’s pass the Paycheck Fairness Act and paid leave. \n",
91 | "\n",
92 | "Raise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty. \n",
93 | "\n",
94 | "Let’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges.\n",
95 | "----------------------------------------------------------------------------------------------------\n",
96 | "Document 4:\n",
97 | "\n",
98 | "My plan will cut the cost in half for most families and help parents, including millions of women, who left the workforce during the pandemic because they couldn’t afford child care, to be able to get back to work. \n",
99 | "\n",
100 | "My plan doesn’t stop there. It also includes home and long-term care. More affordable housing. And Pre-K for every 3- and 4-year-old. \n",
101 | "\n",
102 | "All of these will lower costs. \n",
103 | "\n",
104 | "And under my plan, nobody earning less than $400,000 a year will pay an additional penny in new taxes. Nobody. \n",
105 | "\n",
106 | "The one thing all Americans agree on is that the tax system is not fair. We have to fix it. \n",
107 | "\n",
108 | "I’m not looking to punish anyone. But let’s make sure corporations and the wealthiest Americans start paying their fair share. \n",
109 | "\n",
110 | "Just last year, 55 Fortune 500 corporations earned $40 billion in profits and paid zero dollars in federal income tax. \n",
111 | "\n",
112 | "That’s simply not fair. That’s why I’ve proposed a 15% minimum tax rate for corporations.\n"
113 | ]
114 | }
115 | ],
116 | "source": [
117 | "# Vanilla retrieval\n",
118 | "from langchain_community.document_loaders import TextLoader\n",
119 | "from langchain_community.vectorstores import FAISS\n",
120 | "from langchain_openai import OpenAIEmbeddings\n",
121 | "from langchain_text_splitters import CharacterTextSplitter\n",
122 | "\n",
123 | "documents = TextLoader(\"../data/state_of_the_union.txt\").load()\n",
124 | "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
125 | "texts = text_splitter.split_documents(documents)\n",
126 | "retriever = FAISS.from_documents(texts, OpenAIEmbeddings()).as_retriever()\n",
127 | "\n",
128 | "docs = retriever.get_relevant_documents(\n",
129 | " \"What did the president say about the minimum wage?\"\n",
130 | ")\n",
131 | "pretty_print_docs(docs)"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | "These results are interesting, but nothing about the actual new minimum wage pledge in the top two documents! Let's see if a re-ranker could help...\n",
139 | "\n",
140 | "First, let's load a reranker. The one you load doesn't actually matter -- they all behave exactly the same. For this example, we're using MixedBread's excellent [mixedbread-ai/mxbai-rerank-base-v1](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 3,
146 | "metadata": {},
147 | "outputs": [
148 | {
149 | "name": "stderr",
150 | "output_type": "stream",
151 | "text": [
152 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
153 | " from .autonotebook import tqdm as notebook_tqdm\n"
154 | ]
155 | },
156 | {
157 | "name": "stdout",
158 | "output_type": "stream",
159 | "text": [
160 | "Warning: Model type could not be auto-mapped with the defaults list. Defaulting to TransformerRanker.\n",
161 | "If your model is NOT intended to be ran as a one-label cross-encoder, please reload it and specify the model_type! Otherwise, you may ignore this warning. You may specify `model_type='cross-encoder'` to suppress this warning in the future.\n",
162 | "Loading TransformerRanker model mixedbread-ai/mxbai-rerank-base-v1\n"
163 | ]
164 | }
165 | ],
166 | "source": [
167 | "# Load a reranker and convert it to a LangChain compressor\n",
168 | "from rerankers import Reranker\n",
169 | "\n",
170 | "ranker = Reranker(\"mixedbread-ai/mxbai-rerank-base-v1\", verbose=0)"
171 | ]
172 | },
173 | {
174 | "cell_type": "markdown",
175 | "metadata": {},
176 | "source": [
177 | "Converting it to a Langchain compressor is very straightforward, all you have to do is call `as_langchain_compressor`. You can pass a `k` argument to define how many documents it should retrieve, otherwise, `k` will default to 5."
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 4,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": [
186 | "compressor = ranker.as_langchain_compressor(k=3)"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {},
192 | "source": [
193 | "You're all set! Let's just add it to our pipeline and retrieve+rerank documents:"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 5,
199 | "metadata": {},
200 | "outputs": [
201 | {
202 | "name": "stdout",
203 | "output_type": "stream",
204 | "text": [
205 | "Document 1:\n",
206 | "\n",
207 | "Tonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers. \n",
208 | "\n",
209 | "And as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up. \n",
210 | "\n",
211 | "That ends on my watch. \n",
212 | "\n",
213 | "Medicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect. \n",
214 | "\n",
215 | "We’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees. \n",
216 | "\n",
217 | "Let’s pass the Paycheck Fairness Act and paid leave. \n",
218 | "\n",
219 | "Raise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty. \n",
220 | "\n",
221 | "Let’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges.\n",
222 | "----------------------------------------------------------------------------------------------------\n",
223 | "Document 2:\n",
224 | "\n",
225 | "And so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. \n",
226 | "\n",
227 | "I understand. \n",
228 | "\n",
229 | "I remember when my Dad had to leave our home in Scranton, Pennsylvania to find work. I grew up in a family where if the price of food went up, you felt it. \n",
230 | "\n",
231 | "That’s why one of the first things I did as President was fight to pass the American Rescue Plan. \n",
232 | "\n",
233 | "Because people were hurting. We needed to act, and we did. \n",
234 | "\n",
235 | "Few pieces of legislation have done more in a critical moment in our history to lift us out of crisis. \n",
236 | "\n",
237 | "It fueled our efforts to vaccinate the nation and combat COVID-19. It delivered immediate economic relief for tens of millions of Americans. \n",
238 | "\n",
239 | "Helped put food on their table, keep a roof over their heads, and cut the cost of health insurance. \n",
240 | "\n",
241 | "And as my Dad used to say, it gave people a little breathing room.\n",
242 | "----------------------------------------------------------------------------------------------------\n",
243 | "Document 3:\n",
244 | "\n",
245 | "My plan will cut the cost in half for most families and help parents, including millions of women, who left the workforce during the pandemic because they couldn’t afford child care, to be able to get back to work. \n",
246 | "\n",
247 | "My plan doesn’t stop there. It also includes home and long-term care. More affordable housing. And Pre-K for every 3- and 4-year-old. \n",
248 | "\n",
249 | "All of these will lower costs. \n",
250 | "\n",
251 | "And under my plan, nobody earning less than $400,000 a year will pay an additional penny in new taxes. Nobody. \n",
252 | "\n",
253 | "The one thing all Americans agree on is that the tax system is not fair. We have to fix it. \n",
254 | "\n",
255 | "I’m not looking to punish anyone. But let’s make sure corporations and the wealthiest Americans start paying their fair share. \n",
256 | "\n",
257 | "Just last year, 55 Fortune 500 corporations earned $40 billion in profits and paid zero dollars in federal income tax. \n",
258 | "\n",
259 | "That’s simply not fair. That’s why I’ve proposed a 15% minimum tax rate for corporations.\n"
260 | ]
261 | }
262 | ],
263 | "source": [
264 | "from langchain.retrievers import ContextualCompressionRetriever\n",
265 | "from langchain_openai import OpenAI\n",
266 | "\n",
267 | "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
268 | "texts = text_splitter.split_documents(documents)\n",
269 | "retriever = FAISS.from_documents(texts, OpenAIEmbeddings()).as_retriever()\n",
270 | "\n",
271 | "compression_retriever = ContextualCompressionRetriever(\n",
272 | " base_compressor=compressor, base_retriever=retriever\n",
273 | ")\n",
274 | "\n",
275 | "\n",
276 | "compressed_docs = compression_retriever.get_relevant_documents(\n",
277 | " \"What did the president say about the minimum wage?\"\n",
278 | ")\n",
279 | "\n",
280 | "pretty_print_docs(compressed_docs)"
281 | ]
282 | },
283 | {
284 | "cell_type": "markdown",
285 | "metadata": {},
286 | "source": [
287 | "Here it is! There's not much more to show -- just load any reranker you want and try it out!\n",
288 | "\n",
289 | "Remember, not all rerankers work in the same way. It's important to experiment to find out which one works best for your data, and even to fine-tune them if you have the data to do so. The point of this library is to make it easy to try many different approaches to find the best one for your usecase!"
290 | ]
291 | }
292 | ],
293 | "metadata": {
294 | "kernelspec": {
295 | "display_name": "mcol",
296 | "language": "python",
297 | "name": "python3"
298 | },
299 | "language_info": {
300 | "codemirror_mode": {
301 | "name": "ipython",
302 | "version": 3
303 | },
304 | "file_extension": ".py",
305 | "mimetype": "text/x-python",
306 | "name": "python",
307 | "nbconvert_exporter": "python",
308 | "pygments_lexer": "ipython3",
309 | "version": "3.10.13"
310 | }
311 | },
312 | "nbformat": 4,
313 | "nbformat_minor": 2
314 | }
315 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 |
2 | [build-system]
3 | requires = ["setuptools"]
4 | build-backend = "setuptools.build_meta"
5 |
6 | [tool.setuptools]
7 | packages = [
8 | "rerankers",
9 | "rerankers.models",
10 | "rerankers.integrations",
11 | ]
12 |
13 | [project]
14 | name = "rerankers"
15 |
16 |
17 | version = "0.10.0"
18 |
19 | description = "A unified API for various document re-ranking models."
20 |
21 | readme = "README.md"
22 |
23 | requires-python = ">=3.8"
24 |
25 | license = {file = "LICENSE"}
26 |
27 | keywords = ["reranking", "retrieval", "rag", "nlp"]
28 |
29 | authors = [
30 | {name = "Ben Clavié", email = "bc@answer.ai" }
31 | ]
32 | maintainers = [
33 | {name = "Ben Clavié", email = "bc@answer.ai" }
34 | ]
35 |
36 | classifiers = [
37 | # Specify the Python versions you support here. In particular, ensure
38 | # that you indicate you support Python 3. These classifiers are *not*
39 | # checked by "pip install". See instead "requires-python" key in this file.
40 | "Programming Language :: Python :: 3",
41 | "Programming Language :: Python :: 3.8",
42 | "Programming Language :: Python :: 3.9",
43 | "Programming Language :: Python :: 3.10",
44 | "Programming Language :: Python :: 3.11",
45 | "Programming Language :: Python :: 3.12",
46 | "Programming Language :: Python :: 3 :: Only",
47 | ]
48 |
49 | dependencies = []
50 |
51 | [project.optional-dependencies]
52 | all = [
53 | "transformers>=4.45.0",
54 | "torch",
55 | "litellm",
56 | "requests",
57 | "sentencepiece",
58 | "protobuf",
59 | "flashrank",
60 | "flash-attn",
61 | "pillow",
62 | "accelerate>=0.26.0",
63 | "peft>=0.13.0",
64 | "nmslib-metabrainz; python_version >= '3.10'",
65 | "rank-llm; python_version >= '3.10'"
66 | ]
67 | transformers = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf"]
68 | api = ["requests"]
69 | gpt = ["litellm"]
70 | flashrank = ["flashrank"]
71 | llmlayerwise = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf", "flash-attn"]
72 | monovlm = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf", "flash-attn", "pillow", "accelerate>=0.26.0", "peft>=0.13.0"]
73 | rankllm = [
74 | "nmslib-metabrainz; python_version >= '3.10'",
75 | "rank-llm; python_version >= '3.10'"
76 | ]
77 | pylate = ["pylate"]
78 | dev = ["ruff", "isort", "pytest", "ipyprogress", "ipython", "ranx", "ir_datasets", "srsly"]
79 |
80 | [project.urls]
81 | "Homepage" = "https://github.com/answerdotai/rerankers"
--------------------------------------------------------------------------------
/rerankers/__init__.py:
--------------------------------------------------------------------------------
1 | from rerankers.reranker import Reranker
2 | from rerankers.documents import Document
3 |
4 | __all__ = ["Reranker", "Document"]
5 | __version__ = "0.10.0"
--------------------------------------------------------------------------------
/rerankers/documents.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Literal
2 |
3 |
4 | class Document:
5 | def __init__(
6 | self,
7 | text: Optional[str] = None,
8 | doc_id: Optional[Union[str, int]] = None,
9 | metadata: Optional[dict] = None,
10 | document_type: Literal["text", "image"] = "text",
11 | image_path: Optional[str] = None,
12 | base64: Optional[str] = None,
13 | ):
14 | self.attributes = ["text", "base64", "image_path", "doc_id", "metadata", "document_type"]
15 | self.document_type = document_type
16 | self.text = text
17 | self.base64 = base64
18 | self.image_path = image_path
19 | self.doc_id = doc_id
20 | self.metadata = metadata if metadata is not None else {}
21 |
22 | # Validation
23 | if self.document_type == "text" and self.text is None:
24 | raise ValueError("text field is required when document_type is 'text'")
25 |
26 | def __repr__(self) -> str:
27 | fields = {
28 | "text": self.text,
29 | "doc_id": self.doc_id,
30 | "metadata": self.metadata,
31 | "document_type": self.document_type,
32 | "image_path": self.image_path,
33 | "base64": self.base64,
34 | }
35 | field_str = ", ".join(f"{k}={v!r}" for k, v in fields.items())
36 | return f"{self.__class__.__name__}({field_str})"
37 |
--------------------------------------------------------------------------------
/rerankers/integrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnswerDotAI/rerankers/7bb252179daaf1968eb9d49629c9cf48b0f9b5f2/rerankers/integrations/__init__.py
--------------------------------------------------------------------------------
/rerankers/integrations/langchain.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional, Sequence
2 |
3 | from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
4 | from langchain_core.callbacks.manager import Callbacks
5 | from langchain_core.documents import Document
6 |
7 |
8 | class RerankerLangChainCompressor(BaseDocumentCompressor):
9 | model: Any
10 | kwargs: dict = {}
11 | k: int = 5
12 |
13 | def compress_documents(
14 | self,
15 | documents: Sequence[Document],
16 | query: str,
17 | callbacks: Optional[Callbacks] = None, # noqa
18 | **kwargs,
19 | ) -> Any:
20 | """Rerank a list of documents relevant to a query."""
21 | doc_list = list(documents)
22 | _docs = [d.page_content for d in doc_list]
23 | results = self.model.rank(
24 | query=query,
25 | docs=_docs,
26 | **self.kwargs,
27 | )
28 | final_results = []
29 | for r in results.top_k(kwargs.get("k", self.k)):
30 | doc = doc_list[r.doc_id]
31 | doc.metadata["relevance_score"] = r.score
32 | final_results.append(doc)
33 | return final_results
34 |
--------------------------------------------------------------------------------
/rerankers/models/__init__.py:
--------------------------------------------------------------------------------
1 | AVAILABLE_RANKERS = {}
2 |
3 | try:
4 | from rerankers.models.transformer_ranker import TransformerRanker
5 |
6 | AVAILABLE_RANKERS["TransformerRanker"] = TransformerRanker
7 | except ImportError:
8 | pass
9 | try:
10 | from rerankers.models.api_rankers import APIRanker
11 |
12 | AVAILABLE_RANKERS["APIRanker"] = APIRanker
13 | except ImportError:
14 | pass
15 | try:
16 | from rerankers.models.rankgpt_rankers import RankGPTRanker
17 |
18 | AVAILABLE_RANKERS["RankGPTRanker"] = RankGPTRanker
19 | except ImportError:
20 | pass
21 | try:
22 | from rerankers.models.t5ranker import T5Ranker
23 |
24 | AVAILABLE_RANKERS["T5Ranker"] = T5Ranker
25 | except ImportError:
26 | pass
27 |
28 | try:
29 | from rerankers.models.colbert_ranker import ColBERTRanker
30 |
31 | AVAILABLE_RANKERS["ColBERTRanker"] = ColBERTRanker
32 | except ImportError:
33 | pass
34 |
35 | try:
36 | from rerankers.models.flashrank_ranker import FlashRankRanker
37 |
38 | AVAILABLE_RANKERS["FlashRankRanker"] = FlashRankRanker
39 | except ImportError:
40 | pass
41 |
42 | try:
43 | from rerankers.models.rankllm_ranker import RankLLMRanker
44 |
45 | AVAILABLE_RANKERS["RankLLMRanker"] = RankLLMRanker
46 | except ImportError:
47 | pass
48 |
49 | try:
50 | from rerankers.models.llm_layerwise_ranker import LLMLayerWiseRanker
51 |
52 | AVAILABLE_RANKERS["LLMLayerWiseRanker"] = LLMLayerWiseRanker
53 | except ImportError:
54 | pass
55 |
56 | try:
57 | from rerankers.models.monovlm_ranker import MonoVLMRanker
58 | AVAILABLE_RANKERS["MonoVLMRanker"] = MonoVLMRanker
59 | except ImportError:
60 | pass
61 |
62 | try:
63 | from rerankers.models.llm_relevance_filter import LLMRelevanceFilter
64 | AVAILABLE_RANKERS["LLMRelevanceFilter"] = LLMRelevanceFilter
65 | except ImportError:
66 | pass
67 |
68 | try:
69 | from rerankers.models.upr import UPRRanker
70 | AVAILABLE_RANKERS["UPRRanker"] = UPRRanker
71 | except ImportError:
72 | pass
73 |
74 | try:
75 | from rerankers.models.mxbai_v2 import MxBaiV2Ranker
76 | AVAILABLE_RANKERS["MxBaiV2Ranker"] = MxBaiV2Ranker
77 | except ImportError:
78 | pass
79 |
80 | try:
81 | from rerankers.models.pylate_ranker import PyLateRanker
82 |
83 | AVAILABLE_RANKERS["PyLateRanker"] = PyLateRanker
84 | except ImportError:
85 | pass
86 |
--------------------------------------------------------------------------------
/rerankers/models/api_rankers.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Optional
2 | from rerankers.models.ranker import BaseRanker
3 | from rerankers.results import RankedResults, Result
4 | from rerankers.utils import prep_docs
5 | from rerankers.documents import Document
6 | from string import Template
7 |
8 |
9 | import requests
10 | import json
11 |
12 |
13 | URLS = {
14 | "cohere": "https://api.cohere.ai/v1/rerank",
15 | "jina": "https://api.jina.ai/v1/rerank",
16 | "isaacus": "https://api.isaacus.com/v1/rerankings",
17 | "voyage": "https://api.voyageai.com/v1/rerank",
18 | "mixedbread.ai": "https://api.mixedbread.ai/v1/reranking",
19 | "pinecone": "https://api.pinecone.io/rerank",
20 | }
21 | AUTHORIZATION_KEY_MAPPING = {
22 | "pinecone": "Api-Key"
23 | }
24 | API_VERSION_MAPPING = {
25 | "pinecone": {"X-Pinecone-API-Version": "2024-10"}
26 | }
27 |
28 | API_KEY_MAPPING = {
29 | "pinecone": Template("$api_key")
30 | }
31 |
32 | DOCUMENT_KEY_MAPPING = {
33 | "mixedbread.ai": "input",
34 | "text-embeddings-inference":"texts",
35 | "isaacus": "texts",
36 | }
37 | RETURN_DOCUMENTS_KEY_MAPPING = {
38 | "mixedbread.ai":"return_input",
39 | "text-embeddings-inference":"return_text"
40 | }
41 | RESULTS_KEY_MAPPING = {
42 | "voyage": "data",
43 | "mixedbread.ai": "data",
44 | "pinecone": "data",
45 | "text-embeddings-inference": None
46 | }
47 | SCORE_KEY_MAPPING = {
48 | "mixedbread.ai": "score",
49 | "pinecone": "score",
50 | "text-embeddings-inference":"score",
51 | "isaacus":"score",
52 | }
53 |
54 | class APIRanker(BaseRanker):
55 | def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1, url: str = None):
56 | self.api_provider = api_provider.lower()
57 | self.api_key = API_KEY_MAPPING.get(self.api_provider,Template("Bearer $api_key")).substitute(api_key=api_key)
58 | authorization_key = AUTHORIZATION_KEY_MAPPING.get(self.api_provider,"Authorization")
59 | api_version = API_VERSION_MAPPING.get(self.api_provider,None)
60 | self.model = model
61 | self.verbose = verbose
62 | self.ranking_type = "pointwise"
63 | self.headers = {
64 | "accept": "application/json",
65 | "content-type": "application/json",
66 | authorization_key: self.api_key,
67 | }
68 | if api_version:
69 | self.headers.update(api_version)
70 | self.url = url if url else URLS[self.api_provider]
71 |
72 |
73 | def _get_document_text(self, r: dict) -> str:
74 | if self.api_provider == "voyage":
75 | return r["document"]
76 | elif self.api_provider == "mixedbread.ai":
77 | return r["input"]
78 | elif self.api_provider == "text-embeddings-inference":
79 | return r["text"]
80 | else:
81 | return r["document"]["text"]
82 |
83 | def _get_score(self, r: dict) -> float:
84 | score_key = SCORE_KEY_MAPPING.get(self.api_provider,"relevance_score")
85 | return r[score_key]
86 |
87 | def _parse_response(
88 | self, response: dict, docs: List[Document],
89 | ) -> RankedResults:
90 | ranked_docs = []
91 | results_key = RESULTS_KEY_MAPPING.get(self.api_provider,"results")
92 |
93 | for i, r in enumerate(response[results_key] if results_key else response):
94 | ranked_docs.append(
95 | Result(
96 | document=docs[r["index"]],
97 | score=self._get_score(r),
98 | rank=i + 1,
99 | )
100 | )
101 |
102 | return ranked_docs
103 |
104 | def rank(
105 | self,
106 | query: str,
107 | docs: Union[str, List[str], Document, List[Document]],
108 | doc_ids: Optional[Union[List[str], List[int]]] = None,
109 | metadata: Optional[List[dict]] = None,
110 | ) -> RankedResults:
111 | docs = prep_docs(docs, doc_ids, metadata)
112 | payload = self._format_payload(query, docs)
113 | response = requests.post(self.url, headers=self.headers, data=payload)
114 | results = self._parse_response(response.json(), docs)
115 | return RankedResults(results=results, query=query, has_scores=True)
116 |
117 |
118 | def _format_payload(self, query: str, docs: List[str]) -> str:
119 | top_key = (
120 | "top_n" if self.api_provider not in ["voyage", "mixedbread.ai"] else "top_k"
121 | )
122 | documents_key = DOCUMENT_KEY_MAPPING.get(self.api_provider,"documents")
123 | return_documents_key = RETURN_DOCUMENTS_KEY_MAPPING.get(self.api_provider,"return_documents")
124 |
125 | documents = (
126 | [d.text for d in docs] if self.api_provider not in ["pinecone"] else [{"text": d.text} for d in docs]
127 | )
128 |
129 | payload = {
130 | "model": self.model,
131 | "query": query,
132 | documents_key: documents,
133 | top_key: len(docs),
134 | return_documents_key: True,
135 | }
136 | return json.dumps(payload)
137 |
138 | def score(self, query: str, doc: str) -> float:
139 | payload = self._format_payload(query, [doc])
140 | response = requests.post(self.url, headers=self.headers, data=payload)
141 | results = self._parse_response(response.json(), [doc])
142 | return results[0].score
143 |
--------------------------------------------------------------------------------
/rerankers/models/colbert_ranker.py:
--------------------------------------------------------------------------------
1 | """Code from HotchPotch's JQaRa repository: https://github.com/hotchpotch/JQaRA/blob/main/evaluator/reranker/colbert_reranker.py
2 | Modifications include packaging into a BaseRanker, dynamic query/doc length and batch size handling."""
3 |
4 | import torch
5 | import torch.nn as nn
6 | from transformers import BertPreTrainedModel, BertModel, AutoModel, AutoTokenizer
7 | from typing import List, Optional, Union
8 | from math import ceil
9 |
10 | from rerankers.models.ranker import BaseRanker
11 | from rerankers.documents import Document
12 | from rerankers.results import RankedResults, Result
13 | from rerankers.utils import vprint, get_device, get_dtype, prep_docs
14 |
15 |
16 | def _insert_token(
17 | output: dict,
18 | insert_token_id: int,
19 | insert_position: int = 1,
20 | token_type_id: int = 0,
21 | attention_value: int = 1,
22 | ):
23 | """
24 | Inserts a new token at a specified position into the sequences of a tokenized representation.
25 |
26 | This function takes a dictionary containing tokenized representations
27 | (e.g., 'input_ids', 'token_type_ids', 'attention_mask') as PyTorch tensors,
28 | and inserts a specified token into each sequence at the given position.
29 | This can be used to add special tokens or other modifications to tokenized inputs.
30 |
31 | Parameters:
32 | - output (dict): A dictionary containing the tokenized representations. Expected keys
33 | are 'input_ids', 'token_type_ids', and 'attention_mask'. Each key
34 | is associated with a PyTorch tensor.
35 | - insert_token_id (int): The token ID to be inserted into each sequence.
36 | - insert_position (int, optional): The position in the sequence where the new token
37 | should be inserted. Defaults to 1, which typically
38 | follows a special starting token like '[CLS]' or '[BOS]'.
39 | - token_type_id (int, optional): The token type ID to assign to the inserted token.
40 | Defaults to 0.
41 | - attention_value (int, optional): The attention mask value to assign to the inserted token.
42 | Defaults to 1.
43 |
44 | Returns:
45 | - updated_output (dict): A dictionary containing the updated tokenized representations,
46 | with the new token inserted at the specified position in each sequence.
47 | The structure and keys of the output dictionary are the same as the input.
48 | """
49 | updated_output = {}
50 | for key in output:
51 | updated_tensor_list = []
52 | for seqs in output[key]:
53 | if len(seqs.shape) == 1:
54 | seqs = seqs.unsqueeze(0)
55 | for seq in seqs:
56 | first_part = seq[:insert_position]
57 | second_part = seq[insert_position:]
58 | new_element = (
59 | torch.tensor([insert_token_id])
60 | if key == "input_ids"
61 | else torch.tensor([token_type_id])
62 | )
63 | if key == "attention_mask":
64 | new_element = torch.tensor([attention_value])
65 | updated_seq = torch.cat((first_part, new_element, second_part), dim=0)
66 | updated_tensor_list.append(updated_seq)
67 | updated_output[key] = torch.stack(updated_tensor_list)
68 | return updated_output
69 |
70 |
71 | def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
72 | # calc max sim
73 | # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
74 |
75 | # Assert that all q_reps are at least as long as the query length
76 | assert (
77 | q_reps.shape[1] >= q_mask.shape[1]
78 | ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"
79 |
80 | token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
81 | token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
82 | scores, _ = token_scores.max(-1)
83 | scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
84 | return scores
85 |
86 |
87 | class ColBERTModel(BertPreTrainedModel):
88 | def __init__(self, config, verbose: int):
89 | super().__init__(config)
90 | self.bert = BertModel(config)
91 | self.verbose = verbose
92 | # TODO: Load from artifact.metadata
93 | if "small" in config._name_or_path:
94 | linear_dim = 96
95 | else:
96 | linear_dim = 128
97 | vprint(f"Linear Dim set to: {linear_dim} for downcasting", self.verbose)
98 | self.linear = nn.Linear(config.hidden_size, linear_dim, bias=False)
99 | self.init_weights()
100 |
101 | def forward(
102 | self,
103 | input_ids=None,
104 | attention_mask=None,
105 | token_type_ids=None,
106 | position_ids=None,
107 | head_mask=None,
108 | inputs_embeds=None,
109 | encoder_hidden_states=None,
110 | encoder_attention_mask=None,
111 | output_attentions=None,
112 | output_hidden_states=None,
113 | ):
114 | outputs = self.bert(
115 | input_ids,
116 | attention_mask=attention_mask,
117 | token_type_ids=token_type_ids,
118 | position_ids=position_ids,
119 | head_mask=head_mask,
120 | inputs_embeds=inputs_embeds,
121 | encoder_hidden_states=encoder_hidden_states,
122 | encoder_attention_mask=encoder_attention_mask,
123 | output_attentions=output_attentions,
124 | output_hidden_states=True, # Always output hidden states
125 | )
126 |
127 | sequence_output = outputs[0]
128 |
129 | return self.linear(sequence_output)
130 |
131 | def _encode(self, texts: list[str], insert_token_id: int, is_query: bool = False):
132 | encoding = self.tokenizer(
133 | texts,
134 | return_tensors="pt",
135 | padding=True,
136 | max_length=self.max_length - 1, # for insert token
137 | truncation=True,
138 | )
139 | encoding = _insert_token(encoding, insert_token_id) # type: ignore
140 |
141 | if is_query:
142 | mask_token_id = self.tokenizer.mask_token_id
143 |
144 | new_encodings = {"input_ids": [], "attention_mask": []}
145 |
146 | for i, input_ids in enumerate(encoding["input_ids"]):
147 | original_length = (
148 | (input_ids != self.tokenizer.pad_token_id).sum().item()
149 | )
150 |
151 | # Calculate QLEN dynamically for each query
152 | if original_length % 32 <= 8:
153 | QLEN = original_length + 8
154 | else:
155 | QLEN = ceil(original_length / 32) * 32
156 |
157 | if original_length < QLEN:
158 | pad_length = QLEN - original_length
159 | padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
160 | padded_attention_mask = (
161 | encoding["attention_mask"][i].tolist() + [0] * pad_length
162 | )
163 | else:
164 | padded_input_ids = input_ids[:QLEN].tolist()
165 | padded_attention_mask = encoding["attention_mask"][i][
166 | :QLEN
167 | ].tolist()
168 |
169 | new_encodings["input_ids"].append(padded_input_ids)
170 | new_encodings["attention_mask"].append(padded_attention_mask)
171 |
172 | for key in new_encodings:
173 | new_encodings[key] = torch.tensor(
174 | new_encodings[key], device=self.device
175 | )
176 |
177 | encoding = new_encodings
178 |
179 | encoding = {key: value.to(self.device) for key, value in encoding.items()}
180 | return encoding
181 |
182 | def _query_encode(self, query: list[str]):
183 | return self._encode(query, self.query_token_id, is_query=True)
184 |
185 | def _document_encode(self, documents: list[str]):
186 | return self._encode(documents, self.document_token_id)
187 |
188 | def _to_embs(self, encoding) -> torch.Tensor:
189 | with torch.inference_mode():
190 | # embs = self.model(**encoding).last_hidden_state.squeeze(1)
191 | embs = self.model(**encoding)
192 | if self.normalize:
193 | embs = embs / embs.norm(dim=-1, keepdim=True)
194 | return embs
195 |
196 | def _rerank(self, query: str, documents: list[str]) -> list[float]:
197 | query_encoding = self._query_encode([query])
198 | documents_encoding = self._document_encode(documents)
199 | query_embeddings = self._to_embs(query_encoding)
200 | document_embeddings = self._to_embs(documents_encoding)
201 | scores = (
202 | _colbert_score(
203 | query_embeddings,
204 | document_embeddings,
205 | query_encoding["attention_mask"],
206 | documents_encoding["attention_mask"],
207 | )
208 | .cpu()
209 | .tolist()[0]
210 | )
211 | return scores
212 |
213 |
214 | class ColBERTRanker(BaseRanker):
215 | def __init__(
216 | self,
217 | model_name: str,
218 | batch_size: int = 32,
219 | dtype: Optional[Union[str, torch.dtype]] = None,
220 | device: Optional[Union[str, torch.device]] = None,
221 | verbose: int = 1,
222 | query_token: str = "[unused0]",
223 | document_token: str = "[unused1]",
224 | **kwargs,
225 | ):
226 | self.verbose = verbose
227 | self.device = get_device(device, self.verbose)
228 | self.dtype = get_dtype(dtype, self.device, self.verbose)
229 | self.batch_size = batch_size
230 | vprint(
231 | f"Loading model {model_name}, this might take a while...",
232 | self.verbose,
233 | )
234 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
235 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
236 | model_kwargs = kwargs.get("model_kwargs", {})
237 | self.model = (
238 | ColBERTModel.from_pretrained(
239 | model_name,
240 | verbose=self.verbose,
241 | **model_kwargs
242 | )
243 | .to(self.device)
244 | .to(self.dtype)
245 | )
246 | self.model.eval()
247 | self.query_max_length = 32 # Lower bound
248 | self.doc_max_length = (
249 | self.model.config.max_position_embeddings - 2
250 | ) # Upper bound
251 | self.query_token_id: int = self.tokenizer.convert_tokens_to_ids(query_token) # type: ignore
252 | self.document_token_id: int = self.tokenizer.convert_tokens_to_ids(
253 | document_token
254 | ) # type: ignore
255 | self.normalize = True
256 |
257 | def rank(
258 | self,
259 | query: str,
260 | docs: Union[Document, str, List[Document], List[str]],
261 | doc_ids: Optional[Union[List[str], List[int]]] = None,
262 | metadata: Optional[List[dict]] = None,
263 | ) -> RankedResults:
264 | docs = prep_docs(docs, doc_ids, metadata)
265 |
266 | scores = self._colbert_rank(query, [d.text for d in docs])
267 | ranked_results = [
268 | Result(document=doc, score=score, rank=idx + 1)
269 | for idx, (doc, score) in enumerate(
270 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
271 | )
272 | ]
273 | return RankedResults(results=ranked_results, query=query, has_scores=True)
274 |
275 | def score(self, query: str, doc: str) -> float:
276 | scores = self._colbert_rank(query, [doc])
277 | return scores[0] if scores else 0.0
278 |
279 | @torch.inference_mode()
280 | def _colbert_rank(
281 | self,
282 | query: str,
283 | docs: List[str],
284 | ) -> List[float]:
285 | query_encoding = self._query_encode([query])
286 | documents_encoding = self._document_encode(docs)
287 | query_embeddings = self._to_embs(query_encoding)
288 | document_embeddings = self._to_embs(documents_encoding)
289 | scores = (
290 | _colbert_score(
291 | query_embeddings,
292 | document_embeddings,
293 | query_encoding["attention_mask"],
294 | documents_encoding["attention_mask"],
295 | )
296 | .cpu()
297 | .tolist()[0]
298 | )
299 | return scores
300 |
301 | def _query_encode(self, query: list[str]):
302 | return self._encode(
303 | query, self.query_token_id, max_length=self.doc_max_length, is_query=True
304 | )
305 |
306 | def _document_encode(self, documents: list[str]):
307 | tokenized_doc_lengths = [
308 | len(
309 | self.tokenizer.encode(
310 | doc, max_length=self.doc_max_length, truncation=True
311 | )
312 | )
313 | for doc in documents
314 | ]
315 | max_length = max(tokenized_doc_lengths)
316 | max_length = (
317 | ceil(max_length / 32) * 32
318 | ) # Round up to the nearest multiple of 32
319 | max_length = max(
320 | max_length, self.query_max_length
321 | ) # Ensure not smaller than query_max_length
322 | max_length = int(
323 | min(max_length, self.doc_max_length)
324 | ) # Ensure not larger than doc_max_length
325 | return self._encode(documents, self.document_token_id, max_length)
326 |
327 | def _encode(
328 | self,
329 | texts: list[str],
330 | insert_token_id: int,
331 | max_length: int,
332 | is_query: bool = False,
333 | ):
334 | encoding = self.tokenizer(
335 | texts,
336 | return_tensors="pt",
337 | padding=True,
338 | max_length=max_length - 1, # for insert token
339 | truncation=True,
340 | )
341 | encoding = _insert_token(encoding, insert_token_id) # type: ignore
342 |
343 | if is_query:
344 | mask_token_id = self.tokenizer.mask_token_id
345 |
346 | new_encodings = {"input_ids": [], "attention_mask": []}
347 |
348 | for i, input_ids in enumerate(encoding["input_ids"]):
349 | original_length = (
350 | (input_ids != self.tokenizer.pad_token_id).sum().item()
351 | )
352 |
353 | # Calculate QLEN dynamically for each query
354 | if original_length % 16 <= 8:
355 | QLEN = original_length + 8
356 | else:
357 | QLEN = ceil(original_length / 16) * 16
358 |
359 | if original_length < QLEN:
360 | pad_length = QLEN - original_length
361 | padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
362 | padded_attention_mask = (
363 | encoding["attention_mask"][i].tolist() + [0] * pad_length
364 | )
365 | else:
366 | padded_input_ids = input_ids[:QLEN].tolist()
367 | padded_attention_mask = encoding["attention_mask"][i][
368 | :QLEN
369 | ].tolist()
370 |
371 | new_encodings["input_ids"].append(padded_input_ids)
372 | new_encodings["attention_mask"].append(padded_attention_mask)
373 |
374 | for key in new_encodings:
375 | new_encodings[key] = torch.tensor(
376 | new_encodings[key], device=self.device
377 | )
378 |
379 | encoding = new_encodings
380 |
381 | encoding = {key: value.to(self.device) for key, value in encoding.items()}
382 | return encoding
383 |
384 | def _to_embs(self, encoding) -> torch.Tensor:
385 | with torch.inference_mode():
386 | batched_embs = []
387 | for i in range(0, encoding["input_ids"].size(0), self.batch_size):
388 | batch_encoding = {
389 | key: val[i : i + self.batch_size] for key, val in encoding.items()
390 | }
391 | batch_embs = self.model(**batch_encoding)
392 | batched_embs.append(batch_embs)
393 | embs = torch.cat(batched_embs, dim=0)
394 | if self.normalize:
395 | embs = embs / embs.norm(dim=-1, keepdim=True)
396 | return embs
397 |
--------------------------------------------------------------------------------
/rerankers/models/flashrank_ranker.py:
--------------------------------------------------------------------------------
1 | from rerankers.models.ranker import BaseRanker
2 |
3 | from flashrank import Ranker, RerankRequest
4 |
5 |
6 | from typing import Union, List, Optional, Tuple
7 | from rerankers.utils import vprint, prep_docs
8 | from rerankers.results import RankedResults, Result
9 | from rerankers.documents import Document
10 |
11 |
12 | class FlashRankRanker(BaseRanker):
13 | def __init__(
14 | self,
15 | model_name_or_path: str,
16 | verbose: int = 1,
17 | cache_dir: str = "./.flashrank_cache",
18 | **kwargs
19 | ):
20 | self.verbose = verbose
21 | vprint(
22 | f"Loading model FlashRank model {model_name_or_path}...", verbose=verbose
23 | )
24 | self.model = Ranker(model_name=model_name_or_path, cache_dir=cache_dir)
25 | self.ranking_type = "pointwise"
26 |
27 | def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]):
28 | return self.tokenizer(
29 | inputs, return_tensors="pt", padding=True, truncation=True
30 | ).to(self.device)
31 |
32 | def rank(
33 | self,
34 | query: str,
35 | docs: Union[str, List[str], Document, List[Document]],
36 | doc_ids: Optional[Union[List[str], List[int]]] = None,
37 | metadata: Optional[List[dict]] = None,
38 | ) -> RankedResults:
39 | docs = prep_docs(docs, doc_ids, metadata)
40 | passages = [
41 | {"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs)
42 | ]
43 |
44 | rerank_request = RerankRequest(query=query, passages=passages)
45 | flashrank_results = self.model.rerank(rerank_request)
46 |
47 | ranked_results = [
48 | Result(
49 | document=docs[result["id"]], # Returns reranked documents.
50 | score=result["score"],
51 | rank=idx + 1,
52 | )
53 | for idx, result in enumerate(flashrank_results)
54 | ]
55 |
56 | return RankedResults(results=ranked_results, query=query, has_scores=True)
57 |
58 | def score(self, query: str, doc: str) -> float:
59 | rerank_request = RerankRequest(
60 | query=query, passages=[{"id": "temp_id", "text": doc}]
61 | )
62 | flashrank_result = self.model.rerank(rerank_request)
63 | score = flashrank_result[0]["score"]
64 | return score
65 |
--------------------------------------------------------------------------------
/rerankers/models/llm_layerwise_ranker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 | from rerankers.models.ranker import BaseRanker
4 | from rerankers.documents import Document
5 | from typing import Union, List, Optional
6 | from rerankers.utils import vprint, get_device, get_dtype, prep_docs
7 | from rerankers.results import RankedResults, Result
8 |
9 |
10 | PROMPTS = {
11 | "BAAI/bge-reranker-v2.5-gemma2-lightweight": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
12 | "default": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
13 | }
14 |
15 | DEFAULT_PARAMS = {
16 | "default": {},
17 | "BAAI/bge-multilingual-gemma2": {},
18 | "BAAI/bge-reranker-v2-gemma": {},
19 | "BAAI/bge-reranker-v2-minicpm-layerwise": {"cutoff_layers": [28]},
20 | "BAAI/bge-reranker-v2.5-gemma2-lightweight": {
21 | "cutoff_layers": [28],
22 | "compress_ratio": 2,
23 | "compress_layer": [24, 40],
24 | },
25 | }
26 |
27 |
28 | class LLMLayerWiseRanker(BaseRanker):
29 | def __init__(
30 | self,
31 | model_name_or_path: str = "BAAI/bge-reranker-v2.5-gemma2-lightweight",
32 | max_sequence_length: int = 512,
33 | dtype: Optional[Union[str, torch.dtype]] = None,
34 | device: Optional[Union[str, torch.device]] = None,
35 | batch_size: int = 16,
36 | verbose: int = 1,
37 | prompt: Optional[str] = None,
38 | cutoff_layers: Optional[List[int]] = None,
39 | compress_ratio: Optional[int] = None,
40 | compress_layer: Optional[List[int]] = None,
41 | **kwargs,
42 | ):
43 | self.verbose = verbose
44 | self.device = get_device(device, verbose=self.verbose)
45 | self.dtype = get_dtype(dtype, self.device, self.verbose)
46 | self.batch_size = batch_size
47 |
48 | vprint(
49 | f"Loading model {model_name_or_path}, this might take a while...",
50 | self.verbose,
51 | )
52 | vprint(f"Using device {self.device}.", self.verbose)
53 | vprint(f"Using dtype {self.dtype}.", self.verbose)
54 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
55 | tokenizer_trust_remote_code = tokenizer_kwargs.pop("trust_remote_code", True)
56 | self.tokenizer = AutoTokenizer.from_pretrained(
57 | model_name_or_path,
58 | trust_remote_code=tokenizer_trust_remote_code,
59 | **tokenizer_kwargs,
60 | )
61 | self.max_sequence_length = max_sequence_length
62 | self.tokenizer.model_max_length = self.max_sequence_length
63 | self.tokenizer.padding_side = "right"
64 | model_kwargs = kwargs.get("model_kwargs", {})
65 | model_trust_remote_code = model_kwargs.pop("trust_remote_code", True)
66 |
67 | self.model = AutoModelForCausalLM.from_pretrained(
68 | model_name_or_path,
69 | trust_remote_code=model_trust_remote_code,
70 | torch_dtype=self.dtype,
71 | **model_kwargs,
72 | ).to(self.device)
73 | self.model.eval()
74 |
75 | # Create params dict based on specified values or defaults
76 | params = {}
77 | if cutoff_layers is not None:
78 | params["cutoff_layers"] = cutoff_layers
79 | if compress_ratio is not None:
80 | params["compress_ratio"] = compress_ratio
81 | if compress_layer is not None:
82 | params["compress_layer"] = compress_layer
83 | if not params:
84 | params = DEFAULT_PARAMS.get(model_name_or_path, DEFAULT_PARAMS["default"])
85 | self.params = params
86 |
87 | self.prompt = prompt
88 | if self.prompt is None:
89 | self.prompt = PROMPTS.get(model_name_or_path, PROMPTS["default"])
90 |
91 | def _get_inputs(self, pairs, max_sequence_length: int):
92 | prompt = self.prompt
93 | sep = "\n"
94 | prompt_inputs = self.tokenizer(
95 | prompt, return_tensors=None, add_special_tokens=False
96 | )["input_ids"]
97 | sep_inputs = self.tokenizer(sep, return_tensors=None, add_special_tokens=False)[
98 | "input_ids"
99 | ]
100 | inputs = []
101 | for query, passage in pairs:
102 | query_inputs = self.tokenizer(
103 | f"A: {query}",
104 | return_tensors=None,
105 | add_special_tokens=False,
106 | max_length=max_sequence_length * 3 // 4,
107 | truncation=True,
108 | )
109 | passage_inputs = self.tokenizer(
110 | f"B: {passage}",
111 | return_tensors=None,
112 | add_special_tokens=False,
113 | max_length=max_sequence_length,
114 | truncation=True,
115 | )
116 | item = self.tokenizer.prepare_for_model(
117 | [self.tokenizer.bos_token_id] + query_inputs["input_ids"],
118 | sep_inputs + passage_inputs["input_ids"],
119 | truncation="only_second",
120 | max_length=max_sequence_length,
121 | padding=False,
122 | return_attention_mask=False,
123 | return_token_type_ids=False,
124 | add_special_tokens=False,
125 | )
126 | item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs
127 | item["attention_mask"] = [1] * len(item["input_ids"])
128 | inputs.append(item)
129 |
130 | return self.tokenizer.pad(
131 | inputs,
132 | padding=True,
133 | max_length=max_sequence_length + len(sep_inputs) + len(prompt_inputs),
134 | pad_to_multiple_of=8,
135 | return_tensors="pt",
136 | )
137 |
138 | @torch.inference_mode()
139 | def rank(
140 | self,
141 | query: str,
142 | docs: Union[str, List[str], Document, List[Document]],
143 | doc_ids: Optional[Union[List[str], List[int]]] = None,
144 | metadata: Optional[List[dict]] = None,
145 | batch_size: Optional[int] = None,
146 | max_sequence_length: Optional[int] = None,
147 | ) -> RankedResults:
148 | docs = prep_docs(docs, doc_ids, metadata)
149 | pairs = [(query, doc.text) for doc in docs]
150 |
151 | # Override self.batch_size if explicitly set
152 | if batch_size is None:
153 | batch_size = self.batch_size
154 |
155 | # Same for max_sequence_length
156 | if max_sequence_length is None:
157 | max_sequence_length = self.max_sequence_length
158 |
159 | batched_pairs = [
160 | pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)
161 | ]
162 | scores = []
163 |
164 | for batch in batched_pairs:
165 | inputs = self._get_inputs(batch, max_sequence_length=max_sequence_length)
166 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
167 |
168 | outputs = self.model(**inputs, return_dict=True, **self.params)
169 | all_scores = [
170 | scores[:, -1]
171 | .view(
172 | -1,
173 | )
174 | .float()
175 | for scores in outputs[0]
176 | ]
177 | batch_scores = all_scores[-1].cpu().numpy().tolist()
178 |
179 | scores.extend(batch_scores)
180 |
181 | ranked_results = [
182 | Result(document=doc, score=score, rank=idx + 1)
183 | for idx, (doc, score) in enumerate(
184 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
185 | )
186 | ]
187 | return RankedResults(results=ranked_results, query=query, has_scores=True)
188 |
189 | @torch.inference_mode()
190 | def score(self, query: str, doc: str) -> float:
191 | inputs = self._get_inputs(
192 | [(query, doc)], max_sequence_length=self.max_sequence_length
193 | )
194 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
195 |
196 | outputs = self.model(**inputs, return_dict=True, **self.params)
197 | all_scores = [
198 | scores[:, -1]
199 | .view(
200 | -1,
201 | )
202 | .float()
203 | for scores in outputs[0]
204 | ]
205 | score = all_scores[-1].item()
206 |
207 | return score
208 |
--------------------------------------------------------------------------------
/rerankers/models/llm_relevance_filter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union, Dict
2 | import warnings
3 | import re
4 |
5 | try:
6 | from litellm import completion
7 | except ImportError:
8 | pass
9 |
10 | from rerankers.models.ranker import BaseRanker
11 | from rerankers.documents import Document
12 | from rerankers.results import RankedResults, Result
13 | from rerankers.utils import prep_docs, vprint
14 |
15 | SUPPORTED_BACKENDS = ["litellm"]
16 | UNSUPPORTED_BACKENDS = ["gaspard", "claudette", "cosette"]
17 |
18 | SYSTEM = (
19 | "You are a friendly AI assistant, working on document relevance filtering. Your task is "
20 | "to determine if a document is relevant to answering a given query. You must assign a binary "
21 | "RELEVANT or NOT_RELEVANT label to each document by carefully analysing them and the query."
22 | )
23 | DEFAULT_PROMPT_TEMPLATE = """
24 | Think carefully about whether the following documents would be useful to answer the query.
25 | For each document, explain your reasoning and then provide a binary decision (RELEVANT or NOT_RELEVANT). If a document is partially relevant, you will assign the RELEVANT label.
26 |
27 | The documents will be given to you in the following format:
28 |
29 |
30 |
31 | Text of the query.
32 |
33 |
34 |
35 |
36 | Text of the first document.
37 |
38 |
39 | Text of the second document.
40 |
41 |
42 |
43 | And you will respond in the following format:
44 |
45 |
46 |
47 | Your reasoning regarding the document's relevance.
48 |
49 |
50 | RELEVANT or NOT_RELEVANT
51 |
52 |
53 |
54 |
55 | Here is the query and documents:
56 |
57 |
58 |
59 | {query}
60 |
61 |
62 |
63 | {docu_inputs}
64 |
65 |
66 |
67 | Analyse the above documents and provide your responses using the provided format. You must assign either the RELEVANT or NOT_RELEVANT label, no other option is permitted."""
68 |
69 | class LLMRelevanceFilter(BaseRanker):
70 | def __init__(
71 | self,
72 | model_name_or_path: str,
73 | backend: str = "litellm",
74 | prompt_template: Optional[str] = None,
75 | temperature: float = 0.0,
76 | verbose: int = 1,
77 | default_label: str = "RELEVANT",
78 | **kwargs
79 | ):
80 | """Initialize the LLM Relevance Filter.
81 |
82 | Args:
83 | model_name_or_path: Name of the model to use (e.g. "gpt-4")
84 | backend: One of "litellm", "gaspard", "claudette", "cosette"
85 | prompt_template: Optional custom prompt template. Must include {query} and {docu_inputs} placeholders.
86 | temperature: Temperature for LLM sampling (default 0.0 for deterministic outputs)
87 | verbose: Verbosity level
88 | **kwargs: Additional kwargs passed to the backend
89 | """
90 | super().__init__(model_name_or_path, verbose)
91 | if backend not in SUPPORTED_BACKENDS:
92 | raise ValueError(f"Backend must be one of {SUPPORTED_BACKENDS}")
93 |
94 | if backend != "litellm":
95 | warnings.warn(f"Backend {backend} is experimental and may not work as expected")
96 |
97 | self.backend = backend
98 | self.model_name = model_name_or_path
99 | self.temperature = temperature
100 | self.prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE
101 | self.verbose = verbose
102 | self.additional_kwargs = kwargs
103 | self.default_label = default_label
104 |
105 | vprint(f"Initialized LLMRelevanceFilter with {backend} backend using model {model_name_or_path}", verbose)
106 |
107 | # Verify backend is available
108 | if backend == "litellm" and "completion" not in globals():
109 | raise ImportError("litellm is required for the litellm backend. Install with pip install litellm")
110 |
111 | def _get_completion(self, prompt: str) -> str:
112 | """Get completion from the selected backend."""
113 | if self.backend == "litellm":
114 | response = completion(
115 | model=self.model_name,
116 | messages=[{"role": "user", "content": prompt}],
117 | temperature=self.temperature,
118 | **self.additional_kwargs
119 | )
120 | return response.choices[0].message.content.strip()
121 | else:
122 | raise NotImplementedError(f"Backend {self.backend} not yet implemented")
123 |
124 | def _parse_response(self, response: str) -> str:
125 | """
126 | Parse an XML response to extract the answer from within the tags.
127 | If no answer is found, defaults to "NOT_RELEVANT".
128 | """
129 | match = re.search(r'\s*(RELEVANT|NOT_RELEVANT)\s*', response, re.IGNORECASE)
130 | if match:
131 | return match.group(1).upper()
132 | print(response)
133 | print("MALFORMATTED RESPONSE!")
134 | return self.default_label
135 |
136 | def _format_doc_inputs(self, docs: List[str]) -> str:
137 | """
138 | Format a list of document texts into an XML string with enumerated document IDs.
139 | Each document is wrapped in a ... block.
140 | """
141 | formatted_docs = []
142 | for i, text in enumerate(docs):
143 | formatted_docs.append(f"\n{text}\n")
144 | return "\n".join(formatted_docs)
145 |
146 | def score(self, query: str, doc: str) -> float:
147 | """Score a single document."""
148 | # Format the single document as an XML input.
149 | doc_xml = self._format_doc_inputs([doc])
150 | prompt = self.prompt_template.format(query=query, docu_inputs=doc_xml)
151 | response = self._get_completion(prompt)
152 | print(response)
153 | answer = self._parse_response(response)
154 | print(answer)
155 | return 1.0 if answer == "RELEVANT" else 0.0
156 |
157 | def rank(
158 | self,
159 | query: str,
160 | docs: Union[str, List[str], Document, List[Document]],
161 | doc_ids: Optional[Union[List[str], List[int]]] = None,
162 | metadata: Optional[List[dict]] = None,
163 | ) -> RankedResults:
164 | """Rank a list of documents based on relevance to query."""
165 | docs = prep_docs(docs, doc_ids, metadata)
166 | doc_texts = [doc.text for doc in docs]
167 | # Format all document texts into one XML block.
168 | docs_xml = self._format_doc_inputs(doc_texts)
169 | prompt = self.prompt_template.format(query=query, docu_inputs=docs_xml)
170 | response = self._get_completion(prompt)
171 | print(response)
172 |
173 | pattern = re.compile(r'(.*?)', re.DOTALL)
174 | matches = pattern.findall(response)
175 | doc_scores = {}
176 | for doc_id, content in matches:
177 | ans = self._parse_response(content)
178 | doc_scores[int(doc_id)] = 1.0 if ans == "RELEVANT" else 0.0
179 |
180 | # Preserve original order while sorting by score descending.
181 | scores_with_index = []
182 | for i, doc in enumerate(docs):
183 | score = doc_scores.get(i, 0.0)
184 | scores_with_index.append((score, i, doc))
185 |
186 | scores_with_index.sort(key=lambda x: (-x[0], x[1]))
187 |
188 | ranked_results = [
189 | Result(document=doc, score=score, rank=idx + 1)
190 | for idx, (score, _, doc) in enumerate(scores_with_index)
191 | ]
192 |
193 | return RankedResults(results=ranked_results, query=query, has_scores=True)
194 |
--------------------------------------------------------------------------------
/rerankers/models/monovlm_ranker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import base64
4 | import io
5 | # TODO: Support more than Qwen
6 | from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
7 | from rerankers.models.ranker import BaseRanker
8 | from rerankers.documents import Document
9 | from typing import Union, List, Optional
10 | from rerankers.utils import vprint, get_device, get_dtype, prep_image_docs
11 | from rerankers.results import RankedResults, Result
12 |
13 | PREDICTION_TOKENS = {
14 | "default": ["False", "True"],
15 | "lightonai/MonoQwen2-VL-v0.1": ["False", "True"]
16 | }
17 |
18 | def _get_output_tokens(model_name_or_path, token_false: str, token_true: str):
19 | if token_false == "auto":
20 | if model_name_or_path in PREDICTION_TOKENS:
21 | token_false = PREDICTION_TOKENS[model_name_or_path][0]
22 | else:
23 | token_false = PREDICTION_TOKENS["default"][0]
24 | print(
25 | f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_false to `{token_false}`."
26 | )
27 | if token_true == "auto":
28 | if model_name_or_path in PREDICTION_TOKENS:
29 | token_true = PREDICTION_TOKENS[model_name_or_path][1]
30 | else:
31 | token_true = PREDICTION_TOKENS["default"][1]
32 | print(
33 | f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_true to `{token_true}`."
34 | )
35 |
36 | return token_false, token_true
37 |
38 | class MonoVLMRanker(BaseRanker):
39 | def __init__(
40 | self,
41 | model_name_or_path: str,
42 | processor_name: Optional[str] = None,
43 | dtype: Optional[Union[str, torch.dtype]] = 'bf16',
44 | device: Optional[Union[str, torch.device]] = None,
45 | batch_size: int = 1,
46 | verbose: int = 1,
47 | token_false: str = "auto",
48 | token_true: str = "auto",
49 | return_logits: bool = False,
50 | prompt_template: str = "Assert the relevance of the previous image document to the following query, answer True or False. The query is: {query}",
51 | **kwargs
52 | ):
53 | self.verbose = verbose
54 | self.device = get_device(device, verbose=self.verbose)
55 | if self.device == 'mps':
56 | print("WARNING: MPS is not supported by MonoVLMRanker due to PyTorch limitations. Falling back to CPU.")
57 | self.device = 'cpu'
58 | print(dtype)
59 | self.dtype = get_dtype(dtype, self.device, self.verbose)
60 | self.batch_size = batch_size
61 | self.return_logits = return_logits
62 | self.prompt_template = prompt_template
63 |
64 | vprint(f"Loading model {model_name_or_path}, this might take a while...", self.verbose)
65 | vprint(f"Using device {self.device}.", self.verbose)
66 | vprint(f"Using dtype {self.dtype}.", self.verbose)
67 |
68 | processor_name = processor_name or "Qwen/Qwen2-VL-2B-Instruct"
69 | processor_kwargs = kwargs.get("processor_kwargs", {})
70 | model_kwargs = kwargs.get("model_kwargs", {})
71 | attention_implementation = kwargs.get("attention_implementation", "flash_attention_2")
72 | self.processor = AutoProcessor.from_pretrained(processor_name, **processor_kwargs)
73 | self.model = Qwen2VLForConditionalGeneration.from_pretrained(
74 | model_name_or_path,
75 | device_map=self.device,
76 | torch_dtype=self.dtype,
77 | attn_implementation=attention_implementation,
78 | **model_kwargs
79 | )
80 | self.model.eval()
81 |
82 | token_false, token_true = _get_output_tokens(
83 | model_name_or_path=model_name_or_path,
84 | token_false=token_false,
85 | token_true=token_true,
86 | )
87 | self.token_false_id = self.processor.tokenizer.convert_tokens_to_ids(token_false)
88 | self.token_true_id = self.processor.tokenizer.convert_tokens_to_ids(token_true)
89 |
90 | vprint(f"VLM true token set to {token_true}", self.verbose)
91 | vprint(f"VLM false token set to {token_false}", self.verbose)
92 |
93 | @torch.inference_mode()
94 | def _get_scores(self, query: str, docs: List[Document]) -> List[float]:
95 | scores = []
96 | for doc in docs:
97 | if doc.document_type != "image" or not doc.base64:
98 | raise ValueError("MonoVLMRanker requires image documents with base64 data")
99 |
100 | # Convert base64 to PIL Image
101 | image_io = io.BytesIO(base64.b64decode(doc.base64))
102 | image_io.seek(0) # Reset file pointer to start
103 | image = Image.open(image_io).convert('RGB')
104 |
105 | # Prepare prompt
106 | prompt = self.prompt_template.format(query=query)
107 | messages = [
108 | {
109 | "role": "user",
110 | "content": [
111 | {"type": "image", "image": image},
112 | {"type": "text", "text": prompt},
113 | ],
114 | }
115 | ]
116 |
117 | # Process inputs
118 | text = self.processor.apply_chat_template(
119 | messages,
120 | tokenize=False,
121 | add_generation_prompt=True
122 | )
123 | inputs = self.processor(
124 | text=text,
125 | images=image,
126 | return_tensors="pt"
127 | ).to(self.device).to(self.dtype)
128 |
129 | # Get model outputs
130 | outputs = self.model(**inputs)
131 | logits = outputs.logits[:, -1, :]
132 |
133 | # Calculate scores
134 | relevant_logits = logits[:, [self.token_false_id, self.token_true_id]]
135 | if self.return_logits:
136 | score = relevant_logits[0, 1].cpu().item() # True logit
137 | else:
138 | probs = torch.softmax(relevant_logits, dim=-1)
139 | score = probs[0, 1].cpu().item() # True probability
140 |
141 | scores.append(score)
142 |
143 | return scores
144 |
145 | def rank(
146 | self,
147 | query: str,
148 | docs: Union[str, List[str], Document, List[Document]],
149 | doc_ids: Optional[Union[List[str], List[int]]] = None,
150 | metadata: Optional[List[dict]] = None,
151 | ) -> RankedResults:
152 | docs = prep_image_docs(docs, doc_ids, metadata)
153 | scores = self._get_scores(query, docs)
154 | ranked_results = [
155 | Result(document=doc, score=score, rank=idx + 1)
156 | for idx, (doc, score) in enumerate(
157 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
158 | )
159 | ]
160 | return RankedResults(results=ranked_results, query=query, has_scores=True)
161 |
162 | def score(self, query: str, doc: Union[str, Document]) -> float:
163 | scores = self._get_scores(query, [doc])
164 | return scores[0]
165 |
--------------------------------------------------------------------------------
/rerankers/models/mxbai_v2.py:
--------------------------------------------------------------------------------
1 | """
2 | MXBai V2 Reranker implementation
3 |
4 | Parts of the code borrowed/adapted from the Apache 2.0 licensed original codebase: https://github.com/mixedbread-ai/mxbai-rerank
5 | """
6 |
7 | from __future__ import annotations
8 |
9 | import torch
10 | from typing import Union, List, Optional
11 |
12 | # Rerankers base imports
13 | from rerankers.models.ranker import BaseRanker
14 | from rerankers.documents import Document
15 | from rerankers.results import RankedResults, Result
16 | from rerankers.utils import vprint, prep_docs, get_device, get_dtype
17 |
18 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
19 |
20 |
21 | # Default prompt templates and tokens
22 | SEPS = {
23 | "mixedbread-ai/mxbai-rerank-large-v2": "\n",
24 | "mixedbread-ai/mxbai-rerank-base-v2": "\n",
25 | "default": "\n"
26 | }
27 |
28 | INSTRUCTION_PROMPT = {
29 | "mixedbread-ai/mxbai-rerank-large-v2": "instruction: {instruction}",
30 | "mixedbread-ai/mxbai-rerank-base-v2": "instruction: {instruction}",
31 | "default": "instruction: {instruction}"
32 | }
33 |
34 | QUERY_PROMPT = {
35 | "mixedbread-ai/mxbai-rerank-large-v2": "query: {query}",
36 | "mixedbread-ai/mxbai-rerank-base-v2": "query: {query}",
37 | "default": "query: {query}"
38 | }
39 |
40 | DOC_PROMPT = {
41 | "mixedbread-ai/mxbai-rerank-large-v2": "document: {document}",
42 | "mixedbread-ai/mxbai-rerank-base-v2": "document: {document}",
43 | "default": "document: {document}"
44 | }
45 |
46 | TASK_PROMPT = {
47 | "mixedbread-ai/mxbai-rerank-large-v2": """You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).
48 | Relevance:""",
49 | "mixedbread-ai/mxbai-rerank-base-v2": """You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).
50 | Relevance:""",
51 | "default": """You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).
52 | Relevance:"""
53 | }
54 |
55 | CHAT_TEMPLATE = {
56 | "mixedbread-ai/mxbai-rerank-large-v2": {
57 | "prefix": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n",
58 | "suffix": "<|im_end|>\n<|im_start|>assistant\n",
59 | },
60 | "mixedbread-ai/mxbai-rerank-base-v2": {
61 | "prefix": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n",
62 | "suffix": "<|im_end|>\n<|im_start|>assistant\n",
63 | },
64 | "default": {
65 | "prefix": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n",
66 | "suffix": "<|im_end|>\n<|im_start|>assistant\n",
67 | }
68 | }
69 |
70 | POS_TOKEN = {
71 | "mixedbread-ai/mxbai-rerank-large-v2": "1",
72 | "mixedbread-ai/mxbai-rerank-base-v2": "1",
73 | "default": "1"
74 | }
75 |
76 | NEG_TOKEN = {
77 | "mixedbread-ai/mxbai-rerank-large-v2": "0",
78 | "mixedbread-ai/mxbai-rerank-base-v2": "0",
79 | "default": "0"
80 | }
81 |
82 |
83 | def _ensure_multiple_of_8(x: int, max_value: Optional[int] = None) -> int:
84 | """Make x a multiple of 8, respecting optional max_value"""
85 | if max_value is not None:
86 | max_value = max_value - max_value % 8
87 | x = min(x, max_value)
88 | return x - x % 8
89 |
90 |
91 | class MxBaiV2Ranker(BaseRanker):
92 | """
93 | A reranker that uses MxBai models for yes/no-based relevance classification.
94 |
95 | This ranker uses causal language models from the MxBai family to determine
96 | document relevance by predicting binary relevance scores (0/1).
97 | """
98 |
99 | def __init__(
100 | self,
101 | model_name_or_path: str = "mixedbread-ai/mxbai-rerank-base-v2",
102 | device: Optional[Union[str, torch.device]] = None,
103 | dtype: Optional[Union[str, torch.dtype]] = None,
104 | batch_size: int = 16,
105 | verbose: int = 1,
106 | max_length: int = 8192,
107 | **kwargs
108 | ):
109 | """
110 | Initialize the MxBai reranker.
111 |
112 | Args:
113 | model_name_or_path: Path or name of the MxBai model.
114 | device: Device to use (e.g. 'cpu', 'cuda:0', or 'auto').
115 | dtype: Torch dtype or 'auto'.
116 | batch_size: Batch size for processing multiple documents.
117 | verbose: Verbosity level.
118 | max_length: Maximum token length for inputs.
119 | **kwargs: Additional kwargs for model and tokenizer.
120 | """
121 | super().__init__(model_name_or_path=model_name_or_path, verbose=verbose)
122 | self.verbose = verbose
123 | self.device = get_device(device, verbose=self.verbose)
124 | self.dtype = get_dtype(dtype, self.device, self.verbose)
125 | self.batch_size = batch_size
126 | self.max_length = max_length
127 | self.cfg = AutoConfig.from_pretrained(model_name_or_path)
128 |
129 | vprint(f"Loading MxBai model from {model_name_or_path}", self.verbose)
130 | vprint(f"Device: {self.device}, Dtype: {self.dtype}", self.verbose)
131 |
132 | # Extract model kwargs
133 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
134 | model_kwargs = kwargs.get("model_kwargs", {})
135 |
136 | # Try to use flash attention if available
137 | try:
138 | import flash_attn
139 | attn_impl = "flash_attention_2"
140 | self.dtype = 'bfloat16'
141 | vprint("Flash attention is available. Setting dtype to bfloat16.", self.verbose)
142 | except ImportError:
143 | attn_impl = None
144 |
145 | # Load model and tokenizer
146 | self.model = AutoModelForCausalLM.from_pretrained(
147 | model_name_or_path,
148 | attn_implementation=attn_impl,
149 | torch_dtype=self.dtype if isinstance(self.dtype, torch.dtype) else "auto",
150 | device_map=str(self.device),
151 | **model_kwargs,
152 | )
153 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs)
154 | self.tokenizer.padding_side = "left"
155 |
156 | # Get model-specific templates and tokens
157 | model_key = model_name_or_path.split("/")[-1]
158 | self._setup_templates(model_key)
159 |
160 | # Pre-tokenize static prompts for efficiency
161 | self._prepare_tokenized_templates()
162 |
163 | # Switch to eval mode
164 | self.model.eval()
165 | vprint("MxBaiV2Ranker ready.", self.verbose)
166 |
167 | def _setup_templates(self, model_key: str):
168 | """Set up the templates and tokens for the specific model"""
169 | # Helper function to get template with fallback to default
170 | def get_template(template_dict, key):
171 | return template_dict.get(key, template_dict["default"])
172 |
173 | # Set up all templates
174 | self.task_prompt = get_template(TASK_PROMPT, model_key)
175 | self.chat_template = get_template(CHAT_TEMPLATE, model_key)
176 | self.query_prompt = get_template(QUERY_PROMPT, model_key)
177 | self.doc_prompt = get_template(DOC_PROMPT, model_key)
178 | self.instruction_prompt = get_template(INSTRUCTION_PROMPT, model_key)
179 | self.pos_token = get_template(POS_TOKEN, model_key)
180 | self.neg_token = get_template(NEG_TOKEN, model_key)
181 | self.sep = get_template(SEPS, model_key)
182 |
183 | if not any(model_key in template_dict for template_dict in [TASK_PROMPT, CHAT_TEMPLATE, QUERY_PROMPT, DOC_PROMPT, INSTRUCTION_PROMPT, POS_TOKEN, NEG_TOKEN, SEPS]):
184 | vprint("Model name did not have all necessary instructions. Using default prompt formats which might not be suitable for this model!", self.verbose)
185 |
186 | def _prepare_tokenized_templates(self):
187 | """Pre-tokenize templates for efficiency"""
188 | # Get token IDs for positive and negative tokens
189 | self.pos_id = self.tokenizer(self.pos_token, return_tensors=None, add_special_tokens=False)["input_ids"][0]
190 | self.neg_id = self.tokenizer(self.neg_token, return_tensors=None, add_special_tokens=False)["input_ids"][0]
191 |
192 | # Pre-tokenize chat template parts
193 | self.prefix_ids = self.tokenizer(self.chat_template["prefix"], return_tensors=None, add_special_tokens=False)["input_ids"]
194 | self.suffix_ids = self.tokenizer(self.chat_template["suffix"], return_tensors=None, add_special_tokens=False)["input_ids"]
195 |
196 | # Pre-tokenize task prompt and separator
197 | self.task_prompt_ids = self.tokenizer(self.task_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
198 | self.sep_ids = self.tokenizer(self.sep, return_tensors=None, add_special_tokens=False)["input_ids"]
199 |
200 | # Calculate total length of static tokens
201 | self.static_tokens_length = (
202 | len(self.prefix_ids) +
203 | len(self.task_prompt_ids) +
204 | len(self.suffix_ids) +
205 | len(self.sep_ids)
206 | )
207 |
208 | # Set model max length
209 | self.model_max_length = self.cfg.max_position_embeddings
210 |
211 | # Adjust max_length to account for static tokens
212 | if self.max_length + self.static_tokens_length > self.model_max_length:
213 | self.max_length = self.model_max_length - self.static_tokens_length
214 |
215 | # Ensure padding length is a multiple of 8 for efficiency
216 | self.padding_length = _ensure_multiple_of_8(
217 | max(self.model_max_length, self.max_length + self.static_tokens_length),
218 | max_value=self.model_max_length
219 | )
220 |
221 | def _create_full_input_ids(self, content_ids: List[int]) -> List[int]:
222 | """
223 | Create the full input by combining content with template parts.
224 |
225 | Args:
226 | content_ids: Token IDs for the query-document content
227 |
228 | Returns:
229 | List of token IDs for the complete input
230 | """
231 | return (
232 | self.prefix_ids +
233 | content_ids +
234 | self.sep_ids +
235 | self.task_prompt_ids +
236 | self.suffix_ids
237 | )
238 |
239 | def _prepare_batch(
240 | self,
241 | queries: List[str],
242 | documents: List[str],
243 | instruction: Optional[str] = None,
244 | ) -> dict:
245 | """
246 | Prepare a batch of query-document pairs for the model.
247 |
248 | Args:
249 | queries: List of query strings
250 | documents: List of document strings
251 | instruction: Optional instruction to prepend
252 |
253 | Returns:
254 | Dictionary with input_ids and attention_mask tensors
255 | """
256 | batch_inputs = []
257 |
258 | for query, document in zip(queries, documents):
259 | # Format query with template
260 | query_text = self.query_prompt.format(query=query)
261 |
262 | # Add instruction if provided
263 | if instruction:
264 | instruction_text = self.instruction_prompt.format(instruction=instruction)
265 | query_text = instruction_text + self.sep + query_text
266 |
267 | # Tokenize query with length limit
268 | query_ids = self.tokenizer(
269 | query_text,
270 | return_tensors=None,
271 | add_special_tokens=False,
272 | max_length=self.max_length * 3 // 4, # Use 3/4 of tokens for query
273 | truncation=True,
274 | )["input_ids"]
275 |
276 | # Calculate remaining tokens for document
277 | available_tokens = self.model_max_length - len(query_ids) - self.static_tokens_length
278 | doc_max_length = min(available_tokens, self.max_length // 4) # Use 1/4 of tokens for document
279 |
280 | # Tokenize document
281 | doc_text = self.doc_prompt.format(document=document)
282 | doc_ids = self.tokenizer(
283 | doc_text,
284 | return_tensors=None,
285 | add_special_tokens=False,
286 | max_length=doc_max_length,
287 | truncation=True,
288 | )["input_ids"]
289 |
290 | # Combine query and document
291 | combined = self.tokenizer.prepare_for_model(
292 | query_ids,
293 | self.sep_ids + doc_ids,
294 | truncation="only_second",
295 | max_length=self.max_length,
296 | padding=False,
297 | return_attention_mask=False,
298 | return_token_type_ids=False,
299 | add_special_tokens=False,
300 | )
301 |
302 | # Create full input with template
303 | full_input_ids = self._create_full_input_ids(combined["input_ids"])
304 |
305 | # Add to batch
306 | batch_inputs.append({
307 | "input_ids": full_input_ids,
308 | "attention_mask": [1] * len(full_input_ids),
309 | })
310 |
311 | # Pad all inputs to the same length
312 | padded_batch = self.tokenizer.pad(
313 | batch_inputs,
314 | padding="longest",
315 | max_length=self.padding_length,
316 | pad_to_multiple_of=8,
317 | return_tensors="pt",
318 | )
319 |
320 | return padded_batch
321 |
322 | @torch.inference_mode()
323 | def _predict(
324 | self,
325 | queries: List[str],
326 | documents: List[str],
327 | instruction: Optional[str] = None,
328 | ) -> torch.Tensor:
329 | """
330 | Get relevance scores for query-document pairs.
331 |
332 | Args:
333 | queries: List of query strings
334 | documents: List of document strings
335 | instruction: Optional instruction to prepend
336 |
337 | Returns:
338 | Tensor of relevance scores
339 | """
340 | # Prepare inputs
341 | inputs = self._prepare_batch(queries, documents, instruction=instruction)
342 |
343 | # Move to device
344 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
345 |
346 | # Run model
347 | outputs = self.model(
348 | input_ids=inputs["input_ids"],
349 | attention_mask=inputs["attention_mask"],
350 | output_hidden_states=True,
351 | )
352 |
353 | score = outputs.logits[:, -1, self.pos_id] - outputs.logits[:, -1, self.neg_id]
354 |
355 | # Return scores as CPU tensor
356 | return score.detach().cpu().float()
357 |
358 | @torch.inference_mode()
359 | def rank(
360 | self,
361 | query: str,
362 | docs: Union[str, List[str], Document, List[Document]],
363 | doc_ids: Optional[Union[List[str], List[int]]] = None,
364 | metadata: Optional[List[dict]] = None,
365 | batch_size: Optional[int] = None,
366 | instruction: Optional[str] = None,
367 | ) -> RankedResults:
368 | """
369 | Rank documents by relevance to the query.
370 |
371 | Args:
372 | query: Query string
373 | docs: Documents to rank
374 | doc_ids: Optional document IDs
375 | metadata: Optional document metadata
376 | batch_size: Optional batch size override
377 | instruction: Optional instruction to prepend
378 |
379 | Returns:
380 | RankedResults with documents sorted by relevance
381 | """
382 | # Prepare documents
383 | docs = prep_docs(docs, doc_ids, metadata)
384 |
385 | # Use default batch size if not specified
386 | if batch_size is None:
387 | batch_size = self.batch_size
388 |
389 | all_docs = []
390 | all_scores = []
391 |
392 | # Process in batches
393 | for i in range(0, len(docs), batch_size):
394 | batch_docs = docs[i:i + batch_size]
395 | batch_scores = self._predict(
396 | queries=[query] * len(batch_docs),
397 | documents=[d.text for d in batch_docs],
398 | instruction=instruction,
399 | )
400 |
401 | all_docs.extend(batch_docs)
402 | all_scores.extend(batch_scores.tolist())
403 |
404 | # Sort by descending score
405 | scored_docs = sorted(zip(all_docs, all_scores), key=lambda x: x[1], reverse=True)
406 |
407 | # Create ranked results
408 | results = [
409 | Result(document=doc, score=score, rank=idx + 1)
410 | for idx, (doc, score) in enumerate(scored_docs)
411 | ]
412 |
413 | return RankedResults(results=results, query=query, has_scores=True)
414 |
415 | @torch.inference_mode()
416 | def score(
417 | self,
418 | query: str,
419 | doc: Union[str, Document],
420 | instruction: Optional[str] = None,
421 | ) -> float:
422 | """
423 | Score a single query-document pair.
424 |
425 | Args:
426 | query: Query string
427 | doc: Document to score
428 | instruction: Optional instruction to prepend
429 |
430 | Returns:
431 | Relevance score as float
432 | """
433 | # Extract text if document is a Document object
434 | doc_text = doc.text if isinstance(doc, Document) else doc
435 |
436 | # Get score
437 | scores = self._predict(
438 | queries=[query],
439 | documents=[doc_text],
440 | instruction=instruction,
441 | )
442 |
443 | # Return as float
444 | return float(scores[0])
--------------------------------------------------------------------------------
/rerankers/models/pylate_ranker.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import torch
4 |
5 | from pylate import models, rank
6 | from rerankers.documents import Document
7 | from rerankers.models.ranker import BaseRanker
8 | from rerankers.results import RankedResults, Result
9 | from rerankers.utils import get_device, get_dtype, prep_docs, vprint
10 |
11 |
12 | class PyLateRanker(BaseRanker):
13 | def __init__(
14 | self,
15 | model_name: str,
16 | batch_size: int = 32,
17 | dtype: Optional[Union[str, torch.dtype]] = None,
18 | device: Optional[Union[str, torch.device]] = None,
19 | verbose: int = 1,
20 | **kwargs,
21 | ):
22 | self.verbose = verbose
23 | self.device = get_device(device, self.verbose)
24 | self.dtype = get_dtype(dtype, self.device, self.verbose)
25 | self.batch_size = batch_size
26 | vprint(
27 | f"Loading model {model_name}, this might take a while...",
28 | self.verbose,
29 | )
30 | kwargs = kwargs.get("kwargs", {})
31 | kwargs["device"] = self.device
32 | model_kwargs = kwargs.get("model_kwargs", {})
33 | model_kwargs["torch_dtype"] = self.dtype
34 | self.model = models.ColBERT(
35 | model_name_or_path=model_name,
36 | model_kwargs=model_kwargs,
37 | **kwargs,
38 | )
39 |
40 | def rank(
41 | self,
42 | query: str,
43 | docs: Union[Document, str, List[Document], List[str]],
44 | doc_ids: Optional[Union[List[str], List[int]]] = None,
45 | metadata: Optional[List[dict]] = None,
46 | ) -> RankedResults:
47 | docs = prep_docs(docs, doc_ids, metadata)
48 | documents_embeddings = self.model.encode(
49 | [[d.text for d in docs]],
50 | is_query=False,
51 | )
52 |
53 | query_embeddings = self.model.encode(
54 | [query],
55 | is_query=True,
56 | )
57 | scores = rank.rerank(
58 | documents_ids=[doc_ids],
59 | queries_embeddings=query_embeddings,
60 | documents_embeddings=documents_embeddings,
61 | )
62 |
63 | ranked_results = [
64 | Result(
65 | document=doc,
66 | score=score["score"] / len(query_embeddings[0]),
67 | rank=idx + 1,
68 | )
69 | for idx, (doc, score) in enumerate(zip(docs, scores[0]))
70 | ]
71 | return RankedResults(results=ranked_results, query=query, has_scores=True)
72 |
73 | def score(self, query: str, doc: str) -> float:
74 | document_embeddings = self.model.encode(
75 | doc,
76 | is_query=False,
77 | )
78 |
79 | query_embeddings = self.model.encode(
80 | query,
81 | is_query=True,
82 | )
83 | # This is shamefull, I really need to provide a scoring method with padding inside
84 | scores = rank.rerank(
85 | documents_ids=["0"],
86 | queries_embeddings=query_embeddings,
87 | documents_embeddings=document_embeddings,
88 | )
89 | return scores[0][0]["score"] if scores else 0.0
90 |
--------------------------------------------------------------------------------
/rerankers/models/ranker.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from asyncio import get_event_loop
3 | from functools import partial
4 | from typing import List, Optional, Union
5 | from rerankers.results import RankedResults
6 | from rerankers.documents import Document
7 |
8 |
9 | class BaseRanker(ABC):
10 | @abstractmethod
11 | def __init__(self, model_name_or_path: str, verbose: int):
12 | pass
13 |
14 | @abstractmethod
15 | def score(self, query: str, doc: str) -> float:
16 | pass
17 |
18 | @abstractmethod
19 | def rank(
20 | self,
21 | query: str,
22 | docs: Union[str, List[str], Document, List[Document]],
23 | doc_ids: Optional[Union[List[str], List[int]]] = None,
24 | ) -> RankedResults:
25 | """
26 | End-to-end reranking of documents.
27 | """
28 | pass
29 |
30 | async def rank_async(
31 | self,
32 | query: str,
33 | docs: List[str],
34 | doc_ids: Optional[Union[List[str], str]] = None,
35 | ) -> RankedResults:
36 |
37 |
38 | loop = get_event_loop()
39 | return await loop.run_in_executor(None, partial(self.rank, query, docs, doc_ids))
40 |
41 | def as_langchain_compressor(self, k: int = 10):
42 | try:
43 | from rerankers.integrations.langchain import RerankerLangChainCompressor
44 |
45 | return RerankerLangChainCompressor(model=self, k=k)
46 | except ImportError:
47 | print(
48 | "You need to install langchain and langchain_core to export a reranker as a LangChainCompressor!"
49 | )
50 | print(
51 | 'Please run `pip install "rerankers[langchain]"` to get all the required dependencies.'
52 | )
53 |
--------------------------------------------------------------------------------
/rerankers/models/rankgpt_rankers.py:
--------------------------------------------------------------------------------
1 | """Full implementation is from the original RankGPT repository https://github.com/sunnweiwei/RankGPT under its Apache 2.0 License
2 |
3 | Changes made are:
4 | - Truncating the file to only the relevant functions
5 | - Using only LiteLLM
6 | - make_item() added
7 | - Packaging it onto RankGPTRanker"""
8 |
9 | import copy
10 | from typing import Optional, Union, List, Dict
11 | from litellm import completion
12 | from rerankers.models.ranker import BaseRanker
13 | from rerankers.documents import Document
14 | from rerankers.results import RankedResults, Result
15 | from rerankers.utils import vprint, prep_docs
16 |
17 |
18 | def get_prefix_prompt(query, num):
19 | return [
20 | {
21 | "role": "system",
22 | "content": "You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.",
23 | },
24 | {
25 | "role": "user",
26 | "content": f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}.",
27 | },
28 | {"role": "assistant", "content": "Okay, please provide the passages."},
29 | ]
30 |
31 |
32 | def get_post_prompt(query, num):
33 | return f"Search Query: {query}. \nRank the {num} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only response the ranking results, do not say any word or explain."
34 |
35 |
36 | def create_permutation_instruction(
37 | item=None,
38 | rank_start=0,
39 | rank_end=100,
40 | lang: str = "en",
41 | ):
42 | query = item["query"]
43 | num = len(item["hits"][rank_start:rank_end])
44 |
45 | max_length = 300
46 |
47 | messages = get_prefix_prompt(query, num)
48 | rank = 0
49 | for hit in item["hits"][rank_start:rank_end]:
50 | rank += 1
51 | content = hit["content"]
52 | content = content.replace("Title: Content: ", "")
53 | content = content.strip()
54 | # For Japanese should cut by character: content = content[:int(max_length)]
55 | if lang in ["zh", "ja"]:
56 | content = content[: int(max_length)]
57 | else:
58 | content = " ".join(content.split()[: int(max_length)])
59 | messages.append({"role": "user", "content": f"[{rank}] {content}"})
60 | messages.append({"role": "assistant", "content": f"Received passage [{rank}]."})
61 | messages.append({"role": "user", "content": get_post_prompt(query, num)})
62 |
63 | return messages
64 |
65 |
66 | def clean_response(response: str):
67 | new_response = ""
68 | for c in response:
69 | if not c.isdigit():
70 | new_response += " "
71 | else:
72 | new_response += c
73 | new_response = new_response.strip()
74 | return new_response
75 |
76 |
77 | def remove_duplicate(response):
78 | new_response = []
79 | for c in response:
80 | if c not in new_response:
81 | new_response.append(c)
82 | return new_response
83 |
84 |
85 | def receive_permutation(item, permutation, rank_start=0, rank_end=100):
86 | response = clean_response(permutation)
87 | response = [int(x) - 1 for x in response.split()]
88 | response = remove_duplicate(response)
89 | cut_range = copy.deepcopy(item["hits"][rank_start:rank_end])
90 | original_rank = [tt for tt in range(len(cut_range))]
91 | response = [ss for ss in response if ss in original_rank]
92 | response = response + [tt for tt in original_rank if tt not in response]
93 | for j, x in enumerate(response):
94 | item["hits"][j + rank_start] = copy.deepcopy(cut_range[x])
95 | if "rank" in item["hits"][j + rank_start]:
96 | item["hits"][j + rank_start]["rank"] = cut_range[j]["rank"]
97 | if "score" in item["hits"][j + rank_start]:
98 | item["hits"][j + rank_start]["score"] = cut_range[j]["score"]
99 | return item
100 |
101 |
102 | def make_item(
103 | query: str, docs: List[str]
104 | ) -> Dict[str, Union[List[Dict[str, str]], str]]:
105 | return {
106 | "query": query,
107 | "hits": [{"content": doc} for doc in docs],
108 | }
109 |
110 |
111 | class RankGPTRanker(BaseRanker):
112 | def __init__(
113 | self, model: str, api_key: str, lang: str = "en", verbose: int = 1
114 | ) -> "RankGPTRanker":
115 | self.api_key = api_key
116 | self.model = model
117 | self.verbose = verbose
118 | self.lang = lang
119 |
120 | def _query_llm(self, messages: List[Dict[str, str]]) -> str:
121 | response = completion(
122 | api_key=self.api_key, model=self.model, messages=messages, temperature=0
123 | )
124 | return response.choices[0].message.content
125 |
126 | def rank(
127 | self,
128 | query: str,
129 | docs: Union[str, List[str], Document, List[Document]],
130 | doc_ids: Optional[Union[List[str], List[int]]] = None,
131 | metadata: Optional[List[dict]] = None,
132 | rank_start: int = 0,
133 | rank_end: int = 0,
134 | ) -> RankedResults:
135 | docs = prep_docs(docs, doc_ids, metadata)
136 |
137 | item = make_item(query, [d.text for d in docs])
138 | messages = create_permutation_instruction(
139 | item=item,
140 | rank_start=rank_start,
141 | rank_end=rank_end,
142 | lang=self.lang,
143 | )
144 | vprint(f"Querying model {self.model} with via LiteLLM...", self.verbose)
145 | permutation = self._query_llm(messages)
146 | item = receive_permutation(
147 | item, permutation, rank_start=rank_start, rank_end=rank_end
148 | )
149 | ranked_docs = []
150 | for idx, doc in enumerate(item["hits"]):
151 | ranked_docs.append(
152 | Result(
153 | document=list(filter(lambda x: x.text == doc["content"], docs))[0],
154 | rank=idx + 1,
155 | )
156 | )
157 | ranked_results = RankedResults(
158 | results=ranked_docs, query=query, has_scores=False
159 | )
160 | return ranked_results
161 |
162 | def score(self):
163 | print("Listwise ranking models like RankGPT-4 cannot output scores!")
164 | return None
165 |
--------------------------------------------------------------------------------
/rerankers/models/rankllm_ranker.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, List
2 | from rerankers.models.ranker import BaseRanker
3 | from rerankers.documents import Document
4 | from rerankers.results import RankedResults, Result
5 | from rerankers.utils import prep_docs
6 |
7 | # from rerankers import Reranker
8 |
9 | from rank_llm.rerank.reranker import Reranker as rankllm_Reranker
10 | from rank_llm.rerank import PromptMode, get_azure_openai_args, get_genai_api_key, get_openai_api_key
11 | from rank_llm.data import Candidate, Query, Request
12 |
13 |
14 | class RankLLMRanker(BaseRanker):
15 | def __init__(
16 | self,
17 | model: str = "rank_zephyr",
18 | api_key: Optional[str] = None,
19 | lang: str = "en",
20 | verbose: int = 1,
21 | # RankLLM specific arguments
22 | window_size: int = 20,
23 | context_size: int = 4096,
24 | prompt_mode: PromptMode = PromptMode.RANK_GPT,
25 | num_few_shot_examples: int = 0,
26 | few_shot_file: Optional[str] = None,
27 | num_gpus: int = 1,
28 | variable_passages: bool = False,
29 | use_logits: bool = False,
30 | use_alpha: bool = False,
31 | stride: int = 10,
32 | use_azure_openai: bool = False,
33 | ) -> "RankLLMRanker":
34 | self.api_key = api_key
35 | self.model = model
36 | self.verbose = verbose
37 | self.lang = lang
38 |
39 | # RankLLM-specific parameters
40 | self.window_size = window_size
41 | self.context_size = context_size
42 | self.prompt_mode = prompt_mode
43 | self.num_few_shot_examples = num_few_shot_examples
44 | self.few_shot_file = few_shot_file
45 | self.num_gpus = num_gpus
46 | self.variable_passages = variable_passages
47 | self.use_logits = use_logits
48 | self.use_alpha = use_alpha
49 | self.stride = stride
50 | self.use_azure_openai = use_azure_openai
51 |
52 | kwargs = {
53 | "model_path": self.model,
54 | "default_model_coordinator": None,
55 | "context_size": self.context_size,
56 | "prompt_mode": self.prompt_mode,
57 | "num_gpus": self.num_gpus,
58 | "use_logits": self.use_logits,
59 | "use_alpha": self.use_alpha,
60 | "num_few_shot_examples": self.num_few_shot_examples,
61 | "few_shot_file": self.few_shot_file,
62 | "variable_passages": self.variable_passages,
63 | "interactive": False,
64 | "window_size": self.window_size,
65 | "stride": self.stride,
66 | "use_azure_openai": self.use_azure_openai,
67 | }
68 | model_coordinator = rankllm_Reranker.create_model_coordinator(**kwargs)
69 | self.reranker = rankllm_Reranker(model_coordinator)
70 |
71 | def rank(
72 | self,
73 | query: str,
74 | docs: Union[str, List[str], Document, List[Document]],
75 | doc_ids: Optional[Union[List[str], List[int]]] = None,
76 | metadata: Optional[List[dict]] = None,
77 | rank_start: int = 0,
78 | rank_end: int = 0,
79 | ) -> RankedResults:
80 | docs = prep_docs(docs, doc_ids, metadata)
81 |
82 | request = Request(
83 | query=Query(text=query, qid=1),
84 | candidates=[
85 | Candidate(doc={"text": doc.text}, docid=doc_idx, score=1)
86 | for doc_idx, doc in enumerate(docs)
87 | ],
88 | )
89 |
90 | rankllm_results = self.reranker.rerank(
91 | request,
92 | rank_end=len(docs) if rank_end == 0 else rank_end,
93 | window_size=min(20, len(docs)),
94 | step=10,
95 | )
96 |
97 | ranked_docs = []
98 |
99 | for rank, result in enumerate(rankllm_results.candidates, start=rank_start):
100 | ranked_docs.append(
101 | Result(
102 | document=docs[result.docid],
103 | rank=rank,
104 | )
105 | )
106 |
107 | return RankedResults(results=ranked_docs, query=query, has_scores=False)
108 |
109 | def score(self):
110 | print("Listwise ranking models like RankLLM cannot output scores!")
111 | return None
112 |
--------------------------------------------------------------------------------
/rerankers/models/t5ranker.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for InRanker is taken from the excellent InRanker repo https://github.com/unicamp-dl/InRanker under its Apache 2.0 license.
3 | The only change to the original implementation is the removal of InRanker's BaseRanker, replacing it with our own to support the unified API better.
4 | The main purpose for adapting this code here rather than installing the InRanker library is to ensure greater version compatibility (InRanker requires Python >=3.10)
5 | """
6 |
7 | from typing import List, Optional, Union
8 | from math import ceil
9 |
10 | import torch
11 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12 | from rerankers.models.ranker import BaseRanker
13 | from rerankers.documents import Document
14 |
15 |
16 | from rerankers.results import RankedResults, Result
17 | from rerankers.utils import (
18 | vprint,
19 | get_device,
20 | get_dtype,
21 | prep_docs,
22 | get_chunks,
23 | )
24 | try:
25 | from tqdm.auto import tqdm
26 | except ImportError:
27 | def tqdm(iterable, *args, **kwargs):
28 | return iterable
29 |
30 | PREDICTION_TOKENS = {
31 | "default": ["▁false", "▁true"],
32 | "castorini/monot5-base-msmarco": ["▁false", "▁true"],
33 | "castorini/monot5-base-msmarco-10k": ["▁false", "▁true"],
34 | "castorini/monot5-large-msmarco": ["▁false", "▁true"],
35 | "castorini/monot5-large-msmarco-10k": ["▁false", "▁true"],
36 | "castorini/monot5-base-med-msmarco": ["▁false", "▁true"],
37 | "castorini/monot5-3b-med-msmarco": ["▁false", "▁true"],
38 | "castorini/monot5-3b-msmarco-10k": ["▁false", "▁true"],
39 | "unicamp-dl/InRanker-small": ["▁false", "▁true"],
40 | "unicamp-dl/InRanker-base": ["▁false", "▁true"],
41 | "unicamp-dl/InRanker-3B": ["▁false", "▁true"],
42 | "unicamp-dl/mt5-base-en-msmarco": ["▁no", "▁yes"],
43 | "unicamp-dl/ptt5-base-pt-msmarco-10k-v2": ["▁não", "▁sim"],
44 | "unicamp-dl/ptt5-base-pt-msmarco-100k-v2": ["▁não", "▁sim"],
45 | "unicamp-dl/ptt5-base-en-pt-msmarco-100k-v2": ["▁não", "▁sim"],
46 | "unicamp-dl/mt5-base-en-pt-msmarco-v2": ["▁no", "▁yes"],
47 | "unicamp-dl/mt5-base-mmarco-v2": ["▁no", "▁yes"],
48 | "unicamp-dl/mt5-base-en-pt-msmarco-v1": ["▁no", "▁yes"],
49 | "unicamp-dl/mt5-base-mmarco-v1": ["▁no", "▁yes"],
50 | "unicamp-dl/ptt5-base-pt-msmarco-10k-v1": ["▁não", "▁sim"],
51 | "unicamp-dl/ptt5-base-pt-msmarco-100k-v1": ["▁não", "▁sim"],
52 | "unicamp-dl/ptt5-base-en-pt-msmarco-10k-v1": ["▁não", "▁sim"],
53 | "unicamp-dl/mt5-3B-mmarco-en-pt": ["▁", "▁true"],
54 | "unicamp-dl/mt5-13b-mmarco-100k": ["▁", "▁true"],
55 | "unicamp-dl/monoptt5-small": ["▁Não", "▁Sim"],
56 | "unicamp-dl/monoptt5-base": ["▁Não", "▁Sim"],
57 | "unicamp-dl/monoptt5-large": ["▁Não", "▁Sim"],
58 | "unicamp-dl/monoptt5-3b": ["▁Não", "▁Sim"],
59 | "Dundalia/TWOLAR-large": [6136, 1176],
60 | "Dundalia/TWOLAR-xl": [6136, 1176],
61 | }
62 |
63 |
64 | def _get_output_tokens(model_name_or_path, token_false: str, token_true: str):
65 | if token_false == "auto":
66 | if model_name_or_path in PREDICTION_TOKENS:
67 | token_false = PREDICTION_TOKENS[model_name_or_path][0]
68 | else:
69 | token_false = PREDICTION_TOKENS["default"][0]
70 | print(
71 | f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_false to `{token_false}`."
72 | )
73 | if token_true == "auto":
74 | if model_name_or_path in PREDICTION_TOKENS:
75 | token_true = PREDICTION_TOKENS[model_name_or_path][1]
76 | else:
77 | token_true = PREDICTION_TOKENS["default"][1]
78 | print(
79 | f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_true to `{token_true}`."
80 | )
81 |
82 | return token_false, token_true
83 |
84 |
85 | class T5Ranker(BaseRanker):
86 | def __init__(
87 | self,
88 | model_name_or_path: str,
89 | batch_size: int = 32,
90 | dtype: Optional[Union[str, torch.dtype]] = None,
91 | device: Optional[Union[str, torch.device]] = None,
92 | verbose: int = 1,
93 | token_false: str = "auto",
94 | token_true: str = "auto",
95 | return_logits: bool = False,
96 | inputs_template: str = "Query: {query} Document: {text} Relevant:",
97 | **kwargs,
98 | ):
99 | """
100 | Implementation of the key functions from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py
101 | Changes are detailed in the docstring for each relevant function.
102 |
103 | T5Ranker is a wrapper for using Seq2Seq models for ranking.
104 | Args:
105 | batch_size: The batch size to use when encoding.
106 | dtype: Data type for model weights.
107 | device: The device to use for inference ("cpu", "cuda", or "mps").
108 | verbose: Verbosity level.
109 | silent: Whether to show progress bars.
110 | """
111 | self.verbose = verbose
112 | self.device = get_device(device, self.verbose, no_mps=True)
113 | self.dtype = get_dtype(dtype, self.device, self.verbose)
114 | self.batch_size = batch_size
115 | vprint(
116 | f"Loading model {model_name_or_path}, this might take a while...",
117 | self.verbose,
118 | )
119 | vprint(f"Using device {self.device}.", self.verbose)
120 | vprint(f"Using dtype {self.dtype}.", self.verbose)
121 | model_kwargs = kwargs.get("model_kwargs", {})
122 | self.model = AutoModelForSeq2SeqLM.from_pretrained(
123 | model_name_or_path,
124 | torch_dtype=self.dtype,
125 | **model_kwargs,
126 | ).to(self.device)
127 | self.model.eval()
128 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
129 | self.tokenizer = AutoTokenizer.from_pretrained(
130 | model_name_or_path,
131 | **tokenizer_kwargs,
132 | )
133 |
134 | token_false, token_true = _get_output_tokens(
135 | model_name_or_path=model_name_or_path,
136 | token_false=token_false,
137 | token_true=token_true,
138 | )
139 | if isinstance(token_false, int):
140 | self.token_false_id = token_false
141 | else:
142 | self.token_false_id = self.tokenizer.convert_tokens_to_ids(token_false)
143 | if isinstance(token_true, int):
144 | self.token_true_id = token_true
145 | else:
146 | self.token_true_id = self.tokenizer.convert_tokens_to_ids(token_true)
147 | vprint(f"T5 true token set to {token_true}", self.verbose)
148 | vprint(f"T5 false token set to {token_false}", self.verbose)
149 |
150 | self.return_logits = return_logits
151 | if self.return_logits:
152 | vprint(
153 | f"Returning raw logits for `{token_true}` as scores...", self.verbose
154 | )
155 | else:
156 | vprint("Returning normalised scores...", self.verbose)
157 | self.inputs_template = inputs_template
158 | vprint(f"Inputs template set to {inputs_template}", self.verbose)
159 |
160 | def rank(
161 | self,
162 | query: str,
163 | docs: Union[str, List[str], Document, List[Document]],
164 | doc_ids: Optional[Union[List[str], List[int]]] = None,
165 | metadata: Optional[List[dict]] = None,
166 | ) -> RankedResults:
167 | """
168 | Ranks a list of documents based on their relevance to the query.
169 | """
170 | docs = prep_docs(docs, doc_ids, metadata)
171 | scores = self._get_scores(query, [d.text for d in docs])
172 | ranked_results = [
173 | Result(document=doc, score=score, rank=idx + 1)
174 | for idx, (doc, score) in enumerate(
175 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
176 | )
177 | ]
178 | return RankedResults(results=ranked_results, query=query, has_scores=True)
179 |
180 | def score(self, query: str, doc: str) -> float:
181 | """
182 | Scores a single document's relevance to a query.
183 | """
184 | scores = self._get_scores(query, [doc])
185 | return scores[0] if scores else 0.0
186 |
187 | @torch.inference_mode()
188 | def _get_scores(
189 | self,
190 | query: str,
191 | docs: List[str],
192 | max_length: int = 512,
193 | batch_size: Optional[int] = None,
194 | ) -> List[float]:
195 | """
196 | Implementation from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py.
197 | Lightly modified so only the positive logits are returned and renamed the chunking function.
198 |
199 | Given a query and a list of documents, return a list of scores.
200 | Args:
201 | query: The query string.
202 | docs: A list of document strings.
203 | max_length: The maximum length of the input sequence.
204 | """
205 | if self.return_logits:
206 | logits = []
207 | else:
208 | scores = []
209 | if batch_size is None:
210 | batch_size = self.batch_size
211 | for batch in tqdm(
212 | get_chunks(docs, batch_size),
213 | disable=not self.verbose,
214 | desc="Scoring...",
215 | total=ceil(len(docs) / batch_size),
216 | ):
217 | queries_documents = [
218 | self.inputs_template.format(query=query, text=text)
219 | for text in batch
220 | ]
221 | tokenized = self.tokenizer(
222 | queries_documents,
223 | padding=True,
224 | truncation="longest_first",
225 | return_tensors="pt",
226 | max_length=max_length,
227 | ).to(self.device)
228 | input_ids = tokenized["input_ids"]
229 | attention_mask = tokenized["attention_mask"]
230 |
231 | _, batch_scores = self._greedy_decode(
232 | model=self.model,
233 | input_ids=input_ids,
234 | length=1,
235 | attention_mask=attention_mask,
236 | return_last_logits=True,
237 | )
238 | batch_scores = batch_scores[
239 | :, [self.token_false_id, self.token_true_id]
240 | ].cpu()
241 | if self.return_logits:
242 | logits.extend(batch_scores[:, 1].tolist())
243 | else:
244 | batch_scores = torch.log_softmax(batch_scores, dim=-1)
245 | batch_scores = torch.exp(batch_scores[:, 1])
246 | scores.extend(batch_scores.tolist())
247 |
248 | if self.return_logits:
249 | return logits
250 | return scores
251 |
252 | @torch.inference_mode()
253 | def _greedy_decode(
254 | self,
255 | model,
256 | input_ids: torch.Tensor,
257 | length: int,
258 | attention_mask: torch.Tensor = None,
259 | return_last_logits: bool = True,
260 | ):
261 | """Implementation from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py"""
262 | decode_ids = torch.full(
263 | (input_ids.size(0), 1),
264 | model.config.decoder_start_token_id,
265 | dtype=torch.long,
266 | ).to(input_ids.device)
267 | encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
268 | next_token_logits = None
269 | for _ in range(length):
270 | try:
271 | model_inputs = model.prepare_inputs_for_generation(
272 | decode_ids,
273 | encoder_outputs=encoder_outputs,
274 | past=None,
275 | attention_mask=attention_mask,
276 | use_cache=True,
277 | )
278 | outputs = model(**model_inputs)
279 | except TypeError:
280 | # Newer transformers versions have deprecated `past`
281 | # Our aim is to maintain pipeline compatibility for as many people as possible
282 | # So currently, we maintain a forking path with this error. Might need to do it more elegantly later on (TODO).
283 | model_inputs = model.prepare_inputs_for_generation(
284 | decode_ids,
285 | encoder_outputs=encoder_outputs,
286 | attention_mask=attention_mask,
287 | use_cache=True,
288 | )
289 | outputs = model(**model_inputs) # (batch_size, cur_len, vocab_size)
290 | next_token_logits = outputs[0][:, -1, :] # (batch_size, vocab_size)
291 | decode_ids = torch.cat(
292 | [decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1
293 | )
294 | if return_last_logits:
295 | return decode_ids, next_token_logits
296 | return decode_ids
297 |
--------------------------------------------------------------------------------
/rerankers/models/transformer_ranker.py:
--------------------------------------------------------------------------------
1 | from rerankers.models.ranker import BaseRanker
2 | from rerankers.documents import Document
3 |
4 |
5 | import torch
6 | from typing import Union, List, Optional, Tuple
7 | from transformers import (
8 | AutoModelForSequenceClassification,
9 | AutoTokenizer,
10 | )
11 | from rerankers.utils import (
12 | vprint,
13 | get_device,
14 | get_dtype,
15 | prep_docs,
16 | )
17 | from rerankers.results import RankedResults, Result
18 |
19 |
20 | class TransformerRanker(BaseRanker):
21 | def __init__(
22 | self,
23 | model_name_or_path: str,
24 | dtype: Optional[Union[str, torch.dtype]] = None,
25 | device: Optional[Union[str, torch.device]] = None,
26 | batch_size: int = 16,
27 | verbose: int = 1,
28 | **kwargs,
29 | ):
30 | self.verbose = verbose
31 | self.device = get_device(device, verbose=self.verbose)
32 | self.dtype = get_dtype(dtype, self.device, self.verbose)
33 | self.is_monobert = "monobert" in model_name_or_path.lower()
34 | model_kwargs = kwargs.get("model_kwargs", {})
35 | self.model = AutoModelForSequenceClassification.from_pretrained(
36 | model_name_or_path,
37 | torch_dtype=self.dtype,
38 | **model_kwargs,
39 | ).to(self.device)
40 | vprint(f"Loaded model {model_name_or_path}", self.verbose)
41 | vprint(f"Using device {self.device}.", self.verbose)
42 | vprint(f"Using dtype {self.dtype}.", self.verbose)
43 | self.model.eval()
44 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
45 | self.tokenizer = AutoTokenizer.from_pretrained(
46 | model_name_or_path,
47 | **tokenizer_kwargs,
48 | )
49 | self.ranking_type = "pointwise"
50 | self.batch_size = batch_size
51 |
52 | def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]):
53 | return self.tokenizer(
54 | inputs, return_tensors="pt", padding=True, truncation=True
55 | ).to(self.device)
56 |
57 | @torch.inference_mode()
58 | def rank(
59 | self,
60 | query: str,
61 | docs: Union[str, List[str], Document, List[Document]],
62 | doc_ids: Optional[Union[List[str], List[int]]] = None,
63 | metadata: Optional[List[dict]] = None,
64 | batch_size: Optional[int] = None,
65 | ) -> RankedResults:
66 | docs = prep_docs(docs, doc_ids, metadata)
67 | inputs = [(query, doc.text) for doc in docs]
68 |
69 | # Override self.batch_size if explicitely set
70 | if batch_size is None:
71 | batch_size = self.batch_size
72 | batched_inputs = [
73 | inputs[i : i + batch_size] for i in range(0, len(inputs), batch_size)
74 | ]
75 | scores = []
76 | for batch in batched_inputs:
77 | tokenized_inputs = self.tokenize(batch)
78 | batch_scores = self.model(**tokenized_inputs).logits.squeeze()
79 | if self.dtype != torch.float32:
80 | batch_scores = batch_scores.float()
81 | batch_scores = batch_scores.detach().cpu().numpy().tolist()
82 | if isinstance(batch_scores, float): # Handling the case of single score
83 | scores.append(batch_scores)
84 | else:
85 | scores.extend(batch_scores)
86 | if self.is_monobert: scores = [x[1] - x[0] for x in scores]
87 | if len(scores) == 1:
88 | return RankedResults(results=[Result(document=docs[0], score=scores[0])], query=query, has_scores=True)
89 | else:
90 | ranked_results = [
91 | Result(document=doc, score=score, rank=idx + 1)
92 | for idx, (doc, score) in enumerate(
93 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
94 | )
95 | ]
96 | return RankedResults(results=ranked_results, query=query, has_scores=True)
97 |
98 | @torch.inference_mode()
99 | def score(self, query: str, doc: str) -> float:
100 | inputs = self.tokenize((query, doc))
101 | outputs = self.model(**inputs)
102 | score = outputs.logits.squeeze().detach().cpu().numpy().astype(float)
103 | return score
104 |
--------------------------------------------------------------------------------
/rerankers/models/upr.py:
--------------------------------------------------------------------------------
1 | # models/upr.py
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from transformers import T5ForConditionalGeneration, T5Tokenizer
6 | from math import ceil
7 | from typing import List, Optional, Union
8 |
9 | from rerankers.models.ranker import BaseRanker
10 | from rerankers.documents import Document
11 | from rerankers.results import RankedResults, Result
12 | from rerankers.utils import (
13 | vprint,
14 | get_device,
15 | get_dtype,
16 | prep_docs,
17 | get_chunks,
18 | )
19 |
20 |
21 | class UPRRanker(BaseRanker):
22 | """
23 | UPR (Unsupervised Passage Reranker) replicates the negative log-likelihood
24 | approach from the authors' code. The doc is passed as the encoder input,
25 | and the query is the decoder label.
26 | """
27 |
28 | def __init__(
29 | self,
30 | model_name_or_path: str,
31 | verbose: int = 1,
32 | device: Optional[Union[str, torch.device]] = None,
33 | dtype: Optional[Union[str, torch.dtype]] = None,
34 | batch_size: int = 16,
35 | verbalizer_head: str = "Passage:",
36 | verbalizer: str = "Please write a question based on this passage.",
37 | max_input_length: int = 512,
38 | max_query_length: int = 128,
39 | **kwargs
40 | ):
41 | """
42 | Args:
43 | model_name_or_path: A T5 checkpoint name or path (e.g., 't5-large', 'google/t5-xxl-lm-adapt', etc.)
44 | verbose: Verbosity level.
45 | device: "cuda", "cpu", or None for auto.
46 | dtype: e.g. "float32", "float16", "bf16", or a torch.dtype.
47 | batch_size: How many documents to process at once.
48 | verbalizer_head: Prefixed to the doc text to mimic the 'Passage: ' from the original code.
49 | verbalizer: A short instruction appended to the doc text. The original UPR default is
50 | "Please write a question based on this passage."
51 | max_input_length: Maximum tokens for the encoder side (document).
52 | max_query_length: Maximum tokens for the decoder side (query).
53 | """
54 | self.verbose = verbose
55 | self.device = get_device(device, self.verbose)
56 | self.dtype = get_dtype(dtype, self.device, self.verbose)
57 | self.batch_size = batch_size
58 | self.verbalizer_head = verbalizer_head
59 | self.verbalizer = verbalizer
60 | self.max_input_length = max_input_length
61 | self.max_query_length = max_query_length
62 |
63 | vprint(f"[UPR] Loading T5 model: {model_name_or_path}", self.verbose)
64 | vprint(f"[UPR] device={self.device}, dtype={self.dtype}, batch_size={batch_size}", self.verbose)
65 |
66 | # Load T5
67 | model_kwargs = kwargs.get("model_kwargs", {})
68 | self.model = T5ForConditionalGeneration.from_pretrained(
69 | model_name_or_path, torch_dtype=self.dtype, **model_kwargs
70 | ).to(self.device)
71 | self.model.eval()
72 |
73 | tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
74 | self.tokenizer = T5Tokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs)
75 |
76 | def score(self, query: str, doc: str) -> float:
77 | """
78 | Score a single document. Negative log-likelihood of 'query' given 'doc'.
79 | Higher means more relevant (score = -NLL).
80 | """
81 | scores = self._get_scores(query, [doc])
82 | return scores[0] if scores else 0.0
83 |
84 | def rank(
85 | self,
86 | query: str,
87 | docs: Union[str, List[str], Document, List[Document]],
88 | doc_ids: Optional[Union[List[str], List[int]]] = None,
89 | metadata: Optional[List[dict]] = None,
90 | ) -> RankedResults:
91 | """
92 | Ranks a list of documents by the negative log-likelihood of the query given the doc.
93 | """
94 | # Convert user inputs into a list of Document objects
95 | docs = prep_docs(docs, doc_ids, metadata)
96 |
97 | # Score them
98 | doc_texts = [d.text for d in docs]
99 | scores = self._get_scores(query, doc_texts)
100 |
101 | # Sort in descending order of score
102 | ranked_results = [
103 | Result(document=doc, score=score, rank=idx + 1)
104 | for idx, (doc, score) in enumerate(
105 | sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
106 | )
107 | ]
108 | return RankedResults(results=ranked_results, query=query, has_scores=True)
109 |
110 | @torch.inference_mode()
111 | def _get_scores(self, query: str, docs: List[str]) -> List[float]:
112 | """
113 | Batched negative log-likelihood scoring:
114 | score = - sum_{tokens in query} [ log P(token | doc) ].
115 | """
116 | all_scores = []
117 | # Create mini-batches of docs
118 | for batch in get_chunks(docs, self.batch_size):
119 | # 1) Build the T5 encoder inputs for the doc
120 | # (mimicking "Passage: {doc_text}. Please write a question..." from the original code)
121 | encoder_texts = [
122 | f"{self.verbalizer_head} {doc_text}. {self.verbalizer}"
123 | for doc_text in batch
124 | ]
125 |
126 | encoder_enc = self.tokenizer(
127 | encoder_texts,
128 | padding=True,
129 | truncation=True,
130 | max_length=self.max_input_length,
131 | return_tensors="pt",
132 | ).to(self.device)
133 |
134 | # 2) Build the T5 decoder labels for the query
135 | # (the question is now the *label*, exactly as in original UPR).
136 | decoder_enc = self.tokenizer(
137 | [query] * len(batch),
138 | padding=True,
139 | truncation=True,
140 | max_length=self.max_query_length,
141 | return_tensors="pt",
142 | ).to(self.device)
143 |
144 | # 3) forward pass with `labels=...` so that T5 returns cross-entropy
145 | # but we want the per-token log-likelihood to replicate the approach exactly.
146 | logits = self.model(
147 | input_ids=encoder_enc.input_ids,
148 | attention_mask=encoder_enc.attention_mask,
149 | labels=decoder_enc.input_ids,
150 | ).logits # shape: [batch, seq_len, vocab_size]
151 |
152 | # 4) Compute log-softmax for each token
153 | log_probs = F.log_softmax(logits, dim=-1) # [batch, seq_len, vocab_size]
154 |
155 | # 5) Gather the probabilities at each gold label token => negative log-likelihood
156 | # next_token = decoder_enc.input_ids[..., 1:]
157 | # but T5 shifts internally. We'll simply do gather on
158 | # the label tokens and sum up, replicating the original.
159 | labels = decoder_enc.input_ids.unsqueeze(-1) # [batch, seq_len, 1]
160 | token_log_probs = log_probs.gather(-1, labels).squeeze(-1) # [batch, seq_len]
161 |
162 | # T5 shifts internally, so the first token is the "start token." The original UPR code
163 | # just sums everything. We'll do the same.
164 | nll = -token_log_probs # [batch, seq_len]
165 | sum_nll = torch.sum(nll, dim=1) # sum over query tokens
166 |
167 | # final score = - sum_nll (which is +ve if the NLL is large)
168 | # we want "best doc" to have the largest score => doc that yields the *lowest* NLL
169 | batch_scores = (-sum_nll).tolist()
170 |
171 | all_scores.extend(batch_scores)
172 |
173 | return all_scores
174 |
--------------------------------------------------------------------------------
/rerankers/reranker.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import warnings
3 | from rerankers.models import AVAILABLE_RANKERS
4 | from rerankers.models.ranker import BaseRanker
5 | from rerankers.utils import vprint
6 |
7 | DEFAULTS = {
8 | "jina": {"en": "jina-reranker-v1-base-en"},
9 | "isaacus": {"en": "kanon-universal-classifier"},
10 | "pinecone": {"en": "pinecone-rerank-v0"},
11 | "cohere": {"en": "rerank-english-v3.0", "other": "rerank-multilingual-v3.0"},
12 | "voyage": {"en": "rerank-lite-1"},
13 | "mixedbread.ai": {"en": "mixedbread-ai/mxbai-rerank-large-v1"},
14 | "cross-encoder": {
15 | "en": "mixedbread-ai/mxbai-rerank-base-v1",
16 | "fr": "antoinelouis/crossencoder-camembert-base-mmarcoFR",
17 | "zh": "BAAI/bge-reranker-base",
18 | "other": "corrius/cross-encoder-mmarco-mMiniLMv2-L12-H384-v1",
19 | },
20 | "t5": {"en": "unicamp-dl/InRanker-base", "other": "unicamp-dl/mt5-base-mmarco-v2"},
21 | "lit5": {"en": "castorini/LiT5-Distill-base"},
22 | "rankgpt": {"en": "gpt-4-turbo-preview", "other": "gpt-4-turbo-preview"},
23 | "rankgpt3": {"en": "gpt-3.5-turbo", "other": "gpt-3.5-turbo"},
24 | "rankgpt4": {"en": "gpt-4", "other": "gpt-4"},
25 | "rankllm": {"en": "rank_zephyr", "other": "rank_zephyr"},
26 | "colbert": {
27 | "en": "colbert-ir/colbertv2.0",
28 | "fr": "bclavie/FraColBERTv2",
29 | "ja": "bclavie/JaColBERTv2",
30 | "es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES",
31 | },
32 | "flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"},
33 | "text-embeddings-inference": {"other": "BAAI/bge-reranker-base"},
34 | "llm-layerwise": {
35 | "en": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
36 | "other": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
37 | },
38 | "monovlm": {
39 | "en": "lightonai/MonoQwen2-VL-v0.1",
40 | "other": "lightonai/MonoQwen2-VL-v0.1"
41 | },
42 | "llm-relevance-filter": {
43 | "en": "gpt-4-turbo-preview",
44 | "other": "gpt-4-turbo-preview"
45 | },
46 | "upr": {"en": "google/t5-large-lm-adapt"},
47 | "mxbaiv2": {"en": "mixedbread-ai/mxbai-rerank-base-v2"},
48 | "pylate": {
49 | "en": "lightonai/GTE-ModernColBERT-v1",
50 | "other": "lightonai/GTE-ModernColBERT-v1",
51 | },
52 | }
53 |
54 | DEPS_MAPPING = {
55 | "TransformerRanker": "transformers",
56 | "T5Ranker": "transformers",
57 | "LiT5Ranker": "lit5",
58 | "RankGPTRanker": "gpt",
59 | "APIRanker": "api",
60 | "ColBERTRanker": "transformers",
61 | "FlashRankRanker": "flashrank",
62 | "RankLLMRanker": "rankllm",
63 | "LLMLayerWiseRanker": "transformers",
64 | "MonoVLMRanker": "transformers",
65 | "LLMRelevanceFilter": "litellm",
66 | "UPRRanker": "transformers",
67 | "MxBaiV2Ranker": "transformers",
68 | "PyLateRanker": "pylate",
69 | }
70 |
71 | PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "pinecone", "isaacus", "text-embeddings-inference"]
72 |
73 | def _get_api_provider(model_name: str, model_type: Optional[str] = None) -> Optional[str]:
74 | # If an explicit model_type is provided and it isn't one of the known API providers,
75 | # then we skip auto-detection of an API provider.
76 | if model_type is not None and model_type not in PROVIDERS:
77 | return None
78 | if (model_type in PROVIDERS) or any(provider in model_name for provider in PROVIDERS):
79 | return model_type if model_type in PROVIDERS else next(
80 | (provider for provider in PROVIDERS if provider in model_name), None
81 | )
82 | # Check if the model_name is a key in DEFAULTS to set the provider correctly
83 | return next(
84 | (
85 | provider
86 | for provider in PROVIDERS
87 | if model_name in DEFAULTS and any(provider in values for values in DEFAULTS[model_name].values())
88 | ),
89 | None,
90 | )
91 |
92 | def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None) -> str:
93 | if explicit_model_type:
94 | model_mapping = {
95 | "cohere": "APIRanker",
96 | "pinecone": "APIRanker",
97 | "jina": "APIRanker",
98 | "isaacus": "APIRanker",
99 | "voyage": "APIRanker",
100 | "text-embeddings-inference": "APIRanker",
101 | "rankgpt": "RankGPTRanker",
102 | "lit5": "LiT5Ranker",
103 | "t5": "T5Ranker",
104 | "colbert": "ColBERTRanker",
105 | "cross-encoder": "TransformerRanker",
106 | "flashrank": "FlashRankRanker",
107 | "rankllm": "RankLLMRanker",
108 | "llm-layerwise": "LLMLayerWiseRanker",
109 | "monovlm": "MonoVLMRanker",
110 | "llm-relevance-filter": "LLMRelevanceFilter",
111 | "upr": "UPRRanker",
112 | "mxbaiv2": "MxBaiV2Ranker",
113 | "pylate": "PyLateRanker",
114 | }
115 | return model_mapping.get(explicit_model_type, explicit_model_type)
116 | else:
117 | model_name = model_name.lower()
118 | model_mapping = {
119 | "lit5": "LiT5Ranker",
120 | "t5": "T5Ranker",
121 | "inranker": "T5Ranker",
122 | "rankllm": "RankLLMRanker",
123 | "rankgpt": "RankGPTRanker",
124 | "gpt": "RankGPTRanker",
125 | "colbert": "ColBERTRanker",
126 | "cohere": "APIRanker",
127 | "pinecone": "APIRanker",
128 | "jina": "APIRanker",
129 | "isaacus": "APIRanker",
130 | "voyage": "APIRanker",
131 | "text-embeddings-inference": "APIRanker",
132 | "ms-marco-minilm-l-12-v2": "FlashRankRanker",
133 | "ms-marco-multibert-l-12": "FlashRankRanker",
134 | "vicuna": "RankLLMRanker",
135 | "zephyr": "RankLLMRanker",
136 | "bge-reranker-v2.5-gemma2-lightweight": "LLMLayerWiseRanker",
137 | "monovlm": "MonoVLMRanker",
138 | "monoqwen2-vl": "MonoVLMRanker",
139 | "llm-relevance-filter": "LLMRelevanceFilter",
140 | "upr": "UPRRanker",
141 | "mxbaiv2": "MxBaiV2Ranker",
142 | "mxbai-rerank-base-v2": "MxBaiV2Ranker",
143 | "mxbai-rerank-large-v2": "MxBaiV2Ranker",
144 | "pylate": "PyLateRanker",
145 | }
146 | for key, value in model_mapping.items():
147 | if key in model_name:
148 | if key == "gpt":
149 | warnings.warn(
150 | "The key 'gpt' currently defaults to the rough rankGPT implementation. From version 0.0.5 onwards, 'gpt' will default to RankLLM instead. Please specify the 'rankgpt' `model_type` if you want to keep the current behaviour",
151 | DeprecationWarning,
152 | )
153 | return value
154 | if (
155 | any(
156 | keyword in model_name
157 | for keyword in ["minilm", "bert", "cross-encoders/"]
158 | )
159 | and "/" in model_name
160 | ):
161 | return "TransformerRanker"
162 | print(
163 | "Warning: Model type could not be auto-mapped with the defaults list. Defaulting to TransformerRanker."
164 | )
165 | print(
166 | "If your model is NOT intended to be ran as a one-label cross-encoder, please reload it and specify the model_type!",
167 | "Otherwise, you may ignore this warning. You may specify `model_type='cross-encoder'` to suppress this warning in the future.",
168 | )
169 | return "TransformerRanker"
170 |
171 | def _get_defaults(
172 | model_name: str,
173 | model_type: Optional[str] = None,
174 | lang: str = "en",
175 | verbose: int = 1,
176 | ) -> str:
177 | if model_name in DEFAULTS.keys():
178 | print(f"Loading default {model_name} model for language {lang}")
179 | try:
180 | model_name = DEFAULTS[model_name][lang]
181 | except KeyError:
182 | if "other" not in DEFAULTS[model_name]:
183 | print(
184 | f"Model family {model_name} does not have a default for language {lang}"
185 | )
186 | print(
187 | "Aborting now... Please retry with another model family or by specifying a model"
188 | )
189 | return None, None
190 | model_name = DEFAULTS[model_name]["other"]
191 | model_type = _get_model_type(model_name, model_type)
192 | vprint(f"Default Model: {model_name}", verbose)
193 |
194 | return model_name, model_type
195 |
196 | def Reranker(
197 | model_name: str,
198 | lang: str = "en",
199 | model_type: Optional[str] = None,
200 | verbose: int = 1,
201 | **kwargs,
202 | ) -> Optional[BaseRanker]:
203 | original_model_name = model_name
204 | api_provider = _get_api_provider(model_name, model_type)
205 | if api_provider or model_name.lower() in PROVIDERS:
206 | if model_name.lower() in PROVIDERS:
207 | api_provider = model_name.lower()
208 | # Only override model_type to APIRanker if it hasn't been explicitly set
209 | if model_type is None:
210 | model_type = "APIRanker"
211 | model_name = (
212 | DEFAULTS[api_provider][lang]
213 | if lang in DEFAULTS[api_provider]
214 | else DEFAULTS[api_provider]["other"]
215 | )
216 | print(
217 | f"Auto-updated model_name to {model_name} for API provider {api_provider}"
218 | )
219 | else:
220 | if model_type is None:
221 | model_type = "APIRanker"
222 | else:
223 | if original_model_name in DEFAULTS.keys():
224 | model_name, model_type = _get_defaults(original_model_name, model_type, lang, verbose)
225 | if model_name is None:
226 | return None
227 | api_provider = _get_api_provider(model_name, model_type)
228 | if api_provider and model_type is None:
229 | model_type = "APIRanker"
230 |
231 | if api_provider:
232 | kwargs["api_provider"] = api_provider
233 |
234 | model_type = _get_model_type(model_name, model_type)
235 |
236 | try:
237 | vprint(f"Loading {model_type} model {model_name} (this message can be suppressed by setting verbose=0)", verbose)
238 | return AVAILABLE_RANKERS[model_type](model_name, verbose=verbose, **kwargs)
239 | except KeyError:
240 | print(
241 | f"You don't have the necessary dependencies installed to use {model_type}."
242 | )
243 | print(
244 | f'Please install the necessary dependencies for {model_type} by running `pip install "rerankers[{DEPS_MAPPING[model_type]}]"`',
245 | 'or `pip install "rerankers[all]"` to install the dependencies for all reranker types.',
246 | )
247 | return None
248 |
--------------------------------------------------------------------------------
/rerankers/results.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | from rerankers.documents import Document
4 |
5 |
6 | class Result:
7 | def __init__(self, document: Document, score: Optional[float] = None, rank: Optional[int] = None):
8 | self.document = document
9 | self.score = score
10 | self.rank = rank
11 |
12 | if rank is None and score is None:
13 | raise ValueError("Either score or rank must be provided.")
14 |
15 | def __getattr__(self, item):
16 | if hasattr(self.document, item):
17 | return getattr(self.document, item)
18 | elif item in ["document", "score", "rank"]:
19 | return getattr(self, item)
20 | elif item in self.document.attributes:
21 | return getattr(self.document, item)
22 | elif item in self.document.metadata:
23 | return self.document.metadata[item]
24 | raise AttributeError(
25 | f"'{self.__class__.__name__}' object has no attribute '{item}'"
26 | )
27 |
28 | def __repr__(self) -> str:
29 | fields = {
30 | "document": self.document,
31 | "score": self.score,
32 | "rank": self.rank,
33 | }
34 | field_str = ", ".join(f"{k}={v!r}" for k, v in fields.items())
35 | return f"{self.__class__.__name__}({field_str})"
36 |
37 |
38 | class RankedResults:
39 | def __init__(self, results: List[Result], query: str, has_scores: bool = False):
40 | self.results = results
41 | self.query = query
42 | self.has_scores = has_scores
43 |
44 | def __iter__(self):
45 | """Allows iteration over the results list."""
46 | return iter(self.results)
47 |
48 | def __getitem__(self, index):
49 | """Allows indexing to access results directly."""
50 | return self.results[index]
51 |
52 | def results_count(self) -> int:
53 | """Returns the total number of results."""
54 | return len(self.results)
55 |
56 | def top_k(self, k: int) -> List[Result]:
57 | """Returns the top k results based on the score, if available, or rank."""
58 | if self.has_scores:
59 | return sorted(
60 | self.results,
61 | key=lambda x: x.score if x.score is not None else float("-inf"),
62 | reverse=True,
63 | )[:k]
64 | else:
65 | return sorted(
66 | self.results,
67 | key=lambda x: x.rank if x.rank is not None else float("inf"),
68 | )[:k]
69 |
70 | def get_score_by_docid(self, doc_id: Union[int, str]) -> Optional[float]:
71 | """Fetches the score of a result by its doc_id using a more efficient approach."""
72 | result = next((r for r in self.results if r.document.doc_id == doc_id), None)
73 | return result.score if result else None
74 |
75 | def get_result_by_docid(self, doc_id: Union[int, str]) -> Optional[Result]:
76 | """Fetches a result by its doc_id using a more efficient approach."""
77 | result = next((r for r in self.results if r.document.doc_id == doc_id), None)
78 | return result if result else None
79 |
80 | def __repr__(self) -> str:
81 | fields = {
82 | "results": self.results,
83 | "query": self.query,
84 | "has_scores": self.has_scores,
85 | }
86 | field_str = ", ".join(f"{k}={v!r}" for k, v in fields.items())
87 | return f"{self.__class__.__name__}({field_str})"
88 |
--------------------------------------------------------------------------------
/rerankers/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import binascii
3 | from typing import Union, Optional, List, Iterable
4 | try:
5 | import io
6 | from PIL import Image
7 | except ImportError:
8 | pass
9 | from rerankers.documents import Document
10 |
11 | def vprint(txt: str, verbose: int) -> None:
12 | if verbose > 0:
13 | print(txt)
14 |
15 |
16 | try:
17 | import torch
18 |
19 | def get_dtype(
20 | dtype: Optional[Union[str, torch.dtype]],
21 | device: Optional[Union[str, torch.device]],
22 | verbose: int = 1,
23 | ) -> torch.dtype:
24 | if dtype is None:
25 | vprint("No dtype set", verbose)
26 | # if device == "cpu":
27 | # vprint("Device set to `cpu`, setting dtype to `float32`", verbose)
28 | # dtype = torch.float32
29 | if not isinstance(dtype, torch.dtype):
30 | if dtype == "fp16" or dtype == "float16":
31 | dtype = torch.float16
32 | elif dtype == "bf16" or dtype == "bfloat16":
33 | dtype = torch.bfloat16
34 | else:
35 | dtype = torch.float32
36 | vprint(f"Using dtype {dtype}", verbose)
37 | return dtype
38 |
39 | def get_device(
40 | device: Optional[Union[str, torch.device]],
41 | verbose: int = 1,
42 | no_mps: bool = False,
43 | ) -> Union[str, torch.device]:
44 | if not device:
45 | vprint("No device set", verbose)
46 | if torch.cuda.is_available():
47 | device = "cuda"
48 | elif torch.backends.mps.is_available() and not no_mps:
49 | device = "mps"
50 | else:
51 | device = "cpu"
52 | vprint(f"Using device {device}", verbose)
53 | return device
54 |
55 | except ImportError:
56 | pass
57 |
58 |
59 | def make_documents(
60 | docs: List[str],
61 | doc_ids: Optional[Union[List[str], List[int]]] = None,
62 | ):
63 | if doc_ids is None:
64 | doc_ids = list(range(len(docs)))
65 | return [Document(doc, doc_id=doc_ids[i]) for i, doc in enumerate(docs)]
66 |
67 |
68 | def prep_docs(
69 | docs: Union[str, List[str], Document, List[Document]],
70 | doc_ids: Optional[Union[List[str], List[int]]] = None,
71 | metadata: Optional[List[dict]] = None,
72 | ):
73 | if isinstance(docs, Document) or (
74 | isinstance(docs, List) and isinstance(docs[0], Document)
75 | ):
76 | if isinstance(docs, Document):
77 | docs = [docs]
78 | if doc_ids is not None:
79 | if docs[0].doc_id is not None:
80 | print(
81 | "Overriding doc_ids passed within the Document objects with explicitly passed doc_ids!"
82 | )
83 | print(
84 | "This is not the preferred way of doing so, please double-check your code."
85 | )
86 | for i, doc in enumerate(docs):
87 | doc.doc_id = doc_ids[i]
88 |
89 | elif doc_ids is None:
90 | doc_ids = [doc.doc_id for doc in docs]
91 | if doc_ids[0] is None:
92 | print(
93 | "'None' doc_ids detected, reverting to auto-generated integer ids..."
94 | )
95 | doc_ids = list(range(len(docs)))
96 |
97 | if metadata is not None:
98 | if docs[0].meatadata is not None:
99 | print(
100 | "Overriding doc_ids passed within the Document objects with explicitly passed doc_ids!"
101 | )
102 | print(
103 | "This is not the preferred way of doing so, please double-check your code."
104 | )
105 | for i, doc in enumerate(docs):
106 | doc.metadata = metadata[i]
107 |
108 | return docs
109 |
110 | if isinstance(docs, str):
111 | docs = [docs]
112 | if doc_ids is None:
113 | doc_ids = list(range(len(docs)))
114 | if metadata is None:
115 | metadata = [{} for _ in docs]
116 |
117 | return [
118 | Document(doc, doc_id=doc_ids[i], metadata=metadata[i])
119 | for i, doc in enumerate(docs)
120 | ]
121 |
122 |
123 | def prep_image_docs(
124 | docs: Union[str, List[str], Document, List[Document]],
125 | doc_ids: Optional[Union[List[str], List[int]]] = None,
126 | metadata: Optional[List[dict]] = None,
127 | ) -> List[Document]:
128 | """
129 | Prepare image documents for processing. Can handle base64 encoded images or file paths.
130 | Similar to prep_docs but specialized for image documents.
131 | """
132 | # If already Document objects, handle similarly to prep_docs
133 | if isinstance(docs, Document) or (
134 | isinstance(docs, List) and isinstance(docs[0], Document)
135 | ):
136 | if isinstance(docs, Document):
137 | docs = [docs]
138 | # Validate all docs are image type
139 | for doc in docs:
140 | if doc.document_type != "image":
141 | raise ValueError("All documents must be of type 'image'")
142 | return prep_docs(docs, doc_ids, metadata)
143 |
144 | # Handle string inputs (paths or base64)
145 | if isinstance(docs, str):
146 | docs = [docs]
147 |
148 | processed_docs = []
149 | for doc in docs:
150 | # Check if input is base64 by attempting to decode
151 | try:
152 | # Try to decode and verify it's an image
153 | decoded = base64.b64decode(doc)
154 | try:
155 | Image.open(io.BytesIO(decoded)).verify()
156 | b64 = doc
157 | image_path = None
158 | except:
159 | raise binascii.Error("Invalid image data")
160 | except binascii.Error:
161 | # If decode fails, treat as file path
162 | try:
163 | image_path = doc
164 | with open(doc, 'rb') as img_file:
165 | b64 = base64.b64encode(img_file.read()).decode('utf-8')
166 | except Exception as e:
167 | raise ValueError(f"Could not process image input {doc}: {str(e)}")
168 |
169 | processed_docs.append(
170 | Document(
171 | document_type="image",
172 | base64=b64,
173 | image_path=image_path
174 | )
175 | )
176 |
177 | # Handle doc_ids and metadata
178 | if doc_ids is None:
179 | doc_ids = list(range(len(processed_docs)))
180 | if metadata is None:
181 | metadata = [{} for _ in processed_docs]
182 |
183 | # Set doc_ids and metadata
184 | for i, doc in enumerate(processed_docs):
185 | doc.doc_id = doc_ids[i]
186 | doc.metadata = metadata[i]
187 |
188 |
189 | return processed_docs
190 |
191 |
192 |
193 |
194 | def get_chunks(iterable: Iterable, chunk_size: int): # noqa: E741
195 | """
196 | Implementation from https://github.com/unicamp-dl/InRanker/blob/main/inranker/base.py with extra typing and more descriptive names.
197 | This method is used to split a list l into chunks of batch size n.
198 | """
199 | for i in range(0, len(iterable), chunk_size):
200 | yield iterable[i : i + chunk_size]
201 |
--------------------------------------------------------------------------------
/tests/consistency_notebooks/test_colbert.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from ranx import Qrels, Run"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "qrels = Qrels.from_ir_datasets(\"beir/scifact/test\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 3,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stderr",
28 | "output_type": "stream",
29 | "text": [
30 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
31 | " from .autonotebook import tqdm as notebook_tqdm\n"
32 | ]
33 | }
34 | ],
35 | "source": [
36 | "from rerankers import Reranker"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 4,
42 | "metadata": {},
43 | "outputs": [
44 | {
45 | "data": {
46 | "text/plain": [
47 | "{'_id': '4983',\n",
48 | " 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',\n",
49 | " 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and preterm infants at term showed marked differences in white matter fiber organization. The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural development in cerebral white matter in living infants.',\n",
50 | " 'metadata': {}}"
51 | ]
52 | },
53 | "execution_count": 4,
54 | "metadata": {},
55 | "output_type": "execute_result"
56 | }
57 | ],
58 | "source": [
59 | "import srsly\n",
60 | "\n",
61 | "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n",
62 | "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n",
63 | "\n",
64 | "corpus[0]"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 5,
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "name": "stdout",
74 | "output_type": "stream",
75 | "text": [
76 | "Loading default colbert model for language en\n",
77 | "Loading ColBERTRanker model colbert-ir/colbertv2.0\n"
78 | ]
79 | }
80 | ],
81 | "source": [
82 | "ranker = Reranker('colbert', device='cuda', verbose=0)"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 6,
88 | "metadata": {},
89 | "outputs": [],
90 | "source": [
91 | "top100 = srsly.read_json('./data/scifact/scifact_top_100.json')\n",
92 | "\n"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 7,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 8,
107 | "metadata": {},
108 | "outputs": [
109 | {
110 | "name": "stderr",
111 | "output_type": "stream",
112 | "text": [
113 | " 75%|███████▌ | 226/300 [03:36<01:14, 1.01s/it]"
114 | ]
115 | }
116 | ],
117 | "source": [
118 | "qrels_dict = dict(qrels)\n",
119 | "queries = [q for q in queries if q['_id'] in qrels_dict]\n",
120 | "from tqdm import tqdm\n",
121 | "\n",
122 | "scores = {}\n",
123 | "for q in tqdm(queries):\n",
124 | " doc_ids = top100[q['_id']]\n",
125 | " docs = [corpus_map[x] for x in doc_ids]\n",
126 | " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "scores_dict = {}\n",
136 | "for q_id, ranked_results in scores.items():\n",
137 | " top_10_results = ranked_results.top_k(10)\n",
138 | " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n",
139 | "run = Run(scores_dict)"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "metadata": {},
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "Score is within 0.01 NDCG@10 of the reported score!\n"
152 | ]
153 | }
154 | ],
155 | "source": [
156 | "from ranx import evaluate\n",
157 | "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n",
158 | "litterature_result = 0.693 # From ColBERTv2 Paper https://arxiv.org/abs/2112.01488\n",
159 | "if abs(evaluation_score - litterature_result) > 0.01:\n",
160 | " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\")\n",
161 | "else:\n",
162 | " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n"
163 | ]
164 | }
165 | ],
166 | "metadata": {
167 | "kernelspec": {
168 | "display_name": "rerankers",
169 | "language": "python",
170 | "name": "python3"
171 | },
172 | "language_info": {
173 | "codemirror_mode": {
174 | "name": "ipython",
175 | "version": 3
176 | },
177 | "file_extension": ".py",
178 | "mimetype": "text/x-python",
179 | "name": "python",
180 | "nbconvert_exporter": "python",
181 | "pygments_lexer": "ipython3",
182 | "version": "3.10.13"
183 | }
184 | },
185 | "nbformat": 4,
186 | "nbformat_minor": 2
187 | }
188 |
--------------------------------------------------------------------------------
/tests/consistency_notebooks/test_crossenc.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from ranx import Qrels, Run\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "qrels = Qrels.from_ir_datasets(\"beir/scifact/test\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 3,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stderr",
28 | "output_type": "stream",
29 | "text": [
30 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
31 | " from .autonotebook import tqdm as notebook_tqdm\n"
32 | ]
33 | },
34 | {
35 | "name": "stdout",
36 | "output_type": "stream",
37 | "text": [
38 | "T5Ranker\n",
39 | "{'TransformerRanker': , 'APIRanker': , 'RankGPTRanker': , 'T5Ranker': }\n",
40 | "No dtype set\n",
41 | "Using dtype torch.float16\n",
42 | "Loading model castorini/monot5-base-msmarco-10k, this might take a while...\n",
43 | "Using device cuda.\n",
44 | "Using dtype torch.float16.\n"
45 | ]
46 | },
47 | {
48 | "name": "stderr",
49 | "output_type": "stream",
50 | "text": [
51 | "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
52 | ]
53 | },
54 | {
55 | "name": "stdout",
56 | "output_type": "stream",
57 | "text": [
58 | "T5 true token set to ▁true\n",
59 | "T5 false token set to ▁false\n",
60 | "Returning normalised scores...\n"
61 | ]
62 | }
63 | ],
64 | "source": [
65 | "from rerankers import Reranker\n",
66 | "ranker = Reranker('castorini/monot5-base-msmarco-10k', device='cuda', batch_size=128)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 4,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "data": {
76 | "text/plain": [
77 | "{'_id': '4983',\n",
78 | " 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',\n",
79 | " 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and preterm infants at term showed marked differences in white matter fiber organization. The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural development in cerebral white matter in living infants.',\n",
80 | " 'metadata': {}}"
81 | ]
82 | },
83 | "execution_count": 4,
84 | "metadata": {},
85 | "output_type": "execute_result"
86 | }
87 | ],
88 | "source": [
89 | "import srsly\n",
90 | "\n",
91 | "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n",
92 | "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n",
93 | "\n",
94 | "corpus[0]"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 5,
100 | "metadata": {},
101 | "outputs": [
102 | {
103 | "name": "stdout",
104 | "output_type": "stream",
105 | "text": [
106 | "Warning: Model type could not be auto-mapped. Defaulting to TransformerRanker.\n",
107 | "If your model is NOT intended to be ran as a one-label cross-encoder, please reload it and specify the model_type!\n",
108 | "TransformerRanker\n",
109 | "{'TransformerRanker': , 'APIRanker': , 'RankGPTRanker': , 'T5Ranker': }\n",
110 | "No dtype set\n",
111 | "Using dtype torch.float16\n"
112 | ]
113 | },
114 | {
115 | "name": "stdout",
116 | "output_type": "stream",
117 | "text": [
118 | "Loaded model mixedbread-ai/mxbai-rerank-base-v1\n",
119 | "Using device cuda.\n",
120 | "Using dtype torch.float16.\n"
121 | ]
122 | }
123 | ],
124 | "source": [
125 | "ranker = Reranker('mixedbread-ai/mxbai-rerank-base-v1', device='cuda')"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 6,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 7,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 8,
149 | "metadata": {},
150 | "outputs": [
151 | {
152 | "name": "stderr",
153 | "output_type": "stream",
154 | "text": [
155 | "100%|██████████| 300/300 [02:55<00:00, 1.71it/s]\n"
156 | ]
157 | }
158 | ],
159 | "source": [
160 | "qrels_dict = dict(qrels)\n",
161 | "queries = [q for q in queries if q['_id'] in qrels_dict]\n",
162 | "from tqdm import tqdm\n",
163 | "\n",
164 | "scores = {}\n",
165 | "for q in tqdm(queries):\n",
166 | " doc_ids = top100[q['_id']]\n",
167 | " docs = [corpus_map[x] for x in doc_ids]\n",
168 | " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 9,
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "scores_dict = {}\n",
178 | "for q_id, ranked_results in scores.items():\n",
179 | " top_10_results = ranked_results.top_k(10)\n",
180 | " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 10,
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "run = Run(scores_dict)"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 11,
195 | "metadata": {},
196 | "outputs": [
197 | {
198 | "data": {
199 | "text/plain": [
200 | "300"
201 | ]
202 | },
203 | "execution_count": 11,
204 | "metadata": {},
205 | "output_type": "execute_result"
206 | }
207 | ],
208 | "source": [
209 | "len(run)"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 12,
215 | "metadata": {},
216 | "outputs": [
217 | {
218 | "name": "stdout",
219 | "output_type": "stream",
220 | "text": [
221 | "Score is within 0.01NDCG@10 of the reported score!\n"
222 | ]
223 | }
224 | ],
225 | "source": [
226 | "from ranx import evaluate\n",
227 | "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n",
228 | "litterature_result = 0.724 # from MXBAI https://docs.google.com/spreadsheets/d/15ELkSMFv-oHa5TRiIjDvhIstH9dlc3pnZeO-iGz4Ld4/edit#gid=0\n",
229 | "if abs(evaluation_score - litterature_result) > 0.01:\n",
230 | " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\")\n",
231 | "else:\n",
232 | " print(f\"Score is within 0.01NDCG@10 of the reported score!\")\n"
233 | ]
234 | }
235 | ],
236 | "metadata": {
237 | "kernelspec": {
238 | "display_name": "rerankers",
239 | "language": "python",
240 | "name": "python3"
241 | },
242 | "language_info": {
243 | "codemirror_mode": {
244 | "name": "ipython",
245 | "version": 3
246 | },
247 | "file_extension": ".py",
248 | "mimetype": "text/x-python",
249 | "name": "python",
250 | "nbconvert_exporter": "python",
251 | "pygments_lexer": "ipython3",
252 | "version": "3.10.13"
253 | }
254 | },
255 | "nbformat": 4,
256 | "nbformat_minor": 2
257 | }
258 |
--------------------------------------------------------------------------------
/tests/consistency_notebooks/test_inranker.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from ranx import Qrels, Run\n",
10 | "\n"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "qrels = Qrels.from_ir_datasets(\"beir/scifact/test\")"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 3,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stderr",
29 | "output_type": "stream",
30 | "text": [
31 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
32 | " from .autonotebook import tqdm as notebook_tqdm\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "from rerankers import Reranker"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 4,
43 | "metadata": {},
44 | "outputs": [
45 | {
46 | "data": {
47 | "text/plain": [
48 | "{'_id': '4983',\n",
49 | " 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',\n",
50 | " 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and preterm infants at term showed marked differences in white matter fiber organization. The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural development in cerebral white matter in living infants.',\n",
51 | " 'metadata': {}}"
52 | ]
53 | },
54 | "execution_count": 4,
55 | "metadata": {},
56 | "output_type": "execute_result"
57 | }
58 | ],
59 | "source": [
60 | "import srsly\n",
61 | "\n",
62 | "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n",
63 | "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n",
64 | "\n",
65 | "corpus[0]"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 5,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "T5Ranker\n",
78 | "{'TransformerRanker': , 'APIRanker': , 'RankGPTRanker': , 'T5Ranker': }\n"
79 | ]
80 | },
81 | {
82 | "name": "stderr",
83 | "output_type": "stream",
84 | "text": [
85 | "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
86 | ]
87 | }
88 | ],
89 | "source": [
90 | "ranker = Reranker('unicamp-dl/InRanker-base', device='cuda', batch_size=32, verbose=0)"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 6,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n",
100 | "\n"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 7,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 8,
115 | "metadata": {},
116 | "outputs": [
117 | {
118 | "name": "stderr",
119 | "output_type": "stream",
120 | "text": [
121 | "100%|██████████| 300/300 [05:18<00:00, 1.06s/it]\n"
122 | ]
123 | }
124 | ],
125 | "source": [
126 | "qrels_dict = dict(qrels)\n",
127 | "queries = [q for q in queries if q['_id'] in qrels_dict]\n",
128 | "from tqdm import tqdm\n",
129 | "\n",
130 | "scores = {}\n",
131 | "for q in tqdm(queries):\n",
132 | " doc_ids = top100[q['_id']]\n",
133 | " docs = [corpus_map[x] for x in doc_ids]\n",
134 | " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 9,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "scores_dict = {}\n",
144 | "for q_id, ranked_results in scores.items():\n",
145 | " top_10_results = ranked_results.top_k(10)\n",
146 | " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n",
147 | "run = Run(scores_dict)"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 12,
153 | "metadata": {},
154 | "outputs": [
155 | {
156 | "name": "stdout",
157 | "output_type": "stream",
158 | "text": [
159 | "Score is within 0.01 NDCG@10 of the reported score!\n"
160 | ]
161 | }
162 | ],
163 | "source": [
164 | "from ranx import evaluate\n",
165 | "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n",
166 | "litterature_result = 0.7618 # From InRanker Paper https://arxiv.org/pdf/2401.06910.pdf\n",
167 | "if abs(evaluation_score - litterature_result) > 0.01:\n",
168 | " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\")\n",
169 | "else:\n",
170 | " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n"
171 | ]
172 | }
173 | ],
174 | "metadata": {
175 | "kernelspec": {
176 | "display_name": "rerankers",
177 | "language": "python",
178 | "name": "python3"
179 | },
180 | "language_info": {
181 | "codemirror_mode": {
182 | "name": "ipython",
183 | "version": 3
184 | },
185 | "file_extension": ".py",
186 | "mimetype": "text/x-python",
187 | "name": "python",
188 | "nbconvert_exporter": "python",
189 | "pygments_lexer": "ipython3",
190 | "version": "3.10.13"
191 | }
192 | },
193 | "nbformat": 4,
194 | "nbformat_minor": 2
195 | }
196 |
--------------------------------------------------------------------------------
/tests/consistency_notebooks/test_t5.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from ranx import Qrels, Run"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "qrels = Qrels.from_ir_datasets(\"beir/scifact/test\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 3,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stderr",
28 | "output_type": "stream",
29 | "text": [
30 | "/usr/local/share/miniconda/envs/mcol/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
31 | " from .autonotebook import tqdm as notebook_tqdm\n"
32 | ]
33 | }
34 | ],
35 | "source": [
36 | "from rerankers import Reranker"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 4,
42 | "metadata": {},
43 | "outputs": [
44 | {
45 | "data": {
46 | "text/plain": [
47 | "{'_id': '4983',\n",
48 | " 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',\n",
49 | " 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and preterm infants at term showed marked differences in white matter fiber organization. The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural development in cerebral white matter in living infants.',\n",
50 | " 'metadata': {}}"
51 | ]
52 | },
53 | "execution_count": 4,
54 | "metadata": {},
55 | "output_type": "execute_result"
56 | }
57 | ],
58 | "source": [
59 | "import srsly\n",
60 | "\n",
61 | "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n",
62 | "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n",
63 | "\n",
64 | "corpus[0]"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 5,
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "name": "stdout",
74 | "output_type": "stream",
75 | "text": [
76 | "T5Ranker\n",
77 | "{'TransformerRanker': , 'APIRanker': , 'T5Ranker': }\n"
78 | ]
79 | },
80 | {
81 | "name": "stderr",
82 | "output_type": "stream",
83 | "text": [
84 | "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "ranker = Reranker('castorini/monot5-base-msmarco-10k', device='cuda', batch_size=128, verbose=0)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 6,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n",
99 | "\n"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 7,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 8,
114 | "metadata": {},
115 | "outputs": [
116 | {
117 | "name": "stderr",
118 | "output_type": "stream",
119 | "text": [
120 | "100%|██████████| 300/300 [04:44<00:00, 1.06it/s]\n"
121 | ]
122 | }
123 | ],
124 | "source": [
125 | "qrels_dict = dict(qrels)\n",
126 | "queries = [q for q in queries if q['_id'] in qrels_dict]\n",
127 | "from tqdm import tqdm\n",
128 | "\n",
129 | "scores = {}\n",
130 | "for q in tqdm(queries):\n",
131 | " doc_ids = top100[q['_id']]\n",
132 | " docs = [corpus_map[x] for x in doc_ids]\n",
133 | " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 17,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "scores_dict = {}\n",
143 | "for q_id, ranked_results in scores.items():\n",
144 | " top_10_results = ranked_results.top_k(10)\n",
145 | " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n",
146 | "run = Run(scores_dict)"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 18,
152 | "metadata": {},
153 | "outputs": [
154 | {
155 | "name": "stdout",
156 | "output_type": "stream",
157 | "text": [
158 | "Score 0.731 is within 0.005 NDCG@10 of the reported score!\n"
159 | ]
160 | }
161 | ],
162 | "source": [
163 | "from ranx import evaluate\n",
164 | "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n",
165 | "litterature_result = 0.734 # From RankGPT Paper https://arxiv.org/pdf/2304.09542.pdf\n",
166 | "if abs(evaluation_score - litterature_result) > 0.01:\n",
167 | " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the reported score.\")\n",
168 | "else:\n",
169 | " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n"
170 | ]
171 | }
172 | ],
173 | "metadata": {
174 | "kernelspec": {
175 | "display_name": "rerankers",
176 | "language": "python",
177 | "name": "python3"
178 | },
179 | "language_info": {
180 | "codemirror_mode": {
181 | "name": "ipython",
182 | "version": 3
183 | },
184 | "file_extension": ".py",
185 | "mimetype": "text/x-python",
186 | "name": "python",
187 | "nbconvert_exporter": "python",
188 | "pygments_lexer": "ipython3",
189 | "version": "3.9.18"
190 | }
191 | },
192 | "nbformat": 4,
193 | "nbformat_minor": 2
194 | }
195 |
--------------------------------------------------------------------------------
/tests/test_crossenc.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 | import torch
3 | from rerankers import Reranker
4 | from rerankers.models.transformer_ranker import TransformerRanker
5 | from rerankers.results import Result, RankedResults
6 | from rerankers.documents import Document
7 |
8 | @patch("rerankers.models.transformer_ranker.TransformerRanker.rank")
9 | def test_transformer_ranker_rank(mock_rank):
10 | query = "Gone with the wind is an absolute masterpiece"
11 | docs = [
12 | "Gone with the wind is a masterclass in bad storytelling.",
13 | "Gone with the wind is an all-time classic",
14 | ]
15 | expected_results = RankedResults(
16 | results=[
17 | Result(
18 | document=Document(
19 | doc_id=1, text="Gone with the wind is an all-time classic"
20 | ),
21 | score=1.6181640625,
22 | rank=1,
23 | ),
24 | Result(
25 | document=Document(
26 | doc_id=0,
27 | text="Gone with the wind is a masterclass in bad storytelling.",
28 | ),
29 | score=0.88427734375,
30 | rank=2,
31 | ),
32 | ],
33 | query=query,
34 | has_scores=True,
35 | )
36 | mock_rank.return_value = expected_results
37 | ranker = TransformerRanker("mixedbread-ai/mxbai-rerank-xsmall-v1")
38 | results = ranker.rank(query=query, docs=docs)
39 | assert results == expected_results
40 |
--------------------------------------------------------------------------------
/tests/test_results.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rerankers.results import Result, RankedResults
3 | from rerankers.documents import Document
4 |
5 |
6 | def test_ranked_results_functions():
7 | results = RankedResults(
8 | results=[
9 | Result(document=Document(doc_id=0, text="Doc 0"), score=0.9, rank=2),
10 | Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1),
11 | ],
12 | query="Test Query",
13 | has_scores=True,
14 | )
15 | assert results.results_count() == 2
16 | top_k = results.top_k(1)
17 | assert len(top_k) == 1
18 | assert top_k[0].doc_id == 1
19 | assert results.get_score_by_docid(0) == 0.9
20 |
21 |
22 | def test_result_attributes():
23 | result = Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1)
24 | assert result.doc_id == 1
25 | assert result.text == "Doc 1"
26 | assert result.score == 0.95
27 | assert result.rank == 1
28 |
29 |
30 | def test_result_validation_error():
31 | with pytest.raises(ValueError) as excinfo:
32 | Result(document=Document(doc_id=2, text="Doc 2"))
33 | assert "Either score or rank must be provided." in str(excinfo.value)
34 |
--------------------------------------------------------------------------------