├── .github └── workflows │ └── ci.yaml ├── LICENSE ├── README.md ├── chunked_pooling ├── __init__.py ├── chunked_eval_tasks.py ├── chunking.py ├── mteb_chunked_eval.py └── wrappers.py ├── examples.ipynb ├── explanatory_contextual_retrieval.py ├── img ├── context-problem.png ├── method.png └── rag.png ├── pyproject.toml ├── run_chunked_eval.py └── tests ├── __init__.py ├── conftest.py ├── test_api.py ├── test_chunking_methods.py └── test_v3.py /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | pull_request: 5 | types: [opened, synchronize, reopened] 6 | push: 7 | branches: 8 | - main 9 | 10 | env: 11 | JINA_API_TOKEN: ${{ secrets.JINA_API_TOKEN }} 12 | 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: '3.11' 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install .[dev] 29 | 30 | - name: Run tests 31 | run: pytest tests 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Late Chunking of Short Chunks in Long-Context Embedding Models 2 | 3 | [**Blog part 1**](https://jina.ai/news/late-chunking-in-long-context-embedding-models) | [**Blog part 2**](https://jina.ai/news/what-late-chunking-really-is-and-what-its-not-part-ii/) | [**ArXiv paper**](https://arxiv.org/abs/2409.04701) 4 | 5 | For many applications, encoding a whole text document into a single embedding representation is not useful. Many applications require retrieving smaller parts of the text, and dense vector-based information retrieval systems often perform better with smaller text segments because of the limited information capacity of embedding vectors. 6 | 7 | ![img.png](img/rag.png) 8 | 9 | 10 | RAG (Retrieval-Augmented Generation) is one of the best-known applications that require splitting document collections into smaller text chunks. These chunks are typically stored in a vector database with vector representations created by a text embedding model. 11 | At runtime, the same embedding model encodes a query text into a vector representation, which is used to identify relevant stored text chunks. These are then passed to a large language model (LLM), which synthesizes a response to the query based on the retrieved texts. 12 | 13 | ## Context Problem 14 | 15 | 16 | This simple RAG approach is not without challenges. Long-distance contextual dependencies, i.e. cases where the relevant information is spread over multiple chunks and taking text segments out of context makes them useless, are handled particularly poorly by this approach. 17 | ![img.png](img/context-problem.png) 18 | In the image above, one can see a Wikipedia article that is split into chunks of sentences. 19 | Phrases like "its" and "the city" reference "Berlin", which is mentioned only in the first sentence, so it is harder for the embedding model to link them to the respective entity and produce a high-quality embedding representation. 20 | 21 | 22 | For example, if we split a Wikipedia article into sentence-length segments, as in the example above, a RAG system might not be able to answer a query like "What is the population of Berlin?" The city name and the population never appear together in a single segment, and the segments lack any larger document context. 23 | An LLM to which one of the segments is presented cannot resolve anaphoric references like "it" or "the city". 24 | 25 | ## Context-Sensitive Chunking 26 | 27 | To overcome this problem, we take advantage of the long input sequences that recent embedding models like [`jina-embeddings-v2-base-en`](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) can process. 28 | These models support much longer input texts, for example, 8192 tokens for `jina-embeddings-v2-base-en` or roughly ten standard pages of text. Text segments of this size are much less likely to have contextual dependencies that can only be resolved with a larger context. 29 | However, we still need vector representations of much smaller chunks of text, in part because of the limited input sizes of LLMs but primarily because of the limited information capacity of short embedding vectors. 
30 | 31 | ![img.png](img/method.png) 32 | 33 | 34 | The simple encoding approach (as seen on the left side of the image above) chunks texts before processing them, using sentences, paragraphs, and maximum length limits to split text _a priori_, and then applies an embedding model to the resulting chunks. 35 | Late Chunking, instead, first applies the transformer part of the embedding model to the entire text, or the largest part of it possible. This generates a sequence of vector representations, one for each token, that encompass textual information from the entire text. 36 | To generate a single embedding for a text, many embedding models apply _mean pooling_ to these token representations to output a single vector. Late Chunking instead applies mean pooling to smaller segments of this sequence of token vectors, producing embeddings for each chunk that take into account the entire text. 37 | 38 | ## The Effect of Context-Sensitive Chunking 39 | 40 | This has immediate, measurable effects on retrieval. As an example, in the case of "the city" and "Berlin" in a Wikipedia article, the vectors representing "the city" contain information connecting it to the previous mention of "Berlin", making it a much better match for queries involving that city name. 41 | 42 | You can see this in the numerical results below, which compare the embedding of the string "Berlin" to various sentences from the article about Berlin. The column "Traditional Similarity" contains the similarity values using _a priori_ chunking, and "Late Chunking Similarity" contains the values with context-sensitive chunking. 43 | 44 | | Text | Traditional Similarity | Late Chunking Similarity | 45 | |---------------------------------------------------------------------------------------------------------------------------------------|------------------------|-------------------------------| 46 | | Berlin is the capital and largest city of Germany, both by area and by population. | 0.84862185 | 0.849546 | 47 | | Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. | 0.7084338 | 0.82489026 | 48 | | The city is also one of the states of Germany, and is the third smallest state in the country in terms of area. | 0.7534553 | 0.84980094 | 49 | 50 | As you can see, the similarity scores for the first chunk that contains "Berlin" are very close to each other. 51 | For the other two chunks, they differ significantly, as late chunking dramatically improves matching on sentences that do not explicitly use the word "Berlin" but have anaphoric references to it. 52 | 53 | ## Evaluation on Retrieval Tasks 54 | 55 | 56 | To verify the effectiveness of this approach beyond a few toy examples, we tested it with some of the retrieval benchmarks from [BeIR](https://github.com/beir-cellar/beir). 57 | Those retrieval tasks consist of a query set, a corpus of text documents, and a QRels file that stores information about the IDs of documents that are relevant for each query. 58 | To identify the relevant documents of a query, one can chunk the documents, encode them into an embedding index, and determine for each query embedding the most similar chunks (kNN). 59 | As each chunk corresponds to a document, one can convert the kNN ranking of chunks into a kNN ranking of documents (for documents occurring multiple times in the ranking, only the first occurrence is retained), as sketched below. 
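The snippet below is a minimal, hypothetical sketch of that conversion (the helper name is illustrative; in this repository the equivalent step is performed by `get_doc_results` in `chunked_pooling/mteb_chunked_eval.py`, which keeps the best chunk score per document):

```python
# Hypothetical helper: collapse a scored ranking of chunks into a ranking of documents.
# Chunk IDs are assumed to follow the "<doc_id>~<chunk_index>" convention used in this repository.
def chunk_ranking_to_doc_ranking(chunk_scores: dict) -> dict:
    doc_scores = {}
    # Visit chunks from the highest to the lowest score ...
    for chunk_id, score in sorted(chunk_scores.items(), key=lambda kv: kv[1], reverse=True):
        doc_id = chunk_id.rsplit('~', 1)[0]
        # ... and keep only the first (i.e., best-scoring) chunk of each document.
        doc_scores.setdefault(doc_id, score)
    return doc_scores
```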
60 | After that, one can compare the resulting ranking with the ranking corresponding to the ground-truth QRels file and calculate retrieval metrics like nDCG@10. 61 | We run this evaluation for various BeIR datasets with traditional chunking and our novel late chunking method. 62 | To split texts into chunks, we choose a straightforward method, which chunks the texts into spans of 256 tokens. 63 | Both the traditional and late chunking tests used the [jina-embeddings-v2-small-en](https://huggingface.co/jinaai/jina-embeddings-v2-small-en) model. 64 | 65 | | Dataset | AVG Document Length (characters) | Traditional Chunking (nDCG@10) | Late Chunking (nDCG@10) | No Chunking (nDCG@10) | 66 | |-----------|----------------------------------|--------------------------------|--------------------------------------|-----------------------| 67 | | SciFact | 1498.4 | 64.20% | **66.10%** | 63.89% | 68 | | TRECCOVID | 1116.7 | 63.36% | 64.70% | **65.18%** | 69 | | FiQA2018 | 767.2 | 33.25% | **33.84%** | 33.43% | 70 | | NFCorpus | 1589.8 | 23.46% | 29.98% | **30.40%** | 71 | | Quora | 62.2 | 87.19% | 87.19% | 87.19% | 72 | 73 | In all cases except Quora, where the scores are identical, late chunking improved the score over traditional chunking. In some cases, it also outperforms encoding the whole document into a single embedding, while for other datasets, encoding without chunking performs best; however, skipping chunking only makes sense if one does not need to rank chunks. One can also see that greater average document length correlates with greater improvement in the nDCG scores through late chunking. 74 | 75 | To reproduce the evaluation, you can install the dependencies with `pip install .` and run the following script for the tasks "SciFactChunked", "TRECCOVIDChunked", "FiQA2018Chunked", "NFCorpusChunked", and "QuoraChunked": 76 | 77 | ```bash 78 | python3 run_chunked_eval.py --task-name {TASK_NAME} 79 | ``` 80 | 81 | ## Acknowledgement and References 82 | 83 | Thanks to Isabelle Mohr ([@violenil](https://github.com/violenil)) for contributing some code and Scott Martens ([@scott-martens](https://github.com/scott-martens)) for reviewing the README. 
84 | 85 | More about the evaluation tasks can be found in the [MTEB Repository](https://github.com/embeddings-benchmark/mteb) and details about the training of the models for long input text in our paper: ["Jina embeddings 2: 8192-token general-purpose text embeddings for long documents."](https://arxiv.org/abs/2310.19923) 86 | 87 | If you find Late Chunking useful in your research, you can cite the paper [Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models](https://arxiv.org/abs/2409.04701): 88 | 89 | ``` 90 | @article{gunther2024late, 91 | title={Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models}, 92 | author={G{\"u}nther, Michael and Mohr, Isabelle and Williams, Daniel J and Wang, Bo and Xiao, Han}, 93 | journal={arXiv preprint arXiv:2409.04701}, 94 | year={2024} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /chunked_pooling/__init__.py: -------------------------------------------------------------------------------- 1 | def chunk_by_sentences(input_text: str, tokenizer: callable): 2 | """ 3 | Split the input text into sentences using the tokenizer 4 | :param input_text: The text snippet to split into sentences 5 | :param tokenizer: The tokenizer to use 6 | :return: A tuple containing the list of text chunks and their corresponding token spans 7 | """ 8 | inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True) 9 | punctuation_mark_id = tokenizer.convert_tokens_to_ids('.') 10 | sep_id = tokenizer.convert_tokens_to_ids('[SEP]') 11 | token_offsets = inputs['offset_mapping'][0] 12 | token_ids = inputs['input_ids'][0] 13 | chunk_positions = [ 14 | (i, int(start + 1)) 15 | for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets)) 16 | if token_id == punctuation_mark_id 17 | and ( 18 | token_offsets[i + 1][0] - token_offsets[i][1] > 0 19 | or token_ids[i + 1] == sep_id 20 | ) 21 | ] 22 | chunks = [ 23 | input_text[x[1] : y[1]] 24 | for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions) 25 | ] 26 | span_annotations = [ 27 | (x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions) 28 | ] 29 | return chunks, span_annotations 30 | 31 | 32 | def chunked_pooling( 33 | model_output: 'BatchEncoding', span_annotation: list, max_length=None 34 | ): 35 | token_embeddings = model_output[0] 36 | outputs = [] 37 | for embeddings, annotations in zip(token_embeddings, span_annotation): 38 | if ( 39 | max_length is not None 40 | ): # remove annotations which go beyond the max-length of the model 41 | annotations = [ 42 | (start, min(end, max_length - 1)) 43 | for (start, end) in annotations 44 | if start < (max_length - 1) 45 | ] 46 | pooled_embeddings = [ 47 | embeddings[start:end].sum(dim=0) / (end - start) 48 | for start, end in annotations 49 | if (end - start) >= 1 50 | ] 51 | pooled_embeddings = [ 52 | embedding.float().detach().cpu().numpy() for embedding in pooled_embeddings 53 | ] 54 | outputs.append(pooled_embeddings) 55 | 56 | return outputs 57 | -------------------------------------------------------------------------------- /chunked_pooling/chunked_eval_tasks.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from mteb.abstasks.TaskMetadata import TaskMetadata 3 | 4 | from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval 5 | 6 | 7 | class SciFactChunked(AbsTaskChunkedRetrieval): 8 | metadata = TaskMetadata( 9 | name='SciFactChunked', 10 |
dataset={ 11 | 'path': 'mteb/scifact', 12 | 'revision': '0228b52cf27578f30900b9e5271d331663a030d7', 13 | 'name': 'SciFact', 14 | }, 15 | description=( 16 | 'SciFact verifies scientific claims using evidence from the ' 17 | 'research literature containing scientific paper abstracts.' 18 | ), 19 | reference='https://github.com/allenai/scifact', 20 | type='Retrieval', 21 | category='s2p', 22 | eval_splits=['test'], 23 | eval_langs=['eng-Latn'], 24 | main_score='ndcg_at_10', 25 | date=None, 26 | form=None, 27 | domains=None, 28 | task_subtypes=None, 29 | license=None, 30 | socioeconomic_status=None, 31 | annotations_creators=None, 32 | dialect=None, 33 | text_creation=None, 34 | bibtex_citation=None, 35 | n_samples=None, 36 | avg_character_length=None, 37 | ) 38 | 39 | def __init__(self, **kwargs): 40 | super().__init__(**kwargs) 41 | 42 | 43 | class NarrativeQAChunked(AbsTaskChunkedRetrieval): 44 | metadata = TaskMetadata( 45 | name='NarrativeQAChunked', 46 | dataset={ 47 | 'path': 'narrativeqa', 48 | 'revision': '2e643e7363944af1c33a652d1c87320d0871c4e4', 49 | 'name': 'NarrativeQARetrieval', 50 | }, 51 | reference='https://metatext.io/datasets/narrativeqa', 52 | description=( 53 | 'NarrativeQA is a dataset for the task of question answering ' 54 | 'on long narratives. It consists of realistic QA instances ' 55 | 'collected from literature (fiction and non-fiction) ' 56 | 'and movie scripts. ' 57 | ), 58 | type='Retrieval', 59 | category='s2p', 60 | eval_splits=['test'], 61 | eval_langs=['eng-Latn'], 62 | main_score='ndcg_at_10', 63 | date=None, 64 | form=None, 65 | domains=None, 66 | task_subtypes=None, 67 | license=None, 68 | socioeconomic_status=None, 69 | annotations_creators=None, 70 | dialect=None, 71 | text_creation=None, 72 | bibtex_citation=None, 73 | n_samples=None, 74 | avg_character_length=None, 75 | ) 76 | 77 | def __init__(self, **kwargs): 78 | super().__init__(**kwargs) 79 | 80 | 81 | class NFCorpusChunked(AbsTaskChunkedRetrieval): 82 | metadata = TaskMetadata( 83 | name="NFCorpusChunked", 84 | dataset={ 85 | "path": "mteb/nfcorpus", 86 | "revision": "ec0fa4fe99da2ff19ca1214b7966684033a58814", 87 | 'name': 'NFCorpus', 88 | }, 89 | description="NFCorpus: A Full-Text Learning to Rank Dataset for Medical Information Retrieval", 90 | reference="https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/", 91 | type="Retrieval", 92 | category="s2p", 93 | eval_splits=["test"], 94 | eval_langs=["eng-Latn"], 95 | main_score="ndcg_at_10", 96 | date=None, 97 | form=None, 98 | domains=None, 99 | task_subtypes=None, 100 | license=None, 101 | socioeconomic_status=None, 102 | annotations_creators=None, 103 | dialect=None, 104 | text_creation=None, 105 | bibtex_citation=None, 106 | n_samples=None, 107 | avg_character_length=None, 108 | ) 109 | 110 | def __init__(self, **kwargs): 111 | super().__init__(**kwargs) 112 | 113 | 114 | class QuoraChunked(AbsTaskChunkedRetrieval): 115 | metadata = TaskMetadata( 116 | name="QuoraChunked", 117 | dataset={ 118 | "path": "mteb/quora", 119 | "revision": "e4e08e0b7dbe3c8700f0daef558ff32256715259", 120 | "name": "QuoraRetrieval", 121 | }, 122 | description=( 123 | "QuoraRetrieval is based on questions that are marked as duplicates on the Quora platform. Given a" 124 | " question, find other (duplicate) questions." 
125 | ), 126 | reference="https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs", 127 | type="Retrieval", 128 | category="s2s", 129 | eval_splits=["dev", "test"], 130 | eval_langs=["eng-Latn"], 131 | main_score="ndcg_at_10", 132 | date=None, 133 | form=None, 134 | domains=None, 135 | task_subtypes=None, 136 | license=None, 137 | socioeconomic_status=None, 138 | annotations_creators=None, 139 | dialect=None, 140 | text_creation=None, 141 | bibtex_citation=None, 142 | n_samples=None, 143 | avg_character_length=None, 144 | ) 145 | 146 | def __init__(self, **kwargs): 147 | super().__init__(**kwargs) 148 | 149 | 150 | class FiQA2018Chunked(AbsTaskChunkedRetrieval): 151 | metadata = TaskMetadata( 152 | name="FiQA2018Chunked", 153 | description="Financial Opinion Mining and Question Answering", 154 | reference="https://sites.google.com/view/fiqa/", 155 | dataset={ 156 | "path": "mteb/fiqa", 157 | "revision": "27a168819829fe9bcd655c2df245fb19452e8e06", 158 | 'name': 'FiQA2018', 159 | }, 160 | type="Retrieval", 161 | category="s2p", 162 | eval_splits=["train", "dev", "test"], 163 | eval_langs=["eng-Latn"], 164 | main_score="ndcg_at_10", 165 | date=None, 166 | form=None, 167 | domains=None, 168 | task_subtypes=None, 169 | license=None, 170 | socioeconomic_status=None, 171 | annotations_creators=None, 172 | dialect=None, 173 | text_creation=None, 174 | bibtex_citation=None, 175 | n_samples=None, 176 | avg_character_length=None, 177 | ) 178 | 179 | def __init__(self, **kwargs): 180 | super().__init__(**kwargs) 181 | 182 | 183 | class TRECCOVIDChunked(AbsTaskChunkedRetrieval): 184 | metadata = TaskMetadata( 185 | name='TRECCOVIDChunked', 186 | description=( 187 | 'TRECCOVID is an ad-hoc search challenge based on the ' 188 | 'COVID-19 dataset containing scientific articles ' 189 | 'related to the COVID-19 pandemic.' 
190 | ), 191 | reference='https://ir.nist.gov/covidSubmit/index.html', 192 | dataset={ 193 | 'path': 'mteb/trec-covid', 194 | 'revision': 'bb9466bac8153a0349341eb1b22e06409e78ef4e', 195 | 'name': 'TRECCOVID', 196 | }, 197 | type='Retrieval', 198 | category='s2p', 199 | eval_splits=['test'], 200 | eval_langs=['eng-Latn'], 201 | main_score='ndcg_at_10', 202 | date=None, 203 | form=None, 204 | domains=None, 205 | task_subtypes=None, 206 | license=None, 207 | socioeconomic_status=None, 208 | annotations_creators=None, 209 | dialect=None, 210 | text_creation=None, 211 | bibtex_citation=None, 212 | n_samples=None, 213 | avg_character_length=None, 214 | ) 215 | 216 | def __init__(self, **kwargs): 217 | super().__init__(**kwargs) 218 | 219 | 220 | class LEMBWikimQARetrievalChunked(AbsTaskChunkedRetrieval): 221 | """ 222 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py 223 | """ 224 | 225 | _EVAL_SPLIT = "test" 226 | 227 | metadata = TaskMetadata( 228 | name="LEMBWikimQARetrievalChunked", 229 | dataset={ 230 | "path": "dwzhu/LongEmbed", 231 | "revision": "10039a580487dacecf79db69166e17ace3ede392", 232 | "name": "LEMBWikimQARetrieval", 233 | }, 234 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed", 235 | description=("2wikimqa subset of dwzhu/LongEmbed dataset."), 236 | type="Retrieval", 237 | category="s2p", 238 | modalities=["text"], 239 | eval_splits=[_EVAL_SPLIT], 240 | eval_langs=["eng-Latn"], 241 | main_score="ndcg_at_10", 242 | date=("1950-01-01", "2019-12-31"), 243 | domains=None, 244 | socioeconomic_status=None, 245 | n_samples=None, 246 | avg_character_length=None, 247 | form=None, 248 | text_creation=None, 249 | task_subtypes=["Article retrieval"], 250 | license="not specified", 251 | annotations_creators="derived", 252 | dialect=[], 253 | sample_creation="found", 254 | bibtex_citation=""" 255 | @inproceedings{ho2020constructing, 256 | title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps}, 257 | author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko}, 258 | booktitle={Proceedings of the 28th International Conference on Computational Linguistics}, 259 | pages={6609--6625}, 260 | year={2020} 261 | } 262 | """, 263 | descriptive_stats={ 264 | "n_samples": {_EVAL_SPLIT: 500}, 265 | "avg_character_length": { 266 | "test": { 267 | "average_document_length": 37445.60333333333, 268 | "average_query_length": 67.57, 269 | "num_documents": 300, 270 | "num_queries": 300, 271 | "average_relevant_docs_per_query": 1.0, 272 | } 273 | }, 274 | }, 275 | ) 276 | 277 | def load_data(self, **kwargs): 278 | if self.data_loaded: 279 | return 280 | 281 | dataset_dict = {**self.metadata.dataset} 282 | dataset_dict['name'] = '2wikimqa' 283 | 284 | query_list = datasets.load_dataset(**dataset_dict)["queries"] 285 | queries = {row["qid"]: row["text"] for row in query_list} 286 | 287 | corpus_list = datasets.load_dataset(**dataset_dict)["corpus"] 288 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list} 289 | 290 | qrels_list = datasets.load_dataset(**dataset_dict)["qrels"] 291 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list} 292 | 293 | self.corpus = {self._EVAL_SPLIT: corpus} 294 | self.queries = {self._EVAL_SPLIT: queries} 295 | self.relevant_docs = {self._EVAL_SPLIT: qrels} 296 | 297 | self.data_loaded = True 298 | 299 | 300 | class LEMBSummScreenFDRetrievalChunked(AbsTaskChunkedRetrieval): 301 | """ 302 | modified from 
https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py 303 | """ 304 | 305 | _EVAL_SPLIT = "test" 306 | 307 | metadata = TaskMetadata( 308 | name="LEMBSummScreenFDRetrievalChunked", 309 | dataset={ 310 | "path": "dwzhu/LongEmbed", 311 | "revision": "10039a580487dacecf79db69166e17ace3ede392", 312 | "name": "LEMBSummScreenFDRetrieval", 313 | }, 314 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed", 315 | description=("summ_screen_fd subset of dwzhu/LongEmbed dataset."), 316 | type="Retrieval", 317 | category="s2p", 318 | modalities=["text"], 319 | eval_splits=[_EVAL_SPLIT], 320 | eval_langs=["eng-Latn"], 321 | main_score="ndcg_at_10", 322 | date=("1950-01-01", "2019-12-31"), 323 | domains=None, 324 | socioeconomic_status=None, 325 | n_samples=None, 326 | avg_character_length=None, 327 | form=None, 328 | text_creation=None, 329 | task_subtypes=["Article retrieval"], 330 | license="not specified", 331 | annotations_creators="derived", 332 | dialect=[], 333 | sample_creation="found", 334 | bibtex_citation=""" 335 | @inproceedings{ho2020constructing, 336 | title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps}, 337 | author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko}, 338 | booktitle={Proceedings of the 28th International Conference on Computational Linguistics}, 339 | pages={6609--6625}, 340 | year={2020} 341 | } 342 | """, 343 | descriptive_stats={ 344 | "n_samples": {_EVAL_SPLIT: 500}, 345 | "avg_character_length": { 346 | "test": { 347 | "average_document_length": 30854.327, 348 | "average_query_length": 591.49, 349 | "num_documents": 300, 350 | "num_queries": 300, 351 | "average_relevant_docs_per_query": 1.0, 352 | } 353 | }, 354 | }, 355 | ) 356 | 357 | def load_data(self, **kwargs): 358 | if self.data_loaded: 359 | return 360 | 361 | dataset_dict = {**self.metadata.dataset} 362 | dataset_dict['name'] = 'summ_screen_fd' 363 | 364 | query_list = datasets.load_dataset(**dataset_dict)["queries"] 365 | queries = {row["qid"]: row["text"] for row in query_list} 366 | 367 | corpus_list = datasets.load_dataset(**dataset_dict)["corpus"] 368 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list} 369 | 370 | qrels_list = datasets.load_dataset(**dataset_dict)["qrels"] 371 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list} 372 | 373 | self.corpus = {self._EVAL_SPLIT: corpus} 374 | self.queries = {self._EVAL_SPLIT: queries} 375 | self.relevant_docs = {self._EVAL_SPLIT: qrels} 376 | 377 | self.data_loaded = True 378 | 379 | 380 | class LEMBQMSumRetrievalChunked(AbsTaskChunkedRetrieval): 381 | """ 382 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py 383 | """ 384 | 385 | _EVAL_SPLIT = "test" 386 | 387 | metadata = TaskMetadata( 388 | name="LEMBQMSumRetrievalChunked", 389 | dataset={ 390 | "path": "dwzhu/LongEmbed", 391 | "revision": "10039a580487dacecf79db69166e17ace3ede392", 392 | "name": "LEMBQMSumRetrieval", 393 | }, 394 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed", 395 | description=("qmsum subset of dwzhu/LongEmbed dataset."), 396 | type="Retrieval", 397 | category="s2p", 398 | modalities=["text"], 399 | eval_splits=[_EVAL_SPLIT], 400 | eval_langs=["eng-Latn"], 401 | main_score="ndcg_at_10", 402 | date=("1950-01-01", "2019-12-31"), 403 | domains=None, 404 | socioeconomic_status=None, 405 | n_samples=None, 406 | avg_character_length=None, 407 | form=None, 
408 | text_creation=None, 409 | task_subtypes=["Article retrieval"], 410 | license="not specified", 411 | annotations_creators="derived", 412 | dialect=[], 413 | sample_creation="found", 414 | bibtex_citation=""" 415 | @inproceedings{ho2020constructing, 416 | title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps}, 417 | author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko}, 418 | booktitle={Proceedings of the 28th International Conference on Computational Linguistics}, 419 | pages={6609--6625}, 420 | year={2020} 421 | } 422 | """, 423 | descriptive_stats={ 424 | "n_samples": {_EVAL_SPLIT: 500}, 425 | "avg_character_length": { 426 | "test": { 427 | "average_document_length": 53335.817, 428 | "average_query_length": 433.50, 429 | "num_documents": 300, 430 | "num_queries": 300, 431 | "average_relevant_docs_per_query": 1.0, 432 | } 433 | }, 434 | }, 435 | ) 436 | 437 | def load_data(self, **kwargs): 438 | if self.data_loaded: 439 | return 440 | 441 | dataset_dict = {**self.metadata.dataset} 442 | dataset_dict['name'] = 'qmsum' 443 | 444 | query_list = datasets.load_dataset(**dataset_dict)["queries"] 445 | queries = {row["qid"]: row["text"] for row in query_list} 446 | 447 | corpus_list = datasets.load_dataset(**dataset_dict)["corpus"] 448 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list} 449 | 450 | qrels_list = datasets.load_dataset(**dataset_dict)["qrels"] 451 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list} 452 | 453 | self.corpus = {self._EVAL_SPLIT: corpus} 454 | self.queries = {self._EVAL_SPLIT: queries} 455 | self.relevant_docs = {self._EVAL_SPLIT: qrels} 456 | 457 | self.data_loaded = True 458 | 459 | 460 | class LEMBNeedleRetrievalChunked(AbsTaskChunkedRetrieval): 461 | """ 462 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBNeedleRetrieval.py 463 | """ 464 | 465 | _EVAL_SPLIT = [ 466 | "test_256", 467 | "test_512", 468 | "test_1024", 469 | "test_2048", 470 | "test_4096", 471 | "test_8192", 472 | "test_16384", 473 | "test_32768", 474 | ] 475 | 476 | metadata = TaskMetadata( 477 | name="LEMBNeedleRetrievalChunked", 478 | dataset={ 479 | "path": "dwzhu/LongEmbed", 480 | "revision": "6e346642246bfb4928c560ee08640dc84d074e8c", 481 | "name": "needle", 482 | }, 483 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed", 484 | description=("needle subset of dwzhu/LongEmbed dataset."), 485 | type="Retrieval", 486 | category="s2p", 487 | modalities=["text"], 488 | eval_splits=_EVAL_SPLIT, 489 | eval_langs=["eng-Latn"], 490 | main_score="ndcg_at_1", 491 | date=("2000-01-01", "2023-12-31"), 492 | domains=["Academic", "Blog", "Written"], 493 | task_subtypes=["Article retrieval"], 494 | license="not specified", 495 | annotations_creators="derived", 496 | dialect=[], 497 | sample_creation="found", 498 | bibtex_citation=""" 499 | @article{zhu2024longembed, 500 | title={LongEmbed: Extending Embedding Models for Long Context Retrieval}, 501 | author={Zhu, Dawei and Wang, Liang and Yang, Nan and Song, Yifan and Wu, Wenhao and Wei, Furu and Li, Sujian}, 502 | journal={arXiv preprint arXiv:2404.12096}, 503 | year={2024} 504 | } 505 | """, 506 | descriptive_stats={ 507 | "n_samples": { 508 | "test_256": 150, 509 | "test_512": 150, 510 | "test_1024": 150, 511 | "test_2048": 150, 512 | "test_4096": 150, 513 | "test_8192": 150, 514 | "test_16384": 150, 515 | "test_32768": 150, 516 | }, 517 | "avg_character_length": { 518 | "test_256": { 519 | 
"average_document_length": 1013.22, 520 | "average_query_length": 60.48, 521 | "num_documents": 100, 522 | "num_queries": 50, 523 | "average_relevant_docs_per_query": 1.0, 524 | }, 525 | "test_512": { 526 | "average_document_length": 2009.96, 527 | "average_query_length": 57.3, 528 | "num_documents": 100, 529 | "num_queries": 50, 530 | "average_relevant_docs_per_query": 1.0, 531 | }, 532 | "test_1024": { 533 | "average_document_length": 4069.9, 534 | "average_query_length": 58.28, 535 | "num_documents": 100, 536 | "num_queries": 50, 537 | "average_relevant_docs_per_query": 1.0, 538 | }, 539 | "test_2048": { 540 | "average_document_length": 8453.82, 541 | "average_query_length": 59.92, 542 | "num_documents": 100, 543 | "num_queries": 50, 544 | "average_relevant_docs_per_query": 1.0, 545 | }, 546 | "test_4096": { 547 | "average_document_length": 17395.8, 548 | "average_query_length": 55.86, 549 | "num_documents": 100, 550 | "num_queries": 50, 551 | "average_relevant_docs_per_query": 1.0, 552 | }, 553 | "test_8192": { 554 | "average_document_length": 35203.82, 555 | "average_query_length": 59.6, 556 | "num_documents": 100, 557 | "num_queries": 50, 558 | "average_relevant_docs_per_query": 1.0, 559 | }, 560 | "test_16384": { 561 | "average_document_length": 72054.8, 562 | "average_query_length": 59.12, 563 | "num_documents": 100, 564 | "num_queries": 50, 565 | "average_relevant_docs_per_query": 1.0, 566 | }, 567 | "test_32768": { 568 | "average_document_length": 141769.8, 569 | "average_query_length": 58.34, 570 | "num_documents": 100, 571 | "num_queries": 50, 572 | "average_relevant_docs_per_query": 1.0, 573 | }, 574 | }, 575 | }, 576 | ) 577 | 578 | def load_data(self, **kwargs): 579 | if self.data_loaded: 580 | return 581 | 582 | self.corpus = {} 583 | self.queries = {} 584 | self.relevant_docs = {} 585 | 586 | for split in self._EVAL_SPLIT: 587 | context_length = int(split.split("_")[1]) 588 | query_list = datasets.load_dataset(**self.metadata_dict["dataset"])[ 589 | "queries" 590 | ] # dict_keys(['qid', 'text']) 591 | query_list = query_list.filter( 592 | lambda x: x["context_length"] == context_length 593 | ) 594 | queries = {row["qid"]: row["text"] for row in query_list} 595 | 596 | corpus_list = datasets.load_dataset(**self.metadata_dict["dataset"])[ 597 | "corpus" 598 | ] # dict_keys(['doc_id', 'text']) 599 | corpus_list = corpus_list.filter( 600 | lambda x: x["context_length"] == context_length 601 | ) 602 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list} 603 | 604 | qrels_list = datasets.load_dataset(**self.metadata_dict["dataset"])[ 605 | "qrels" 606 | ] # dict_keys(['qid', 'doc_id']) 607 | qrels_list = qrels_list.filter( 608 | lambda x: x["context_length"] == context_length 609 | ) 610 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list} 611 | 612 | self.corpus[split] = corpus 613 | self.queries[split] = queries 614 | self.relevant_docs[split] = qrels 615 | 616 | self.data_loaded = True 617 | 618 | 619 | class LEMBPasskeyRetrievalChunked(AbsTaskChunkedRetrieval): 620 | """ 621 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBPasskeyRetrieval.py 622 | """ 623 | 624 | _EVAL_SPLIT = [ 625 | "test_256", 626 | "test_512", 627 | "test_1024", 628 | "test_2048", 629 | "test_4096", 630 | "test_8192", 631 | "test_16384", 632 | "test_32768", 633 | ] 634 | 635 | metadata = TaskMetadata( 636 | name="LEMBPasskeyRetrievalChunked", 637 | dataset={ 638 | "path": "dwzhu/LongEmbed", 639 | "revision": 
"6e346642246bfb4928c560ee08640dc84d074e8c", 640 | "name": "passkey", 641 | }, 642 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed", 643 | description=("passkey subset of dwzhu/LongEmbed dataset."), 644 | type="Retrieval", 645 | category="s2p", 646 | modalities=["text"], 647 | eval_splits=_EVAL_SPLIT, 648 | eval_langs=["eng-Latn"], 649 | main_score="ndcg_at_1", 650 | date=("2000-01-01", "2023-12-31"), 651 | domains=["Fiction", "Written"], 652 | task_subtypes=["Article retrieval"], 653 | license="not specified", 654 | annotations_creators="derived", 655 | dialect=[], 656 | sample_creation="found", 657 | bibtex_citation=""" 658 | @article{zhu2024longembed, 659 | title={LongEmbed: Extending Embedding Models for Long Context Retrieval}, 660 | author={Zhu, Dawei and Wang, Liang and Yang, Nan and Song, Yifan and Wu, Wenhao and Wei, Furu and Li, Sujian}, 661 | journal={arXiv preprint arXiv:2404.12096}, 662 | year={2024} 663 | } 664 | """, 665 | descriptive_stats={ 666 | "n_samples": { 667 | "test_256": 150, 668 | "test_512": 150, 669 | "test_1024": 150, 670 | "test_2048": 150, 671 | "test_4096": 150, 672 | "test_8192": 150, 673 | "test_16384": 150, 674 | "test_32768": 150, 675 | }, 676 | "avg_character_length": { 677 | "test_256": { 678 | "average_document_length": 876.24, 679 | "average_query_length": 38.1, 680 | "num_documents": 100, 681 | "num_queries": 50, 682 | "average_relevant_docs_per_query": 1.0, 683 | }, 684 | "test_512": { 685 | "average_document_length": 1785.2, 686 | "average_query_length": 37.76, 687 | "num_documents": 100, 688 | "num_queries": 50, 689 | "average_relevant_docs_per_query": 1.0, 690 | }, 691 | "test_1024": { 692 | "average_document_length": 3607.18, 693 | "average_query_length": 37.68, 694 | "num_documents": 100, 695 | "num_queries": 50, 696 | "average_relevant_docs_per_query": 1.0, 697 | }, 698 | "test_2048": { 699 | "average_document_length": 7242.2, 700 | "average_query_length": 37.8, 701 | "num_documents": 100, 702 | "num_queries": 50, 703 | "average_relevant_docs_per_query": 1.0, 704 | }, 705 | "test_4096": { 706 | "average_document_length": 14518.16, 707 | "average_query_length": 37.64, 708 | "num_documents": 100, 709 | "num_queries": 50, 710 | "average_relevant_docs_per_query": 1.0, 711 | }, 712 | "test_8192": { 713 | "average_document_length": 29071.16, 714 | "average_query_length": 37.54, 715 | "num_documents": 100, 716 | "num_queries": 50, 717 | "average_relevant_docs_per_query": 1.0, 718 | }, 719 | "test_16384": { 720 | "average_document_length": 58175.16, 721 | "average_query_length": 38.12, 722 | "num_documents": 100, 723 | "num_queries": 50, 724 | "average_relevant_docs_per_query": 1.0, 725 | }, 726 | "test_32768": { 727 | "average_document_length": 116380.16, 728 | "average_query_length": 37.74, 729 | "num_documents": 100, 730 | "num_queries": 50, 731 | "average_relevant_docs_per_query": 1.0, 732 | }, 733 | }, 734 | }, 735 | ) 736 | 737 | def load_data(self, **kwargs): 738 | if self.data_loaded: 739 | return 740 | 741 | self.corpus = {} 742 | self.queries = {} 743 | self.relevant_docs = {} 744 | 745 | for split in self._EVAL_SPLIT: 746 | context_length = int(split.split("_")[1]) 747 | query_list = datasets.load_dataset(**self.metadata_dict["dataset"])[ 748 | "queries" 749 | ] # dict_keys(['qid', 'text']) 750 | query_list = query_list.filter( 751 | lambda x: x["context_length"] == context_length 752 | ) 753 | queries = {row["qid"]: row["text"] for row in query_list} 754 | 755 | corpus_list = 
datasets.load_dataset(**self.metadata_dict["dataset"])[ 756 | "corpus" 757 | ] # dict_keys(['doc_id', 'text']) 758 | corpus_list = corpus_list.filter( 759 | lambda x: x["context_length"] == context_length 760 | ) 761 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list} 762 | 763 | qrels_list = datasets.load_dataset(**self.metadata_dict["dataset"])[ 764 | "qrels" 765 | ] # dict_keys(['qid', 'doc_id']) 766 | qrels_list = qrels_list.filter( 767 | lambda x: x["context_length"] == context_length 768 | ) 769 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list} 770 | 771 | self.corpus[split] = corpus 772 | self.queries[split] = queries 773 | self.relevant_docs[split] = qrels 774 | 775 | self.data_loaded = True 776 | -------------------------------------------------------------------------------- /chunked_pooling/chunking.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import logging 3 | from typing import Dict, List, Optional, Tuple, Union 4 | 5 | from llama_index.core.node_parser import SemanticSplitterNodeParser 6 | from llama_index.core.schema import Document 7 | from llama_index.embeddings.huggingface import HuggingFaceEmbedding 8 | from transformers import AutoTokenizer 9 | 10 | # Set the logging level to WARNING to suppress INFO and DEBUG messages 11 | logging.getLogger('sentence_transformers').setLevel(logging.WARNING) 12 | 13 | CHUNKING_STRATEGIES = ['semantic', 'fixed', 'sentences'] 14 | 15 | 16 | class Chunker: 17 | def __init__( 18 | self, 19 | chunking_strategy: str, 20 | ): 21 | if chunking_strategy not in CHUNKING_STRATEGIES: 22 | raise ValueError("Unsupported chunking strategy: ", chunking_strategy) 23 | self.chunking_strategy = chunking_strategy 24 | self.embed_model = None 25 | self.embedding_model_name = None 26 | 27 | def _setup_semantic_chunking(self, embedding_model_name): 28 | if embedding_model_name: 29 | self.embedding_model_name = embedding_model_name 30 | 31 | self.embed_model = HuggingFaceEmbedding( 32 | model_name=self.embedding_model_name, 33 | trust_remote_code=True, 34 | embed_batch_size=1, 35 | ) 36 | self.splitter = SemanticSplitterNodeParser( 37 | embed_model=self.embed_model, 38 | show_progress=False, 39 | ) 40 | 41 | def chunk_semantically( 42 | self, 43 | text: str, 44 | tokenizer: 'AutoTokenizer', 45 | embedding_model_name: Optional[str] = None, 46 | ) -> List[Tuple[int, int]]: 47 | if self.embed_model is None: 48 | self._setup_semantic_chunking(embedding_model_name) 49 | 50 | # Get semantic nodes 51 | nodes = [ 52 | (node.start_char_idx, node.end_char_idx) 53 | for node in self.splitter.get_nodes_from_documents( 54 | [Document(text=text)], show_progress=False 55 | ) 56 | ] 57 | 58 | # Tokenize the entire text 59 | tokens = tokenizer.encode_plus( 60 | text, 61 | return_offsets_mapping=True, 62 | add_special_tokens=False, 63 | padding=True, 64 | truncation=True, 65 | ) 66 | token_offsets = tokens.offset_mapping 67 | 68 | chunk_spans = [] 69 | 70 | for char_start, char_end in nodes: 71 | # Convert char indices to token indices 72 | start_chunk_index = bisect.bisect_left( 73 | [offset[0] for offset in token_offsets], char_start 74 | ) 75 | end_chunk_index = bisect.bisect_right( 76 | [offset[1] for offset in token_offsets], char_end 77 | ) 78 | 79 | # Add the chunk span if it's within the tokenized text 80 | if start_chunk_index < len(token_offsets) and end_chunk_index <= len( 81 | token_offsets 82 | ): 83 | chunk_spans.append((start_chunk_index, end_chunk_index)) 84 | 
else: 85 | break 86 | 87 | return chunk_spans 88 | 89 | def chunk_by_tokens( 90 | self, 91 | text: str, 92 | chunk_size: int, 93 | tokenizer: 'AutoTokenizer', 94 | ) -> List[Tuple[int, int, int]]: 95 | tokens = tokenizer.encode_plus( 96 | text, return_offsets_mapping=True, add_special_tokens=False 97 | ) 98 | token_offsets = tokens.offset_mapping 99 | 100 | chunk_spans = [] 101 | for i in range(0, len(token_offsets), chunk_size): 102 | chunk_end = min(i + chunk_size, len(token_offsets)) 103 | if chunk_end - i > 0: 104 | chunk_spans.append((i, chunk_end)) 105 | 106 | return chunk_spans 107 | 108 | def chunk_by_sentences( 109 | self, 110 | text: str, 111 | n_sentences: int, 112 | tokenizer: 'AutoTokenizer', 113 | ) -> List[Tuple[int, int, int]]: 114 | tokens = tokenizer.encode_plus( 115 | text, return_offsets_mapping=True, add_special_tokens=False 116 | ) 117 | token_offsets = tokens.offset_mapping 118 | 119 | chunk_spans = [] 120 | chunk_start = 0 121 | count_chunks = 0 122 | for i in range(0, len(token_offsets)): 123 | if tokens.tokens(0)[i] in ('.', '!', '?') and ( 124 | (len(tokens.tokens(0)) == i + 1) 125 | or (tokens.token_to_chars(i).end != tokens.token_to_chars(i + 1).start) 126 | ): 127 | count_chunks += 1 128 | if count_chunks == n_sentences: 129 | chunk_spans.append((chunk_start, i + 1)) 130 | chunk_start = i + 1 131 | count_chunks = 0 132 | if len(tokens.tokens(0)) - chunk_start > 1: 133 | chunk_spans.append((chunk_start, len(tokens.tokens(0)))) 134 | return chunk_spans 135 | 136 | def chunk( 137 | self, 138 | text: str, 139 | tokenizer: 'AutoTokenizer', 140 | chunking_strategy: str = None, 141 | chunk_size: Optional[int] = None, 142 | n_sentences: Optional[int] = None, 143 | embedding_model_name: Optional[str] = None, 144 | ): 145 | chunking_strategy = chunking_strategy or self.chunking_strategy 146 | if chunking_strategy == "semantic": 147 | return self.chunk_semantically( 148 | text, 149 | embedding_model_name=embedding_model_name, 150 | tokenizer=tokenizer, 151 | ) 152 | elif chunking_strategy == "fixed": 153 | if chunk_size < 4: 154 | raise ValueError("Chunk size must be >= 4.") 155 | return self.chunk_by_tokens(text, chunk_size, tokenizer) 156 | elif chunking_strategy == "sentences": 157 | return self.chunk_by_sentences(text, n_sentences, tokenizer) 158 | else: 159 | raise ValueError("Unsupported chunking strategy") 160 | -------------------------------------------------------------------------------- /chunked_pooling/mteb_chunked_eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Optional 3 | 4 | import numpy as np 5 | import torch 6 | from mteb.abstasks import AbsTask 7 | from mteb.evaluation.evaluators import RetrievalEvaluator 8 | from mteb.load_results.mteb_results import ScoresDict 9 | from mteb.tasks import Retrieval 10 | from tqdm import tqdm 11 | 12 | from chunked_pooling import chunked_pooling 13 | from chunked_pooling.chunking import Chunker 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class AbsTaskChunkedRetrieval(AbsTask): 19 | def __init__( 20 | self, 21 | chunking_strategy: str = None, 22 | chunked_pooling_enabled: bool = False, 23 | tokenizer: Optional[Any] = None, 24 | prune_size: Optional[int] = None, 25 | chunk_size: Optional[int] = None, 26 | n_sentences: Optional[int] = None, 27 | model_has_instructions: bool = False, 28 | embedding_model_name: Optional[str] = None, # for semantic chunking 29 | truncate_max_length: Optional[int] = 8192, 30 | 
long_late_chunking_embed_size: Optional[int] = 0, 31 | long_late_chunking_overlap_size: Optional[int] = 512, 32 | **kwargs, 33 | ): 34 | super().__init__(**kwargs) 35 | try: 36 | self.retrieval_task = getattr( 37 | Retrieval, 38 | self.metadata_dict['dataset'].get('name', None) 39 | or self.metadata_dict.get('name'), 40 | )() 41 | except: 42 | logger.warning('Could not initialize retrieval_task') 43 | self.chunking_strategy = chunking_strategy 44 | self.chunker = Chunker(self.chunking_strategy) 45 | self.chunked_pooling_enabled = chunked_pooling_enabled 46 | self.tokenizer = tokenizer 47 | self.prune_size = prune_size 48 | self.model_has_instructions = model_has_instructions 49 | self.chunking_args = { 50 | 'chunk_size': chunk_size, 51 | 'n_sentences': n_sentences, 52 | 'embedding_model_name': embedding_model_name, 53 | } 54 | self.truncate_max_length = ( 55 | truncate_max_length if truncate_max_length is not None and truncate_max_length > 0 else None 56 | ) 57 | 58 | self.long_late_chunking_embed_size = long_late_chunking_embed_size 59 | self.long_late_chunking_overlap_size = long_late_chunking_overlap_size 60 | 61 | def load_data(self, **kwargs): 62 | self.retrieval_task.load_data(**kwargs) 63 | self.corpus = self.retrieval_task.corpus 64 | self.queries = self.retrieval_task.queries 65 | self.relevant_docs = self.retrieval_task.relevant_docs 66 | # prune dataset 67 | if self.prune_size: 68 | self.queries, self.corpus, self.relevant_docs = self._prune( 69 | self.queries, self.corpus, self.relevant_docs, self.prune_size 70 | ) 71 | 72 | def calculate_metadata_metrics(self): 73 | self.retrieval_task.calculate_metadata_metrics() 74 | 75 | def evaluate( 76 | self, model, split: str = "test", encode_kwargs: dict[str, Any] = {}, **kwargs 77 | ) -> dict[str, ScoresDict]: 78 | scores: dict[str, ScoresDict] = {} 79 | hf_subsets = list(self.hf_subsets) if self.is_multilingual else ["default"] 80 | 81 | for hf_subset in hf_subsets: 82 | logger.info(f"Subset: {hf_subset}") 83 | 84 | if hf_subset == "default": 85 | corpus, queries, relevant_docs = ( 86 | self.corpus[split], 87 | self.queries[split], 88 | self.relevant_docs[split], 89 | ) 90 | else: 91 | corpus, queries, relevant_docs = ( 92 | self.corpus[hf_subset][split], 93 | self.queries[hf_subset][split], 94 | self.relevant_docs[hf_subset][split], 95 | ) 96 | 97 | scores[hf_subset] = self._evaluate_monolingual( 98 | model, 99 | corpus, 100 | queries, 101 | relevant_docs, 102 | hf_subset, 103 | encode_kwargs=encode_kwargs, 104 | **kwargs, 105 | ) 106 | 107 | return scores 108 | 109 | def _truncate_documents(self, corpus): 110 | for k, v in corpus.items(): 111 | title_tokens = 0 112 | if 'title' in v: 113 | tokens = self.tokenizer( 114 | v['title'] + ' ', 115 | return_offsets_mapping=True, 116 | max_length=self.truncate_max_length, 117 | ) 118 | title_tokens = len(tokens.input_ids) 119 | tokens = self.tokenizer( 120 | v['text'], 121 | return_offsets_mapping=True, 122 | max_length=self.truncate_max_length - title_tokens, 123 | ) 124 | last_token_span = tokens.offset_mapping[-2] 125 | v['text'] = v['text'][: last_token_span[1]] 126 | return corpus 127 | 128 | def _embed_with_overlap(self, model, model_inputs): 129 | len_tokens = len(model_inputs["input_ids"][0]) 130 | 131 | if len_tokens > self.long_late_chunking_embed_size: 132 | indices = [] 133 | for i in range( 134 | 0, 135 | len_tokens, 136 | self.long_late_chunking_embed_size 137 | - self.long_late_chunking_overlap_size, 138 | ): 139 | start = i 140 | end = min(i + 
self.long_late_chunking_embed_size, len_tokens) 141 | indices.append((start, end)) 142 | else: 143 | indices = [(0, len_tokens)] 144 | 145 | outputs = [] 146 | for start, end in indices: 147 | batch_inputs = {k: v[:, start:end] for k, v in model_inputs.items()} 148 | 149 | with torch.no_grad(): 150 | model_output = model(**batch_inputs) 151 | 152 | if start > 0: 153 | outputs.append( 154 | model_output[0][:, self.long_late_chunking_overlap_size :] 155 | ) 156 | else: 157 | outputs.append(model_output[0]) 158 | 159 | return torch.cat(outputs, dim=1).to(model.device) 160 | 161 | def _evaluate_monolingual( 162 | self, 163 | model, 164 | corpus, 165 | queries, 166 | relevant_docs, 167 | lang=None, 168 | batch_size=1, 169 | encode_kwargs=None, 170 | **kwargs, 171 | ): 172 | if self.truncate_max_length: 173 | corpus = self._truncate_documents(corpus) 174 | # split corpus into chunks 175 | if not self.chunked_pooling_enabled: 176 | corpus = self._apply_chunking(corpus, self.tokenizer) 177 | max_chunks = max([len(x) for x in corpus.values()]) 178 | corpus = self._flatten_chunks(corpus) 179 | k_values = self._calculate_k_values(max_chunks) 180 | # determine the maximum number of documents to consider in a ranking 181 | max_k = int(max(k_values) / max_chunks) 182 | retriever = RetrievalEvaluator( 183 | model, 184 | k_values=k_values, 185 | encode_kwargs=(encode_kwargs or dict()), 186 | **kwargs, 187 | ) 188 | results = retriever(corpus, queries) 189 | else: 190 | query_ids = list(queries.keys()) 191 | query_texts = [queries[k] for k in query_ids] 192 | if hasattr(model, 'encode_queries'): 193 | query_embs = model.encode_queries(query_texts) 194 | else: 195 | query_embs = model.encode(query_texts) 196 | 197 | corpus_ids = list(corpus.keys()) 198 | corpus_texts = [ 199 | ( 200 | f"{corpus[k]['title']} {corpus[k]['text']}" 201 | if 'title' in corpus[k] 202 | else corpus[k]['text'] 203 | ) 204 | for k in corpus_ids 205 | ] 206 | 207 | chunk_annotations = self._calculate_annotations(model, corpus_texts) 208 | 209 | corpus_embs = [] 210 | with torch.no_grad(): 211 | for inputs in tqdm( 212 | self._batch_inputs( 213 | list(zip(corpus_texts, chunk_annotations)), 214 | batch_size=batch_size, 215 | ), 216 | total=(len(corpus_texts) // batch_size), 217 | ): 218 | if self.model_has_instructions: 219 | instr = model.get_instructions()[1] 220 | else: 221 | instr = '' 222 | text_inputs = [instr + x[0] for x in inputs] 223 | annotations = [x[1] for x in inputs] 224 | model_inputs = self.tokenizer( 225 | text_inputs, 226 | return_tensors='pt', 227 | padding=True, 228 | truncation=self.truncate_max_length is not None, 229 | max_length=self.truncate_max_length, 230 | ) 231 | if model.device.type == 'cuda': 232 | model_inputs = { 233 | k: v.to(model.device) for k, v in model_inputs.items() 234 | } 235 | 236 | if self.long_late_chunking_embed_size > 0: 237 | model_outputs = self._embed_with_overlap(model, model_inputs) 238 | output_embs = chunked_pooling( 239 | [model_outputs], annotations, max_length=None 240 | ) 241 | else: # truncation 242 | model_outputs = model(**model_inputs) 243 | output_embs = chunked_pooling( 244 | model_outputs, 245 | annotations, 246 | max_length=self.truncate_max_length, 247 | ) 248 | corpus_embs.extend(output_embs) 249 | 250 | max_chunks = max([len(x) for x in corpus_embs]) 251 | k_values = self._calculate_k_values(max_chunks) 252 | # determine the maximum number of documents to consider in a ranking 253 | max_k = int(max(k_values) / max_chunks) 254 | ( 255 | chunk_id_list, 256 | 
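# flatten_corpus_embs stacks the per-document chunk embeddings into one matrix; chunk ids take the form '<doc_id>~<chunk_idx>' and doc_to_chunk maps each chunk id back to its source document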
doc_to_chunk, 257 | flattened_corpus_embs, 258 | ) = self.flatten_corpus_embs(corpus_embs, corpus_ids) 259 | similarity_matrix = np.dot(query_embs, flattened_corpus_embs.T) 260 | results = self.get_results( 261 | chunk_id_list, k_values, query_ids, similarity_matrix 262 | ) 263 | 264 | doc_results = self.get_doc_results(results) 265 | 266 | ndcg, _map, recall, precision, _ = RetrievalEvaluator.evaluate( 267 | relevant_docs, 268 | doc_results, 269 | [k for k in k_values if k <= max_k], 270 | ignore_identical_ids=kwargs.get('ignore_identical_ids', True), 271 | ) 272 | mrr, _ = RetrievalEvaluator.evaluate_custom( 273 | relevant_docs, 274 | doc_results, 275 | [k for k in k_values if k <= max_k], 276 | 'mrr', 277 | ) 278 | scores = { 279 | **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()}, 280 | **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()}, 281 | **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()}, 282 | **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()}, 283 | **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()}, 284 | } 285 | self._add_main_score(scores) 286 | return scores 287 | 288 | def _add_main_score(self, scores: ScoresDict) -> None: 289 | scores["main_score"] = scores[self.metadata.main_score] 290 | 291 | def get_results(self, chunk_id_list, k_values, query_ids, similarity_matrix): 292 | results = {} 293 | for i, query_id in enumerate(query_ids): 294 | query_results = {} 295 | for idx, score in enumerate(similarity_matrix[i]): 296 | chunk_id = chunk_id_list[idx] 297 | query_results[chunk_id] = score 298 | # Sort results by score and only keep the top k scores 299 | sorted_query_results = dict( 300 | sorted(query_results.items(), key=lambda item: item[1], reverse=True)[ 301 | : max(k_values) 302 | ] 303 | ) 304 | results[query_id] = sorted_query_results 305 | return results 306 | 307 | def flatten_corpus_embs(self, corpus_embs, corpus_ids): 308 | doc_to_chunk = {} 309 | flattened_corpus_embs = [] 310 | chunk_id_list = [] 311 | for doc_id, emb in zip(corpus_ids, corpus_embs): 312 | for i, chunk in enumerate(emb): 313 | flattened_corpus_embs.append(chunk) 314 | doc_to_chunk[f"{doc_id}~{i}"] = doc_id 315 | chunk_id_list.append(f"{doc_id}~{i}") 316 | flattened_corpus_embs = np.vstack(flattened_corpus_embs) 317 | flattened_corpus_embs = self._normalize(flattened_corpus_embs) 318 | return chunk_id_list, doc_to_chunk, flattened_corpus_embs 319 | 320 | @staticmethod 321 | def get_doc_results(results): 322 | doc_results = dict() 323 | for q, result_chunks in results.items(): 324 | docs = dict() 325 | for c_id, score in result_chunks.items(): 326 | d_id = '~'.join(c_id.split('~')[:-1]) 327 | if (d_id not in docs) or (score > docs[d_id]): 328 | docs[d_id] = float(score) 329 | doc_results[q] = docs 330 | return doc_results 331 | 332 | def _calculate_k_values(self, max_chunks): 333 | k_values = [1, 3, 5, 10, 20] 334 | n = 2 335 | while 10**n < 100 * max_chunks: 336 | k_values.append(10**n) 337 | n += 1 338 | return k_values 339 | 340 | def _apply_chunking(self, corpus, tokenizer): 341 | chunked_corpus = dict() 342 | for k, v in corpus.items(): 343 | text = f"{v['title']} {v['text']}" if 'title' in v else v['text'] 344 | current_doc = [] 345 | chunk_annotations = self.chunker.chunk( 346 | text, 347 | tokenizer, 348 | chunking_strategy=self.chunking_strategy, 349 | **self.chunking_args, 350 | ) 351 | tokens = tokenizer.encode_plus(text, add_special_tokens=False) 352 | for start_token_idx, end_token_idx in chunk_annotations: 353 
| text_chunk = tokenizer.decode( 354 | tokens.encodings[0].ids[start_token_idx:end_token_idx] 355 | ) 356 | current_doc.append({'text': text_chunk}) 357 | chunked_corpus[k] = current_doc 358 | return chunked_corpus 359 | 360 | def _calculate_annotations(self, model, corpus_texts): 361 | if self.model_has_instructions: 362 | instr = model.get_instructions()[1] 363 | instr_tokens = self.tokenizer(instr, add_special_tokens=False) 364 | n_instruction_tokens = len(instr_tokens[0]) 365 | else: 366 | n_instruction_tokens = 0 367 | chunk_annotations = [ 368 | self._extend_special_tokens( 369 | self.chunker.chunk( 370 | text, 371 | self.tokenizer, 372 | chunking_strategy=self.chunking_strategy, 373 | **self.chunking_args, 374 | ), 375 | n_instruction_tokens=n_instruction_tokens, 376 | ) 377 | for text in corpus_texts 378 | ] 379 | return chunk_annotations 380 | 381 | @staticmethod 382 | def _flatten_chunks(chunked_corpus): 383 | flattened_corpus = dict() 384 | for k, li in chunked_corpus.items(): 385 | for i, c in enumerate(li): 386 | flattened_corpus[f'{k}~{i}'] = c 387 | 388 | return flattened_corpus 389 | 390 | @staticmethod 391 | def _normalize(x): 392 | return x / np.linalg.norm(x, axis=1)[:, None] 393 | 394 | @staticmethod 395 | def _batch_inputs(li, batch_size): 396 | for i in range(0, len(li), batch_size): 397 | yield li[i : i + batch_size] 398 | 399 | @staticmethod 400 | def _extend_special_tokens( 401 | annotations, n_instruction_tokens=0, include_prefix=True, include_sep=True 402 | ): 403 | """Extends the spans because of additional special tokens, e.g. the CLS token 404 | which are not considered by the chunker. 405 | """ 406 | new_annotations = [] 407 | for i in range(len(annotations)): 408 | add_left_offset = 1 if (not include_prefix) or int(i > 0) else 0 409 | left_offset = 1 + n_instruction_tokens 410 | left = ( 411 | annotations[i][0] + add_left_offset * left_offset 412 | ) # move everything by one for [CLS] 413 | 414 | add_sep = 1 if include_sep and ((i + 1) == len(annotations)) else 0 415 | right_offset = left_offset + add_sep 416 | right = ( 417 | annotations[i][1] + right_offset 418 | ) # move everything by one for [CLS] and the last one for [SEP] 419 | 420 | new_annotations.append((left, right)) 421 | return new_annotations 422 | 423 | @staticmethod 424 | def _prune(queries, corpus, relevant_docs, prune_size): 425 | new_queries = {'test': {}} 426 | new_corpus = {'test': {}} 427 | new_relevant_docs = {'test': {}} 428 | for i, key in enumerate(relevant_docs['test']): 429 | if i >= prune_size: 430 | break 431 | new_relevant_docs['test'][key] = relevant_docs['test'][key] 432 | for x in relevant_docs['test'][key]: 433 | new_corpus['test'][x] = corpus['test'][x] 434 | new_queries['test'][key] = queries['test'][key] 435 | return new_queries, new_corpus, new_relevant_docs 436 | 437 | def _calculate_metrics_from_split(*args, **kwargs): 438 | pass 439 | 440 | def _evaluate_subset(*args, **kwargs): 441 | pass 442 | -------------------------------------------------------------------------------- /chunked_pooling/wrappers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from sentence_transformers import SentenceTransformer 7 | from transformers import AutoModel 8 | from transformers.modeling_outputs import BaseModelOutputWithPooling 9 | 10 | 11 | def construct_document(doc): 12 | if isinstance(doc, str): 13 | return doc 14 | elif 'title' in doc: 
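# corpus entries are either plain strings or dicts with a 'text' field and an optional 'title'; when a title is present it is prepended so the embedding model sees the full document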
15 | return f'{doc["title"]} {doc["text"].strip()}' 16 | else: 17 | return doc['text'].strip() 18 | 19 | 20 | class JinaEmbeddingsV3Wrapper(nn.Module): 21 | def __init__( 22 | self, model_name, tasks=['retrieval.query', 'retrieval.passage'], **model_kwargs 23 | ): 24 | super().__init__() 25 | self._model = AutoModel.from_pretrained( 26 | model_name, trust_remote_code=True, **model_kwargs 27 | ) 28 | self.tasks = tasks 29 | 30 | def encode_queries( 31 | self, 32 | sentences: Union[str, List[str]], 33 | *args, 34 | task: Optional[str] = None, 35 | **kwargs, 36 | ): 37 | return self._model.encode(sentences, *args, task=self.tasks[0], **kwargs) 38 | 39 | def encode_corpus( 40 | self, 41 | sentences: Union[str, List[str]], 42 | *args, 43 | **kwargs, 44 | ): 45 | _sentences = [construct_document(sentence) for sentence in sentences] 46 | return self._model.encode(_sentences, *args, task=self.tasks[1], **kwargs) 47 | 48 | def get_instructions(self): 49 | return [self._model._task_instructions[x] for x in self.tasks] 50 | 51 | def forward(self, *args, **kwargs): 52 | task_id = self._model._adaptation_map[self.tasks[1]] 53 | num_examples = kwargs['input_ids'].shape[0] 54 | adapter_mask = torch.full( 55 | (num_examples,), task_id, dtype=torch.int32, device=self._model.device 56 | ) 57 | return self._model.forward(*args, adapter_mask=adapter_mask, **kwargs) 58 | 59 | @property 60 | def device(self): 61 | return self._model.device 62 | 63 | @staticmethod 64 | def has_instructions(): 65 | return True 66 | 67 | 68 | class NomicAIWrapper(nn.Module): 69 | def __init__(self, model_name, **model_kwargs): 70 | super().__init__() 71 | self._model = SentenceTransformer( 72 | model_name, trust_remote_code=True, **model_kwargs 73 | ) 74 | self.instructions = ['search_query: ', 'search_document: '] 75 | 76 | def get_instructions(self): 77 | return self.instructions 78 | 79 | def forward(self, *args, **kwargs): 80 | model_output = self._model.forward(kwargs) 81 | base_model_output = BaseModelOutputWithPooling( 82 | last_hidden_state=model_output['token_embeddings'], 83 | pooler_output=model_output['sentence_embedding'], 84 | attentions=model_output['attention_mask'], 85 | ) 86 | return base_model_output 87 | 88 | def encode_queries( 89 | self, 90 | sentences: Union[str, List[str]], 91 | *args, 92 | **kwargs, 93 | ): 94 | return self._model.encode( 95 | [self.instructions[0] + s for s in sentences], *args, **kwargs 96 | ) 97 | 98 | def encode_corpus( 99 | self, 100 | sentences: Union[str, List[str]], 101 | *args, 102 | **kwargs, 103 | ): 104 | return self._model.encode( 105 | [self.instructions[1] + construct_document(s) for s in sentences], 106 | *args, 107 | **kwargs, 108 | ) 109 | 110 | @property 111 | def device(self): 112 | return self._model.device 113 | 114 | @staticmethod 115 | def has_instructions(): 116 | return True 117 | 118 | 119 | MODEL_WRAPPERS = { 120 | 'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper, 121 | 'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer, 122 | 'nomic-ai/nomic-embed-text-v1': NomicAIWrapper, 123 | } 124 | 125 | MODELS_WITHOUT_PROMPT_NAME_ARG = [ 126 | 'jinaai/jina-embeddings-v2-small-en', 127 | 'jinaai/jina-embeddings-v2-base-en', 128 | 'jinaai/jina-embeddings-v3', 129 | ] 130 | 131 | 132 | def remove_unsupported_kwargs(original_encode): 133 | def wrapper(self, *args, **kwargs): 134 | # Remove 'prompt_name' from kwargs if present 135 | kwargs.pop('prompt_name', None) 136 | kwargs.pop('request_qid', None) 137 | return original_encode(self, *args, **kwargs) 138 | 139 
| return wrapper 140 | 141 | 142 | def load_model(model_name, model_weights=None, **model_kwargs): 143 | if model_name in MODEL_WRAPPERS: 144 | model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs) 145 | if hasattr(MODEL_WRAPPERS[model_name], 'has_instructions'): 146 | has_instructions = MODEL_WRAPPERS[model_name].has_instructions() 147 | else: 148 | has_instructions = False 149 | else: 150 | model = AutoModel.from_pretrained(model_name, trust_remote_code=True) 151 | has_instructions = False 152 | 153 | if model_weights and os.path.exists(model_weights): 154 | model._model.load_state_dict(torch.load(model_weights, map_location=model.device)) 155 | 156 | # the encode functions of some models do not support all sentence-transformers kwargs 157 | if model_name in MODELS_WITHOUT_PROMPT_NAME_ARG: 158 | ENCODE_FUNC_NAMES = ['encode', 'encode_queries', 'encode_corpus'] 159 | for func_name in ENCODE_FUNC_NAMES: 160 | if hasattr(model, func_name): 161 | setattr( 162 | model, 163 | func_name, 164 | remove_unsupported_kwargs(getattr(model, func_name)), 165 | ) 166 | 167 | return model, has_instructions 168 | -------------------------------------------------------------------------------- /examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e1173893c4f0ea56", 6 | "metadata": { 7 | "collapsed": false, 8 | "jupyter": { 9 | "outputs_hidden": false 10 | } 11 | }, 12 | "source": [ 13 | "# Chunked Pooling\n", 14 | "This notebook explains how chunked pooling can be implemented. First, you need to install the requirements: " 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "d02a920f-cde0-4035-9834-49b087aab5cc", 21 | "metadata": { 22 | "is_executing": true 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "!pip install -r requirements.txt" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "58a8fbc1e477db48", 32 | "metadata": { 33 | "collapsed": false, 34 | "jupyter": { 35 | "outputs_hidden": false 36 | } 37 | }, 38 | "source": [ 39 | "Then we load the model we want to use for the embeddings. We choose `jinaai/jina-embeddings-v2-base-en`, but any other model that supports mean pooling is possible. However, models with a large maximum context length are preferred." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "id": "1380abf7acde9517", 46 | "metadata": { 47 | "collapsed": false, 48 | "jupyter": { 49 | "outputs_hidden": false 50 | } 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stderr", 55 | "output_type": "stream", 56 | "text": [ 57 | "/home/michael/workspace/chunked-pooling/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 58 | " from .autonotebook import tqdm as notebook_tqdm\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "from transformers import AutoModel\n", 64 | "from transformers import AutoTokenizer\n", 65 | "\n", 66 | "from chunked_pooling import chunked_pooling, chunk_by_sentences\n", 67 | "\n", 68 | "# load model and tokenizer\n", 69 | "tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)\n", 70 | "model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "2cc0c1162797ffb0", 76 | "metadata": { 77 | "collapsed": false, 78 | "jupyter": { 79 | "outputs_hidden": false 80 | } 81 | }, 82 | "source": [ 83 | "Now we define the text which we want to encode and split it into chunks. The `chunk_by_sentences` function also returns the span annotations. Those specify the number of tokens per chunk which is needed for the chunked pooling." 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 2, 89 | "id": "8ef392f3437ef82e", 90 | "metadata": { 91 | "collapsed": false, 92 | "jupyter": { 93 | "outputs_hidden": false 94 | } 95 | }, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "Chunks:\n", 102 | "- \"Berlin is the capital and largest city of Germany, both by area and by population.\"\n", 103 | "- \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"\n", 104 | "- \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "input_text = \"Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. 
The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"\n", 110 | "\n", 111 | "# determine chunks\n", 112 | "chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)\n", 113 | "print('Chunks:\\n- \"' + '\"\\n- \"'.join(chunks) + '\"')\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "id": "9ac41fd1f0560da7", 119 | "metadata": { 120 | "collapsed": false, 121 | "jupyter": { 122 | "outputs_hidden": false 123 | } 124 | }, 125 | "source": [ 126 | "Now we encode the chunks with the traditional and the context-sensitive chunked pooling method:" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 3, 132 | "id": "abe3d93b9e6609b9", 133 | "metadata": { 134 | "collapsed": false, 135 | "jupyter": { 136 | "outputs_hidden": false 137 | } 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "# chunk before\n", 142 | "embeddings_traditional_chunking = model.encode(chunks)\n", 143 | "\n", 144 | "# chunk afterwards (context-sensitive chunked pooling)\n", 145 | "inputs = tokenizer(input_text, return_tensors='pt')\n", 146 | "model_output = model(**inputs)\n", 147 | "embeddings = chunked_pooling(model_output, [span_annotations])[0]" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "e84b1b9d48cb6367", 153 | "metadata": { 154 | "collapsed": false, 155 | "jupyter": { 156 | "outputs_hidden": false 157 | } 158 | }, 159 | "source": [ 160 | "Finally, we compare the similarity of the word \"Berlin\" with the chunks. The similarity should be higher for the context-sensitive chunked pooling method:" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 4, 166 | "id": "da0cec59a3ece76", 167 | "metadata": { 168 | "collapsed": false, 169 | "jupyter": { 170 | "outputs_hidden": false 171 | } 172 | }, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "similarity_new(\"Berlin\", \"Berlin is the capital and largest city of Germany, both by area and by population.\"): 0.849546\n", 179 | "similarity_trad(\"Berlin\", \"Berlin is the capital and largest city of Germany, both by area and by population.\"): 0.84862185\n", 180 | "similarity_new(\"Berlin\", \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"): 0.82489026\n", 181 | "similarity_trad(\"Berlin\", \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"): 0.7084338\n", 182 | "similarity_new(\"Berlin\", \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"): 0.84980094\n", 183 | "similarity_trad(\"Berlin\", \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"): 0.7534553\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "import numpy as np\n", 189 | "\n", 190 | "cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))\n", 191 | "\n", 192 | "berlin_embedding = model.encode('Berlin')\n", 193 | "\n", 194 | "for chunk, new_embedding, trad_embeddings in zip(chunks, embeddings, embeddings_traditional_chunking):\n", 195 | " print(f'similarity_new(\"Berlin\", \"{chunk}\"):', cos_sim(berlin_embedding, new_embedding))\n", 196 | " print(f'similarity_trad(\"Berlin\", \"{chunk}\"):', cos_sim(berlin_embedding, trad_embeddings))" 197 | ] 198 | } 199 | ], 200 | "metadata": 
{ 201 | "kernelspec": { 202 | "display_name": "Python 3 (ipykernel)", 203 | "language": "python", 204 | "name": "python3" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.10.12" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 5 221 | } 222 | -------------------------------------------------------------------------------- /explanatory_contextual_retrieval.py: -------------------------------------------------------------------------------- 1 | # experiments/explanatory_contextual_retrieval.py 2 | # 3 | # a simple example with a trivial piece of text to showcase the late chunking method against 4 | # contextual retrieval method. contextual retrieval manually inserts context to each 5 | # chunk, i.e. forces context to be around each chunk. so works as a good comparison 6 | # to late chunking to see if the similarities are similar (which they appear to be) 7 | 8 | from chunked_pooling.wrappers import load_model 9 | from transformers import AutoModel, AutoTokenizer, pipeline, AutoModelForCausalLM 10 | import torch 11 | import numpy as np 12 | 13 | import chunked_pooling 14 | from chunked_pooling import chunked_pooling 15 | from chunked_pooling.chunking import Chunker 16 | 17 | from typing import List, Tuple 18 | from transformers import AutoModel, AutoTokenizer, pipeline 19 | 20 | import requests 21 | import os 22 | 23 | def request_anthropic_api(prompt: str): 24 | url = "https://api.anthropic.com/v1/messages" 25 | headers = { 26 | "x-api-key": os.getenv("ANTHROPIC_API_KEY"), 27 | "anthropic-version": "2023-06-01", 28 | "content-type": "application/json" 29 | } 30 | data = { 31 | "model": "claude-3-haiku-20240307", 32 | "max_tokens": 2048, 33 | "messages": [ 34 | {"role": "user", "content": prompt} 35 | ] 36 | } 37 | response = requests.post(url, headers=headers, json=data) 38 | return response.json()["content"][0]["text"] 39 | 40 | def setup_local_llm(llm_name): 41 | 42 | model = AutoModelForCausalLM.from_pretrained(llm_name, trust_remote_code=True) 43 | tokenizer = AutoTokenizer.from_pretrained(llm_name, trust_remote_code=True) 44 | 45 | def llm(prompt): 46 | messages = [{"role": "user", "content": prompt}] 47 | inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt") 48 | inputs = inputs.to(model.device) 49 | outputs = model.generate(inputs, max_new_tokens=512) 50 | text_output = tokenizer.batch_decode(outputs)[0] 51 | if "<|assistant|>" in text_output: 52 | text_output = text_output.split("<|assistant|>")[1].strip() 53 | return text_output 54 | 55 | return llm 56 | 57 | def cosine_similarity(vector1, vector2): 58 | vector1_norm = vector1 / np.linalg.norm(vector1) 59 | vector2_norm = vector2 / np.linalg.norm(vector2) 60 | return np.dot(vector1_norm, vector2_norm) 61 | 62 | class LateChunkingEmbedder: 63 | 64 | def __init__(self, 65 | model: AutoModel, 66 | tokenizer: AutoTokenizer, 67 | chunking_strategy: str = "sentences", 68 | n_sentences: int = 1 69 | ): 70 | 71 | self.model = model 72 | self.tokenizer = tokenizer 73 | 74 | self.chunker = Chunker(chunking_strategy = chunking_strategy) 75 | self.n_sentences = n_sentences 76 | 77 | 78 | def run(self, document: str): 79 | annotations = [self.chunker.chunk(text=document, tokenizer=self.tokenizer, n_sentences=self.n_sentences)] 80 | model_inputs 
= self.tokenizer( 81 | document, 82 | return_tensors='pt', 83 | padding=True, 84 | truncation=True, 85 | max_length=8192, 86 | ) 87 | model_outputs = self.model(**model_inputs) 88 | self.output_embs = chunked_pooling( 89 | model_outputs, annotations, max_length=8192, 90 | )[0] 91 | return self.output_embs 92 | 93 | def query(self, query: str): 94 | if "output_embs" not in dir(self): 95 | raise ValueError("no embeddings calculated, use .run(document) to create chunk embeddings") 96 | query_embedding = self.model.encode(query) 97 | similarities = [] 98 | for emb in self.output_embs: 99 | similarities.append(cosine_similarity(query_embedding, emb)) 100 | 101 | return similarities 102 | 103 | 104 | class ContextualRetrievalEmbedder(): 105 | def __init__(self, 106 | model: AutoModel, 107 | tokenizer: AutoTokenizer, 108 | llm_name: str = "microsoft/Phi-3.5-mini-instruct", 109 | chunking_strategy: str = "fixed" 110 | ): 111 | 112 | self.llm = setup_local_llm(llm_name) 113 | # self.llm = request_anthropic_api 114 | 115 | self.prompt = """ 116 | 117 | {{WHOLE_DOCUMENT}} 118 | 119 | Here is the chunk we want to situate within the whole document 120 | 121 | {{CHUNK_CONTENT}} 122 | 123 | Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else. 124 | """.strip() 125 | 126 | self.model = model 127 | self.tokenizer = tokenizer 128 | 129 | self.chunker = Chunker(chunking_strategy = chunking_strategy) 130 | 131 | 132 | def _add_context(self, chunk: str, document: str): 133 | prompt = self.prompt.replace("{{WHOLE_DOCUMENT}}", document).replace("{{CHUNK_CONTENT}}", chunk) 134 | extra_context = self.llm(prompt) 135 | return extra_context + " " + chunk 136 | 137 | def _tokens_to_text(self, text: str, annotations: List[Tuple[int, int]]): 138 | tokens = self.tokenizer.encode_plus( 139 | text, return_offsets_mapping=True, add_special_tokens=False 140 | ) 141 | token_offsets = tokens.offset_mapping 142 | chunks = [] 143 | for start, end in annotations: 144 | chunk = text[token_offsets[start][0]:token_offsets[end-1][1]] 145 | chunks.append(chunk) 146 | return chunks 147 | 148 | def run(self, document: str): 149 | annotations = [self.chunker.chunk(text=document, tokenizer=self.tokenizer, n_sentences=1)] 150 | self.chunks = self._tokens_to_text(text=document, annotations=annotations[0]) 151 | self.chunks = [self._add_context(chunk, document) for chunk in self.chunks] 152 | 153 | model_outputs = self.model.encode(self.chunks) 154 | self.output_embs = [model_outputs[i, :] for i in range(len(self.chunks))] 155 | return self.output_embs 156 | 157 | def query(self, query: str): 158 | if "output_embs" not in dir(self): 159 | raise ValueError("no embeddings calculated, use .run(document) to create chunk embeddings") 160 | query_embedding = self.model.encode(query) 161 | similarities = [] 162 | for emb in self.output_embs: 163 | similarities.append(cosine_similarity(query_embedding, emb)) 164 | 165 | return similarities 166 | 167 | 168 | 169 | if __name__ == "__main__": 170 | 171 | text = """ 172 | The recent SEC filing provided insights into ACME Corp's performance for Q2 2023. 173 | It highlighted a 3% revenue growth over the previous quarter. 174 | The company, which had a revenue of $314 million in the prior quarter, showed steady progress. 175 | They attributed this growth to strategic initiatives and operational efficiencies. 
176 | The report emphasized the company's resilience and ability to navigate market challenges, reflecting positively on their financial health and future prospects. 177 | """.strip().replace("\n", "") 178 | 179 | llm_model_name = "microsoft/Phi-3.5-mini-instruct" 180 | embedding_model_name = "jinaai/jina-embeddings-v2-small-en" 181 | 182 | embedding_model, has_instructions = load_model(embedding_model_name) 183 | embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True) 184 | 185 | cr = ContextualRetrievalEmbedder(embedding_model, embedding_tokenizer, llm_model_name, chunking_strategy="sentences") 186 | cr.run(text); 187 | cr_cosine_similarities = cr.query("What is ACME Corp's revenue growth for Q2 2023?") 188 | 189 | lc = LateChunkingEmbedder(embedding_model, embedding_tokenizer) 190 | lc.run(text) 191 | lc_cosine_similarities = lc.query("What is ACME Corp's revenue growth for Q2 2023?") 192 | 193 | # import pandas as pd 194 | for i, (cr_similarity, lc_similarity) in enumerate(zip(cr_cosine_similarities, lc_cosine_similarities)): 195 | print(f"{text.split('.')[:-1][i].strip()}") 196 | print(f"Similarities: Contextual Retrieval: {cr_similarity:.4f} | Late Chunking: {lc_similarity:.4f}") 197 | print("") -------------------------------------------------------------------------------- /img/context-problem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jina-ai/late-chunking/1d3bb02bf091becd0771455e4e7959463935e26c/img/context-problem.png -------------------------------------------------------------------------------- /img/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jina-ai/late-chunking/1d3bb02bf091becd0771455e4e7959463935e26c/img/method.png -------------------------------------------------------------------------------- /img/rag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jina-ai/late-chunking/1d3bb02bf091becd0771455e4e7959463935e26c/img/rag.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "late_chunking" 3 | requires-python = "~=3.8" 4 | dependencies = [ 5 | "jupyterlab==4.2.5", 6 | "transformers==4.43.4", 7 | "torch==2.4.0", 8 | "mteb==1.14.20", 9 | "datasets==2.19.1", 10 | "llama-index-embeddings-huggingface==0.3.1", 11 | "llama-index==0.11.10", 12 | "click==8.1.7", 13 | "einops==0.6.1", 14 | ] 15 | version = "0.0.0" 16 | 17 | [project.optional-dependencies] 18 | dev = [ 19 | "pytest~=7.3.2", 20 | "black==23.3.0", 21 | "isort==5.12.0", 22 | "ruff==0.0.265", 23 | ] 24 | 25 | [tool.setuptools.packages.find] 26 | include = ["chunked_pooling"] 27 | -------------------------------------------------------------------------------- /run_chunked_eval.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch.cuda 3 | from mteb import MTEB 4 | from transformers import AutoModel, AutoTokenizer 5 | 6 | from chunked_pooling.chunked_eval_tasks import * 7 | from chunked_pooling.wrappers import load_model 8 | 9 | DEFAULT_CHUNKING_STRATEGY = 'fixed' 10 | DEFAULT_CHUNK_SIZE = 256 11 | DEFAULT_N_SENTENCES = 5 12 | BATCH_SIZE = 1 13 | DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE = 256 14 | DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE = 0 
# set to 0 to disable long late chunking 15 | DEFAULT_TRUNCATE_MAX_LENGTH = None 16 | 17 | 18 | @click.command() 19 | @click.option( 20 | '--model-name', 21 | default='jinaai/jina-embeddings-v2-small-en', 22 | help='The name of the model to use.', 23 | ) 24 | @click.option( 25 | '--model-weights', 26 | default=None, 27 | help='The path to the model weights to use, e.g. in case of finetuning.', 28 | ) 29 | @click.option( 30 | '--strategy', 31 | default=DEFAULT_CHUNKING_STRATEGY, 32 | help='The chunking strategy to be applied.', 33 | ) 34 | @click.option( 35 | '--task-name', default='SciFactChunked', help='The evaluation task to perform.' 36 | ) 37 | @click.option( 38 | '--eval-split', default='test', help='The name of the evaluation split in the task.' 39 | ) 40 | @click.option( 41 | '--chunking-model', 42 | default=None, 43 | required=False, 44 | help='The name of the model used for semantic chunking.', 45 | ) 46 | @click.option( 47 | '--truncate-max-length', 48 | default=DEFAULT_TRUNCATE_MAX_LENGTH, 49 | type=int, 50 | help='Maximum number of tokens; by default, truncation to 8192 tokens. If None, Long Late Chunking algorithm should be enabled.', 51 | ) 52 | @click.option( 53 | '--chunk-size', 54 | default=DEFAULT_CHUNK_SIZE, 55 | type=int, 56 | help='Number of tokens per chunk for fixed strategy.', 57 | ) 58 | @click.option( 59 | '--n-sentences', 60 | default=DEFAULT_N_SENTENCES, 61 | type=int, 62 | help='Number of sentences per chunk for sentence strategy.', 63 | ) 64 | @click.option( 65 | '--long-late-chunking-embed-size', 66 | default=DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE, 67 | type=int, 68 | help='Number of tokens per macro chunk used for long late chunking.', 69 | ) 70 | @click.option( 71 | '--long-late-chunking-overlap-size', 72 | default=DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE, 73 | type=int, 74 | help='Token length of the embeddings that come before/after soft boundaries (i.e. overlapping embeddings). Above zero, overlap is used between neighbouring embeddings.', 75 | ) 76 | def main( 77 | model_name, 78 | model_weights, 79 | strategy, 80 | task_name, 81 | eval_split, 82 | chunking_model, 83 | truncate_max_length, 84 | chunk_size, 85 | n_sentences, 86 | long_late_chunking_embed_size, 87 | long_late_chunking_overlap_size, 88 | ): 89 | try: 90 | task_cls = globals()[task_name] 91 | except: 92 | raise ValueError(f'Unknown task name: {task_name}') 93 | 94 | if truncate_max_length is not None and (long_late_chunking_embed_size > 0): 95 | truncate_max_length = None 96 | print( 97 | f'Truncation is disabled because Long Late Chunking algorithm is enabled.' 
98 | ) 99 | 100 | model, has_instructions = load_model(model_name, model_weights) 101 | 102 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 103 | 104 | chunking_args = { 105 | 'chunk_size': chunk_size, 106 | 'n_sentences': n_sentences, 107 | 'chunking_strategy': strategy, 108 | 'model_has_instructions': has_instructions, 109 | 'embedding_model_name': chunking_model if chunking_model else model_name, 110 | } 111 | 112 | if torch.cuda.is_available(): 113 | model = model.cuda() 114 | 115 | model.eval() 116 | 117 | # Evaluate with late chunking 118 | tasks = [ 119 | task_cls( 120 | chunked_pooling_enabled=True, 121 | tokenizer=tokenizer, 122 | prune_size=None, 123 | truncate_max_length=truncate_max_length, 124 | long_late_chunking_embed_size=long_late_chunking_embed_size, 125 | long_late_chunking_overlap_size=long_late_chunking_overlap_size, 126 | **chunking_args, 127 | ) 128 | ] 129 | 130 | evaluation = MTEB( 131 | tasks=tasks, 132 | chunked_pooling_enabled=True, 133 | tokenizer=tokenizer, 134 | prune_size=None, 135 | **chunking_args, 136 | ) 137 | evaluation.run( 138 | model, 139 | output_folder='results-chunked-pooling', 140 | eval_splits=[eval_split], 141 | overwrite_results=True, 142 | batch_size=BATCH_SIZE, 143 | encode_kwargs={'batch_size': BATCH_SIZE}, 144 | ) 145 | 146 | # Encode without late chunking 147 | tasks = [ 148 | task_cls( 149 | chunked_pooling_enabled=False, 150 | tokenizer=tokenizer, 151 | prune_size=None, 152 | truncate_max_length=truncate_max_length, 153 | **chunking_args, 154 | ) 155 | ] 156 | 157 | evaluation = MTEB( 158 | tasks=tasks, 159 | chunked_pooling_enabled=False, 160 | tokenizer=tokenizer, 161 | prune_size=None, 162 | **chunking_args, 163 | ) 164 | evaluation.run( 165 | model, 166 | output_folder='results-normal-pooling', 167 | eval_splits=[eval_split], 168 | overwrite_results=True, 169 | batch_size=BATCH_SIZE, 170 | encode_kwargs={'batch_size': BATCH_SIZE}, 171 | ) 172 | 173 | 174 | if __name__ == '__main__': 175 | main() 176 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jina-ai/late-chunking/1d3bb02bf091becd0771455e4e7959463935e26c/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mteb.abstasks.TaskMetadata import TaskMetadata 3 | 4 | from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval 5 | 6 | 7 | class DummyTask(AbsTaskChunkedRetrieval): 8 | metadata = TaskMetadata( 9 | dataset={ 10 | 'path': '~', 11 | 'revision': '', 12 | }, 13 | name='dummy', 14 | description='', 15 | type='Retrieval', 16 | category='s2p', 17 | reference=None, 18 | eval_splits=[], 19 | eval_langs=[], 20 | main_score='ndcg_at_10', 21 | date=None, 22 | form=None, 23 | domains=None, 24 | task_subtypes=None, 25 | license=None, 26 | socioeconomic_status=None, 27 | annotations_creators=None, 28 | dialect=None, 29 | text_creation=None, 30 | bibtex_citation=None, 31 | n_samples=None, 32 | avg_character_length=None, 33 | ) 34 | 35 | def load_data(): 36 | pass 37 | 38 | def __init__(self, **kwargs): 39 | super().__init__(**kwargs) 40 | 41 | 42 | @pytest.fixture() 43 | def dummy_task_factory(): 44 | def _create_dummy_task(*args, **kwargs): 45 | return DummyTask(*args, **kwargs) 46 | 47 | return 
_create_dummy_task 48 | -------------------------------------------------------------------------------- /tests/test_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from transformers import AutoModel, AutoTokenizer 4 | 5 | from chunked_pooling import chunked_pooling 6 | from chunked_pooling.wrappers import load_model 7 | from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval 8 | 9 | MODEL_NAME = 'jinaai/jina-embeddings-v3' 10 | 11 | # Define Text and Chunk 12 | CHUNKS = ["Organic skincare", "for sensitive skin", "with aloe vera and chamomile"] 13 | FULL_TEXT = ' '.join(CHUNKS) 14 | 15 | 16 | def load_api_results(): 17 | import requests 18 | 19 | url = 'https://api.jina.ai/v1/embeddings' 20 | headers = { 21 | 'Content-Type': 'application/json', 22 | 'Authorization': f'Bearer {os.environ["JINA_API_TOKEN"]}', 23 | } 24 | data = { 25 | "model": "jina-embeddings-v3", 26 | "task": "retrieval.passage", 27 | "dimensions": 1024, 28 | "late_chunking": True, 29 | "embedding_type": "float", 30 | "input": CHUNKS, 31 | } 32 | response = requests.post(url, headers=headers, json=data) 33 | data = response.json() 34 | return [np.array(x['embedding']) for x in data['data']] 35 | 36 | 37 | def calculate_annotations(model, boundary_cues, model_has_instructions, tokenizer): 38 | if model_has_instructions: 39 | instr = model.get_instructions()[1] 40 | instr_tokens = tokenizer(instr, add_special_tokens=False) 41 | n_instruction_tokens = len(instr_tokens[0]) 42 | else: 43 | n_instruction_tokens = 0 44 | chunk_annotations = [ 45 | AbsTaskChunkedRetrieval._extend_special_tokens( 46 | annotations, 47 | n_instruction_tokens=n_instruction_tokens, 48 | include_prefix=True, 49 | include_sep=True, 50 | ) 51 | for annotations in boundary_cues 52 | ] 53 | return chunk_annotations 54 | 55 | 56 | def test_compare_v3_api_embeddings(): 57 | # Load Model 58 | model, has_instr = load_model(MODEL_NAME, use_flash_attn=False) 59 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) 60 | 61 | # Determine Boundary Cues 62 | tokenization = tokenizer( 63 | FULL_TEXT, return_offsets_mapping=True, add_special_tokens=False 64 | ) 65 | boundary_cues = [] 66 | chunk_i = 0 67 | last_cue = 0 68 | last_end = 0 69 | for i, (start, end) in enumerate(tokenization.offset_mapping): 70 | if end >= (last_end + len(CHUNKS[chunk_i])): 71 | boundary_cues.append((last_cue, i + 1)) 72 | chunk_i += 1 73 | last_cue = i + 1 74 | last_end = end 75 | extended_boundary_cues = calculate_annotations( 76 | model, [boundary_cues], has_instr, tokenizer 77 | ) 78 | 79 | # Append Instruction for Retrieval Task 80 | instr = model.get_instructions()[1] 81 | text_inputs = [instr + FULL_TEXT] 82 | model_inputs = tokenizer( 83 | text_inputs, 84 | return_tensors='pt', 85 | padding=True, 86 | truncation=True, 87 | max_length=8192, 88 | ) 89 | model_outputs = model(**model_inputs) 90 | 91 | # Apply Late Chunking 92 | output_embs = chunked_pooling( 93 | model_outputs, extended_boundary_cues, max_length=8192 94 | )[0] 95 | api_embs = load_api_results() 96 | for local_emb, api_emb in zip(output_embs, api_embs): 97 | local_emb_norm = local_emb / np.linalg.norm(local_emb) 98 | api_emb_norm = api_emb / np.linalg.norm(api_emb) 99 | assert np.allclose(local_emb_norm, api_emb_norm, rtol=1e-02, atol=1e-02) 100 | assert 1.0 - np.dot(local_emb_norm, api_emb_norm) < 1e-3 101 | -------------------------------------------------------------------------------- /tests/test_chunking_methods.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import AutoTokenizer 3 | 4 | from chunked_pooling.chunking import CHUNKING_STRATEGIES, Chunker 5 | from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval 6 | 7 | EXAMPLE_TEXT_1 = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area." 8 | PUNCTATIONS = ('.', '!', '?') 9 | 10 | 11 | @pytest.mark.parametrize("n_sentences", [1, 2, 3, 4]) 12 | def test_chunk_by_sentences(n_sentences): 13 | strategy = 'sentences' 14 | model_name = 'jinaai/jina-embeddings-v2-small-en' 15 | chunker = Chunker(chunking_strategy=strategy) 16 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 17 | boundary_cues = chunker.chunk( 18 | text=EXAMPLE_TEXT_1, 19 | tokenizer=tokenizer, 20 | chunking_strategy=strategy, 21 | n_sentences=n_sentences, 22 | ) 23 | extended_boundary_cues = AbsTaskChunkedRetrieval._extend_special_tokens( 24 | boundary_cues 25 | ) 26 | model_inputs = tokenizer( 27 | EXAMPLE_TEXT_1, 28 | return_tensors='pt', 29 | padding=True, 30 | truncation=True, 31 | max_length=8192, 32 | ) 33 | 34 | # check that the cues start with 0 and end with the last token 35 | assert extended_boundary_cues[0][0] == 0 36 | assert len(model_inputs.tokens()) == extended_boundary_cues[-1][1] 37 | 38 | # check that all chunks but the last one end with a punctuation 39 | assert all( 40 | model_inputs.tokens()[x:y][-1] in PUNCTATIONS 41 | for (x, y) in extended_boundary_cues[:-1] 42 | ) 43 | 44 | # check that the last chunk ends with a "[SEP]" token 45 | last_cue = extended_boundary_cues[-1] 46 | assert model_inputs.tokens()[last_cue[0] : last_cue[1]][-1] == "[SEP]" 47 | 48 | # check that the boundary cues are continuous (no token is missing) 49 | assert all( 50 | [ 51 | extended_boundary_cues[i][1] == extended_boundary_cues[i + 1][0] 52 | for i in range(len(extended_boundary_cues) - 1) 53 | ] 54 | ) 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "boundary_cues", [[(0, 17), (17, 44), (44, 69)], [(0, 44), (44, 69)]] 59 | ) 60 | def test_token_equivalence(boundary_cues): 61 | model_name = 'jinaai/jina-embeddings-v2-small-en' 62 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 63 | tokens = tokenizer.encode_plus( 64 | EXAMPLE_TEXT_1, add_special_tokens=False, return_offsets_mapping=True 65 | ) 66 | for start_token_idx, end_token_idx in boundary_cues: 67 | decoded_text_chunk = tokenizer.decode( 68 | tokens.input_ids[start_token_idx:end_token_idx] 69 | ) 70 | 71 | original_text_chunk = EXAMPLE_TEXT_1[ 72 | tokens.offset_mapping[start_token_idx][0] : tokens.offset_mapping[ 73 | end_token_idx - 1 74 | ][1] 75 | ] 76 | chunk_tokens_original = tokenizer.encode_plus(original_text_chunk) 77 | chunk_tokens_decoded = tokenizer.encode_plus(decoded_text_chunk) 78 | assert chunk_tokens_original == chunk_tokens_decoded 79 | 80 | 81 | def test_chunker_initialization(): 82 | for strategy in CHUNKING_STRATEGIES: 83 | chunker = Chunker(chunking_strategy=strategy) 84 | assert chunker.chunking_strategy == strategy 85 | 86 | 87 | def test_invalid_chunking_strategy(): 88 | with pytest.raises(ValueError): 89 | Chunker(chunking_strategy="invalid") 90 | 91 | 92 | def test_chunk_by_tokens(): 93 | 
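# the fixed-size strategy should split the example text into several spans, each covering at most chunk_size (here 10) tokens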
chunker = Chunker(chunking_strategy="fixed") 94 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 95 | chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=10) 96 | assert len(chunks) > 1 97 | for start, end in chunks: 98 | assert end - start <= 10 99 | 100 | 101 | @pytest.mark.parametrize( 102 | 'model_name', 103 | ['jinaai/jina-embeddings-v2-small-en', 'sentence-transformers/all-MiniLM-L6-v2'], 104 | ) 105 | def test_chunk_semantically(model_name): 106 | chunker = Chunker(chunking_strategy="semantic") 107 | tokenizer = AutoTokenizer.from_pretrained(model_name) 108 | tokens = tokenizer.encode_plus( 109 | EXAMPLE_TEXT_1, add_special_tokens=False, return_offsets_mapping=True 110 | ) 111 | boundary_cues = chunker.chunk( 112 | EXAMPLE_TEXT_1, 113 | tokenizer=tokenizer, 114 | chunking_strategy='semantic', 115 | embedding_model_name=model_name, 116 | ) 117 | 118 | # check if it returns boundary cues 119 | assert len(boundary_cues) > 0 120 | 121 | # test if bounaries are at the end of sentences 122 | for start_token_idx, end_token_idx in boundary_cues: 123 | assert ( 124 | EXAMPLE_TEXT_1[tokens.offset_mapping[end_token_idx - 1][0]] in PUNCTATIONS 125 | ) 126 | decoded_text_chunk = tokenizer.decode( 127 | tokens.input_ids[start_token_idx:end_token_idx] 128 | ) 129 | 130 | # check that the boundary cues are continuous (no token is missing) 131 | assert all( 132 | [ 133 | boundary_cues[i][1] == boundary_cues[i + 1][0] 134 | for i in range(len(boundary_cues) - 1) 135 | ] 136 | ) 137 | 138 | 139 | def test_empty_input(): 140 | chunker = Chunker(chunking_strategy="fixed") 141 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 142 | chunks = chunker.chunk("", tokenizer=tokenizer, chunk_size=10) 143 | assert len(chunks) == 0 144 | 145 | 146 | def test_input_shorter_than_chunk_size(): 147 | short_text = "Short text." 
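# a text shorter than the chunk size should still come back as exactly one span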
148 | chunker = Chunker(chunking_strategy="fixed") 149 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 150 | chunks = chunker.chunk(short_text, tokenizer=tokenizer, chunk_size=20) 151 | assert len(chunks) == 1 152 | 153 | 154 | @pytest.mark.parametrize("chunk_size", [10, 20, 50]) 155 | def test_various_chunk_sizes(chunk_size): 156 | chunker = Chunker(chunking_strategy="fixed") 157 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 158 | chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=chunk_size) 159 | assert len(chunks) > 0 160 | for start, end in chunks: 161 | assert end - start <= chunk_size 162 | 163 | 164 | def test_chunk_method_with_different_strategies(): 165 | chunker = Chunker(chunking_strategy="fixed") 166 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 167 | fixed_chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=10) 168 | semantic_chunks = chunker.chunk( 169 | EXAMPLE_TEXT_1, 170 | tokenizer=tokenizer, 171 | chunking_strategy="semantic", 172 | embedding_model_name='jinaai/jina-embeddings-v2-small-en', 173 | ) 174 | assert fixed_chunks != semantic_chunks 175 | 176 | 177 | def test_chunk_by_sentences_different_n(): 178 | chunker = Chunker(chunking_strategy="sentences") 179 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 180 | chunks_1 = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, n_sentences=1) 181 | chunks_2 = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, n_sentences=2) 182 | assert len(chunks_1) > len(chunks_2) 183 | -------------------------------------------------------------------------------- /tests/test_v3.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | from run_chunked_eval import DEFAULT_CHUNK_SIZE, load_model 4 | 5 | MODEL_NAME = 'jinaai/jina-embeddings-v3' 6 | 7 | 8 | def test_instruction_handling(dummy_task_factory): 9 | model, has_instructions = load_model(MODEL_NAME) 10 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) 11 | task = dummy_task_factory( 12 | chunking_strategy='fixed', 13 | chunk_size=DEFAULT_CHUNK_SIZE, 14 | tokenizer=tokenizer, 15 | model_has_instructions=has_instructions, 16 | ) 17 | n_instruction_tokens = len( 18 | tokenizer(model.get_instructions()[1], add_special_tokens=False)['input_ids'] 19 | ) 20 | annotations_one_token = task._calculate_annotations(model, ['A'])[0] 21 | assert len(annotations_one_token) == 1 22 | assert annotations_one_token[0] == (0, n_instruction_tokens + 3) 23 | --------------------------------------------------------------------------------