├── .github
│   └── workflows
│       └── ci.yaml
├── LICENSE
├── README.md
├── chunked_pooling
│   ├── __init__.py
│   ├── chunked_eval_tasks.py
│   ├── chunking.py
│   ├── mteb_chunked_eval.py
│   └── wrappers.py
├── examples.ipynb
├── explanatory_contextual_retrieval.py
├── img
│   ├── context-problem.png
│   ├── method.png
│   └── rag.png
├── pyproject.toml
├── run_chunked_eval.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── test_api.py
    ├── test_chunking_methods.py
    └── test_v3.py
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: Run Tests
2 |
3 | on:
4 | pull_request:
5 | types: [opened, synchronize, reopened]
6 | push:
7 | branches:
8 | - main
9 |
10 | env:
11 | JINA_API_TOKEN: ${{ secrets.JINA_API_TOKEN }}
12 |
13 | jobs:
14 | test:
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v3
19 |
20 | - name: Set up Python
21 | uses: actions/setup-python@v4
22 | with:
23 | python-version: '3.11'
24 |
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install .[dev]
29 |
30 | - name: Run tests
31 | run: pytest tests
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Late Chunking of Short Chunks in Long-Context Embedding Models
2 |
3 | [**Blog part 1**](https://jina.ai/news/late-chunking-in-long-context-embedding-models) | [**Blog part 2**](https://jina.ai/news/what-late-chunking-really-is-and-what-its-not-part-ii/) | [**ArXiv paper**](https://arxiv.org/abs/2409.04701)
4 |
5 | For many applications, encoding a whole text document into a single embedding representation is not useful. Many applications require retrieving smaller parts of the text, and dense vector-based information retrieval systems often perform better with smaller text segments because of the limited information capacity of embedding vectors.
6 |
7 | 
8 |
9 |
10 | RAG (Retrieval-Augmented Generation) is one of the best-known applications that requires splitting document collections into smaller text chunks. These chunks are typically stored in a vector database with vector representations created by a text embedding model.
11 | At runtime, the same embedding model encodes a query text into a vector representation, which is used to identify relevant stored text chunks. These are then passed to a large language model (LLM), which synthesizes a response to the query based on the retrieved texts.
12 |
13 | ## Context Problem
14 |
15 |
16 | This simple RAG approach is not without challenges. It handles long-distance contextual dependencies particularly poorly, i.e., cases where the relevant information is spread over multiple chunks, so that taking text segments out of context makes them useless.
17 | 
18 | In the image above, one can see a Wikipedia article that is split into chunks of single sentences.
19 | Phrases like "its" and "the city" reference "Berlin", which is mentioned only in the first sentence, making it harder for the embedding model to link these mentions to the respective entity and produce a high-quality embedding representation.
20 |
21 |
22 | For example, if we split a Wikipedia article into sentence-length segments, as in the example above, a RAG system might not be able to answer a query like "What is the population of Berlin?": the city name and the population never appear together in a single segment, and there is no larger document context to fall back on.
23 | An LLM presented with one of these segments cannot resolve anaphoric references like "it" or "the city".
24 |
25 | ## Context-Sensitive Chunking
26 |
27 | To overcome this problem, we take advantage of the long input sequences that recent embedding models like [`jina-embeddings-v2-base-en`](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) can process.
28 | These models support much longer input texts, for example, 8192 tokens for `jina-embeddings-v2-base-en` or roughly ten standard pages of text. Text segments of this size are much less likely to have contextual dependencies that can only be resolved with a larger context.
29 | However, we still need vector representations of much smaller chunks of text, in part because of the limited input sizes of LLMs but primarily because of the limited information capacity of short embedding vectors.
30 |
31 | 
32 |
33 |
34 | The simple encoding approach (as seen on the left side of the image above) chunks texts before processing them, using sentences, paragraphs, and maximum length limits to split the text _a priori_, and then applies an embedding model to the resulting chunks.
35 | Late Chunking, instead, first applies the transformer part of the embedding model to the entire text, or to as much of it as possible. This generates a sequence of vector representations, one per token, that encompass textual information from the entire text.
36 | To generate a single embedding for a text, many embedding models apply _mean pooling_ to these token representations to output a single vector. Late Chunking instead applies mean pooling to smaller segments of this sequence of token vectors, producing embeddings for each chunk that take into account the entire text.
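
Concretely, the procedure can be sketched in a few lines with the Hugging Face `transformers` API. This is a minimal sketch, not the repository's implementation: the chunk boundaries in `span_annotations` are chosen for illustration and would normally come from a chunker such as `chunk_by_sentences` in `chunked_pooling/__init__.py`.

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_name = 'jinaai/jina-embeddings-v2-base-en'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

text = "Berlin is the capital and largest city of Germany. Its more than 3.85 million inhabitants make it the European Union's most populous city."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    token_embeddings = model(**inputs)[0]  # shape (1, seq_len, dim): one contextualized vector per token

# Token spans of the chunks; the boundaries here are illustrative only.
seq_len = token_embeddings.shape[1]
span_annotations = [(1, 12), (12, seq_len - 1)]  # skip the special [CLS]/[SEP] tokens

chunk_embeddings = [
    token_embeddings[0, start:end].mean(dim=0)  # mean pooling restricted to one chunk's tokens
    for start, end in span_annotations
]
```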
37 |
38 | ## The Effect of Context-Sensitive Chunking
39 |
40 | This has an immediate, measurable effect on retrieval. For example, in the case of "the city" and "Berlin" in a Wikipedia article, the vectors representing "the city" contain information connecting it to the previous mention of "Berlin", making them a much better match for queries involving that city name.
41 |
42 | You can see this in the numerical results below, which compare the embedding of the string "Berlin" to various sentences from the article about Berlin. The column "Traditional Similarity" shows the similarity values with _a priori_ chunking, and "Late Chunking Similarity" those with context-sensitive chunking.
43 |
44 | | Text                                                                                                                                   | Traditional Similarity | Late Chunking Similarity |
45 | |----------------------------------------------------------------------------------------------------------------------------------------|------------------------|--------------------------|
46 | | Berlin is the capital and largest city of Germany, both by area and by population.                                                    | 0.84862185             | 0.849546                 |
47 | | Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. | 0.7084338              | 0.82489026               |
48 | | The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.                       | 0.7534553              | 0.84980094               |
49 |
50 | As you can see, the similarity scores for the first chunk, which contains "Berlin" explicitly, are very close to each other.
51 | For the other two chunks they differ significantly, as late chunking dramatically improves matching on sentences that do not explicitly use the word "Berlin" but contain anaphoric references to it.
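
Continuing the sketch above, such similarity values can be computed by embedding the query string "Berlin" the same way and taking cosine similarities against the chunk embeddings. This snippet reuses `tokenizer`, `model`, `span_annotations`, and `chunk_embeddings` from the previous example:

```python
import numpy as np

def cos_sim(x, y):
    # cosine similarity between two 1-D vectors
    return float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))

# embed the query like a tiny single-chunk document
q_inputs = tokenizer('Berlin', return_tensors='pt')
with torch.no_grad():
    q_emb = model(**q_inputs)[0][0].mean(dim=0).numpy()

for (start, end), emb in zip(span_annotations, chunk_embeddings):
    print(f'tokens {start}-{end}: {cos_sim(q_emb, emb.numpy()):.4f}')
```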
52 |
53 | ## Evaluation on Retrieval Tasks
54 |
55 |
56 | To verify the effectiveness of this approach beyond a few toy examples, we tested it with some of the retrieval benchmarks from [BeIR](https://github.com/beir-cellar/beir).
57 | Those retrieval tasks consist of a query set, a corpus of text documents, and a QRels file that stores information about the IDs of documents that are relevant for each query.
58 | To identify the relevant documents of a query, one can chunk the documents, encode them into an embedding index, and determine for each query embedding the most similar chunks (kNN).
59 | As each chunk corresponds to a document, one can convert the kNN ranking of chunks into a kNN ranking of documents (for documents occurring multiple times in the ranking, only the first occurrence is retained).
60 | After that, one can compare the resulting ranking with the ranking corresponding to the ground-truth QRels file and calculate retrieval metrics like nDCG@10.
61 | We run this evaluation for various BeIR datasets with traditional chunking and our novel late chunking method.
62 | To split texts into chunks, we use a straightforward method that splits the texts into chunks of 256 tokens each.
63 | Both the traditional and late chunking tests used the [jina-embeddings-v2-small-en](https://huggingface.co/jinaai/jina-embeddings-v2-small-en) model.
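
The chunk-to-document conversion step described above can be sketched as follows; the names `ranked_chunk_ids` and `chunk_to_doc` are illustrative, not part of this repository:

```python
def chunks_to_doc_ranking(ranked_chunk_ids, chunk_to_doc):
    """Collapse a kNN ranking of chunks into a ranking of documents,
    keeping only the first (best-ranked) occurrence of each document."""
    seen = set()
    doc_ranking = []
    for chunk_id in ranked_chunk_ids:
        doc_id = chunk_to_doc[chunk_id]
        if doc_id not in seen:
            seen.add(doc_id)
            doc_ranking.append(doc_id)
    return doc_ranking
```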
64 |
65 | | Dataset | AVG Document Length (characters) | Traditional Chunking (nDCG@10) | Late Chunking (nDCG@10) | No Chunking (nDCG@10) |
66 | |-----------|----------------------------------|--------------------------------|--------------------------------------|-----------------------|
67 | | SciFact | 1498.4 | 64.20% | **66.10%** | 63.89% |
68 | | TRECCOVID | 1116.7 | 63.36% | 64.70% | **65.18%** |
69 | | FiQA2018 | 767.2 | 33.25% | **33.84%** | 33.43% |
70 | | NFCorpus | 1589.8 | 23.46% | 29.98% | **30.40%** |
71 | | Quora | 62.2 | 87.19% | 87.19% | 87.19% |
72 |
73 | Compared to traditional chunking, late chunking matched or improved the score in every case. In some cases, it also outperforms encoding the whole document into a single embedding, while for other datasets no chunking performs best; however, no chunking is only an option if one does not need to rank individual chunks. One can also see that greater average document length correlates with greater improvement in the nDCG scores through late chunking.
74 |
75 | To reproduce the evaluation, you can install the dependencies with `pip install .` and run the following script with one of the task names "SciFactChunked", "TRECCOVIDChunked", "FiQA2018Chunked", "NFCorpusChunked", or "QuoraChunked":
76 |
77 | ```bash
78 | python3 run_chunked_eval.py --task-name {TASK_NAME}
79 | ```
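
For example, `python3 run_chunked_eval.py --task-name SciFactChunked` runs the evaluation on SciFact.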
80 |
81 | ## Acknowledgement and References
82 |
83 | Thanks to Isabelle Mohr ([@violenil](https://github.com/violenil)) for contributing some code and Scott Martens ([@scott-martens](https://github.com/scott-martens)) for reviewing the README.
84 |
85 | More about the evaluation tasks can be found in the [MTEB Repository](https://github.com/embeddings-benchmark/mteb), and details about the training of the models for long input texts in our paper: ["Jina embeddings 2: 8192-token general-purpose text embeddings for long documents."](https://arxiv.org/abs/2310.19923)
86 |
87 | If you find Late Chunking useful in your research, you can cite the paper [Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models](https://arxiv.org/abs/2409.04701):
88 |
89 | ```
90 | @article{gunther2024late,
91 | title={Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models},
92 | author={G{\"u}nther, Michael and Mohr, Isabelle and Williams, Daniel J and Wang, Bo and Xiao, Han},
93 | journal={arXiv preprint arXiv:2409.04701},
94 | year={2024}
95 | }
96 | ```
97 |
--------------------------------------------------------------------------------
/chunked_pooling/__init__.py:
--------------------------------------------------------------------------------
1 | def chunk_by_sentences(input_text: str, tokenizer: callable):
2 | """
3 | Split the input text into sentences using the tokenizer
4 | :param input_text: The text snippet to split into sentences
5 | :param tokenizer: The tokenizer to use
6 | :return: A tuple containing the list of text chunks and their corresponding token spans
7 | """
8 | inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
9 | punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
10 | sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
11 | token_offsets = inputs['offset_mapping'][0]
12 | token_ids = inputs['input_ids'][0]
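    # A sentence boundary is a '.' token followed either by a gap in the character
    # offsets (i.e. whitespace before the next token) or directly by the [SEP] token.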
13 | chunk_positions = [
14 | (i, int(start + 1))
15 | for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))
16 | if token_id == punctuation_mark_id
17 | and (
18 | token_offsets[i + 1][0] - token_offsets[i][1] > 0
19 | or token_ids[i + 1] == sep_id
20 | )
21 | ]
22 | chunks = [
23 | input_text[x[1] : y[1]]
24 | for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
25 | ]
26 | span_annotations = [
27 | (x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
28 | ]
29 | return chunks, span_annotations
30 |
31 |
32 | def chunked_pooling(
33 | model_output: 'BatchEncoding', span_annotation: list, max_length=None
34 | ):
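    """
    Apply late chunking: mean-pool the token embeddings of each annotated span
    into a single embedding vector per chunk.
    :param model_output: output of the transformer model; token embeddings at index 0
    :param span_annotation: one list of (start, end) token spans per input text
    :param max_length: if set, spans are clipped to the model's maximum length
    :return: for each input text, a list of chunk embeddings as numpy arrays
    """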
35 | token_embeddings = model_output[0]
36 | outputs = []
37 | for embeddings, annotations in zip(token_embeddings, span_annotation):
38 | if (
39 | max_length is not None
40 |         ):  # remove annotations which go beyond the max-length of the model
41 | annotations = [
42 | (start, min(end, max_length - 1))
43 | for (start, end) in annotations
44 | if start < (max_length - 1)
45 | ]
46 | pooled_embeddings = [
47 | embeddings[start:end].sum(dim=0) / (end - start)
48 | for start, end in annotations
49 | if (end - start) >= 1
50 | ]
51 | pooled_embeddings = [
52 | embedding.float().detach().cpu().numpy() for embedding in pooled_embeddings
53 | ]
54 | outputs.append(pooled_embeddings)
55 |
56 | return outputs
57 |
--------------------------------------------------------------------------------
/chunked_pooling/chunked_eval_tasks.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | from mteb.abstasks.TaskMetadata import TaskMetadata
3 |
4 | from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval
5 |
6 |
7 | class SciFactChunked(AbsTaskChunkedRetrieval):
8 | metadata = TaskMetadata(
9 | name='SciFactChunked',
10 | dataset={
11 | 'path': 'mteb/scifact',
12 | 'revision': '0228b52cf27578f30900b9e5271d331663a030d7',
13 | 'name': 'SciFact',
14 | },
15 | description=(
16 | 'SciFact verifies scientific claims using evidence from the '
17 | 'research literature containing scientific paper abstracts.'
18 | ),
19 | reference='https://github.com/allenai/scifact',
20 | type='Retrieval',
21 | category='s2p',
22 | eval_splits=['test'],
23 | eval_langs=['eng-Latn'],
24 | main_score='ndcg_at_10',
25 | date=None,
26 | form=None,
27 | domains=None,
28 | task_subtypes=None,
29 | license=None,
30 | socioeconomic_status=None,
31 | annotations_creators=None,
32 | dialect=None,
33 | text_creation=None,
34 | bibtex_citation=None,
35 | n_samples=None,
36 | avg_character_length=None,
37 | )
38 |
39 | def __init__(self, **kwargs):
40 | super().__init__(**kwargs)
41 |
42 |
43 | class NarrativeQAChunked(AbsTaskChunkedRetrieval):
44 | metadata = TaskMetadata(
45 | name='NarrativeQAChunked',
46 | dataset={
47 | 'path': 'narrativeqa',
48 | 'revision': '2e643e7363944af1c33a652d1c87320d0871c4e4',
49 | 'name': 'NarrativeQARetrieval',
50 | },
51 | reference='https://metatext.io/datasets/narrativeqa',
52 | description=(
53 | 'NarrativeQA is a dataset for the task of question answering '
54 | 'on long narratives. It consists of realistic QA instances '
55 | 'collected from literature (fiction and non-fiction) '
56 | 'and movie scripts. '
57 | ),
58 | type='Retrieval',
59 | category='s2p',
60 | eval_splits=['test'],
61 | eval_langs=['eng-Latn'],
62 | main_score='ndcg_at_10',
63 | date=None,
64 | form=None,
65 | domains=None,
66 | task_subtypes=None,
67 | license=None,
68 | socioeconomic_status=None,
69 | annotations_creators=None,
70 | dialect=None,
71 | text_creation=None,
72 | bibtex_citation=None,
73 | n_samples=None,
74 | avg_character_length=None,
75 | )
76 |
77 | def __init__(self, **kwargs):
78 | super().__init__(**kwargs)
79 |
80 |
81 | class NFCorpusChunked(AbsTaskChunkedRetrieval):
82 | metadata = TaskMetadata(
83 | name="NFCorpusChunked",
84 | dataset={
85 | "path": "mteb/nfcorpus",
86 | "revision": "ec0fa4fe99da2ff19ca1214b7966684033a58814",
87 | 'name': 'NFCorpus',
88 | },
89 | description="NFCorpus: A Full-Text Learning to Rank Dataset for Medical Information Retrieval",
90 | reference="https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/",
91 | type="Retrieval",
92 | category="s2p",
93 | eval_splits=["test"],
94 | eval_langs=["eng-Latn"],
95 | main_score="ndcg_at_10",
96 | date=None,
97 | form=None,
98 | domains=None,
99 | task_subtypes=None,
100 | license=None,
101 | socioeconomic_status=None,
102 | annotations_creators=None,
103 | dialect=None,
104 | text_creation=None,
105 | bibtex_citation=None,
106 | n_samples=None,
107 | avg_character_length=None,
108 | )
109 |
110 | def __init__(self, **kwargs):
111 | super().__init__(**kwargs)
112 |
113 |
114 | class QuoraChunked(AbsTaskChunkedRetrieval):
115 | metadata = TaskMetadata(
116 | name="QuoraChunked",
117 | dataset={
118 | "path": "mteb/quora",
119 | "revision": "e4e08e0b7dbe3c8700f0daef558ff32256715259",
120 | "name": "QuoraRetrieval",
121 | },
122 | description=(
123 | "QuoraRetrieval is based on questions that are marked as duplicates on the Quora platform. Given a"
124 | " question, find other (duplicate) questions."
125 | ),
126 | reference="https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs",
127 | type="Retrieval",
128 | category="s2s",
129 | eval_splits=["dev", "test"],
130 | eval_langs=["eng-Latn"],
131 | main_score="ndcg_at_10",
132 | date=None,
133 | form=None,
134 | domains=None,
135 | task_subtypes=None,
136 | license=None,
137 | socioeconomic_status=None,
138 | annotations_creators=None,
139 | dialect=None,
140 | text_creation=None,
141 | bibtex_citation=None,
142 | n_samples=None,
143 | avg_character_length=None,
144 | )
145 |
146 | def __init__(self, **kwargs):
147 | super().__init__(**kwargs)
148 |
149 |
150 | class FiQA2018Chunked(AbsTaskChunkedRetrieval):
151 | metadata = TaskMetadata(
152 | name="FiQA2018Chunked",
153 | description="Financial Opinion Mining and Question Answering",
154 | reference="https://sites.google.com/view/fiqa/",
155 | dataset={
156 | "path": "mteb/fiqa",
157 | "revision": "27a168819829fe9bcd655c2df245fb19452e8e06",
158 | 'name': 'FiQA2018',
159 | },
160 | type="Retrieval",
161 | category="s2p",
162 | eval_splits=["train", "dev", "test"],
163 | eval_langs=["eng-Latn"],
164 | main_score="ndcg_at_10",
165 | date=None,
166 | form=None,
167 | domains=None,
168 | task_subtypes=None,
169 | license=None,
170 | socioeconomic_status=None,
171 | annotations_creators=None,
172 | dialect=None,
173 | text_creation=None,
174 | bibtex_citation=None,
175 | n_samples=None,
176 | avg_character_length=None,
177 | )
178 |
179 | def __init__(self, **kwargs):
180 | super().__init__(**kwargs)
181 |
182 |
183 | class TRECCOVIDChunked(AbsTaskChunkedRetrieval):
184 | metadata = TaskMetadata(
185 | name='TRECCOVIDChunked',
186 | description=(
187 | 'TRECCOVID is an ad-hoc search challenge based on the '
188 | 'COVID-19 dataset containing scientific articles '
189 | 'related to the COVID-19 pandemic.'
190 | ),
191 | reference='https://ir.nist.gov/covidSubmit/index.html',
192 | dataset={
193 | 'path': 'mteb/trec-covid',
194 | 'revision': 'bb9466bac8153a0349341eb1b22e06409e78ef4e',
195 | 'name': 'TRECCOVID',
196 | },
197 | type='Retrieval',
198 | category='s2p',
199 | eval_splits=['test'],
200 | eval_langs=['eng-Latn'],
201 | main_score='ndcg_at_10',
202 | date=None,
203 | form=None,
204 | domains=None,
205 | task_subtypes=None,
206 | license=None,
207 | socioeconomic_status=None,
208 | annotations_creators=None,
209 | dialect=None,
210 | text_creation=None,
211 | bibtex_citation=None,
212 | n_samples=None,
213 | avg_character_length=None,
214 | )
215 |
216 | def __init__(self, **kwargs):
217 | super().__init__(**kwargs)
218 |
219 |
220 | class LEMBWikimQARetrievalChunked(AbsTaskChunkedRetrieval):
221 | """
222 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py
223 | """
224 |
225 | _EVAL_SPLIT = "test"
226 |
227 | metadata = TaskMetadata(
228 | name="LEMBWikimQARetrievalChunked",
229 | dataset={
230 | "path": "dwzhu/LongEmbed",
231 | "revision": "10039a580487dacecf79db69166e17ace3ede392",
232 | "name": "LEMBWikimQARetrieval",
233 | },
234 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
235 | description=("2wikimqa subset of dwzhu/LongEmbed dataset."),
236 | type="Retrieval",
237 | category="s2p",
238 | modalities=["text"],
239 | eval_splits=[_EVAL_SPLIT],
240 | eval_langs=["eng-Latn"],
241 | main_score="ndcg_at_10",
242 | date=("1950-01-01", "2019-12-31"),
243 | domains=None,
244 | socioeconomic_status=None,
245 | n_samples=None,
246 | avg_character_length=None,
247 | form=None,
248 | text_creation=None,
249 | task_subtypes=["Article retrieval"],
250 | license="not specified",
251 | annotations_creators="derived",
252 | dialect=[],
253 | sample_creation="found",
254 | bibtex_citation="""
255 | @inproceedings{ho2020constructing,
256 | title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps},
257 | author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko},
258 | booktitle={Proceedings of the 28th International Conference on Computational Linguistics},
259 | pages={6609--6625},
260 | year={2020}
261 | }
262 | """,
263 | descriptive_stats={
264 | "n_samples": {_EVAL_SPLIT: 500},
265 | "avg_character_length": {
266 | "test": {
267 | "average_document_length": 37445.60333333333,
268 | "average_query_length": 67.57,
269 | "num_documents": 300,
270 | "num_queries": 300,
271 | "average_relevant_docs_per_query": 1.0,
272 | }
273 | },
274 | },
275 | )
276 |
277 | def load_data(self, **kwargs):
278 | if self.data_loaded:
279 | return
280 |
281 | dataset_dict = {**self.metadata.dataset}
282 | dataset_dict['name'] = '2wikimqa'
283 |
284 | query_list = datasets.load_dataset(**dataset_dict)["queries"]
285 | queries = {row["qid"]: row["text"] for row in query_list}
286 |
287 | corpus_list = datasets.load_dataset(**dataset_dict)["corpus"]
288 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}
289 |
290 | qrels_list = datasets.load_dataset(**dataset_dict)["qrels"]
291 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}
292 |
293 | self.corpus = {self._EVAL_SPLIT: corpus}
294 | self.queries = {self._EVAL_SPLIT: queries}
295 | self.relevant_docs = {self._EVAL_SPLIT: qrels}
296 |
297 | self.data_loaded = True
298 |
299 |
300 | class LEMBSummScreenFDRetrievalChunked(AbsTaskChunkedRetrieval):
301 | """
302 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py
303 | """
304 |
305 | _EVAL_SPLIT = "test"
306 |
307 | metadata = TaskMetadata(
308 | name="LEMBSummScreenFDRetrievalChunked",
309 | dataset={
310 | "path": "dwzhu/LongEmbed",
311 | "revision": "10039a580487dacecf79db69166e17ace3ede392",
312 | "name": "LEMBSummScreenFDRetrieval",
313 | },
314 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
315 | description=("summ_screen_fd subset of dwzhu/LongEmbed dataset."),
316 | type="Retrieval",
317 | category="s2p",
318 | modalities=["text"],
319 | eval_splits=[_EVAL_SPLIT],
320 | eval_langs=["eng-Latn"],
321 | main_score="ndcg_at_10",
322 | date=("1950-01-01", "2019-12-31"),
323 | domains=None,
324 | socioeconomic_status=None,
325 | n_samples=None,
326 | avg_character_length=None,
327 | form=None,
328 | text_creation=None,
329 | task_subtypes=["Article retrieval"],
330 | license="not specified",
331 | annotations_creators="derived",
332 | dialect=[],
333 | sample_creation="found",
334 | bibtex_citation="""
335 | @inproceedings{ho2020constructing,
336 | title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps},
337 | author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko},
338 | booktitle={Proceedings of the 28th International Conference on Computational Linguistics},
339 | pages={6609--6625},
340 | year={2020}
341 | }
342 | """,
343 | descriptive_stats={
344 | "n_samples": {_EVAL_SPLIT: 500},
345 | "avg_character_length": {
346 | "test": {
347 | "average_document_length": 30854.327,
348 | "average_query_length": 591.49,
349 | "num_documents": 300,
350 | "num_queries": 300,
351 | "average_relevant_docs_per_query": 1.0,
352 | }
353 | },
354 | },
355 | )
356 |
357 | def load_data(self, **kwargs):
358 | if self.data_loaded:
359 | return
360 |
361 | dataset_dict = {**self.metadata.dataset}
362 | dataset_dict['name'] = 'summ_screen_fd'
363 |
364 | query_list = datasets.load_dataset(**dataset_dict)["queries"]
365 | queries = {row["qid"]: row["text"] for row in query_list}
366 |
367 | corpus_list = datasets.load_dataset(**dataset_dict)["corpus"]
368 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}
369 |
370 | qrels_list = datasets.load_dataset(**dataset_dict)["qrels"]
371 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}
372 |
373 | self.corpus = {self._EVAL_SPLIT: corpus}
374 | self.queries = {self._EVAL_SPLIT: queries}
375 | self.relevant_docs = {self._EVAL_SPLIT: qrels}
376 |
377 | self.data_loaded = True
378 |
379 |
380 | class LEMBQMSumRetrievalChunked(AbsTaskChunkedRetrieval):
381 | """
382 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py
383 | """
384 |
385 | _EVAL_SPLIT = "test"
386 |
387 | metadata = TaskMetadata(
388 | name="LEMBQMSumRetrievalChunked",
389 | dataset={
390 | "path": "dwzhu/LongEmbed",
391 | "revision": "10039a580487dacecf79db69166e17ace3ede392",
392 | "name": "LEMBQMSumRetrieval",
393 | },
394 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
395 | description=("qmsum subset of dwzhu/LongEmbed dataset."),
396 | type="Retrieval",
397 | category="s2p",
398 | modalities=["text"],
399 | eval_splits=[_EVAL_SPLIT],
400 | eval_langs=["eng-Latn"],
401 | main_score="ndcg_at_10",
402 | date=("1950-01-01", "2019-12-31"),
403 | domains=None,
404 | socioeconomic_status=None,
405 | n_samples=None,
406 | avg_character_length=None,
407 | form=None,
408 | text_creation=None,
409 | task_subtypes=["Article retrieval"],
410 | license="not specified",
411 | annotations_creators="derived",
412 | dialect=[],
413 | sample_creation="found",
414 | bibtex_citation="""
415 | @inproceedings{ho2020constructing,
416 | title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps},
417 | author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko},
418 | booktitle={Proceedings of the 28th International Conference on Computational Linguistics},
419 | pages={6609--6625},
420 | year={2020}
421 | }
422 | """,
423 | descriptive_stats={
424 | "n_samples": {_EVAL_SPLIT: 500},
425 | "avg_character_length": {
426 | "test": {
427 | "average_document_length": 53335.817,
428 | "average_query_length": 433.50,
429 | "num_documents": 300,
430 | "num_queries": 300,
431 | "average_relevant_docs_per_query": 1.0,
432 | }
433 | },
434 | },
435 | )
436 |
437 | def load_data(self, **kwargs):
438 | if self.data_loaded:
439 | return
440 |
441 | dataset_dict = {**self.metadata.dataset}
442 | dataset_dict['name'] = 'qmsum'
443 |
444 | query_list = datasets.load_dataset(**dataset_dict)["queries"]
445 | queries = {row["qid"]: row["text"] for row in query_list}
446 |
447 | corpus_list = datasets.load_dataset(**dataset_dict)["corpus"]
448 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}
449 |
450 | qrels_list = datasets.load_dataset(**dataset_dict)["qrels"]
451 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}
452 |
453 | self.corpus = {self._EVAL_SPLIT: corpus}
454 | self.queries = {self._EVAL_SPLIT: queries}
455 | self.relevant_docs = {self._EVAL_SPLIT: qrels}
456 |
457 | self.data_loaded = True
458 |
459 |
460 | class LEMBNeedleRetrievalChunked(AbsTaskChunkedRetrieval):
461 | """
462 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBNeedleRetrieval.py
463 | """
464 |
465 | _EVAL_SPLIT = [
466 | "test_256",
467 | "test_512",
468 | "test_1024",
469 | "test_2048",
470 | "test_4096",
471 | "test_8192",
472 | "test_16384",
473 | "test_32768",
474 | ]
475 |
476 | metadata = TaskMetadata(
477 | name="LEMBNeedleRetrievalChunked",
478 | dataset={
479 | "path": "dwzhu/LongEmbed",
480 | "revision": "6e346642246bfb4928c560ee08640dc84d074e8c",
481 | "name": "needle",
482 | },
483 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
484 | description=("needle subset of dwzhu/LongEmbed dataset."),
485 | type="Retrieval",
486 | category="s2p",
487 | modalities=["text"],
488 | eval_splits=_EVAL_SPLIT,
489 | eval_langs=["eng-Latn"],
490 | main_score="ndcg_at_1",
491 | date=("2000-01-01", "2023-12-31"),
492 | domains=["Academic", "Blog", "Written"],
493 | task_subtypes=["Article retrieval"],
494 | license="not specified",
495 | annotations_creators="derived",
496 | dialect=[],
497 | sample_creation="found",
498 | bibtex_citation="""
499 | @article{zhu2024longembed,
500 | title={LongEmbed: Extending Embedding Models for Long Context Retrieval},
501 | author={Zhu, Dawei and Wang, Liang and Yang, Nan and Song, Yifan and Wu, Wenhao and Wei, Furu and Li, Sujian},
502 | journal={arXiv preprint arXiv:2404.12096},
503 | year={2024}
504 | }
505 | """,
506 | descriptive_stats={
507 | "n_samples": {
508 | "test_256": 150,
509 | "test_512": 150,
510 | "test_1024": 150,
511 | "test_2048": 150,
512 | "test_4096": 150,
513 | "test_8192": 150,
514 | "test_16384": 150,
515 | "test_32768": 150,
516 | },
517 | "avg_character_length": {
518 | "test_256": {
519 | "average_document_length": 1013.22,
520 | "average_query_length": 60.48,
521 | "num_documents": 100,
522 | "num_queries": 50,
523 | "average_relevant_docs_per_query": 1.0,
524 | },
525 | "test_512": {
526 | "average_document_length": 2009.96,
527 | "average_query_length": 57.3,
528 | "num_documents": 100,
529 | "num_queries": 50,
530 | "average_relevant_docs_per_query": 1.0,
531 | },
532 | "test_1024": {
533 | "average_document_length": 4069.9,
534 | "average_query_length": 58.28,
535 | "num_documents": 100,
536 | "num_queries": 50,
537 | "average_relevant_docs_per_query": 1.0,
538 | },
539 | "test_2048": {
540 | "average_document_length": 8453.82,
541 | "average_query_length": 59.92,
542 | "num_documents": 100,
543 | "num_queries": 50,
544 | "average_relevant_docs_per_query": 1.0,
545 | },
546 | "test_4096": {
547 | "average_document_length": 17395.8,
548 | "average_query_length": 55.86,
549 | "num_documents": 100,
550 | "num_queries": 50,
551 | "average_relevant_docs_per_query": 1.0,
552 | },
553 | "test_8192": {
554 | "average_document_length": 35203.82,
555 | "average_query_length": 59.6,
556 | "num_documents": 100,
557 | "num_queries": 50,
558 | "average_relevant_docs_per_query": 1.0,
559 | },
560 | "test_16384": {
561 | "average_document_length": 72054.8,
562 | "average_query_length": 59.12,
563 | "num_documents": 100,
564 | "num_queries": 50,
565 | "average_relevant_docs_per_query": 1.0,
566 | },
567 | "test_32768": {
568 | "average_document_length": 141769.8,
569 | "average_query_length": 58.34,
570 | "num_documents": 100,
571 | "num_queries": 50,
572 | "average_relevant_docs_per_query": 1.0,
573 | },
574 | },
575 | },
576 | )
577 |
578 | def load_data(self, **kwargs):
579 | if self.data_loaded:
580 | return
581 |
582 | self.corpus = {}
583 | self.queries = {}
584 | self.relevant_docs = {}
585 |
586 | for split in self._EVAL_SPLIT:
587 | context_length = int(split.split("_")[1])
588 | query_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
589 | "queries"
590 | ] # dict_keys(['qid', 'text'])
591 | query_list = query_list.filter(
592 | lambda x: x["context_length"] == context_length
593 | )
594 | queries = {row["qid"]: row["text"] for row in query_list}
595 |
596 | corpus_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
597 | "corpus"
598 | ] # dict_keys(['doc_id', 'text'])
599 | corpus_list = corpus_list.filter(
600 | lambda x: x["context_length"] == context_length
601 | )
602 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}
603 |
604 | qrels_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
605 | "qrels"
606 | ] # dict_keys(['qid', 'doc_id'])
607 | qrels_list = qrels_list.filter(
608 | lambda x: x["context_length"] == context_length
609 | )
610 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}
611 |
612 | self.corpus[split] = corpus
613 | self.queries[split] = queries
614 | self.relevant_docs[split] = qrels
615 |
616 | self.data_loaded = True
617 |
618 |
619 | class LEMBPasskeyRetrievalChunked(AbsTaskChunkedRetrieval):
620 | """
621 | modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBPasskeyRetrieval.py
622 | """
623 |
624 | _EVAL_SPLIT = [
625 | "test_256",
626 | "test_512",
627 | "test_1024",
628 | "test_2048",
629 | "test_4096",
630 | "test_8192",
631 | "test_16384",
632 | "test_32768",
633 | ]
634 |
635 | metadata = TaskMetadata(
636 | name="LEMBPasskeyRetrievalChunked",
637 | dataset={
638 | "path": "dwzhu/LongEmbed",
639 | "revision": "6e346642246bfb4928c560ee08640dc84d074e8c",
640 | "name": "passkey",
641 | },
642 | reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
643 | description=("passkey subset of dwzhu/LongEmbed dataset."),
644 | type="Retrieval",
645 | category="s2p",
646 | modalities=["text"],
647 | eval_splits=_EVAL_SPLIT,
648 | eval_langs=["eng-Latn"],
649 | main_score="ndcg_at_1",
650 | date=("2000-01-01", "2023-12-31"),
651 | domains=["Fiction", "Written"],
652 | task_subtypes=["Article retrieval"],
653 | license="not specified",
654 | annotations_creators="derived",
655 | dialect=[],
656 | sample_creation="found",
657 | bibtex_citation="""
658 | @article{zhu2024longembed,
659 | title={LongEmbed: Extending Embedding Models for Long Context Retrieval},
660 | author={Zhu, Dawei and Wang, Liang and Yang, Nan and Song, Yifan and Wu, Wenhao and Wei, Furu and Li, Sujian},
661 | journal={arXiv preprint arXiv:2404.12096},
662 | year={2024}
663 | }
664 | """,
665 | descriptive_stats={
666 | "n_samples": {
667 | "test_256": 150,
668 | "test_512": 150,
669 | "test_1024": 150,
670 | "test_2048": 150,
671 | "test_4096": 150,
672 | "test_8192": 150,
673 | "test_16384": 150,
674 | "test_32768": 150,
675 | },
676 | "avg_character_length": {
677 | "test_256": {
678 | "average_document_length": 876.24,
679 | "average_query_length": 38.1,
680 | "num_documents": 100,
681 | "num_queries": 50,
682 | "average_relevant_docs_per_query": 1.0,
683 | },
684 | "test_512": {
685 | "average_document_length": 1785.2,
686 | "average_query_length": 37.76,
687 | "num_documents": 100,
688 | "num_queries": 50,
689 | "average_relevant_docs_per_query": 1.0,
690 | },
691 | "test_1024": {
692 | "average_document_length": 3607.18,
693 | "average_query_length": 37.68,
694 | "num_documents": 100,
695 | "num_queries": 50,
696 | "average_relevant_docs_per_query": 1.0,
697 | },
698 | "test_2048": {
699 | "average_document_length": 7242.2,
700 | "average_query_length": 37.8,
701 | "num_documents": 100,
702 | "num_queries": 50,
703 | "average_relevant_docs_per_query": 1.0,
704 | },
705 | "test_4096": {
706 | "average_document_length": 14518.16,
707 | "average_query_length": 37.64,
708 | "num_documents": 100,
709 | "num_queries": 50,
710 | "average_relevant_docs_per_query": 1.0,
711 | },
712 | "test_8192": {
713 | "average_document_length": 29071.16,
714 | "average_query_length": 37.54,
715 | "num_documents": 100,
716 | "num_queries": 50,
717 | "average_relevant_docs_per_query": 1.0,
718 | },
719 | "test_16384": {
720 | "average_document_length": 58175.16,
721 | "average_query_length": 38.12,
722 | "num_documents": 100,
723 | "num_queries": 50,
724 | "average_relevant_docs_per_query": 1.0,
725 | },
726 | "test_32768": {
727 | "average_document_length": 116380.16,
728 | "average_query_length": 37.74,
729 | "num_documents": 100,
730 | "num_queries": 50,
731 | "average_relevant_docs_per_query": 1.0,
732 | },
733 | },
734 | },
735 | )
736 |
737 | def load_data(self, **kwargs):
738 | if self.data_loaded:
739 | return
740 |
741 | self.corpus = {}
742 | self.queries = {}
743 | self.relevant_docs = {}
744 |
745 | for split in self._EVAL_SPLIT:
746 | context_length = int(split.split("_")[1])
747 | query_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
748 | "queries"
749 | ] # dict_keys(['qid', 'text'])
750 | query_list = query_list.filter(
751 | lambda x: x["context_length"] == context_length
752 | )
753 | queries = {row["qid"]: row["text"] for row in query_list}
754 |
755 | corpus_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
756 | "corpus"
757 | ] # dict_keys(['doc_id', 'text'])
758 | corpus_list = corpus_list.filter(
759 | lambda x: x["context_length"] == context_length
760 | )
761 | corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}
762 |
763 | qrels_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
764 | "qrels"
765 | ] # dict_keys(['qid', 'doc_id'])
766 | qrels_list = qrels_list.filter(
767 | lambda x: x["context_length"] == context_length
768 | )
769 | qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}
770 |
771 | self.corpus[split] = corpus
772 | self.queries[split] = queries
773 | self.relevant_docs[split] = qrels
774 |
775 | self.data_loaded = True
776 |
--------------------------------------------------------------------------------
/chunked_pooling/chunking.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import logging
3 | from typing import Dict, List, Optional, Tuple, Union
4 |
5 | from llama_index.core.node_parser import SemanticSplitterNodeParser
6 | from llama_index.core.schema import Document
7 | from llama_index.embeddings.huggingface import HuggingFaceEmbedding
8 | from transformers import AutoTokenizer
9 |
10 | # Set the logging level to WARNING to suppress INFO and DEBUG messages
11 | logging.getLogger('sentence_transformers').setLevel(logging.WARNING)
12 |
13 | CHUNKING_STRATEGIES = ['semantic', 'fixed', 'sentences']
14 |
15 |
16 | class Chunker:
17 | def __init__(
18 | self,
19 | chunking_strategy: str,
20 | ):
21 | if chunking_strategy not in CHUNKING_STRATEGIES:
22 |             raise ValueError(f"Unsupported chunking strategy: {chunking_strategy}")
23 | self.chunking_strategy = chunking_strategy
24 | self.embed_model = None
25 | self.embedding_model_name = None
26 |
27 | def _setup_semantic_chunking(self, embedding_model_name):
28 | if embedding_model_name:
29 | self.embedding_model_name = embedding_model_name
30 |
31 | self.embed_model = HuggingFaceEmbedding(
32 | model_name=self.embedding_model_name,
33 | trust_remote_code=True,
34 | embed_batch_size=1,
35 | )
36 | self.splitter = SemanticSplitterNodeParser(
37 | embed_model=self.embed_model,
38 | show_progress=False,
39 | )
40 |
41 | def chunk_semantically(
42 | self,
43 | text: str,
44 | tokenizer: 'AutoTokenizer',
45 | embedding_model_name: Optional[str] = None,
46 | ) -> List[Tuple[int, int]]:
47 | if self.embed_model is None:
48 | self._setup_semantic_chunking(embedding_model_name)
49 |
50 | # Get semantic nodes
51 | nodes = [
52 | (node.start_char_idx, node.end_char_idx)
53 | for node in self.splitter.get_nodes_from_documents(
54 | [Document(text=text)], show_progress=False
55 | )
56 | ]
57 |
58 | # Tokenize the entire text
59 | tokens = tokenizer.encode_plus(
60 | text,
61 | return_offsets_mapping=True,
62 | add_special_tokens=False,
63 | padding=True,
64 | truncation=True,
65 | )
66 | token_offsets = tokens.offset_mapping
67 |
68 | chunk_spans = []
69 |
70 | for char_start, char_end in nodes:
71 | # Convert char indices to token indices
72 | start_chunk_index = bisect.bisect_left(
73 | [offset[0] for offset in token_offsets], char_start
74 | )
75 | end_chunk_index = bisect.bisect_right(
76 | [offset[1] for offset in token_offsets], char_end
77 | )
78 |
79 | # Add the chunk span if it's within the tokenized text
80 | if start_chunk_index < len(token_offsets) and end_chunk_index <= len(
81 | token_offsets
82 | ):
83 | chunk_spans.append((start_chunk_index, end_chunk_index))
84 | else:
85 | break
86 |
87 | return chunk_spans
88 |
89 | def chunk_by_tokens(
90 | self,
91 | text: str,
92 | chunk_size: int,
93 | tokenizer: 'AutoTokenizer',
94 |     ) -> List[Tuple[int, int]]:
95 | tokens = tokenizer.encode_plus(
96 | text, return_offsets_mapping=True, add_special_tokens=False
97 | )
98 | token_offsets = tokens.offset_mapping
99 |
100 | chunk_spans = []
101 | for i in range(0, len(token_offsets), chunk_size):
102 | chunk_end = min(i + chunk_size, len(token_offsets))
103 | if chunk_end - i > 0:
104 | chunk_spans.append((i, chunk_end))
105 |
106 | return chunk_spans
107 |
108 | def chunk_by_sentences(
109 | self,
110 | text: str,
111 | n_sentences: int,
112 | tokenizer: 'AutoTokenizer',
113 |     ) -> List[Tuple[int, int]]:
114 | tokens = tokenizer.encode_plus(
115 | text, return_offsets_mapping=True, add_special_tokens=False
116 | )
117 | token_offsets = tokens.offset_mapping
118 |
119 | chunk_spans = []
120 | chunk_start = 0
121 | count_chunks = 0
122 | for i in range(0, len(token_offsets)):
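            # A sentence boundary is a '.', '!' or '?' token that either ends the
            # text or is followed by a character gap (whitespace) before the next token.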
123 | if tokens.tokens(0)[i] in ('.', '!', '?') and (
124 | (len(tokens.tokens(0)) == i + 1)
125 | or (tokens.token_to_chars(i).end != tokens.token_to_chars(i + 1).start)
126 | ):
127 | count_chunks += 1
128 | if count_chunks == n_sentences:
129 | chunk_spans.append((chunk_start, i + 1))
130 | chunk_start = i + 1
131 | count_chunks = 0
132 | if len(tokens.tokens(0)) - chunk_start > 1:
133 | chunk_spans.append((chunk_start, len(tokens.tokens(0))))
134 | return chunk_spans
135 |
136 | def chunk(
137 | self,
138 | text: str,
139 | tokenizer: 'AutoTokenizer',
140 | chunking_strategy: str = None,
141 | chunk_size: Optional[int] = None,
142 | n_sentences: Optional[int] = None,
143 | embedding_model_name: Optional[str] = None,
144 | ):
145 | chunking_strategy = chunking_strategy or self.chunking_strategy
146 | if chunking_strategy == "semantic":
147 | return self.chunk_semantically(
148 | text,
149 | embedding_model_name=embedding_model_name,
150 | tokenizer=tokenizer,
151 | )
152 | elif chunking_strategy == "fixed":
153 |             if chunk_size is None or chunk_size < 4:
154 |                 raise ValueError("Chunk size must be >= 4.")
155 | return self.chunk_by_tokens(text, chunk_size, tokenizer)
156 | elif chunking_strategy == "sentences":
157 | return self.chunk_by_sentences(text, n_sentences, tokenizer)
158 | else:
159 | raise ValueError("Unsupported chunking strategy")
160 |
--------------------------------------------------------------------------------
/chunked_pooling/mteb_chunked_eval.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Optional
3 |
4 | import numpy as np
5 | import torch
6 | from mteb.abstasks import AbsTask
7 | from mteb.evaluation.evaluators import RetrievalEvaluator
8 | from mteb.load_results.mteb_results import ScoresDict
9 | from mteb.tasks import Retrieval
10 | from tqdm import tqdm
11 |
12 | from chunked_pooling import chunked_pooling
13 | from chunked_pooling.chunking import Chunker
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class AbsTaskChunkedRetrieval(AbsTask):
19 | def __init__(
20 | self,
21 | chunking_strategy: str = None,
22 | chunked_pooling_enabled: bool = False,
23 | tokenizer: Optional[Any] = None,
24 | prune_size: Optional[int] = None,
25 | chunk_size: Optional[int] = None,
26 | n_sentences: Optional[int] = None,
27 | model_has_instructions: bool = False,
28 | embedding_model_name: Optional[str] = None, # for semantic chunking
29 | truncate_max_length: Optional[int] = 8192,
30 | long_late_chunking_embed_size: Optional[int] = 0,
31 | long_late_chunking_overlap_size: Optional[int] = 512,
32 | **kwargs,
33 | ):
34 | super().__init__(**kwargs)
35 | try:
36 | self.retrieval_task = getattr(
37 | Retrieval,
38 | self.metadata_dict['dataset'].get('name', None)
39 | or self.metadata_dict.get('name'),
40 | )()
41 |         except Exception:
42 |             logger.warning('Could not initialize retrieval_task')
43 | self.chunking_strategy = chunking_strategy
44 | self.chunker = Chunker(self.chunking_strategy)
45 | self.chunked_pooling_enabled = chunked_pooling_enabled
46 | self.tokenizer = tokenizer
47 | self.prune_size = prune_size
48 | self.model_has_instructions = model_has_instructions
49 | self.chunking_args = {
50 | 'chunk_size': chunk_size,
51 | 'n_sentences': n_sentences,
52 | 'embedding_model_name': embedding_model_name,
53 | }
54 | self.truncate_max_length = (
55 | truncate_max_length if truncate_max_length is not None and truncate_max_length > 0 else None
56 | )
57 |
58 | self.long_late_chunking_embed_size = long_late_chunking_embed_size
59 | self.long_late_chunking_overlap_size = long_late_chunking_overlap_size
60 |
61 | def load_data(self, **kwargs):
62 | self.retrieval_task.load_data(**kwargs)
63 | self.corpus = self.retrieval_task.corpus
64 | self.queries = self.retrieval_task.queries
65 | self.relevant_docs = self.retrieval_task.relevant_docs
66 | # prune dataset
67 | if self.prune_size:
68 | self.queries, self.corpus, self.relevant_docs = self._prune(
69 | self.queries, self.corpus, self.relevant_docs, self.prune_size
70 | )
71 |
72 | def calculate_metadata_metrics(self):
73 | self.retrieval_task.calculate_metadata_metrics()
74 |
75 | def evaluate(
76 | self, model, split: str = "test", encode_kwargs: dict[str, Any] = {}, **kwargs
77 | ) -> dict[str, ScoresDict]:
78 | scores: dict[str, ScoresDict] = {}
79 | hf_subsets = list(self.hf_subsets) if self.is_multilingual else ["default"]
80 |
81 | for hf_subset in hf_subsets:
82 | logger.info(f"Subset: {hf_subset}")
83 |
84 | if hf_subset == "default":
85 | corpus, queries, relevant_docs = (
86 | self.corpus[split],
87 | self.queries[split],
88 | self.relevant_docs[split],
89 | )
90 | else:
91 | corpus, queries, relevant_docs = (
92 | self.corpus[hf_subset][split],
93 | self.queries[hf_subset][split],
94 | self.relevant_docs[hf_subset][split],
95 | )
96 |
97 | scores[hf_subset] = self._evaluate_monolingual(
98 | model,
99 | corpus,
100 | queries,
101 | relevant_docs,
102 | hf_subset,
103 | encode_kwargs=encode_kwargs,
104 | **kwargs,
105 | )
106 |
107 | return scores
108 |
109 | def _truncate_documents(self, corpus):
110 | for k, v in corpus.items():
111 | title_tokens = 0
112 | if 'title' in v:
113 |                 tokens = self.tokenizer(
114 |                     v['title'] + ' ',
115 |                     return_offsets_mapping=True,
116 |                     truncation=True, max_length=self.truncate_max_length,
117 |                 )
118 |                 title_tokens = len(tokens.input_ids)
119 |             tokens = self.tokenizer(
120 |                 v['text'],
121 |                 return_offsets_mapping=True,
122 |                 truncation=True, max_length=self.truncate_max_length - title_tokens,
123 |             )
124 |             last_token_span = tokens.offset_mapping[-2]  # [-1] is the final special token ([SEP])
125 |             v['text'] = v['text'][: last_token_span[1]]
126 | return corpus
127 |
128 | def _embed_with_overlap(self, model, model_inputs):
129 | len_tokens = len(model_inputs["input_ids"][0])
130 |
131 | if len_tokens > self.long_late_chunking_embed_size:
132 | indices = []
133 | for i in range(
134 | 0,
135 | len_tokens,
136 | self.long_late_chunking_embed_size
137 | - self.long_late_chunking_overlap_size,
138 |             ):  # step size: macro-chunk length minus the overlap
139 | start = i
140 | end = min(i + self.long_late_chunking_embed_size, len_tokens)
141 | indices.append((start, end))
142 | else:
143 | indices = [(0, len_tokens)]
144 |
145 | outputs = []
146 | for start, end in indices:
147 | batch_inputs = {k: v[:, start:end] for k, v in model_inputs.items()}
148 |
149 | with torch.no_grad():
150 | model_output = model(**batch_inputs)
151 |
152 |             if start > 0:  # drop the embeddings of the overlapping prefix tokens
153 | outputs.append(
154 | model_output[0][:, self.long_late_chunking_overlap_size :]
155 | )
156 | else:
157 | outputs.append(model_output[0])
158 |
159 | return torch.cat(outputs, dim=1).to(model.device)
160 |
161 | def _evaluate_monolingual(
162 | self,
163 | model,
164 | corpus,
165 | queries,
166 | relevant_docs,
167 | lang=None,
168 | batch_size=1,
169 | encode_kwargs=None,
170 | **kwargs,
171 | ):
172 | if self.truncate_max_length:
173 | corpus = self._truncate_documents(corpus)
174 | # split corpus into chunks
175 | if not self.chunked_pooling_enabled:
176 | corpus = self._apply_chunking(corpus, self.tokenizer)
177 | max_chunks = max([len(x) for x in corpus.values()])
178 | corpus = self._flatten_chunks(corpus)
179 | k_values = self._calculate_k_values(max_chunks)
180 | # determine the maximum number of documents to consider in a ranking
181 | max_k = int(max(k_values) / max_chunks)
182 | retriever = RetrievalEvaluator(
183 | model,
184 | k_values=k_values,
185 | encode_kwargs=(encode_kwargs or dict()),
186 | **kwargs,
187 | )
188 | results = retriever(corpus, queries)
189 | else:
190 | query_ids = list(queries.keys())
191 | query_texts = [queries[k] for k in query_ids]
192 | if hasattr(model, 'encode_queries'):
193 | query_embs = model.encode_queries(query_texts)
194 | else:
195 | query_embs = model.encode(query_texts)
196 |
197 | corpus_ids = list(corpus.keys())
198 | corpus_texts = [
199 | (
200 | f"{corpus[k]['title']} {corpus[k]['text']}"
201 | if 'title' in corpus[k]
202 | else corpus[k]['text']
203 | )
204 | for k in corpus_ids
205 | ]
206 |
207 | chunk_annotations = self._calculate_annotations(model, corpus_texts)
208 |
209 | corpus_embs = []
210 | with torch.no_grad():
211 | for inputs in tqdm(
212 | self._batch_inputs(
213 | list(zip(corpus_texts, chunk_annotations)),
214 | batch_size=batch_size,
215 | ),
216 | total=(len(corpus_texts) // batch_size),
217 | ):
218 | if self.model_has_instructions:
219 | instr = model.get_instructions()[1]
220 | else:
221 | instr = ''
222 | text_inputs = [instr + x[0] for x in inputs]
223 | annotations = [x[1] for x in inputs]
224 | model_inputs = self.tokenizer(
225 | text_inputs,
226 | return_tensors='pt',
227 | padding=True,
228 | truncation=self.truncate_max_length is not None,
229 | max_length=self.truncate_max_length,
230 | )
231 | if model.device.type == 'cuda':
232 | model_inputs = {
233 | k: v.to(model.device) for k, v in model_inputs.items()
234 | }
235 |
236 | if self.long_late_chunking_embed_size > 0:
237 | model_outputs = self._embed_with_overlap(model, model_inputs)
238 | output_embs = chunked_pooling(
239 | [model_outputs], annotations, max_length=None
240 | )
241 | else: # truncation
242 | model_outputs = model(**model_inputs)
243 | output_embs = chunked_pooling(
244 | model_outputs,
245 | annotations,
246 | max_length=self.truncate_max_length,
247 | )
248 | corpus_embs.extend(output_embs)
249 |
250 | max_chunks = max([len(x) for x in corpus_embs])
251 | k_values = self._calculate_k_values(max_chunks)
252 | # determine the maximum number of documents to consider in a ranking
253 | max_k = int(max(k_values) / max_chunks)
254 | (
255 | chunk_id_list,
256 | doc_to_chunk,
257 | flattened_corpus_embs,
258 | ) = self.flatten_corpus_embs(corpus_embs, corpus_ids)
259 | similarity_matrix = np.dot(query_embs, flattened_corpus_embs.T)
260 | results = self.get_results(
261 | chunk_id_list, k_values, query_ids, similarity_matrix
262 | )
263 |
264 | doc_results = self.get_doc_results(results)
265 |
266 | ndcg, _map, recall, precision, _ = RetrievalEvaluator.evaluate(
267 | relevant_docs,
268 | doc_results,
269 | [k for k in k_values if k <= max_k],
270 | ignore_identical_ids=kwargs.get('ignore_identical_ids', True),
271 | )
272 | mrr, _ = RetrievalEvaluator.evaluate_custom(
273 | relevant_docs,
274 | doc_results,
275 | [k for k in k_values if k <= max_k],
276 | 'mrr',
277 | )
278 | scores = {
279 | **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
280 | **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
281 | **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
282 | **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
283 | **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
284 | }
285 | self._add_main_score(scores)
286 | return scores
287 |
288 | def _add_main_score(self, scores: ScoresDict) -> None:
289 | scores["main_score"] = scores[self.metadata.main_score]
290 |
291 | def get_results(self, chunk_id_list, k_values, query_ids, similarity_matrix):
292 | results = {}
293 | for i, query_id in enumerate(query_ids):
294 | query_results = {}
295 | for idx, score in enumerate(similarity_matrix[i]):
296 | chunk_id = chunk_id_list[idx]
297 | query_results[chunk_id] = score
298 | # Sort results by score and only keep the top k scores
299 | sorted_query_results = dict(
300 | sorted(query_results.items(), key=lambda item: item[1], reverse=True)[
301 | : max(k_values)
302 | ]
303 | )
304 | results[query_id] = sorted_query_results
305 | return results
306 |
307 | def flatten_corpus_embs(self, corpus_embs, corpus_ids):
308 | doc_to_chunk = {}
309 | flattened_corpus_embs = []
310 | chunk_id_list = []
311 | for doc_id, emb in zip(corpus_ids, corpus_embs):
312 | for i, chunk in enumerate(emb):
313 | flattened_corpus_embs.append(chunk)
314 | doc_to_chunk[f"{doc_id}~{i}"] = doc_id
315 | chunk_id_list.append(f"{doc_id}~{i}")
316 | flattened_corpus_embs = np.vstack(flattened_corpus_embs)
317 | flattened_corpus_embs = self._normalize(flattened_corpus_embs)
318 | return chunk_id_list, doc_to_chunk, flattened_corpus_embs
319 |
320 | @staticmethod
321 | def get_doc_results(results):
322 | doc_results = dict()
323 | for q, result_chunks in results.items():
324 | docs = dict()
325 | for c_id, score in result_chunks.items():
326 | d_id = '~'.join(c_id.split('~')[:-1])
327 | if (d_id not in docs) or (score > docs[d_id]):
328 | docs[d_id] = float(score)
329 | doc_results[q] = docs
330 | return doc_results
331 |
332 | def _calculate_k_values(self, max_chunks):
333 | k_values = [1, 3, 5, 10, 20]
334 | n = 2
335 | while 10**n < 100 * max_chunks:
336 | k_values.append(10**n)
337 | n += 1
338 | return k_values
339 |
340 | def _apply_chunking(self, corpus, tokenizer):
341 | chunked_corpus = dict()
342 | for k, v in corpus.items():
343 | text = f"{v['title']} {v['text']}" if 'title' in v else v['text']
344 | current_doc = []
345 | chunk_annotations = self.chunker.chunk(
346 | text,
347 | tokenizer,
348 | chunking_strategy=self.chunking_strategy,
349 | **self.chunking_args,
350 | )
351 | tokens = tokenizer.encode_plus(text, add_special_tokens=False)
352 | for start_token_idx, end_token_idx in chunk_annotations:
353 | text_chunk = tokenizer.decode(
354 | tokens.encodings[0].ids[start_token_idx:end_token_idx]
355 | )
356 | current_doc.append({'text': text_chunk})
357 | chunked_corpus[k] = current_doc
358 | return chunked_corpus
359 |
360 | def _calculate_annotations(self, model, corpus_texts):
361 | if self.model_has_instructions:
362 | instr = model.get_instructions()[1]
363 | instr_tokens = self.tokenizer(instr, add_special_tokens=False)
364 | n_instruction_tokens = len(instr_tokens[0])
365 | else:
366 | n_instruction_tokens = 0
367 | chunk_annotations = [
368 | self._extend_special_tokens(
369 | self.chunker.chunk(
370 | text,
371 | self.tokenizer,
372 | chunking_strategy=self.chunking_strategy,
373 | **self.chunking_args,
374 | ),
375 | n_instruction_tokens=n_instruction_tokens,
376 | )
377 | for text in corpus_texts
378 | ]
379 | return chunk_annotations
380 |
381 | @staticmethod
382 | def _flatten_chunks(chunked_corpus):
383 | flattened_corpus = dict()
384 | for k, li in chunked_corpus.items():
385 | for i, c in enumerate(li):
386 | flattened_corpus[f'{k}~{i}'] = c
387 |
388 | return flattened_corpus
389 |
390 | @staticmethod
391 | def _normalize(x):
392 | return x / np.linalg.norm(x, axis=1)[:, None]
393 |
394 | @staticmethod
395 | def _batch_inputs(li, batch_size):
396 | for i in range(0, len(li), batch_size):
397 | yield li[i : i + batch_size]
398 |
399 | @staticmethod
400 | def _extend_special_tokens(
401 | annotations, n_instruction_tokens=0, include_prefix=True, include_sep=True
402 | ):
403 |         """Extends the spans to account for additional special tokens (e.g. the
404 |         [CLS] token) that are added by the tokenizer but not considered by the chunker.
405 |         """
406 | new_annotations = []
407 | for i in range(len(annotations)):
408 |             add_left_offset = 1 if (not include_prefix) or i > 0 else 0
409 | left_offset = 1 + n_instruction_tokens
410 | left = (
411 | annotations[i][0] + add_left_offset * left_offset
412 | ) # move everything by one for [CLS]
413 |
414 | add_sep = 1 if include_sep and ((i + 1) == len(annotations)) else 0
415 | right_offset = left_offset + add_sep
416 | right = (
417 | annotations[i][1] + right_offset
418 | ) # move everything by one for [CLS] and the last one for [SEP]
419 |
420 | new_annotations.append((left, right))
421 | return new_annotations
422 |
423 | @staticmethod
424 | def _prune(queries, corpus, relevant_docs, prune_size):
425 | new_queries = {'test': {}}
426 | new_corpus = {'test': {}}
427 | new_relevant_docs = {'test': {}}
428 | for i, key in enumerate(relevant_docs['test']):
429 | if i >= prune_size:
430 | break
431 | new_relevant_docs['test'][key] = relevant_docs['test'][key]
432 | for x in relevant_docs['test'][key]:
433 | new_corpus['test'][x] = corpus['test'][x]
434 | new_queries['test'][key] = queries['test'][key]
435 | return new_queries, new_corpus, new_relevant_docs
436 |
437 |     def _calculate_metrics_from_split(self, *args, **kwargs):
438 |         pass
439 | 
440 |     def _evaluate_subset(self, *args, **kwargs):
441 |         pass
442 |
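443 | 
444 | if __name__ == '__main__':
445 |     # Minimal worked example, added for illustration (not part of the original module).
446 |     # It shows how `_extend_special_tokens` shifts chunk spans to account for the
447 |     # [CLS]/[SEP] tokens that the chunker never sees: two chunks covering 9 text
448 |     # tokens are widened to cover the 11-token sequence [CLS] + text + [SEP].
449 |     spans = AbsTaskChunkedRetrieval._extend_special_tokens([(0, 5), (5, 9)])
450 |     print(spans)  # [(0, 6), (6, 11)]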
--------------------------------------------------------------------------------
/chunked_pooling/wrappers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Optional, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | from sentence_transformers import SentenceTransformer
7 | from transformers import AutoModel
8 | from transformers.modeling_outputs import BaseModelOutputWithPooling
9 |
10 |
11 | def construct_document(doc):
12 | if isinstance(doc, str):
13 | return doc
14 | elif 'title' in doc:
15 | return f'{doc["title"]} {doc["text"].strip()}'
16 | else:
17 | return doc['text'].strip()
18 |
19 |
20 | class JinaEmbeddingsV3Wrapper(nn.Module):
21 | def __init__(
22 | self, model_name, tasks=['retrieval.query', 'retrieval.passage'], **model_kwargs
23 | ):
24 | super().__init__()
25 | self._model = AutoModel.from_pretrained(
26 | model_name, trust_remote_code=True, **model_kwargs
27 | )
28 | self.tasks = tasks
29 |
30 | def encode_queries(
31 | self,
32 | sentences: Union[str, List[str]],
33 | *args,
34 | task: Optional[str] = None,
35 | **kwargs,
36 | ):
37 | return self._model.encode(sentences, *args, task=self.tasks[0], **kwargs)
38 |
39 | def encode_corpus(
40 | self,
41 | sentences: Union[str, List[str]],
42 | *args,
43 | **kwargs,
44 | ):
45 | _sentences = [construct_document(sentence) for sentence in sentences]
46 | return self._model.encode(_sentences, *args, task=self.tasks[1], **kwargs)
47 |
48 | def get_instructions(self):
49 | return [self._model._task_instructions[x] for x in self.tasks]
50 |
51 | def forward(self, *args, **kwargs):
52 | task_id = self._model._adaptation_map[self.tasks[1]]
53 | num_examples = kwargs['input_ids'].shape[0]
54 | adapter_mask = torch.full(
55 | (num_examples,), task_id, dtype=torch.int32, device=self._model.device
56 | )
57 | return self._model.forward(*args, adapter_mask=adapter_mask, **kwargs)
58 |
59 | @property
60 | def device(self):
61 | return self._model.device
62 |
63 | @staticmethod
64 | def has_instructions():
65 | return True
66 |
67 |
68 | class NomicAIWrapper(nn.Module):
69 | def __init__(self, model_name, **model_kwargs):
70 | super().__init__()
71 | self._model = SentenceTransformer(
72 | model_name, trust_remote_code=True, **model_kwargs
73 | )
74 | self.instructions = ['search_query: ', 'search_document: ']
75 |
76 | def get_instructions(self):
77 | return self.instructions
78 |
79 | def forward(self, *args, **kwargs):
80 | model_output = self._model.forward(kwargs)
81 | base_model_output = BaseModelOutputWithPooling(
82 | last_hidden_state=model_output['token_embeddings'],
83 | pooler_output=model_output['sentence_embedding'],
84 | attentions=model_output['attention_mask'],
85 | )
86 | return base_model_output
87 |
88 | def encode_queries(
89 | self,
90 | sentences: Union[str, List[str]],
91 | *args,
92 | **kwargs,
93 | ):
94 | return self._model.encode(
95 | [self.instructions[0] + s for s in sentences], *args, **kwargs
96 | )
97 |
98 | def encode_corpus(
99 | self,
100 | sentences: Union[str, List[str]],
101 | *args,
102 | **kwargs,
103 | ):
104 | return self._model.encode(
105 | [self.instructions[1] + construct_document(s) for s in sentences],
106 | *args,
107 | **kwargs,
108 | )
109 |
110 | @property
111 | def device(self):
112 | return self._model.device
113 |
114 | @staticmethod
115 | def has_instructions():
116 | return True
117 |
118 |
119 | MODEL_WRAPPERS = {
120 | 'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper,
121 | 'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer,
122 | 'nomic-ai/nomic-embed-text-v1': NomicAIWrapper,
123 | }
124 |
125 | MODELS_WITHOUT_PROMPT_NAME_ARG = [
126 | 'jinaai/jina-embeddings-v2-small-en',
127 | 'jinaai/jina-embeddings-v2-base-en',
128 | 'jinaai/jina-embeddings-v3',
129 | ]
130 |
131 |
132 | def remove_unsupported_kwargs(original_encode):
133 | def wrapper(self, *args, **kwargs):
134 | # Remove 'prompt_name' from kwargs if present
135 | kwargs.pop('prompt_name', None)
136 | kwargs.pop('request_qid', None)
137 | return original_encode(self, *args, **kwargs)
138 |
139 | return wrapper
140 |
141 |
142 | def load_model(model_name, model_weights=None, **model_kwargs):
143 | if model_name in MODEL_WRAPPERS:
144 | model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs)
145 | if hasattr(MODEL_WRAPPERS[model_name], 'has_instructions'):
146 | has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
147 | else:
148 | has_instructions = False
149 | else:
150 | model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
151 | has_instructions = False
152 |
153 | if model_weights and os.path.exists(model_weights):
154 |         model._model.load_state_dict(torch.load(model_weights, map_location=model.device))
155 |
156 |     # The encode functions of some models do not support all sentence-transformers kwargs
157 | if model_name in MODELS_WITHOUT_PROMPT_NAME_ARG:
158 | ENCODE_FUNC_NAMES = ['encode', 'encode_queries', 'encode_corpus']
159 | for func_name in ENCODE_FUNC_NAMES:
160 | if hasattr(model, func_name):
161 | setattr(
162 | model,
163 | func_name,
164 | remove_unsupported_kwargs(getattr(model, func_name)),
165 | )
166 |
167 | return model, has_instructions
168 |
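169 | 
170 | if __name__ == '__main__':
171 |     # Minimal usage sketch, added for illustration (not part of the original module).
172 |     # It loads one of the registered models via `load_model` and embeds a query and
173 |     # a document; all-MiniLM-L6-v2 maps directly to a SentenceTransformer, so it
174 |     # reports has_instructions=False.
175 |     model, has_instructions = load_model('sentence-transformers/all-MiniLM-L6-v2')
176 |     query_emb = model.encode(['What is the capital of Germany?'])
177 |     doc_emb = model.encode(['Berlin is the capital of Germany.'])
178 |     print(query_emb.shape, doc_emb.shape, has_instructions)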
--------------------------------------------------------------------------------
/examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "e1173893c4f0ea56",
6 | "metadata": {
7 | "collapsed": false,
8 | "jupyter": {
9 | "outputs_hidden": false
10 | }
11 | },
12 | "source": [
13 | "# Chunked Pooling\n",
14 |     "This notebook explains how chunked pooling can be implemented. First, you need to install the requirements:"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "id": "d02a920f-cde0-4035-9834-49b087aab5cc",
21 | "metadata": {
22 | "is_executing": true
23 | },
24 | "outputs": [],
25 | "source": [
26 | "!pip install -r requirements.txt"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "id": "58a8fbc1e477db48",
32 | "metadata": {
33 | "collapsed": false,
34 | "jupyter": {
35 | "outputs_hidden": false
36 | }
37 | },
38 | "source": [
39 |     "Then we load the model we want to use for the embeddings. We choose `jinaai/jina-embeddings-v2-base-en`, but any other model that supports mean pooling can be used. However, models with a large maximum context length are preferred."
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 1,
45 | "id": "1380abf7acde9517",
46 | "metadata": {
47 | "collapsed": false,
48 | "jupyter": {
49 | "outputs_hidden": false
50 | }
51 | },
52 | "outputs": [
53 | {
54 | "name": "stderr",
55 | "output_type": "stream",
56 | "text": [
57 | "/home/michael/workspace/chunked-pooling/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
58 | " from .autonotebook import tqdm as notebook_tqdm\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "from transformers import AutoModel\n",
64 | "from transformers import AutoTokenizer\n",
65 | "\n",
66 | "from chunked_pooling import chunked_pooling, chunk_by_sentences\n",
67 | "\n",
68 | "# load model and tokenizer\n",
69 | "tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)\n",
70 | "model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "id": "2cc0c1162797ffb0",
76 | "metadata": {
77 | "collapsed": false,
78 | "jupyter": {
79 | "outputs_hidden": false
80 | }
81 | },
82 | "source": [
83 |     "Now we define the text which we want to encode and split it into chunks. The `chunk_by_sentences` function also returns span annotations, which specify the token span of each chunk; these are needed for the chunked pooling."
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 2,
89 | "id": "8ef392f3437ef82e",
90 | "metadata": {
91 | "collapsed": false,
92 | "jupyter": {
93 | "outputs_hidden": false
94 | }
95 | },
96 | "outputs": [
97 | {
98 | "name": "stdout",
99 | "output_type": "stream",
100 | "text": [
101 | "Chunks:\n",
102 | "- \"Berlin is the capital and largest city of Germany, both by area and by population.\"\n",
103 | "- \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"\n",
104 | "- \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"\n"
105 | ]
106 | }
107 | ],
108 | "source": [
109 | "input_text = \"Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"\n",
110 | "\n",
111 | "# determine chunks\n",
112 | "chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)\n",
113 | "print('Chunks:\\n- \"' + '\"\\n- \"'.join(chunks) + '\"')\n"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "id": "9ac41fd1f0560da7",
119 | "metadata": {
120 | "collapsed": false,
121 | "jupyter": {
122 | "outputs_hidden": false
123 | }
124 | },
125 | "source": [
126 |     "Now we encode the chunks with both the traditional approach and the context-sensitive chunked pooling method:"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 3,
132 | "id": "abe3d93b9e6609b9",
133 | "metadata": {
134 | "collapsed": false,
135 | "jupyter": {
136 | "outputs_hidden": false
137 | }
138 | },
139 | "outputs": [],
140 | "source": [
141 | "# chunk before\n",
142 | "embeddings_traditional_chunking = model.encode(chunks)\n",
143 | "\n",
144 | "# chunk afterwards (context-sensitive chunked pooling)\n",
145 | "inputs = tokenizer(input_text, return_tensors='pt')\n",
146 | "model_output = model(**inputs)\n",
147 | "embeddings = chunked_pooling(model_output, [span_annotations])[0]"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "id": "e84b1b9d48cb6367",
153 | "metadata": {
154 | "collapsed": false,
155 | "jupyter": {
156 | "outputs_hidden": false
157 | }
158 | },
159 | "source": [
160 | "Finally, we compare the similarity of the word \"Berlin\" with the chunks. The similarity should be higher for the context-sensitive chunked pooling method:"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 4,
166 | "id": "da0cec59a3ece76",
167 | "metadata": {
168 | "collapsed": false,
169 | "jupyter": {
170 | "outputs_hidden": false
171 | }
172 | },
173 | "outputs": [
174 | {
175 | "name": "stdout",
176 | "output_type": "stream",
177 | "text": [
178 | "similarity_new(\"Berlin\", \"Berlin is the capital and largest city of Germany, both by area and by population.\"): 0.849546\n",
179 | "similarity_trad(\"Berlin\", \"Berlin is the capital and largest city of Germany, both by area and by population.\"): 0.84862185\n",
180 | "similarity_new(\"Berlin\", \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"): 0.82489026\n",
181 | "similarity_trad(\"Berlin\", \" Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\"): 0.7084338\n",
182 | "similarity_new(\"Berlin\", \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"): 0.84980094\n",
183 | "similarity_trad(\"Berlin\", \" The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.\"): 0.7534553\n"
184 | ]
185 | }
186 | ],
187 | "source": [
188 | "import numpy as np\n",
189 | "\n",
190 | "cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))\n",
191 | "\n",
192 | "berlin_embedding = model.encode('Berlin')\n",
193 | "\n",
194 |     "for chunk, new_embedding, trad_embedding in zip(chunks, embeddings, embeddings_traditional_chunking):\n",
195 |     "    print(f'similarity_new(\"Berlin\", \"{chunk}\"):', cos_sim(berlin_embedding, new_embedding))\n",
196 |     "    print(f'similarity_trad(\"Berlin\", \"{chunk}\"):', cos_sim(berlin_embedding, trad_embedding))"
197 | ]
198 | }
199 | ],
200 | "metadata": {
201 | "kernelspec": {
202 | "display_name": "Python 3 (ipykernel)",
203 | "language": "python",
204 | "name": "python3"
205 | },
206 | "language_info": {
207 | "codemirror_mode": {
208 | "name": "ipython",
209 | "version": 3
210 | },
211 | "file_extension": ".py",
212 | "mimetype": "text/x-python",
213 | "name": "python",
214 | "nbconvert_exporter": "python",
215 | "pygments_lexer": "ipython3",
216 | "version": "3.10.12"
217 | }
218 | },
219 | "nbformat": 4,
220 | "nbformat_minor": 5
221 | }
222 |
--------------------------------------------------------------------------------
/explanatory_contextual_retrieval.py:
--------------------------------------------------------------------------------
1 | # experiments/explanatory_contextual_retrieval.py
2 | #
3 | # A simple example with a trivial piece of text that showcases the late chunking method
4 | # against the contextual retrieval method. Contextual retrieval manually inserts context
5 | # into each chunk, i.e. it forces context to be present around each chunk, so it works as
6 | # a good comparison for late chunking to check whether the similarities are comparable (which they appear to be).
7 |
8 | from chunked_pooling.wrappers import load_model
9 | from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
10 | import torch
11 | import numpy as np
12 |
13 | import chunked_pooling
14 | from chunked_pooling import chunked_pooling
15 | from chunked_pooling.chunking import Chunker
16 |
17 | from typing import List, Tuple
18 |
19 |
20 | import requests
21 | import os
22 |
23 | def request_anthropic_api(prompt: str):
24 | url = "https://api.anthropic.com/v1/messages"
25 | headers = {
26 | "x-api-key": os.getenv("ANTHROPIC_API_KEY"),
27 | "anthropic-version": "2023-06-01",
28 | "content-type": "application/json"
29 | }
30 | data = {
31 | "model": "claude-3-haiku-20240307",
32 | "max_tokens": 2048,
33 | "messages": [
34 | {"role": "user", "content": prompt}
35 | ]
36 | }
37 | response = requests.post(url, headers=headers, json=data)
38 | return response.json()["content"][0]["text"]
39 |
40 | def setup_local_llm(llm_name):
41 |
42 | model = AutoModelForCausalLM.from_pretrained(llm_name, trust_remote_code=True)
43 | tokenizer = AutoTokenizer.from_pretrained(llm_name, trust_remote_code=True)
44 |
45 | def llm(prompt):
46 | messages = [{"role": "user", "content": prompt}]
47 | inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
48 | inputs = inputs.to(model.device)
49 | outputs = model.generate(inputs, max_new_tokens=512)
50 | text_output = tokenizer.batch_decode(outputs)[0]
51 | if "<|assistant|>" in text_output:
52 | text_output = text_output.split("<|assistant|>")[1].strip()
53 | return text_output
54 |
55 | return llm
56 |
57 | def cosine_similarity(vector1, vector2):
58 | vector1_norm = vector1 / np.linalg.norm(vector1)
59 | vector2_norm = vector2 / np.linalg.norm(vector2)
60 | return np.dot(vector1_norm, vector2_norm)
61 |
62 | class LateChunkingEmbedder:
63 |
64 | def __init__(self,
65 | model: AutoModel,
66 | tokenizer: AutoTokenizer,
67 | chunking_strategy: str = "sentences",
68 | n_sentences: int = 1
69 | ):
70 |
71 | self.model = model
72 | self.tokenizer = tokenizer
73 |
74 | self.chunker = Chunker(chunking_strategy = chunking_strategy)
75 | self.n_sentences = n_sentences
76 |
77 |
78 | def run(self, document: str):
79 | annotations = [self.chunker.chunk(text=document, tokenizer=self.tokenizer, n_sentences=self.n_sentences)]
80 | model_inputs = self.tokenizer(
81 | document,
82 | return_tensors='pt',
83 | padding=True,
84 | truncation=True,
85 | max_length=8192,
86 | )
87 | model_outputs = self.model(**model_inputs)
88 | self.output_embs = chunked_pooling(
89 | model_outputs, annotations, max_length=8192,
90 | )[0]
91 | return self.output_embs
92 |
93 | def query(self, query: str):
94 |         if not hasattr(self, "output_embs"):
95 | raise ValueError("no embeddings calculated, use .run(document) to create chunk embeddings")
96 | query_embedding = self.model.encode(query)
97 | similarities = []
98 | for emb in self.output_embs:
99 | similarities.append(cosine_similarity(query_embedding, emb))
100 |
101 | return similarities
102 |
103 |
104 | class ContextualRetrievalEmbedder:
105 | def __init__(self,
106 | model: AutoModel,
107 | tokenizer: AutoTokenizer,
108 | llm_name: str = "microsoft/Phi-3.5-mini-instruct",
109 | chunking_strategy: str = "fixed"
110 | ):
111 |
112 | self.llm = setup_local_llm(llm_name)
113 | # self.llm = request_anthropic_api
114 |
115 | self.prompt = """
116 |
117 | {{WHOLE_DOCUMENT}}
118 |
119 | Here is the chunk we want to situate within the whole document
120 |
121 | {{CHUNK_CONTENT}}
122 |
123 | Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else.
124 | """.strip()
125 |
126 | self.model = model
127 | self.tokenizer = tokenizer
128 |
129 | self.chunker = Chunker(chunking_strategy = chunking_strategy)
130 |
131 |
132 | def _add_context(self, chunk: str, document: str):
133 | prompt = self.prompt.replace("{{WHOLE_DOCUMENT}}", document).replace("{{CHUNK_CONTENT}}", chunk)
134 | extra_context = self.llm(prompt)
135 | return extra_context + " " + chunk
136 |
137 | def _tokens_to_text(self, text: str, annotations: List[Tuple[int, int]]):
138 | tokens = self.tokenizer.encode_plus(
139 | text, return_offsets_mapping=True, add_special_tokens=False
140 | )
141 | token_offsets = tokens.offset_mapping
142 | chunks = []
143 | for start, end in annotations:
144 | chunk = text[token_offsets[start][0]:token_offsets[end-1][1]]
145 | chunks.append(chunk)
146 | return chunks
147 |
148 | def run(self, document: str):
149 | annotations = [self.chunker.chunk(text=document, tokenizer=self.tokenizer, n_sentences=1)]
150 | self.chunks = self._tokens_to_text(text=document, annotations=annotations[0])
151 | self.chunks = [self._add_context(chunk, document) for chunk in self.chunks]
152 |
153 | model_outputs = self.model.encode(self.chunks)
154 | self.output_embs = [model_outputs[i, :] for i in range(len(self.chunks))]
155 | return self.output_embs
156 |
157 | def query(self, query: str):
158 |         if not hasattr(self, "output_embs"):
159 | raise ValueError("no embeddings calculated, use .run(document) to create chunk embeddings")
160 | query_embedding = self.model.encode(query)
161 | similarities = []
162 | for emb in self.output_embs:
163 | similarities.append(cosine_similarity(query_embedding, emb))
164 |
165 | return similarities
166 |
167 |
168 |
169 | if __name__ == "__main__":
170 |
171 | text = """
172 | The recent SEC filing provided insights into ACME Corp's performance for Q2 2023.
173 | It highlighted a 3% revenue growth over the previous quarter.
174 | The company, which had a revenue of $314 million in the prior quarter, showed steady progress.
175 | They attributed this growth to strategic initiatives and operational efficiencies.
176 | The report emphasized the company's resilience and ability to navigate market challenges, reflecting positively on their financial health and future prospects.
177 | """.strip().replace("\n", "")
178 |
179 | llm_model_name = "microsoft/Phi-3.5-mini-instruct"
180 | embedding_model_name = "jinaai/jina-embeddings-v2-small-en"
181 |
182 | embedding_model, has_instructions = load_model(embedding_model_name)
183 | embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True)
184 |
185 | cr = ContextualRetrievalEmbedder(embedding_model, embedding_tokenizer, llm_model_name, chunking_strategy="sentences")
186 |     cr.run(text)
187 | cr_cosine_similarities = cr.query("What is ACME Corp's revenue growth for Q2 2023?")
188 |
189 | lc = LateChunkingEmbedder(embedding_model, embedding_tokenizer)
190 | lc.run(text)
191 | lc_cosine_similarities = lc.query("What is ACME Corp's revenue growth for Q2 2023?")
192 |
193 |
194 | for i, (cr_similarity, lc_similarity) in enumerate(zip(cr_cosine_similarities, lc_cosine_similarities)):
195 | print(f"{text.split('.')[:-1][i].strip()}")
196 | print(f"Similarities: Contextual Retrieval: {cr_similarity:.4f} | Late Chunking: {lc_similarity:.4f}")
197 | print("")
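198 | 
199 | # Note (added for illustration): to generate the chunk contexts with the hosted
200 | # Anthropic API instead of the local LLM, set the ANTHROPIC_API_KEY environment
201 | # variable and swap the backend in ContextualRetrievalEmbedder.__init__:
202 | #
203 | #     self.llm = request_anthropic_api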
--------------------------------------------------------------------------------
/img/context-problem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jina-ai/late-chunking/1d3bb02bf091becd0771455e4e7959463935e26c/img/context-problem.png
--------------------------------------------------------------------------------
/img/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jina-ai/late-chunking/1d3bb02bf091becd0771455e4e7959463935e26c/img/method.png
--------------------------------------------------------------------------------
/img/rag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jina-ai/late-chunking/1d3bb02bf091becd0771455e4e7959463935e26c/img/rag.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "late_chunking"
3 | requires-python = "~=3.8"
4 | dependencies = [
5 | "jupyterlab==4.2.5",
6 | "transformers==4.43.4",
7 | "torch==2.4.0",
8 | "mteb==1.14.20",
9 | "datasets==2.19.1",
10 | "llama-index-embeddings-huggingface==0.3.1",
11 | "llama-index==0.11.10",
12 | "click==8.1.7",
13 | "einops==0.6.1",
14 | ]
15 | version = "0.0.0"
16 |
17 | [project.optional-dependencies]
18 | dev = [
19 | "pytest~=7.3.2",
20 | "black==23.3.0",
21 | "isort==5.12.0",
22 | "ruff==0.0.265",
23 | ]
24 |
25 | [tool.setuptools.packages.find]
26 | include = ["chunked_pooling"]
27 |
--------------------------------------------------------------------------------
/run_chunked_eval.py:
--------------------------------------------------------------------------------
1 | import click
2 | import torch.cuda
3 | from mteb import MTEB
4 | from transformers import AutoTokenizer
5 |
6 | from chunked_pooling.chunked_eval_tasks import *
7 | from chunked_pooling.wrappers import load_model
8 |
9 | DEFAULT_CHUNKING_STRATEGY = 'fixed'
10 | DEFAULT_CHUNK_SIZE = 256
11 | DEFAULT_N_SENTENCES = 5
12 | BATCH_SIZE = 1
13 | DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE = 256
14 | DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE = 0 # set to 0 to disable long late chunking
15 | DEFAULT_TRUNCATE_MAX_LENGTH = None
16 |
17 |
18 | @click.command()
19 | @click.option(
20 | '--model-name',
21 | default='jinaai/jina-embeddings-v2-small-en',
22 | help='The name of the model to use.',
23 | )
24 | @click.option(
25 | '--model-weights',
26 | default=None,
27 | help='The path to the model weights to use, e.g. in case of finetuning.',
28 | )
29 | @click.option(
30 | '--strategy',
31 | default=DEFAULT_CHUNKING_STRATEGY,
32 | help='The chunking strategy to be applied.',
33 | )
34 | @click.option(
35 | '--task-name', default='SciFactChunked', help='The evaluation task to perform.'
36 | )
37 | @click.option(
38 | '--eval-split', default='test', help='The name of the evaluation split in the task.'
39 | )
40 | @click.option(
41 | '--chunking-model',
42 | default=None,
43 | required=False,
44 | help='The name of the model used for semantic chunking.',
45 | )
46 | @click.option(
47 | '--truncate-max-length',
48 | default=DEFAULT_TRUNCATE_MAX_LENGTH,
49 | type=int,
50 |     help='Maximum number of tokens per document; longer documents are truncated. If not set, no truncation is applied and the Long Late Chunking algorithm should be enabled instead.',
51 | )
52 | @click.option(
53 | '--chunk-size',
54 | default=DEFAULT_CHUNK_SIZE,
55 | type=int,
56 | help='Number of tokens per chunk for fixed strategy.',
57 | )
58 | @click.option(
59 | '--n-sentences',
60 | default=DEFAULT_N_SENTENCES,
61 | type=int,
62 | help='Number of sentences per chunk for sentence strategy.',
63 | )
64 | @click.option(
65 | '--long-late-chunking-embed-size',
66 | default=DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE,
67 | type=int,
68 | help='Number of tokens per macro chunk used for long late chunking.',
69 | )
70 | @click.option(
71 | '--long-late-chunking-overlap-size',
72 | default=DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE,
73 | type=int,
74 |     help='Number of overlapping tokens between neighbouring macro chunks in long late chunking. Values above zero enable the overlap.',
75 | )
76 | def main(
77 | model_name,
78 | model_weights,
79 | strategy,
80 | task_name,
81 | eval_split,
82 | chunking_model,
83 | truncate_max_length,
84 | chunk_size,
85 | n_sentences,
86 | long_late_chunking_embed_size,
87 | long_late_chunking_overlap_size,
88 | ):
89 | try:
90 | task_cls = globals()[task_name]
91 |     except KeyError:
92 | raise ValueError(f'Unknown task name: {task_name}')
93 |
94 | if truncate_max_length is not None and (long_late_chunking_embed_size > 0):
95 | truncate_max_length = None
96 | print(
97 |             'Truncation is disabled because the Long Late Chunking algorithm is enabled.'
98 | )
99 |
100 | model, has_instructions = load_model(model_name, model_weights)
101 |
102 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
103 |
104 | chunking_args = {
105 | 'chunk_size': chunk_size,
106 | 'n_sentences': n_sentences,
107 | 'chunking_strategy': strategy,
108 | 'model_has_instructions': has_instructions,
109 | 'embedding_model_name': chunking_model if chunking_model else model_name,
110 | }
111 |
112 | if torch.cuda.is_available():
113 | model = model.cuda()
114 |
115 | model.eval()
116 |
117 | # Evaluate with late chunking
118 | tasks = [
119 | task_cls(
120 | chunked_pooling_enabled=True,
121 | tokenizer=tokenizer,
122 | prune_size=None,
123 | truncate_max_length=truncate_max_length,
124 | long_late_chunking_embed_size=long_late_chunking_embed_size,
125 | long_late_chunking_overlap_size=long_late_chunking_overlap_size,
126 | **chunking_args,
127 | )
128 | ]
129 |
130 | evaluation = MTEB(
131 | tasks=tasks,
132 | chunked_pooling_enabled=True,
133 | tokenizer=tokenizer,
134 | prune_size=None,
135 | **chunking_args,
136 | )
137 | evaluation.run(
138 | model,
139 | output_folder='results-chunked-pooling',
140 | eval_splits=[eval_split],
141 | overwrite_results=True,
142 | batch_size=BATCH_SIZE,
143 | encode_kwargs={'batch_size': BATCH_SIZE},
144 | )
145 |
146 | # Encode without late chunking
147 | tasks = [
148 | task_cls(
149 | chunked_pooling_enabled=False,
150 | tokenizer=tokenizer,
151 | prune_size=None,
152 | truncate_max_length=truncate_max_length,
153 | **chunking_args,
154 | )
155 | ]
156 |
157 | evaluation = MTEB(
158 | tasks=tasks,
159 | chunked_pooling_enabled=False,
160 | tokenizer=tokenizer,
161 | prune_size=None,
162 | **chunking_args,
163 | )
164 | evaluation.run(
165 | model,
166 | output_folder='results-normal-pooling',
167 | eval_splits=[eval_split],
168 | overwrite_results=True,
169 | batch_size=BATCH_SIZE,
170 | encode_kwargs={'batch_size': BATCH_SIZE},
171 | )
172 |
173 |
174 | if __name__ == '__main__':
175 | main()
176 |
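177 | # Example invocation (a sketch; task names must match the classes defined in
178 | # chunked_pooling/chunked_eval_tasks.py):
179 | #
180 | #   python run_chunked_eval.py --model-name jinaai/jina-embeddings-v2-small-en \
181 | #       --task-name SciFactChunked --strategy fixed --chunk-size 256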
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jina-ai/late-chunking/1d3bb02bf091becd0771455e4e7959463935e26c/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from mteb.abstasks.TaskMetadata import TaskMetadata
3 |
4 | from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval
5 |
6 |
7 | class DummyTask(AbsTaskChunkedRetrieval):
8 | metadata = TaskMetadata(
9 | dataset={
10 | 'path': '~',
11 | 'revision': '',
12 | },
13 | name='dummy',
14 | description='',
15 | type='Retrieval',
16 | category='s2p',
17 | reference=None,
18 | eval_splits=[],
19 | eval_langs=[],
20 | main_score='ndcg_at_10',
21 | date=None,
22 | form=None,
23 | domains=None,
24 | task_subtypes=None,
25 | license=None,
26 | socioeconomic_status=None,
27 | annotations_creators=None,
28 | dialect=None,
29 | text_creation=None,
30 | bibtex_citation=None,
31 | n_samples=None,
32 | avg_character_length=None,
33 | )
34 |
35 |     def load_data(self, **kwargs):
36 | pass
37 |
38 | def __init__(self, **kwargs):
39 | super().__init__(**kwargs)
40 |
41 |
42 | @pytest.fixture()
43 | def dummy_task_factory():
44 | def _create_dummy_task(*args, **kwargs):
45 | return DummyTask(*args, **kwargs)
46 |
47 | return _create_dummy_task
48 |
--------------------------------------------------------------------------------
/tests/test_api.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from transformers import AutoModel, AutoTokenizer
4 |
5 | from chunked_pooling import chunked_pooling
6 | from chunked_pooling.wrappers import load_model
7 | from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval
8 |
9 | MODEL_NAME = 'jinaai/jina-embeddings-v3'
10 |
11 | # Define Text and Chunks
12 | CHUNKS = ["Organic skincare", "for sensitive skin", "with aloe vera and chamomile"]
13 | FULL_TEXT = ' '.join(CHUNKS)
14 |
15 |
16 | def load_api_results():
17 | import requests
18 |
19 | url = 'https://api.jina.ai/v1/embeddings'
20 | headers = {
21 | 'Content-Type': 'application/json',
22 | 'Authorization': f'Bearer {os.environ["JINA_API_TOKEN"]}',
23 | }
24 | data = {
25 | "model": "jina-embeddings-v3",
26 | "task": "retrieval.passage",
27 | "dimensions": 1024,
28 | "late_chunking": True,
29 | "embedding_type": "float",
30 | "input": CHUNKS,
31 | }
32 | response = requests.post(url, headers=headers, json=data)
33 | data = response.json()
34 | return [np.array(x['embedding']) for x in data['data']]
35 |
36 |
37 | def calculate_annotations(model, boundary_cues, model_has_instructions, tokenizer):
38 | if model_has_instructions:
39 | instr = model.get_instructions()[1]
40 | instr_tokens = tokenizer(instr, add_special_tokens=False)
41 | n_instruction_tokens = len(instr_tokens[0])
42 | else:
43 | n_instruction_tokens = 0
44 | chunk_annotations = [
45 | AbsTaskChunkedRetrieval._extend_special_tokens(
46 | annotations,
47 | n_instruction_tokens=n_instruction_tokens,
48 | include_prefix=True,
49 | include_sep=True,
50 | )
51 | for annotations in boundary_cues
52 | ]
53 | return chunk_annotations
54 |
55 |
56 | def test_compare_v3_api_embeddings():
57 | # Load Model
58 | model, has_instr = load_model(MODEL_NAME, use_flash_attn=False)
59 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
60 |
61 | # Determine Boundary Cues
62 | tokenization = tokenizer(
63 | FULL_TEXT, return_offsets_mapping=True, add_special_tokens=False
64 | )
65 | boundary_cues = []
66 | chunk_i = 0
67 | last_cue = 0
68 | last_end = 0
69 | for i, (start, end) in enumerate(tokenization.offset_mapping):
70 | if end >= (last_end + len(CHUNKS[chunk_i])):
71 | boundary_cues.append((last_cue, i + 1))
72 | chunk_i += 1
73 | last_cue = i + 1
74 | last_end = end
75 | extended_boundary_cues = calculate_annotations(
76 | model, [boundary_cues], has_instr, tokenizer
77 | )
78 |
79 | # Append Instruction for Retrieval Task
80 | instr = model.get_instructions()[1]
81 | text_inputs = [instr + FULL_TEXT]
82 | model_inputs = tokenizer(
83 | text_inputs,
84 | return_tensors='pt',
85 | padding=True,
86 | truncation=True,
87 | max_length=8192,
88 | )
89 | model_outputs = model(**model_inputs)
90 |
91 | # Apply Late Chunking
92 | output_embs = chunked_pooling(
93 | model_outputs, extended_boundary_cues, max_length=8192
94 | )[0]
95 | api_embs = load_api_results()
96 | for local_emb, api_emb in zip(output_embs, api_embs):
97 | local_emb_norm = local_emb / np.linalg.norm(local_emb)
98 | api_emb_norm = api_emb / np.linalg.norm(api_emb)
99 | assert np.allclose(local_emb_norm, api_emb_norm, rtol=1e-02, atol=1e-02)
100 | assert 1.0 - np.dot(local_emb_norm, api_emb_norm) < 1e-3
101 |
--------------------------------------------------------------------------------
/tests/test_chunking_methods.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from transformers import AutoTokenizer
3 |
4 | from chunked_pooling.chunking import CHUNKING_STRATEGIES, Chunker
5 | from chunked_pooling.mteb_chunked_eval import AbsTaskChunkedRetrieval
6 |
7 | EXAMPLE_TEXT_1 = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."
8 | PUNCTUATIONS = ('.', '!', '?')
9 |
10 |
11 | @pytest.mark.parametrize("n_sentences", [1, 2, 3, 4])
12 | def test_chunk_by_sentences(n_sentences):
13 | strategy = 'sentences'
14 | model_name = 'jinaai/jina-embeddings-v2-small-en'
15 | chunker = Chunker(chunking_strategy=strategy)
16 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
17 | boundary_cues = chunker.chunk(
18 | text=EXAMPLE_TEXT_1,
19 | tokenizer=tokenizer,
20 | chunking_strategy=strategy,
21 | n_sentences=n_sentences,
22 | )
23 | extended_boundary_cues = AbsTaskChunkedRetrieval._extend_special_tokens(
24 | boundary_cues
25 | )
26 | model_inputs = tokenizer(
27 | EXAMPLE_TEXT_1,
28 | return_tensors='pt',
29 | padding=True,
30 | truncation=True,
31 | max_length=8192,
32 | )
33 |
34 | # check that the cues start with 0 and end with the last token
35 | assert extended_boundary_cues[0][0] == 0
36 | assert len(model_inputs.tokens()) == extended_boundary_cues[-1][1]
37 |
38 | # check that all chunks but the last one end with a punctuation
39 | assert all(
40 |         model_inputs.tokens()[x:y][-1] in PUNCTUATIONS
41 | for (x, y) in extended_boundary_cues[:-1]
42 | )
43 |
44 | # check that the last chunk ends with a "[SEP]" token
45 | last_cue = extended_boundary_cues[-1]
46 | assert model_inputs.tokens()[last_cue[0] : last_cue[1]][-1] == "[SEP]"
47 |
48 | # check that the boundary cues are continuous (no token is missing)
49 | assert all(
50 | [
51 | extended_boundary_cues[i][1] == extended_boundary_cues[i + 1][0]
52 | for i in range(len(extended_boundary_cues) - 1)
53 | ]
54 | )
55 |
56 |
57 | @pytest.mark.parametrize(
58 | "boundary_cues", [[(0, 17), (17, 44), (44, 69)], [(0, 44), (44, 69)]]
59 | )
60 | def test_token_equivalence(boundary_cues):
61 | model_name = 'jinaai/jina-embeddings-v2-small-en'
62 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
63 | tokens = tokenizer.encode_plus(
64 | EXAMPLE_TEXT_1, add_special_tokens=False, return_offsets_mapping=True
65 | )
66 | for start_token_idx, end_token_idx in boundary_cues:
67 | decoded_text_chunk = tokenizer.decode(
68 | tokens.input_ids[start_token_idx:end_token_idx]
69 | )
70 |
71 | original_text_chunk = EXAMPLE_TEXT_1[
72 | tokens.offset_mapping[start_token_idx][0] : tokens.offset_mapping[
73 | end_token_idx - 1
74 | ][1]
75 | ]
76 | chunk_tokens_original = tokenizer.encode_plus(original_text_chunk)
77 | chunk_tokens_decoded = tokenizer.encode_plus(decoded_text_chunk)
78 | assert chunk_tokens_original == chunk_tokens_decoded
79 |
80 |
81 | def test_chunker_initialization():
82 | for strategy in CHUNKING_STRATEGIES:
83 | chunker = Chunker(chunking_strategy=strategy)
84 | assert chunker.chunking_strategy == strategy
85 |
86 |
87 | def test_invalid_chunking_strategy():
88 | with pytest.raises(ValueError):
89 | Chunker(chunking_strategy="invalid")
90 |
91 |
92 | def test_chunk_by_tokens():
93 | chunker = Chunker(chunking_strategy="fixed")
94 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
95 | chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=10)
96 | assert len(chunks) > 1
97 | for start, end in chunks:
98 | assert end - start <= 10
99 |
100 |
101 | @pytest.mark.parametrize(
102 | 'model_name',
103 | ['jinaai/jina-embeddings-v2-small-en', 'sentence-transformers/all-MiniLM-L6-v2'],
104 | )
105 | def test_chunk_semantically(model_name):
106 | chunker = Chunker(chunking_strategy="semantic")
107 | tokenizer = AutoTokenizer.from_pretrained(model_name)
108 | tokens = tokenizer.encode_plus(
109 | EXAMPLE_TEXT_1, add_special_tokens=False, return_offsets_mapping=True
110 | )
111 | boundary_cues = chunker.chunk(
112 | EXAMPLE_TEXT_1,
113 | tokenizer=tokenizer,
114 | chunking_strategy='semantic',
115 | embedding_model_name=model_name,
116 | )
117 |
118 | # check if it returns boundary cues
119 | assert len(boundary_cues) > 0
120 |
121 |     # test if boundaries are at the end of sentences
122 |     for start_token_idx, end_token_idx in boundary_cues:
123 |         assert (
124 |             EXAMPLE_TEXT_1[tokens.offset_mapping[end_token_idx - 1][0]] in PUNCTUATIONS
125 | )
126 | decoded_text_chunk = tokenizer.decode(
127 | tokens.input_ids[start_token_idx:end_token_idx]
128 | )
129 |
130 | # check that the boundary cues are continuous (no token is missing)
131 | assert all(
132 | [
133 | boundary_cues[i][1] == boundary_cues[i + 1][0]
134 | for i in range(len(boundary_cues) - 1)
135 | ]
136 | )
137 |
138 |
139 | def test_empty_input():
140 | chunker = Chunker(chunking_strategy="fixed")
141 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
142 | chunks = chunker.chunk("", tokenizer=tokenizer, chunk_size=10)
143 | assert len(chunks) == 0
144 |
145 |
146 | def test_input_shorter_than_chunk_size():
147 | short_text = "Short text."
148 | chunker = Chunker(chunking_strategy="fixed")
149 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
150 | chunks = chunker.chunk(short_text, tokenizer=tokenizer, chunk_size=20)
151 | assert len(chunks) == 1
152 |
153 |
154 | @pytest.mark.parametrize("chunk_size", [10, 20, 50])
155 | def test_various_chunk_sizes(chunk_size):
156 | chunker = Chunker(chunking_strategy="fixed")
157 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
158 | chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=chunk_size)
159 | assert len(chunks) > 0
160 | for start, end in chunks:
161 | assert end - start <= chunk_size
162 |
163 |
164 | def test_chunk_method_with_different_strategies():
165 | chunker = Chunker(chunking_strategy="fixed")
166 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
167 | fixed_chunks = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, chunk_size=10)
168 | semantic_chunks = chunker.chunk(
169 | EXAMPLE_TEXT_1,
170 | tokenizer=tokenizer,
171 | chunking_strategy="semantic",
172 | embedding_model_name='jinaai/jina-embeddings-v2-small-en',
173 | )
174 | assert fixed_chunks != semantic_chunks
175 |
176 |
177 | def test_chunk_by_sentences_different_n():
178 | chunker = Chunker(chunking_strategy="sentences")
179 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
180 | chunks_1 = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, n_sentences=1)
181 | chunks_2 = chunker.chunk(EXAMPLE_TEXT_1, tokenizer=tokenizer, n_sentences=2)
182 | assert len(chunks_1) > len(chunks_2)
183 |
--------------------------------------------------------------------------------
/tests/test_v3.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 |
3 | from run_chunked_eval import DEFAULT_CHUNK_SIZE, load_model
4 |
5 | MODEL_NAME = 'jinaai/jina-embeddings-v3'
6 |
7 |
8 | def test_instruction_handling(dummy_task_factory):
9 | model, has_instructions = load_model(MODEL_NAME)
10 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
11 | task = dummy_task_factory(
12 | chunking_strategy='fixed',
13 | chunk_size=DEFAULT_CHUNK_SIZE,
14 | tokenizer=tokenizer,
15 | model_has_instructions=has_instructions,
16 | )
17 | n_instruction_tokens = len(
18 | tokenizer(model.get_instructions()[1], add_special_tokens=False)['input_ids']
19 | )
20 | annotations_one_token = task._calculate_annotations(model, ['A'])[0]
21 | assert len(annotations_one_token) == 1
22 | assert annotations_one_token[0] == (0, n_instruction_tokens + 3)
23 |
--------------------------------------------------------------------------------