├── .gitignore ├── README.md ├── contextual_embeddings ├── __init__.py ├── collators │ ├── __init__.py │ └── contextual_collator.py ├── models │ ├── __init__.py │ ├── long_context_model.py │ └── utils.py └── training │ ├── __init__.py │ ├── contextual_trainer.py │ └── contextual_training.py ├── pyproject.toml ├── scripts ├── configs │ └── examples │ │ ├── modernbert.yaml │ │ └── moderncolbert.yaml └── training │ └── training.py └── tests ├── __init__.py └── models ├── __init__.py ├── test_contextual_inference_model.py ├── test_contextual_model.py └── test_li_inference_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConTEB: Context is Gold to find the Gold Passage: Evaluating and Training Contextual Document Embeddings 2 | 3 | This repository contains all training and inference code released with our preprint [*Context is Gold to find the Gold Passage: Evaluating and Training Contextual Document Embeddings*](https://arxiv.org/abs/2505.24782). 4 | 5 | 6 | [![arXiv](https://img.shields.io/badge/arXiv-2505.24782-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/2505.24782) 7 | [![GitHub](https://img.shields.io/badge/Code_Repository-100000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/illuin-tech/contextual-embeddings) 8 | [![Hugging Face](https://img.shields.io/badge/ConTEB_HF_Page-FFD21E?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/illuin-conteb) 9 | 10 | 11 | 12 | 13 | 14 | ## Installation 15 | 16 | ```bash 17 | pip install -e . 18 | ``` 19 | 20 | ## Training 21 | 22 | Example configurations can be found in the `configs` directory. To run a training job, use the following command: 23 | 24 | ```bash 25 | accelerate launch scripts/training/training.py scripts/configs/examples/modernbert.yaml 26 | ``` 27 | 28 | ## Inference 29 | 30 | To run inference with a contextual model, you can use the following examples: 31 | ```python 32 | from contextual_embeddings import LongContextEmbeddingModel 33 | from sentence_transformers import SentenceTransformer 34 | from pylate.models import ColBERT 35 | 36 | documents = [ 37 | [ 38 | "The old lighthouse keeper trimmed his lamp, its beam cutting a lonely path through the fog.", 39 | "He remembered nights of violent storms, when the ocean seemed to swallow the sky whole.", 40 | "Still, he found comfort in his duty, a silent guardian against the treacherous sea." 
41 | ], 42 | [ 43 | "A curious fox cub, all rust and wonder, ventured out from its den for the first time.", 44 | "Each rustle of leaves, every chirping bird, was a new symphony to its tiny ears.", 45 | "Under the watchful eye of its mother, it began to learn the secrets of the whispering forest." 46 | ] 47 | ] 48 | 49 | # ============================== AVERAGE POOLING EXAMPLE ============================== 50 | 51 | base_model = SentenceTransformer("illuin-conteb/modernbert-large-insent") 52 | contextual_model = LongContextEmbeddingModel( 53 | base_model=base_model, 54 | add_prefix=True 55 | ) 56 | embeddings = contextual_model.embed_documents(documents) 57 | print("Length of embeddings:", len(embeddings)) # 2 58 | print("Length of first document embedding:", len(embeddings[0])) # 3 59 | print(f"Shape of first chunk embedding: {embeddings[0][0].shape}") # torch.Size([768]) 60 | 61 | # ============================== LATE INTERACTION (COLBERT) EXAMPLE ============================== 62 | 63 | base_model = ColBERT("illuin-conteb/modern-colbert-insent") 64 | contextual_model = LongContextEmbeddingModel( 65 | base_model=base_model, 66 | pooling_mode="tokens" 67 | ) 68 | embeddings = contextual_model.embed_documents(documents) 69 | print("Length of embeddings:", len(embeddings)) # 2 70 | print("Length of first document embedding:", len(embeddings[0])) # 3 71 | print(f"Shape of first chunk embedding: {embeddings[0][0].shape}") # torch.Size([22, 128]) 72 | ``` 73 | 74 | ## Evaluation 75 | 76 | Code for evaluation can be found in the [ConTEB](https://github.com/illuin-tech/conteb) repository. 77 | 78 | 79 | ### Abstract 80 | 81 | A limitation of modern document retrieval embedding methods is that they typically encode passages (chunks) from the same documents independently, often overlooking crucial contextual information from the rest of the document that could greatly improve individual chunk representations. 82 | 83 | In this work, we introduce *ConTEB* (Context-aware Text Embedding Benchmark), a benchmark designed to evaluate retrieval models on their ability to leverage document-wide context. Our results show that state-of-the-art embedding models struggle in retrieval scenarios where context is required. To address this limitation, we propose *InSeNT* (In-sequence Negative Training), a novel contrastive post-training approach which, combined with *late chunking* pooling, enhances contextual representation learning while preserving computational efficiency. Our method significantly improves retrieval quality on *ConTEB* without sacrificing base model performance. 84 | We further find that chunks embedded with our method are more robust to suboptimal chunking strategies and larger retrieval corpus sizes. 85 | We open-source all artifacts here and at https://github.com/illuin-tech/contextual-embeddings. 86 | 87 | ## Resources 88 | 89 | - [*HuggingFace Project Page*](https://huggingface.co/illuin-conteb): The HF page centralizing everything!
90 | - [*(Model) ModernBERT*](https://huggingface.co/illuin-conteb/modernbert-large-insent): The contextualized ModernBERT bi-encoder trained with the InSeNT loss and late chunking 91 | - [*(Model) ModernColBERT*](https://huggingface.co/illuin-conteb/modern-colbert-insent): The contextualized ModernColBERT trained with the InSeNT loss and late chunking 92 | - [*Leaderboard*](TODO): Coming Soon 93 | - [*(Data) ConTEB Benchmark Datasets*](https://huggingface.co/collections/illuin-conteb/conteb-evaluation-datasets-6839fffd25f1d3685f3ad604): Datasets included in ConTEB. 94 | - [*(Code) Contextual Document Engine*](https://github.com/illuin-tech/contextual-embeddings): The code used to train and run inference with our architecture. 95 | - [*(Code) ConTEB Benchmark*](https://github.com/illuin-tech/conteb): A Python package/CLI tool to evaluate document retrieval systems on the ConTEB benchmark. 96 | - [*Preprint*](https://arxiv.org/abs/2505.24782): The paper with all the details! 97 | - [*Blog*](https://huggingface.co/blog/manu/conteb): A blog post that covers the paper in a 5-minute read. 98 | 99 | ## Contact (First Authors) 100 | 101 | - Manuel Faysse: manuel.faysse@illuin.tech 102 | - Max Conti: max.conti@illuin.tech 103 | 104 | ## Citation 105 | 106 | If you use any datasets or models from this work in your research, please cite us as follows: 107 | 108 | ```bibtex 109 | @misc{conti2025contextgoldgoldpassage, 110 | title={Context is Gold to find the Gold Passage: Evaluating and Training Contextual Document Embeddings}, 111 | author={Max Conti and Manuel Faysse and Gautier Viaud and Antoine Bosselut and Céline Hudelot and Pierre Colombo}, 112 | year={2025}, 113 | eprint={2505.24782}, 114 | archivePrefix={arXiv}, 115 | primaryClass={cs.IR}, 116 | url={https://arxiv.org/abs/2505.24782}, 117 | } 118 | ``` 119 | 120 | ## Acknowledgments 121 | 122 | This work is partially supported by [ILLUIN Technology](https://www.illuin.tech/), and by a grant from ANRT France. 123 | This work was performed using HPC resources from the GENCI Jean Zay supercomputer under grant AD011016393.
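## Appendix: Scoring Chunks Against a Query

The inference example above stops at producing chunk embeddings. The snippet below is a minimal sketch of how those embeddings could be scored against a query with the average-pooling model; it reuses the `documents`, `base_model`, and `embeddings` variables from that example. The query string is made up for illustration, and the `search_query: ` prefix mirrors what `LongContextEmbeddingModel.embed_queries` adds when `add_prefix=True`.

```python
import torch

# Embed a query with the same base model, prefixed the way embed_queries would do it.
query = "Who keeps watch over the sea at night?"
query_embedding = base_model.encode(
    "search_query: " + query,
    convert_to_tensor=True,
    normalize_embeddings=True,
).cpu()

# Flatten the per-document chunk embeddings into one matrix and score by cosine
# similarity (embed_documents already L2-normalizes the chunk embeddings).
flat_chunks = [chunk for doc_chunks in embeddings for chunk in doc_chunks]
chunk_matrix = torch.stack(flat_chunks)      # (n_chunks, 768)
scores = chunk_matrix @ query_embedding      # (n_chunks,)

best = int(scores.argmax())
doc_idx, chunk_idx = best // 3, best % 3     # 3 chunks per document in this toy example
print(f"Best chunk (score={scores[best].item():.3f}):", documents[doc_idx][chunk_idx])
```

For the late-interaction (ColBERT) model, the equivalent scoring is a MaxSim over token embeddings rather than a single dot product, as implemented in `_compute_max_sim_scores` below.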
124 | 125 | -------------------------------------------------------------------------------- /contextual_embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from .collators import ContextualDataCollator 2 | from .models import LongContextEmbeddingModel 3 | from .training import ContextualTrainer, ContextualTraining, ContextualTrainingConfig 4 | -------------------------------------------------------------------------------- /contextual_embeddings/collators/__init__.py: -------------------------------------------------------------------------------- 1 | from .contextual_collator import ContextualDataCollator 2 | -------------------------------------------------------------------------------- /contextual_embeddings/collators/contextual_collator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Dict, List, Optional 3 | 4 | from sentence_transformers.data_collator import SentenceTransformerDataCollator 5 | from torch import Tensor 6 | 7 | 8 | @dataclass 9 | class ContextualDataCollator(SentenceTransformerDataCollator): 10 | is_multi_ctx_training: bool = field(default=True, kw_only=True) 11 | add_prefixes: bool = field(default=False, kw_only=True) 12 | sep_token: Optional[str] = field(default=None, kw_only=True) 13 | colbert_tokenize: Optional[bool] = field(default=False, kw_only=True) 14 | 15 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Tensor]: 16 | # We should always be able to return a loss, label or not: 17 | batch = {} 18 | 19 | # tokenize the queries 20 | 21 | # merge the lists of queries into one list for the tokenizer 22 | all_queries = [query for row in features for query in row["queries"]] 23 | 24 | # TODO: put prefixes as parameters 25 | if self.add_prefixes: 26 | all_queries = [f"search_query: {query}" for query in all_queries] 27 | 28 | tokenized = self.tokenize_fn(all_queries) 29 | for key, value in tokenized.items(): 30 | batch[f"queries_{key}"] = value 31 | 32 | # tokenize the documents 33 | if self.is_multi_ctx_training: 34 | concatenated_docs = [self.sep_token.join(row["docs_list"]) for row in features] 35 | if self.add_prefixes: 36 | concatenated_docs = [f"search_document: {doc}" for doc in concatenated_docs] 37 | 38 | if self.colbert_tokenize: 39 | tokenized = self.tokenize_fn(concatenated_docs, is_query=False) 40 | else: 41 | tokenized = self.tokenize_fn(concatenated_docs) 42 | else: 43 | single_docs = [row["docs_list"] for row in features] 44 | batch["n_docs_per_sample"] = [len(docs) for docs in single_docs] 45 | flattened_docs = [doc for sublist in single_docs for doc in sublist] 46 | 47 | if self.add_prefixes: 48 | flattened_docs = [f"search_document: {doc}" for doc in flattened_docs] 49 | 50 | # TODO: add assertion to check that max idx in queries_chunk_indices does not exceed n_docs 51 | assert len(single_docs) >= 1, "Single context collation requires at least one document" 52 | 53 | if self.colbert_tokenize: 54 | tokenized = self.tokenize_fn(flattened_docs, is_query=False) 55 | else: 56 | tokenized = self.tokenize_fn(flattened_docs) 57 | 58 | for key, value in tokenized.items(): 59 | batch[f"docs_{key}"] = value 60 | 61 | # merge the queries_chunk_ids into one tensor 62 | batch["queries_chunk_indices"] = [row["queries_chunk_ids"] for row in features] 63 | 64 | batch["add_prefixes"] = self.add_prefixes 65 | 66 | return batch 67 | 
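The collator above expects each dataset row to carry `queries`, `docs_list`, and `queries_chunk_ids` fields (the layout produced by `get_dataset_from_beir_format` in `models/utils.py`). The following is a small sketch of calling it directly to inspect the batch it builds; the checkpoint name and the toy rows are placeholder assumptions.

```python
from sentence_transformers import SentenceTransformer

from contextual_embeddings import ContextualDataCollator

model = SentenceTransformer("illuin-conteb/modernbert-large-insent")

rows = [
    {
        "queries": ["What does the keeper tend?", "What does he guard against?"],
        "queries_chunk_ids": [0, 1],
        "docs_list": [
            "The old lighthouse keeper trimmed his lamp.",
            "He was a silent guardian against the treacherous sea.",
        ],
    },
    {
        "queries": ["Who ventured out of the den?"],
        "queries_chunk_ids": [0],
        "docs_list": ["A curious fox cub ventured out from its den."],
    },
]

collator = ContextualDataCollator(
    tokenize_fn=model.tokenize,
    sep_token=model.tokenizer.sep_token,
    is_multi_ctx_training=True,  # concatenate each row's chunks into one long sequence
)
batch = collator(rows)
print(sorted(batch.keys()))
# e.g. ['add_prefixes', 'docs_attention_mask', 'docs_input_ids',
#       'queries_attention_mask', 'queries_chunk_indices', 'queries_input_ids']
```

In single-context mode (`is_multi_ctx_training=False`), chunks are tokenized individually instead of being joined with `sep_token`, and the batch additionally carries `n_docs_per_sample` so the loss can recover per-row offsets.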
-------------------------------------------------------------------------------- /contextual_embeddings/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .long_context_model import LongContextEmbeddingModel 2 | -------------------------------------------------------------------------------- /contextual_embeddings/models/long_context_model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from pylate.models import ColBERT 5 | from sentence_transformers import SentenceTransformer 6 | from torch import Tensor, nn 7 | from tqdm import tqdm 8 | 9 | 10 | class LongContextEmbeddingModel(SentenceTransformer): 11 | def __init__( 12 | self, 13 | base_model: SentenceTransformer, 14 | sim_score_scale: float = 20.0, 15 | normalize_embeddings: bool = True, 16 | multi_ctx_training: bool = True, 17 | lambda_seq: float = 0.5, 18 | doc_prefix_str: str = "search_document:", 19 | pooling_mode: str = "average", 20 | add_prefix: bool = False, 21 | show_progress_bar: bool = True, 22 | ): 23 | super().__init__() 24 | 25 | self.add_prefix = add_prefix 26 | self.base_model = base_model 27 | self.cross_entropy_loss = nn.CrossEntropyLoss() 28 | self.sim_score_scale = sim_score_scale 29 | self.normalize_embeddings = normalize_embeddings 30 | self.multi_ctx_training = multi_ctx_training 31 | self.lambda_seq = lambda_seq 32 | self.n_tokens_prefix = ( 33 | len(base_model.tokenizer(doc_prefix_str, add_special_tokens=True)["input_ids"]) - 1 34 | ) # - 1 to remove the [SEP] token (but count the [CLS] token) 35 | self.show_progress_bar = show_progress_bar 36 | 37 | # test fix for models that do not have a sep_token (decoders) 38 | if self.base_model.tokenizer.sep_token is None: 39 | self.base_model.tokenizer.sep_token = self.base_model.tokenizer.pad_token 40 | 41 | # if hasattr(self.base_model, "document_prefix_id"): 42 | # self.join_token_id = self.base_model.document_prefix_id 43 | # else: 44 | # self.join_token_id = self.base_model.tokenizer.sep_token_id 45 | self.join_token_id = self.base_model.tokenizer.sep_token_id 46 | 47 | self.pooling_mode = pooling_mode 48 | 49 | def tokenize( 50 | # self, texts: list[str] | list[dict] | list[tuple[str, str]] 51 | self, 52 | texts: list[str], 53 | **kwargs, 54 | ) -> dict[str, Tensor]: 55 | return self.base_model.tokenize(texts, **kwargs) 56 | 57 | def encode(self, args, **kwargs): 58 | return self.base_model.encode(args, **kwargs) 59 | 60 | def _get_doc_indices( 61 | self, batch_indices: torch.Tensor, sep_indices: torch.Tensor 62 | ) -> Tuple[torch.Tensor, torch.Tensor]: 63 | """Get the start and end indices of each document in a batch of sequences separated by sep_tokens. 64 | 65 | Args: 66 | sep_indices (torch.Tensor): indices of the sep_tokens in the input sequences of the batch\ 67 | (result of torch.where) 68 | batch_indices (torch.Tensor): batch indices corresponding to each sep_token (result of torch.where) 69 | 70 | Returns: 71 | starts (torch.Tensor): tensor of start indices of each document in the batch (batch_size, max_n_docs).\ 72 | Contains -1 for padding values. 73 | ends (torch.Tensor): tensor of end indices of each document in the batch (batch_size, max_n_docs).\ 74 | Contains -1 for padding values. 
75 | """ 76 | # Get batch size and max number of docs per sample 77 | unique, counts = torch.unique(batch_indices, return_counts=True) 78 | batch_size = len(unique) 79 | 80 | # if isinstance(self.base_model, ColBERT): 81 | # max_docs = counts.max().item() - 1 # to uncount the last join token 82 | # else: 83 | # 84 | max_docs = counts.max().item() 85 | 86 | # Create output tensor 87 | ends = ( 88 | torch.zeros((batch_size, max_docs), dtype=torch.long, device=batch_indices.device) - 1 89 | ) # -1 is a padding value 90 | starts = ends.clone() 91 | starts[:, 0] = 0 92 | 93 | # Fill values using masked indexing (for each sample of the batch) 94 | for i, group_idx in enumerate(unique): 95 | mask = batch_indices == group_idx 96 | values = sep_indices[mask] 97 | # if isinstance(self.base_model, ColBERT): 98 | # ends[i, : len(values) - 1] = values[1:] 99 | # starts[i, 1 : len(values) - 1] = values[1: -1] # + 1 # shift by 1 to avoid taking the [SEP] token 100 | # else: 101 | # ends[i, : len(values)] = values 102 | # starts[i, 1 : len(values)] = values[:-1] + 1 # shift by 1 to avoid taking the [SEP] token 103 | ends[i, : len(values)] = values 104 | starts[i, 1 : len(values)] = values[:-1] + 1 # shift by 1 to avoid taking the [SEP] token 105 | 106 | return starts, ends 107 | 108 | def _late_chunking_pooling( 109 | self, 110 | input_ids: torch.Tensor, 111 | token_embeddings: torch.Tensor, 112 | ) -> Tuple[torch.Tensor, torch.Tensor]: 113 | """Computes the chunk (or document) embeddings for sequences of multiple documents separated by sep_tokens. 114 | Returns embeddings pooled depending on the pooling mode of the model (tokens or average pooling). 115 | 116 | Args: 117 | token_embeddings (Tensor): The token embeddings of the input sequences, of shape (batch_size, seq_len, dim). 118 | input_ids (Tensor): The input ids of the sequences, of shape (batch_size, seq_len). 119 | 120 | Returns: 121 | batch_embeddings (Tensor): The chunk embeddings of the input sequences, of shape \ 122 | (batch_size, max_n_chunks, dim), where max_n_chunks is the maximum number of chunks in the batch, 123 | in the case of average pooling. 124 | For token pooling, the shape is (batch_size, max_n_chunks, max_chunk_len, dim), 125 | where max_chunk_len is the maximum length of a chunk in the batch. 126 | padding_mask (Tensor): A mask indicating which chunks are padding, of shape (batch_size, max_n_chunks). 
127 | """ 128 | # batch_indices, sep_indices = torch.where(input_ids == self.join_token_id) 129 | batch_indices, sep_indices = torch.where(input_ids == self.base_model.tokenizer.sep_token_id) 130 | starts, ends = self._get_doc_indices(batch_indices, sep_indices) # (batch_size, max_n_chunks) 131 | 132 | # Get useful values 133 | batch_size, max_seq_len, dim = token_embeddings.shape 134 | max_n_chunks = starts.shape[1] 135 | max_chunk_len = (ends - starts).max() 136 | 137 | # Create index tensor (to index all tokens of each chunk) 138 | indices = torch.arange(max_chunk_len, device=token_embeddings.device) 139 | # Repeat indices for each chunk 140 | indices = indices.expand(batch_size, max_n_chunks, -1) # (batch_size, max_n_chunks, max_chunk_len) 141 | # Shift indices by start positions 142 | indices = indices + starts.unsqueeze(-1) 143 | 144 | # Create mask for valid indices 145 | mask = indices < ends.unsqueeze(-1) # (batch_size, max_n_chunks, max_chunk_len) 146 | 147 | # Put toy index at invalid positions (for gather) 148 | indices[~mask] = 0 149 | 150 | all_chunk_lengths = ends - starts # (batch_size, max_n_chunks) 151 | all_chunk_lengths = all_chunk_lengths.masked_fill( 152 | all_chunk_lengths == 0, 1 153 | ) # avoid division by zero (for empty chunks) 154 | 155 | # Gather values using advanced indexing for each chunk 156 | batch_embeddings = [] 157 | for i in range(max_n_chunks): 158 | chunk_indices = indices[:, i, :].unsqueeze(-1) # (batch_size, max_chunk_len, 1) 159 | chunk_mask = mask[:, i, :].unsqueeze(-1) 160 | chunk_lengths = all_chunk_lengths[:, i].unsqueeze(-1) # (batch_size, 1) 161 | 162 | # Expand the chunk indices to the embedding dimension and gather the values 163 | gathered = token_embeddings.gather(1, chunk_indices.expand(-1, -1, dim)) # (batch_size, max_chunk_len, dim) 164 | 165 | # Mask invalid positions and compute mean 166 | gathered = gathered.masked_fill(~chunk_mask, 0) 167 | 168 | if self.pooling_mode == "average": 169 | chunk_embeddings = gathered.sum(dim=1) / chunk_lengths # (batch_size, dim) 170 | batch_embeddings.append(chunk_embeddings) 171 | elif self.pooling_mode == "tokens": 172 | # If not average pooling, we keep the token embeddings 173 | batch_embeddings.append(gathered) # (batch_size, max_chunk_len, dim) 174 | else: 175 | raise ValueError(f"Pooling mode {self.pooling_mode} not supported. 
Use 'average' or 'tokens'.") 176 | 177 | batch_embeddings = torch.stack(batch_embeddings, dim=1) 178 | 179 | padding_mask = starts == -1 180 | return ( 181 | batch_embeddings, 182 | padding_mask, 183 | ) # return the start indices for masking the loss 184 | 185 | def _compute_sim(self, embeddings_a: torch.Tensor, embeddings_b: torch.Tensor) -> torch.Tensor: 186 | if len(embeddings_a.shape) not in [2, 3] or len(embeddings_b.shape) not in [ 187 | 2, 188 | 3, 189 | ]: 190 | raise ValueError("Embeddings should have shape (n, dim) or (b, n, dim)") 191 | 192 | normalized_a = torch.nn.functional.normalize(embeddings_a, p=2, dim=-1) 193 | normalized_b = torch.nn.functional.normalize(embeddings_b, p=2, dim=-1) 194 | return torch.matmul(normalized_a, normalized_b.transpose(-2, -1)) 195 | 196 | def _compute_rowwise_sim(self, embeddings_a: torch.Tensor, embeddings_b: torch.Tensor) -> torch.Tensor: 197 | if len(embeddings_a.shape) != 2 or len(embeddings_b.shape) != 2: 198 | raise ValueError("Embeddings should have shape (n, dim) or (b, n, dim)") 199 | 200 | normalized_a = torch.nn.functional.normalize(embeddings_a, p=2, dim=-1) 201 | normalized_b = torch.nn.functional.normalize(embeddings_b, p=2, dim=-1) 202 | cosine_similarities = (normalized_a * normalized_b).sum(dim=-1) 203 | 204 | return cosine_similarities 205 | 206 | def _compute_loss_inbatch_inseq( 207 | self, 208 | query_embeddings: torch.Tensor, 209 | chunk_embeddings: torch.Tensor, 210 | doc_padding_mask: torch.Tensor, 211 | queries_chunk_indices: List[List[int]], 212 | n_docs_per_sample: Optional[List[int]] = None, 213 | ) -> torch.Tensor: 214 | """Computes a contrastive loss between queries and chunks using in-batch and in-sequence negatives. 215 | The queries are matched with the corresponding document among all negatives. 216 | 217 | Queries are of shape `(n_queries_in_batch, dim)`, and chunk_embeddings of \ 218 | shape `(batch_size, max_n_docs, dim)`, \ 219 | where `max_n_docs` is the maximum number of documents concatenated in one sample of the batch. 220 | This implies that the chunk_embeddings are padded for samples with less than `max_n_docs` documents. 221 | 222 | Since all negatives are considered, we directly compare all `n_queries_in_batch` \ 223 | queries with all `batch_size * max_n_docs` documents. 224 | We then mask out the similarity scores for padding values in the documents (setting them to -inf),\ 225 | so that they do not contribute to the loss. 
226 | 227 | Args: 228 | query_embeddings (torch.Tensor): the computed embeddings for the queries, \ 229 | shape: (n_queries_in_batch, dim) 230 | chunk_embeddings (torch.Tensor): the computed embeddings for the documents, \ 231 | shape: (batch_size, max_n_docs, dim) 232 | doc_padding_mask (torch.Tensor): the mask of padding values in the documents, \ 233 | shape: (batch_size, max_n_docs) 234 | queries_chunk_indices (List[List[int]]): the position in the sequence of the chunk that query `i` refers to 235 | n_docs_per_sample (List[int], optional): the number of documents for each sample in the batch, \ 236 | only necesary for single context training 237 | 238 | Returns: 239 | loss (torch.Tensor): the cross-entropy loss between queries and documents 240 | """ 241 | max_n_docs = chunk_embeddings.shape[1] 242 | 243 | # reshape the chunk embeddings and doc mask 244 | chunk_embeddings = chunk_embeddings.view(-1, chunk_embeddings.shape[-1]) # (batch_size * n_docs, dim) 245 | doc_padding_mask = doc_padding_mask.view(-1) # (batch_size * n_docs) 246 | 247 | sim_scores = ( 248 | self._compute_sim(query_embeddings, chunk_embeddings) * self.sim_score_scale 249 | ) # (n_queries_in_batch, batch_size * n_docs) 250 | 251 | # mask out the similarity scores for padding values 252 | padding_indices = doc_padding_mask.expand( 253 | query_embeddings.shape[0], -1 254 | ) # (n_queries_in_batch, batch_size * n_docs) 255 | masked_scores = sim_scores.masked_fill(padding_indices, float("-inf")) 256 | 257 | if self.multi_ctx_training: 258 | labels = torch.cat( 259 | [ 260 | torch.tensor(indices, device=sim_scores.device) + batch_idx * max_n_docs 261 | for batch_idx, indices in enumerate(queries_chunk_indices) 262 | ] 263 | ) 264 | else: 265 | assert len(n_docs_per_sample) == len(queries_chunk_indices), ( 266 | "There should be the same number of original samples in the batch" 267 | ) 268 | 269 | offsets = torch.tensor([0] + n_docs_per_sample[:-1], device=sim_scores.device).cumsum(dim=0) 270 | labels = torch.cat( 271 | [ 272 | torch.tensor(indices, device=sim_scores.device) + offsets[batch_idx] 273 | for batch_idx, indices in enumerate(queries_chunk_indices) 274 | ] 275 | ) 276 | 277 | assert labels.shape[0] == query_embeddings.shape[0], ( 278 | "The number of labels should match the number of queries" 279 | ) 280 | assert torch.all(labels < sim_scores.shape[1]), ( 281 | "All labels should be valid indices (i.e. less than the number of documents)" 282 | ) 283 | assert torch.all(labels >= 0), "All labels should be valid indices (i.e. 
greater or equal to 0)" 284 | 285 | loss = self.cross_entropy_loss(masked_scores, labels) 286 | return loss 287 | 288 | def _compute_loss_from_embeddings(self, query_embeddings, doc_embeddings, labels): 289 | # compute similarity scores 290 | sim_scores = self._compute_sim(query_embeddings, doc_embeddings) * self.sim_score_scale 291 | # compute the cross-entropy loss 292 | return self.cross_entropy_loss(sim_scores, labels) 293 | 294 | def _compute_loss_batch_negatives( 295 | self, 296 | query_embeddings: torch.Tensor, 297 | batch_negatives: torch.Tensor, 298 | golden_doc_embeddings: torch.Tensor, 299 | ): 300 | assert query_embeddings.shape[0] == golden_doc_embeddings.shape[0], ( 301 | "The number of queries should match the number of golden documents" 302 | ) 303 | # compute similarity scores for the golden documents 304 | golden_sim_scores = self._compute_rowwise_sim(query_embeddings, golden_doc_embeddings) * self.sim_score_scale 305 | 306 | # compute similarity scores for the batch negatives 307 | batch_sim_scores = self._compute_sim(query_embeddings, batch_negatives) * self.sim_score_scale 308 | 309 | # concatenate the scores 310 | all_sim_scores = torch.cat([golden_sim_scores.unsqueeze(1), batch_sim_scores], dim=1) 311 | 312 | # true labels are the first scores for each query 313 | labels = torch.zeros(query_embeddings.shape[0], dtype=torch.long, device=query_embeddings.device) 314 | 315 | # compute the cross-entropy loss 316 | return self.cross_entropy_loss(all_sim_scores, labels) 317 | 318 | def _compute_loss_weighted_inbatch_inseq( 319 | self, 320 | query_embeddings: torch.Tensor, 321 | chunk_embeddings: torch.Tensor, 322 | doc_padding_mask: torch.Tensor, 323 | queries_chunk_indices: List[List[int]], 324 | # loss_type: str = "inbatch_inseq", 325 | ) -> torch.Tensor: 326 | batch_size = chunk_embeddings.shape[0] 327 | 328 | query_offset = 0 329 | in_seq_loss_arr = [] 330 | in_batch_loss_arr = [] 331 | 332 | # iterate over each sample in the batch 333 | for b_idx in range(batch_size): 334 | # compute the loss for the in-sequence negatives 335 | sequence_doc_embeddings = chunk_embeddings[b_idx] # shape (max_n_docs, dim) 336 | sequence_doc_mask = doc_padding_mask[b_idx] # shape (max_n_docs) 337 | sequence_doc_embeddings = sequence_doc_embeddings[~sequence_doc_mask] # shape (n_docs, dim) 338 | 339 | assert len(sequence_doc_embeddings) > max(queries_chunk_indices[b_idx]), ( 340 | "Query-chunk indices should be less than the number of documents" 341 | ) 342 | sample_query_embeddings = query_embeddings[ 343 | query_offset : query_offset + len(queries_chunk_indices[b_idx]) 344 | ] # shape (n_queries, dim) 345 | query_offset += len(queries_chunk_indices[b_idx]) 346 | 347 | in_sequence_loss = self._compute_loss_from_embeddings( 348 | sample_query_embeddings, 349 | sequence_doc_embeddings, 350 | labels=torch.tensor( 351 | queries_chunk_indices[b_idx], 352 | dtype=torch.long, 353 | device=query_embeddings.device, 354 | ), # labels for this sample 355 | ) 356 | in_seq_loss_arr.append(in_sequence_loss) 357 | 358 | # compute the loss for the in-batch negatives (without in-sequence negatives) 359 | batch_negatives = chunk_embeddings[ 360 | torch.arange(batch_size) != b_idx 361 | ] # shape (batch_size - 1, max_n_docs, dim) 362 | batch_doc_padding_mask = doc_padding_mask[ 363 | torch.arange(batch_size) != b_idx 364 | ] # shape (batch_size - 1, max_n_docs) 365 | batch_negatives = batch_negatives[~batch_doc_padding_mask] 366 | # TODO: check if line below is necessary 367 | batch_negatives = 
batch_negatives.view(-1, batch_negatives.shape[-1]) # shape (n_batch_negs, dim) 368 | golden_doc_embeddings = sequence_doc_embeddings[queries_chunk_indices[b_idx]] 369 | in_batch_loss = self._compute_loss_batch_negatives( 370 | sample_query_embeddings, batch_negatives, golden_doc_embeddings 371 | ) 372 | in_batch_loss_arr.append(in_batch_loss) 373 | 374 | in_seq_mean_loss = torch.stack(in_seq_loss_arr).mean() 375 | in_batch_mean_loss = torch.stack(in_batch_loss_arr).mean() 376 | 377 | loss = self.lambda_seq * in_seq_mean_loss + (1 - self.lambda_seq) * in_batch_mean_loss 378 | return loss 379 | 380 | def _compute_max_sim_scores( 381 | self, 382 | q_token_embeddings: torch.Tensor, 383 | d_token_embeddings: torch.Tensor, 384 | ): 385 | # normalize the embeddings 386 | q_token_embeddings = torch.nn.functional.normalize(q_token_embeddings, p=2, dim=-1) 387 | d_token_embeddings = torch.nn.functional.normalize(d_token_embeddings, p=2, dim=-1) 388 | 389 | # perform dot product between query and document embeddings 390 | sim_scores = torch.einsum( 391 | "qnd,bmd->qbnm", q_token_embeddings, d_token_embeddings 392 | ) # (n_queries, n_docs, max_q_len, max_doc_len) 393 | 394 | # max_sim: take the max over doc embeddings, then sum over query embeddings 395 | max_sim_scores = sim_scores.max(dim=3)[0].sum(dim=2) 396 | return max_sim_scores # (n_queries, n_docs) 397 | 398 | def _compute_max_sim_loss_in_batch( 399 | self, 400 | q_token_embeddings: torch.Tensor, 401 | bn_d_token_embeddings: torch.Tensor, 402 | golden_d_token_embeddings: torch.Tensor, 403 | ): 404 | assert q_token_embeddings.shape[0] == golden_d_token_embeddings.shape[0], ( 405 | "The number of queries should match the number of golden documents" 406 | ) 407 | golden_max_sim_scores = self._compute_max_sim_scores( 408 | q_token_embeddings, golden_d_token_embeddings 409 | ) # (n_queries, n_golden_docs) 410 | 411 | # take only the score of the golden doc for the corresponding query 412 | golden_max_sim_scores = golden_max_sim_scores.diag() # (n_queries,) 413 | 414 | # compute similarity scores for the batch negatives 415 | bn_max_sim_scores = self._compute_max_sim_scores( 416 | q_token_embeddings, bn_d_token_embeddings 417 | ) # (n_queries, n_batch_negs) 418 | 419 | # concatenate the scores 420 | all_sim_scores = torch.cat([golden_max_sim_scores.unsqueeze(1), bn_max_sim_scores], dim=1) 421 | 422 | # true labels are the first scores for each query 423 | labels = torch.zeros(q_token_embeddings.shape[0], dtype=torch.long, device=q_token_embeddings.device) 424 | 425 | # compute the cross-entropy loss 426 | loss = self.cross_entropy_loss(all_sim_scores, labels) 427 | 428 | return loss 429 | 430 | def _compute_late_interaction_loss( 431 | self, 432 | q_token_embeddings: torch.Tensor, 433 | d_token_embeddings: torch.Tensor, 434 | doc_padding_mask: torch.Tensor, 435 | queries_chunk_indices: List[List[int]], 436 | ): 437 | # query_token_embeddings have shape (n_queries_in_batch, max_query_len, dim) 438 | # doc_token_embeddings have shape (batch_size, max_n_docs, max_doc_len, dim) 439 | 440 | batch_size = d_token_embeddings.shape[0] 441 | query_offset = 0 442 | in_seq_loss_arr = [] 443 | in_batch_loss_arr = [] 444 | 445 | for b_idx in range(batch_size): 446 | # in-sequence loss 447 | seq_d_token_embeddings = d_token_embeddings[b_idx] # (max_n_docs, max_doc_len, dim) 448 | sequence_doc_mask = doc_padding_mask[b_idx] # shape (max_n_docs) 449 | seq_d_token_embeddings = seq_d_token_embeddings[~sequence_doc_mask] # shape (n_docs, max_doc_len, dim) 450 | 451 | 
sample_q_token_embeddings = q_token_embeddings[ 452 | query_offset : query_offset + len(queries_chunk_indices[b_idx]) 453 | ] # shape (n_queries, max_q_len, dim) 454 | query_offset += len(queries_chunk_indices[b_idx]) 455 | 456 | # TODO: make sure q_token_embeddings are padded with 0's 457 | max_sim_scores = self._compute_max_sim_scores( 458 | sample_q_token_embeddings, seq_d_token_embeddings 459 | ) # (n_queries, n_docs) 460 | 461 | labels = torch.tensor( 462 | queries_chunk_indices[b_idx], 463 | dtype=torch.long, 464 | device=q_token_embeddings.device, 465 | ) # labels for this sample 466 | in_seq_loss = self.cross_entropy_loss(max_sim_scores, labels) 467 | in_seq_loss_arr.append(in_seq_loss) 468 | 469 | # in-batch loss 470 | batch_negatives = d_token_embeddings[ 471 | torch.arange(batch_size) != b_idx 472 | ] # shape (batch_size - 1, max_n_docs, max_doc_len, dim) 473 | batch_doc_padding_mask = doc_padding_mask[ 474 | torch.arange(batch_size) != b_idx 475 | ] # shape (batch_size - 1, max_n_docs) 476 | batch_negatives = batch_negatives[~batch_doc_padding_mask] # (n_batch_negs, max_doc_len, dim) 477 | golden_doc_token_embeddings = seq_d_token_embeddings[ 478 | queries_chunk_indices[b_idx] 479 | ] # (n_queries, max_q_len, dim) 480 | 481 | # compute sim and loss 482 | in_batch_loss = self._compute_max_sim_loss_in_batch( 483 | sample_q_token_embeddings, batch_negatives, golden_doc_token_embeddings 484 | ) 485 | in_batch_loss_arr.append(in_batch_loss) 486 | 487 | # compute the weighted mean loss 488 | in_seq_mean_loss = torch.stack(in_seq_loss_arr).mean() 489 | in_batch_mean_loss = torch.stack(in_batch_loss_arr).mean() 490 | 491 | loss = self.lambda_seq * in_seq_mean_loss + (1 - self.lambda_seq) * in_batch_mean_loss 492 | return loss 493 | 494 | def _add_last_join_token( 495 | self, 496 | input_ids: torch.Tensor, 497 | attention_mask: torch.Tensor, 498 | ) -> torch.Tensor: 499 | # find the end of each sequence with the attention mask 500 | seq_ends = attention_mask.sum(dim=1) 501 | # add a column to the input_ids to make sure there is no overflow 502 | input_ids = torch.cat([input_ids, torch.zeros((input_ids.shape[0], 1), device=input_ids.device)], dim=1) 503 | # add a join token at that spot in each row 504 | input_ids[torch.arange(input_ids.shape[0]), seq_ends] = self.join_token_id 505 | return input_ids 506 | 507 | def forward(self, *args, **kwargs): 508 | """ 509 | Forward pass of the model. Computes the loss for the given batch of queries and documents. 510 | Args: 511 | args: A list containing a dictionary with the following keys: 512 | - "query_inputs": The input tensors for the queries. 513 | - "docs_inputs": The input tensors for the documents. 514 | kwargs: A dictionary containing the following keys: 515 | - "queries_chunk_indices": A list of lists containing the indices of the chunks for each query. 516 | - "loss_type": The type of loss to compute. Can be "inbatch_inseq", "weighted", or "late_interaction" 517 | (default: weighted). 518 | - "n_docs_per_sample": The number of documents per sample in the batch 519 | (only needed for single context training). 520 | - "add_prefixes": Whether to add prefixes to the document inputs (default: False). 
521 | """ 522 | query_inputs = args[0]["query_inputs"] 523 | query_model_outputs = self.base_model(query_inputs) 524 | 525 | doc_inputs = args[0]["docs_inputs"] 526 | doc_token_embeddings = self.base_model(doc_inputs)["token_embeddings"] 527 | 528 | if kwargs.get("add_prefixes", False): 529 | doc_token_embeddings = doc_token_embeddings[:, self.n_tokens_prefix :, :] 530 | for k, v in doc_inputs.items(): 531 | doc_inputs[k] = v[:, self.n_tokens_prefix :] 532 | 533 | doc_inputs["input_ids"] = self._add_last_join_token( 534 | doc_inputs["input_ids"], 535 | doc_inputs["attention_mask"], 536 | ) # add the join token at the end of each sequence 537 | 538 | # for now, assume multi-docs 539 | chunk_embeddings, doc_padding_mask = self._late_chunking_pooling( 540 | doc_inputs["input_ids"], 541 | doc_token_embeddings, 542 | ) 543 | 544 | if self.pooling_mode == "average": 545 | query_embeddings = query_model_outputs["sentence_embedding"] # (n_queries_in_batch, dim) 546 | elif self.pooling_mode == "tokens": 547 | query_embeddings = query_model_outputs["token_embeddings"] # (n_queries_in_batch, max_query_len, dim) 548 | else: 549 | raise ValueError(f"Pooling mode {self.pooling_mode} not supported. Use 'average' or 'tokens'.") 550 | 551 | if "loss_type" not in kwargs or kwargs["loss_type"] == "weighted": 552 | loss = self._compute_loss_weighted_inbatch_inseq( 553 | query_embeddings, 554 | chunk_embeddings, 555 | doc_padding_mask, 556 | kwargs["queries_chunk_indices"], 557 | ) 558 | elif kwargs["loss_type"] == "inbatch_inseq": 559 | # in-seq and in-batch negatives 560 | loss = self._compute_loss_inbatch_inseq( 561 | query_embeddings, 562 | chunk_embeddings, 563 | doc_padding_mask, 564 | kwargs["queries_chunk_indices"], 565 | kwargs.get("n_docs_per_sample"), 566 | ) 567 | 568 | elif kwargs["loss_type"] == "late_interaction": 569 | loss = self._compute_late_interaction_loss( 570 | query_embeddings, 571 | chunk_embeddings, 572 | doc_padding_mask, 573 | kwargs["queries_chunk_indices"], 574 | ) 575 | else: 576 | raise ValueError( 577 | f"Unknown loss type: {kwargs['loss_type']}.\ 578 | Supported types are: inbatch_inseq, weighted, late_interaction." 579 | ) 580 | 581 | return {"loss": loss} 582 | 583 | def embed_queries(self, queries): 584 | """ 585 | Embeds a batch of queries. 586 | Args: 587 | queries (list[str]): A list of queries to be embedded. 588 | Returns: 589 | torch.Tensor: A tensor of shape (n_queries, dim) containing the embeddings of the queries. 
590 | """ 591 | self.base_model.eval() 592 | 593 | kwargs = { 594 | "show_progress_bar": self.show_progress_bar, 595 | "batch_size": self.batch_size, 596 | "normalize_embeddings": self.normalize_embeddings, 597 | "convert_to_tensor": True, 598 | "prompt": "search_query: " if self.add_prefix else None, 599 | } 600 | if not isinstance(self.base_model, ColBERT): 601 | output_value = "sentence_embedding" if self.pooling_mode == "average" else "token_embeddings" 602 | kwargs["output_value"] = output_value 603 | 604 | outputs = self.base_model.encode( 605 | queries, 606 | **kwargs, 607 | ) 608 | 609 | # pad token embeddings 610 | if self.pooling_mode == "tokens": 611 | outputs = torch.nn.utils.rnn.pad_sequence(outputs, batch_first=True) 612 | 613 | if self.add_prefix: 614 | # remove the prefix from the token embeddings 615 | outputs = outputs[:, self.n_tokens_prefix :, :] 616 | 617 | return outputs 618 | 619 | def _tokenize_docs(self, documents): 620 | inputs_list = [] 621 | 622 | # iterate over each document, which is a list of chunks 623 | for docs in documents: 624 | # tokenize all chunks 625 | doc_inputs = self.tokenize(docs, is_query=False) 626 | # remove CLS token from all inputs except the first 627 | doc_inputs = {k: v[:, 1:] for k, v in doc_inputs.items()} 628 | 629 | input_ids = doc_inputs["input_ids"] 630 | attention_mask = doc_inputs["attention_mask"] 631 | # take only the ids of the valid tokens of all chunks 632 | valid_input_ids_list = [input_ids[i, : attention_mask[i].sum()] for i in range(input_ids.shape[0])] 633 | # concat all chunks together 634 | concat_seq = torch.cat(valid_input_ids_list, dim=0) 635 | # add a join_token at the end for the late chunking pooling logic 636 | concat_seq = torch.cat( 637 | [concat_seq, torch.tensor([self.join_token_id], device=concat_seq.device)], 638 | dim=0, 639 | ) 640 | inputs_list.append(concat_seq) 641 | 642 | # pad the sequences 643 | padded_inputs = torch.nn.utils.rnn.pad_sequence(inputs_list, batch_first=True) 644 | # add the attention mask 645 | attention_mask = torch.zeros(padded_inputs.shape[0], padded_inputs.shape[1], device=padded_inputs.device) 646 | for i in range(padded_inputs.shape[0]): 647 | attention_mask[i, : len(inputs_list[i]) - 1] = 1 648 | 649 | return { 650 | "input_ids": padded_inputs, 651 | "attention_mask": attention_mask, 652 | } 653 | 654 | def _partition_long_doc(self, doc_chunks, n_chunks_overlap: Optional[int] = 10): 655 | # tokenize all chunks 656 | if isinstance(self.base_model, ColBERT): 657 | tokenized_chunks = self.tokenize(doc_chunks, is_query=False) 658 | else: 659 | tokenized_chunks = self.tokenize(doc_chunks) 660 | 661 | tokenized_chunks = [ 662 | ( 663 | tokenized_chunks["input_ids"][i, : tokenized_chunks["attention_mask"][i].sum()], 664 | tokenized_chunks["attention_mask"][i].sum().item(), 665 | ) 666 | for i in range(tokenized_chunks["input_ids"].shape[0]) 667 | ] 668 | 669 | chunk_buffer = [] 670 | chunk_lists = [] # stores lists of chunks that will be embedded together 671 | valid_chunks_lists = [] # masks to avoid duplicate chunk embeddings 672 | overlap_in_buffer = False 673 | 674 | for idx, (_, chunk_size) in enumerate(tokenized_chunks): 675 | # if buffer is full, empty it 676 | if chunk_size + sum([line for _, line in chunk_buffer]) >= self.base_model.max_seq_length: 677 | chunk_lists.append([doc_chunks[idx] for idx, _ in chunk_buffer]) 678 | valid_chunks = [1 for _ in range(len(chunk_buffer))] 679 | if overlap_in_buffer: 680 | valid_chunks[:n_chunks_overlap] = [0 for _ in 
range(n_chunks_overlap)] 681 | valid_chunks_lists.append(valid_chunks) 682 | 683 | # empty buffer, leaving some overlapping chunks 684 | chunk_buffer = chunk_buffer[-n_chunks_overlap:] 685 | overlap_in_buffer = True 686 | 687 | chunk_buffer.append((idx, chunk_size)) 688 | 689 | # add last list 690 | chunk_lists.append([doc_chunks[idx] for idx, _ in chunk_buffer]) 691 | valid_chunks = [1 for _ in range(len(chunk_buffer))] 692 | if overlap_in_buffer: 693 | valid_chunks[:n_chunks_overlap] = [0 for _ in range(n_chunks_overlap)] 694 | valid_chunks_lists.append(valid_chunks) 695 | 696 | assert sum([c for valid_chunks in valid_chunks_lists for c in valid_chunks]) == len(doc_chunks) 697 | 698 | return chunk_lists, valid_chunks_lists 699 | 700 | def _embed_long_doc(self, doc_chunks): 701 | chunk_lists, valid_chunks_lists = self._partition_long_doc(doc_chunks) 702 | chunk_embeddings = self.embed_batch_documents(chunk_lists) 703 | valid_chunk_embeddings = [embeds[: sum(valid)] for embeds, valid in zip(chunk_embeddings, valid_chunks_lists)] 704 | # flatten array of embeddings 705 | valid_chunk_embeddings = [v for valids in valid_chunk_embeddings for v in valids] 706 | 707 | return valid_chunk_embeddings 708 | 709 | def embed_batch_documents(self, documents): 710 | """ 711 | Embeds a batch of documents, where each document is a list of chunks. 712 | Args: 713 | documents (list[list[str]]): A list of documents, where each document is a list of chunks (strings). 714 | Each chunk is a string that will be tokenized and embedded. 715 | Returns: 716 | list[list[torch.Tensor]]: A list of lists of embeddings, where each inner list corresponds to a document, 717 | and contains the embeddings of the chunks in that document. 718 | """ 719 | # documents are lists of chunks 720 | self.base_model.eval() 721 | doc_strings = [self.base_model.tokenizer.sep_token.join(doc) for doc in documents] 722 | 723 | if self.add_prefix: 724 | doc_strings = ["search_document: " + doc for doc in doc_strings] 725 | 726 | if isinstance(self.base_model, ColBERT): 727 | doc_inputs = self.tokenize(doc_strings, is_query=False) 728 | else: 729 | doc_inputs = self.tokenize(doc_strings) 730 | doc_inputs = {k: v.to(self.base_model.device) for k, v in doc_inputs.items()} 731 | 732 | with torch.no_grad(): 733 | doc_embeddings = self.base_model(doc_inputs)["token_embeddings"] 734 | 735 | if self.add_prefix: 736 | doc_embeddings = doc_embeddings[:, self.n_tokens_prefix :, :] 737 | for k, v in doc_inputs.items(): 738 | doc_inputs[k] = v[:, self.n_tokens_prefix :] 739 | 740 | chunk_embeddings, padding_mask = self._late_chunking_pooling(doc_inputs["input_ids"], doc_embeddings) 741 | if self.normalize_embeddings: 742 | chunk_embeddings = torch.nn.functional.normalize(chunk_embeddings, p=2, dim=-1) 743 | 744 | outputs = [] 745 | for i in range(len(documents)): 746 | filtered = chunk_embeddings[i, ~padding_mask[i]] 747 | embedding_list = [filtered[j].cpu() for j in range(filtered.shape[0])] 748 | outputs.append(embedding_list) 749 | 750 | return outputs 751 | 752 | def embed_documents(self, documents, batch_size=64): 753 | """ 754 | 755 | Embeds a list of documents, where each document is a list of chunks. 756 | Args: 757 | documents (list[list[str]]): A list of documents, where each document is a list of chunks (strings). 758 | Each chunk is a string that will be tokenized and embedded. 759 | batch_size (int): The size of the batch to use for embedding the documents. 
760 | Returns: 761 | list[list[torch.Tensor]]: A list of lists of embeddings, where each inner list corresponds to a document, 762 | and contains the embeddings of the chunks in that document.""" 763 | all_outputs = [] 764 | for i in tqdm(range(0, len(documents), batch_size), "Embedding documents", disable=not self.show_progress_bar): 765 | outputs = self.embed_batch_documents(documents[i : i + batch_size]) 766 | all_outputs.extend(outputs) 767 | 768 | return all_outputs 769 | 770 | def gradient_checkpointing_enable(self, **kwargs): 771 | """ 772 | Activates gradient checkpointing for the current model (not sure if necessary yet). 773 | """ 774 | self.base_model.gradient_checkpointing_enable(**kwargs) 775 | 776 | def enable_input_require_grads(self, **kwargs): 777 | """ 778 | Enables the gradients for the input embeddings (not sure if necessary yet). 779 | """ 780 | self.base_model.enable_input_require_grads(**kwargs) 781 | 782 | def save( 783 | self, 784 | path: str, 785 | model_name: str | None = None, 786 | create_model_card: bool = True, 787 | train_datasets: list[str] | None = None, 788 | safe_serialization: bool = True, 789 | ) -> None: 790 | self.base_model.save( 791 | path, 792 | model_name=model_name, 793 | create_model_card=create_model_card, 794 | train_datasets=train_datasets, 795 | safe_serialization=safe_serialization, 796 | ) 797 | -------------------------------------------------------------------------------- /contextual_embeddings/models/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datasets import Dataset, concatenate_datasets, load_dataset 3 | from tqdm import tqdm 4 | 5 | N_SPECIAL_TOKENS = 3 # to be safe, we imagine the tokenizer will add 3 special tokens 6 | 7 | 8 | def get_dataset_from_beir_format( 9 | queries: Dataset, 10 | chunks: Dataset, 11 | sep_token: str, 12 | ) -> Dataset: 13 | # Extract ids for both queries and chunks (format is always docID_chunkID) 14 | def extract_ids(sample): 15 | split = sample["chunk_id"].split("_") 16 | sample["doc_id"] = split[0] 17 | sample["chunk_id"] = int(split[1]) 18 | return sample 19 | 20 | chunks = chunks.map(extract_ids) 21 | queries = queries.map(extract_ids) 22 | 23 | df_chunks = pd.DataFrame(chunks) 24 | grouped_chunks = ( 25 | df_chunks.sort_values(["doc_id", "chunk_id"]).groupby("doc_id").agg({"chunk": list, "chunk_id": list}) 26 | ) 27 | grouped_chunks = grouped_chunks.rename(columns={"chunk": "docs_list"}) 28 | 29 | assert grouped_chunks["chunk_id"].apply(lambda x: x == sorted(x)).all() 30 | grouped_chunks = grouped_chunks.drop(columns=["chunk_id"]) 31 | 32 | # concatenate chunks with sep_tokens (keep the list as a separate entry) 33 | grouped_chunks["docs"] = grouped_chunks["docs_list"].apply(lambda x: sep_token.join(x)) 34 | 35 | df_queries = pd.DataFrame(queries) 36 | df_queries = df_queries[df_queries["query"].apply(lambda x: x.strip() != "")] 37 | grouped_queries = ( 38 | df_queries.sort_values(["doc_id", "chunk_id"]) 39 | .groupby("doc_id") 40 | .agg({"query": list, "chunk_id": list, "answer": list}) 41 | ) 42 | assert grouped_queries["chunk_id"].apply(lambda x: x == sorted(x)).all() 43 | grouped_queries = grouped_queries.rename(columns={"query": "queries", "chunk_id": "queries_chunk_ids"}) 44 | 45 | df_dataset = grouped_queries.join(grouped_chunks, how="inner") 46 | 47 | assert len(df_dataset) == len(grouped_queries), "All queries should have a corresponding document." 
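# Note: at this point each row of `df_dataset` describes one source document and carries
# (per the grouping above): `queries` (list[str]), `queries_chunk_ids` (list[int], position
# of the gold chunk for each query), `answer` (list[str]), `docs_list` (the ordered chunks),
# and `docs` (the chunks joined with `sep_token`). This is the row layout consumed by
# ContextualDataCollator.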
48 | 49 | dataset = Dataset.from_pandas(df_dataset) 50 | 51 | return dataset 52 | 53 | 54 | def get_chunked_mldr_st(path, base_model, split="train", all_queries=True, filter_long_samples=True): 55 | ds_docs = load_dataset(path, "documents", split=split) 56 | 57 | if filter_long_samples: 58 | 59 | def extract_n_tokens(sample): 60 | sample["n_tokens"] = len(base_model.tokenizer(sample["chunk"]).input_ids) 61 | return sample 62 | 63 | ds_docs = ds_docs.map(extract_n_tokens) 64 | ds_docs = ds_docs.filter(lambda x: x["n_tokens"] < base_model.max_seq_length - N_SPECIAL_TOKENS) 65 | 66 | ds_queries = load_dataset(path, "queries", split=split) 67 | if all_queries: 68 | ds_synthetic = load_dataset(path, "synthetic_queries", split=split) 69 | ds_queries = concatenate_datasets([ds_queries, ds_synthetic]) 70 | 71 | chunk_ids_mapping = {s["chunk_id"]: i for i, s in enumerate(ds_docs)} 72 | 73 | dataset_queries = [] 74 | dataset_docs = [] 75 | for sample in tqdm(ds_queries): 76 | if sample["chunk_id"] in chunk_ids_mapping: 77 | dataset_queries.append(sample["query"]) 78 | dataset_docs.append(ds_docs[chunk_ids_mapping[sample["chunk_id"]]]["chunk"]) 79 | 80 | dataset = Dataset.from_dict({"queries": dataset_queries, "docs": dataset_docs}) 81 | return dataset 82 | 83 | 84 | def create_contextual_dataset(path, base_model, split="train", all_queries=True, filter_long_samples=True): 85 | ds_docs = load_dataset(path, "documents", split=split) 86 | ds_queries = load_dataset(path, "queries", split=split) 87 | if all_queries: 88 | ds_synthetic = load_dataset(path, "synthetic_queries", split=split) 89 | ds_queries = concatenate_datasets([ds_queries, ds_synthetic]) 90 | 91 | print(f"Number of total queries: {len(ds_queries)}") 92 | dataset = get_dataset_from_beir_format(ds_queries, ds_docs, base_model.tokenizer.sep_token) 93 | 94 | if filter_long_samples: 95 | dataset = dataset.filter( 96 | lambda x: len(base_model.tokenizer(x["docs"]).input_ids) < base_model.max_seq_length - N_SPECIAL_TOKENS 97 | ) 98 | 99 | def remove_bad_queries(sample): 100 | mask = [qc_id < len(sample["docs_list"]) for qc_id in sample["queries_chunk_ids"]] 101 | sample["queries_chunk_ids"] = [qc_id for qc_id, m in zip(sample["queries_chunk_ids"], mask) if m] 102 | sample["queries"] = [q for q, m in zip(sample["queries"], mask) if m] 103 | return sample 104 | 105 | print(f"Number of queries after filtering: {sum([len(q) for q in dataset['queries']])}") 106 | dataset = dataset.map(remove_bad_queries) 107 | print(f"Number of queries after removing bad queries: {sum([len(q) for q in dataset['queries']])}") 108 | 109 | return dataset 110 | 111 | 112 | def get_nomic_clusters_dataset(path, base_model, split="train", num_proc=64): 113 | dataset = load_dataset(path, split=split) 114 | 115 | def extract_n_tokens(sample): 116 | sample["n_tokens"] = len(base_model.tokenizer(sample["docs"]).input_ids) 117 | sample["queries_chunk_ids"] = list(range(len(sample["queries"]))) 118 | return sample 119 | 120 | dataset = dataset.map(extract_n_tokens) 121 | 122 | dataset = dataset.filter( 123 | lambda x: x["n_tokens"] < base_model.max_seq_length - N_SPECIAL_TOKENS, 124 | ) 125 | 126 | return dataset 127 | 128 | 129 | def get_long_context_dataset(base_model, base_path="./data_dir/", split="train", return_all=False): 130 | ds_mldr_big = create_contextual_dataset(f"{base_path}/chunked-mldr-big", base_model, split=split) 131 | ds_narrative_qa = create_contextual_dataset(f"{base_path}/narrative_qa", base_model, split=split, all_queries=False) 132 | ds_squad = 
create_contextual_dataset(f"{base_path}/squad", base_model, split=split, all_queries=False) 133 | full_dataset = concatenate_datasets([ds_mldr_big, ds_narrative_qa, ds_squad]) 134 | 135 | if return_all: 136 | return ds_mldr_big, ds_narrative_qa, ds_squad, full_dataset 137 | 138 | return full_dataset 139 | 140 | 141 | def get_smaller_chunks_dataset(base_model, base_path="./data_dir/", split="train"): 142 | ds_mldr_big = create_contextual_dataset(f"{base_path}/chunked-mldr-big-100", base_model, split=split) 143 | ds_squad = create_contextual_dataset( 144 | f"{base_path}/squad-chunked-par-100", base_model, split=split, all_queries=False 145 | ) 146 | full_dataset = concatenate_datasets([ds_mldr_big, ds_squad]) 147 | return full_dataset 148 | 149 | 150 | def get_mixed_granularity_dataset(base_model, base_path="./data_dir/", split="train"): 151 | ds_mldr_big = create_contextual_dataset(f"{base_path}/chunked-mldr-big", base_model, split=split) 152 | ds_narrative_qa = create_contextual_dataset(f"{base_path}/narrative_qa", base_model, split=split, all_queries=False) 153 | ds_squad = create_contextual_dataset(f"{base_path}/squad", base_model, split=split, all_queries=False) 154 | ds_mldr_big_chunked = create_contextual_dataset(f"{base_path}/chunked-mldr-big-100", base_model, split=split) 155 | ds_squad_chunked = create_contextual_dataset( 156 | f"{base_path}/squad-chunked-par-100", base_model, split=split, all_queries=False 157 | ) 158 | full_dataset = concatenate_datasets([ds_mldr_big, ds_narrative_qa, ds_squad, ds_mldr_big_chunked, ds_squad_chunked]) 159 | return full_dataset 160 | 161 | 162 | def get_nomic_st(path, base_model, split="train"): 163 | # nomic_embed_supervised_clustered 164 | dataset = load_dataset(path, split=split) 165 | 166 | def extract_n_tokens(sample): 167 | sample["n_tokens"] = len(base_model.tokenizer(sample["document"]).input_ids) 168 | return sample 169 | 170 | dataset = dataset.map(extract_n_tokens) 171 | 172 | dataset = dataset.filter( 173 | lambda x: x["n_tokens"] < base_model.max_seq_length - N_SPECIAL_TOKENS, 174 | ) 175 | 176 | # only keep query and document fields 177 | dataset = Dataset.from_dict({"queries": dataset["query"], "docs": dataset["document"]}) 178 | 179 | return dataset 180 | -------------------------------------------------------------------------------- /contextual_embeddings/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .contextual_trainer import ContextualTrainer 2 | from .contextual_training import ContextualTraining, ContextualTrainingConfig 3 | -------------------------------------------------------------------------------- /contextual_embeddings/training/contextual_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer 4 | from torch import Tensor 5 | 6 | 7 | class ContextualTrainer(SentenceTransformerTrainer): 8 | def __init__(self, training_loss_type: str = "inbatch_inseq", **super_kwargs): 9 | super().__init__(**super_kwargs) 10 | self.training_loss_type = training_loss_type 11 | 12 | def compute_loss( 13 | self, 14 | model: SentenceTransformer, 15 | inputs: dict[str, Tensor | Any], 16 | return_outputs: bool = False, 17 | num_items_in_batch=None, 18 | ) -> Tensor | tuple[Tensor, dict[str, Any]]: 19 | features, _ = self.collect_features(inputs) 20 | 21 | # Pass everything to the model and compute the loss in the model directly 
22 | features_dict = { 23 | "query_inputs": {k: v for k, v in features[0].items()}, 24 | "docs_inputs": {k: v for k, v in features[1].items()}, 25 | } 26 | 27 | model_outputs = model( 28 | features_dict, 29 | queries_chunk_indices=inputs["queries_chunk_indices"], 30 | n_docs_per_sample=inputs.get("n_docs_per_sample"), 31 | loss_type=self.training_loss_type, 32 | add_prefixes=inputs.get("add_prefixes", False), 33 | ) 34 | loss = model_outputs["loss"] 35 | 36 | if return_outputs: 37 | # During prediction/evaluation, `compute_loss` will be called with `return_outputs=True`. 38 | # However, Sentence Transformer losses do not return outputs, so we return an empty dictionary. 39 | # This does not result in any problems, as the SentenceTransformerTrainingArguments sets 40 | # `prediction_loss_only=True` which means that the output is not used. 41 | return loss, {} 42 | return loss 43 | -------------------------------------------------------------------------------- /contextual_embeddings/training/contextual_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from datetime import datetime 4 | from typing import Optional 5 | 6 | from datasets import Dataset 7 | from sentence_transformers import SentenceTransformerTrainingArguments 8 | from sentence_transformers.evaluation import SentenceEvaluator 9 | 10 | from ..collators.contextual_collator import ContextualDataCollator 11 | from ..models.long_context_model import LongContextEmbeddingModel 12 | from ..training.contextual_trainer import ContextualTrainer 13 | 14 | 15 | @dataclass 16 | class ContextualTrainingConfig: 17 | model: LongContextEmbeddingModel 18 | exp_name: str = "test_exp" 19 | n_gpus: int = 1 20 | training_args: SentenceTransformerTrainingArguments = None 21 | output_dir: str = None 22 | train_dataset: Optional[Dataset] = None 23 | eval_dataset: Optional[Dataset] = None 24 | evaluator: Optional[SentenceEvaluator] = None 25 | run_train: bool = True 26 | multi_ctx_training: bool = True 27 | wandb_project: str = "long-context-model" 28 | wandb_group: str = "main_exp" 29 | loss_type: str = "inbatch_inseq" 30 | add_prefixes: bool = False 31 | colbert_tokenize: bool = False 32 | 33 | def __post_init__(self): 34 | """ 35 | Initialize the model and tokenizer if not provided 36 | """ 37 | self.base_output_dir = self.output_dir 38 | self.output_dir = os.path.join(self.output_dir, self.exp_name) 39 | 40 | if os.path.exists(self.output_dir): 41 | dt = datetime.now().strftime("%Y%m%d%H%M%S") 42 | self.output_dir += f"_{dt}" 43 | 44 | if self.training_args is None: 45 | self.training_args = SentenceTransformerTrainingArguments(output_dir=self.output_dir) 46 | elif self.training_args.output_dir is None: 47 | self.training_args.output_dir = self.output_dir 48 | 49 | if self.training_args.run_name is None: 50 | self.training_args.run_name = self.exp_name 51 | 52 | # cast if string 53 | if isinstance(self.training_args.learning_rate, str): 54 | self.training_args.learning_rate = float(self.training_args.learning_rate) 55 | 56 | self.training_args.remove_unused_columns = False 57 | 58 | def set_exp_name(self, exp_name: str): 59 | self.exp_name = exp_name 60 | self.output_dir = os.path.join(self.base_output_dir, self.exp_name) 61 | self.training_args.output_dir = self.output_dir 62 | self.training_args.run_name = self.exp_name 63 | 64 | 65 | class ContextualTraining: 66 | def __init__(self, config: ContextualTrainingConfig): 67 | self.config = config 68 | 
self.model = config.model 69 | 70 | def train(self): 71 | """ 72 | Train the model using the provided configuration. 73 | """ 74 | trainer = ContextualTrainer( 75 | training_loss_type=self.config.loss_type, 76 | model=self.model, 77 | tokenizer=self.model.base_model.tokenizer, 78 | args=self.config.training_args, 79 | train_dataset=self.config.train_dataset, 80 | eval_dataset=self.config.eval_dataset, 81 | evaluator=self.config.evaluator, 82 | data_collator=ContextualDataCollator( 83 | tokenize_fn=self.model.tokenize, 84 | is_multi_ctx_training=self.config.multi_ctx_training, 85 | sep_token=self.model.base_model.tokenizer.sep_token, 86 | add_prefixes=self.config.add_prefixes, 87 | colbert_tokenize=self.config.colbert_tokenize, 88 | ), 89 | ) 90 | 91 | trainer.train() 92 | 93 | def save(self, config_file): 94 | """ 95 | Save the trained model and configuration. 96 | Args: 97 | config_file: Path to the configuration file to be copied. 98 | """ 99 | # save model 100 | self.model.save(self.config.output_dir) 101 | 102 | # copy-paste the yml file with os 103 | # ugly but no other way since the file is formatted for the configue library 104 | # so not fully complying to the yaml syntax (hence not supported by e.g. pyyaml) 105 | os.system(f"cp {config_file} {self.config.output_dir}/training_config.yml") 106 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.version] 6 | source = "vcs" 7 | 8 | [tool.hatch.build.targets.wheel] 9 | include = ["contextual_embeddings"] 10 | 11 | [project] 12 | name = "contextual_embeddings" 13 | version = "0.0.1" 14 | description = "This is the package used to model contextualized embeddings." 
15 | readme = "README.md" 16 | requires-python = ">=3.11" 17 | dependencies = [ 18 | "accelerate>=1.7.0", 19 | "configue>=5.0.0", 20 | "datasets>=3.2.0", 21 | "flash-attn==2.7.2.post1", 22 | "numpy<2", 23 | "pylate>=1.2.0", 24 | "sentence-transformers>=3.3.1", 25 | "torch==2.2.2", 26 | "transformers>=4.47.1", 27 | "typer>=0.15.1", 28 | ] 29 | 30 | [tool.uv] 31 | dev-dependencies = [ 32 | "pytest>=8.3.4", 33 | ] 34 | 35 | [tool.ruff] 36 | line-length = 120 37 | 38 | [tool.ruff.lint] 39 | select = ["E", "F", "W", "I", "N"] 40 | -------------------------------------------------------------------------------- /scripts/configs/examples/modernbert.yaml: -------------------------------------------------------------------------------- 1 | multi_ctx_training: True 2 | base_model: 3 | (): sentence_transformers.SentenceTransformer 4 | model_name_or_path: "nomic-ai/modernbert-embed-base" 5 | model_kwargs: 6 | attn_implementation: "flash_attention_2" 7 | torch_dtype: !ext torch.bfloat16 8 | 9 | config: 10 | (): contextual_embeddings.ContextualTrainingConfig 11 | model: 12 | (): contextual_embeddings.LongContextEmbeddingModel 13 | base_model: !cfg base_model # points to the variable defined above 14 | multi_ctx_training: !cfg multi_ctx_training 15 | multi_ctx_training: !cfg multi_ctx_training # passed to both model and trainer 16 | loss_type: "weighted" 17 | exp_name: "modernbert-test" 18 | n_gpus: 4 19 | output_dir: "./checkpoints/test" 20 | train_dataset: 21 | (): contextual_embeddings.models.utils.get_long_context_dataset # function returning the dataset 22 | base_model: !cfg base_model 23 | eval_dataset: 24 | mldr: 25 | (): contextual_embeddings.models.utils.create_contextual_dataset 26 | path: "illuin-cde/chunked-mldr-big" 27 | split: "test" 28 | base_model: !cfg base_model 29 | squad: 30 | (): contextual_embeddings.models.utils.create_contextual_dataset 31 | path: "illuin-cde/squad" 32 | split: "validation" 33 | base_model: !cfg base_model 34 | all_queries: False 35 | narrative_qa: 36 | (): contextual_embeddings.models.utils.create_contextual_dataset 37 | path: "illuin-cde/narrative_qa" 38 | split: "test" 39 | base_model: !cfg base_model 40 | all_queries: False 41 | run_train: True 42 | training_args: 43 | (): sentence_transformers.SentenceTransformerTrainingArguments 44 | output_dir: null 45 | overwrite_output_dir: true 46 | num_train_epochs: 2 47 | per_device_train_batch_size: 4 48 | per_device_eval_batch_size: 4 49 | fp16: False # Set to False if you get an error that your GPU can't run on FP16 50 | bf16: True # Set to True if you have a GPU that supports BF16 51 | learning_rate: 5e-5 52 | warmup_steps: 55 53 | lr_scheduler_type: "cosine" 54 | eval_strategy: "steps" 55 | eval_on_start: True 56 | eval_steps: 100 57 | logging_steps: 10 # how often to log to W&B 58 | report_to: "wandb" 59 | -------------------------------------------------------------------------------- /scripts/configs/examples/moderncolbert.yaml: -------------------------------------------------------------------------------- 1 | multi_ctx_training: True 2 | base_model: 3 | (): pylate.models.ColBERT 4 | model_name_or_path: "lightonai/GTE-ModernColBERT-v1" 5 | model_kwargs: 6 | attn_implementation: "flash_attention_2" 7 | torch_dtype: !ext torch.bfloat16 8 | document_length: 8192 9 | 10 | config: 11 | (): contextual_embeddings.ContextualTrainingConfig 12 | model: 13 | (): contextual_embeddings.LongContextEmbeddingModel 14 | base_model: !cfg base_model # points to the variable defined above 15 | multi_ctx_training: !cfg 
multi_ctx_training 16 | lambda_seq: 0.1 17 | pooling_mode: "tokens" 18 | multi_ctx_training: !cfg multi_ctx_training # passed to both model and trainer 19 | colbert_tokenize: True 20 | loss_type: "late_interaction" 21 | exp_name: "moderncolbert-test" 22 | n_gpus: 1 23 | output_dir: "./checkpoints/test" 24 | train_dataset: 25 | (): contextual_embeddings.models.utils.get_long_context_dataset # function returning the dataset 26 | base_model: !cfg base_model 27 | eval_dataset: 28 | mldr: 29 | (): contextual_embeddings.models.utils.create_contextual_dataset 30 | path: "illuin-cde/chunked-mldr-big" 31 | split: "test" 32 | base_model: !cfg base_model 33 | squad: 34 | (): contextual_embeddings.models.utils.create_contextual_dataset 35 | path: "illuin-cde/squad" 36 | split: "validation" 37 | base_model: !cfg base_model 38 | all_queries: False 39 | narrative_qa: 40 | (): contextual_embeddings.models.utils.create_contextual_dataset 41 | path: "illuin-cde/narrative_qa" 42 | split: "test" 43 | base_model: !cfg base_model 44 | all_queries: False 45 | run_train: True 46 | training_args: 47 | (): sentence_transformers.SentenceTransformerTrainingArguments 48 | output_dir: null 49 | overwrite_output_dir: true 50 | num_train_epochs: 2 51 | per_device_train_batch_size: 1 52 | per_device_eval_batch_size: 1 53 | fp16: False # Set to False if you get an error that your GPU can't run on FP16 54 | bf16: True # Set to True if you have a GPU that supports BF16 55 | learning_rate: 5e-5 56 | warmup_steps: 55 57 | lr_scheduler_type: "cosine" 58 | eval_strategy: "steps" 59 | eval_on_start: True 60 | eval_steps: 100 61 | logging_steps: 10 # how often to log to W&B 62 | report_to: "wandb" 63 | -------------------------------------------------------------------------------- /scripts/training/training.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import configue 4 | import typer 5 | 6 | from contextual_embeddings.training.contextual_training import ( 7 | ContextualTraining, 8 | ContextualTrainingConfig, 9 | ) 10 | 11 | 12 | def main(config_file: Path, lambda_seq: float = -1.0) -> None: 13 | print("Loading config") 14 | config = configue.load(config_file, sub_path="config") 15 | print("Creating Setup") 16 | if isinstance(config, ContextualTrainingConfig): 17 | app = ContextualTraining(config) 18 | if lambda_seq > 0: 19 | app.model.lambda_seq = lambda_seq 20 | config.set_exp_name(f"{config.exp_name}_{'-'.join(str(lambda_seq).split('.'))}") 21 | else: 22 | raise ValueError("Config must be of type ContextualTrainingConfig") 23 | 24 | if config.run_train: 25 | print("Training model") 26 | app.train() 27 | app.save(config_file=config_file) 28 | 29 | """if config.run_eval: 30 | print("Running evaluation") 31 | app.eval() 32 | print("Done!") """ 33 | 34 | 35 | if __name__ == "__main__": 36 | typer.run(main) 37 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/illuin-tech/contextual-embeddings/9f4c6ff586067dfde1d4270ccce99ef1f71812e6/tests/__init__.py -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/illuin-tech/contextual-embeddings/9f4c6ff586067dfde1d4270ccce99ef1f71812e6/tests/models/__init__.py 
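
The `scripts/training/training.py` entry point above can also be driven programmatically, which is convenient for notebooks or parameter sweeps. Below is a minimal sketch that mirrors its `main` function; the config path and the `lambda_seq` value are illustrative placeholders, not prescribed defaults.

```python
from pathlib import Path

import configue

from contextual_embeddings.training.contextual_training import (
    ContextualTraining,
    ContextualTrainingConfig,
)

# Placeholder: any example config under scripts/configs/examples works the same way.
config_file = Path("scripts/configs/examples/modernbert.yaml")

# `sub_path="config"` loads only the `config:` node of the YAML file.
config = configue.load(config_file, sub_path="config")
if not isinstance(config, ContextualTrainingConfig):
    raise ValueError("Config must be of type ContextualTrainingConfig")

app = ContextualTraining(config)

# Optional override of lambda_seq, as the CLI's `lambda_seq` option does.
lambda_seq = 0.1
if lambda_seq > 0:
    app.model.lambda_seq = lambda_seq
    config.set_exp_name(f"{config.exp_name}_{'-'.join(str(lambda_seq).split('.'))}")

if config.run_train:
    app.train()
    app.save(config_file=config_file)
```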
-------------------------------------------------------------------------------- /tests/models/test_contextual_inference_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from sentence_transformers import SentenceTransformer 4 | 5 | from contextual_embeddings import LongContextEmbeddingModel 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def model(): 10 | # Use a small model for testing 11 | base_model = SentenceTransformer("nomic-ai/modernbert-embed-base") 12 | model = LongContextEmbeddingModel(base_model=base_model, normalize_embeddings=True, pooling_mode="average") 13 | model.batch_size = 2 14 | model.show_progress_bar = False 15 | return model 16 | 17 | 18 | @pytest.fixture 19 | def test_documents(): 20 | return [ 21 | ["This is chunk 1 of doc 1", "This is chunk 2 of doc 1"], # 2 chunks 22 | ["This is chunk 1 of doc 2", "This is chunk 2 of doc 2", "This is chunk 3 of doc 2"], # 3 chunks 23 | ["Single chunk document"], # 1 chunk 24 | ["Chunk 1 of doc 4", "Chunk 2 of doc 4", "Chunk 3 of doc 4", "Chunk 4 of doc 4"], # 4 chunks 25 | ] 26 | 27 | 28 | @pytest.fixture 29 | def test_queries(): 30 | return [ 31 | "Short query", 32 | "This is a longer query with more words", 33 | "Query 3", 34 | "Very long query that contains many words to test the embedding function", 35 | ] 36 | 37 | 38 | @pytest.fixture(scope="module") 39 | def model_with_prefix(): 40 | base_model = SentenceTransformer("nomic-ai/modernbert-embed-base") 41 | model = LongContextEmbeddingModel( 42 | base_model=base_model, normalize_embeddings=True, pooling_mode="average", add_prefix=True 43 | ) 44 | model.batch_size = 2 45 | model.show_progress_bar = False 46 | return model 47 | 48 | 49 | def test_embed_documents(model, test_documents): 50 | # Get embeddings 51 | embeddings = model.embed_documents(test_documents, batch_size=2) 52 | 53 | # Test the output structure 54 | assert len(embeddings) == len(test_documents), ( 55 | "Number of document embeddings should match number of input documents" 56 | ) 57 | 58 | # Check that each document has the right number of chunk embeddings 59 | for i, doc in enumerate(test_documents): 60 | assert len(embeddings[i]) == len(doc), ( 61 | f"Document {i} should have {len(doc)} chunk embeddings but has {len(embeddings[i])}" 62 | ) 63 | 64 | # Check embedding dimensions 65 | embedding_dim = model.base_model.get_sentence_embedding_dimension() 66 | for chunk_embedding in embeddings[i]: 67 | assert chunk_embedding.shape == (embedding_dim,), f"Chunk embedding dimension should be {embedding_dim}" 68 | 69 | # Check if embeddings are normalized 70 | if model.normalize_embeddings: 71 | norm = torch.norm(chunk_embedding).item() 72 | assert abs(norm - 1.0) < 1e-5, "Embeddings should be normalized to unit length" 73 | 74 | 75 | def test_embed_queries(model, test_queries): 76 | # Get query embeddings 77 | query_embeddings = model.embed_queries(test_queries) 78 | 79 | # Test output structure 80 | if model.pooling_mode == "average": 81 | # For average pooling, we expect a 2D tensor 82 | assert len(query_embeddings) == len(test_queries), ( 83 | "Number of query embeddings should match number of input queries" 84 | ) 85 | 86 | embedding_dim = model.base_model.get_sentence_embedding_dimension() 87 | assert query_embeddings.shape == (len(test_queries), embedding_dim), ( 88 | f"Query embeddings should have shape ({len(test_queries)}, {embedding_dim})" 89 | ) 90 | 91 | # Check if embeddings are normalized 92 | if model.normalize_embeddings: 
93 | norms = torch.norm(query_embeddings, dim=1) 94 | for norm in norms: 95 | assert abs(norm.item() - 1.0) < 1e-5, "Query embeddings should be normalized to unit length" 96 | else: 97 | # For token pooling, we expect padded token embeddings 98 | assert query_embeddings.shape[0] == len(test_queries), ( 99 | "Number of query embeddings should match number of input queries" 100 | ) 101 | 102 | embedding_dim = model.base_model.get_sentence_embedding_dimension() 103 | assert query_embeddings.shape[2] == embedding_dim, ( 104 | f"Query token embeddings should have dimension {embedding_dim}" 105 | ) 106 | 107 | 108 | @pytest.mark.parametrize("batch_size", [1, 2, 5]) 109 | def test_embed_documents_with_batch_size(model, batch_size): 110 | # Create a larger test set to test batch processing 111 | test_documents = [ 112 | [f"Doc {i} Chunk {j}" for j in range(i % 3 + 1)] 113 | for i in range(5) # 5 documents with 1, 2, 3, 1, 2 chunks 114 | ] 115 | 116 | embeddings = model.embed_documents(test_documents, batch_size=batch_size) 117 | 118 | # Verify all documents have embeddings 119 | assert len(embeddings) == len(test_documents) 120 | 121 | # Verify each document has the right number of chunk embeddings 122 | for i, doc in enumerate(test_documents): 123 | assert len(embeddings[i]) == len(doc), ( 124 | f"Document {i} should have {len(doc)} chunk embeddings with batch size {batch_size}" 125 | ) 126 | 127 | 128 | def test_embed_documents_with_prefix(model_with_prefix, test_documents): 129 | # Get embeddings with prefix 130 | embeddings = model_with_prefix.embed_documents(test_documents, batch_size=2) 131 | 132 | # Test the output structure 133 | assert len(embeddings) == len(test_documents), ( 134 | "Number of document embeddings should match number of input documents" 135 | ) 136 | 137 | # Check that each document has the right number of chunk embeddings 138 | for i, doc in enumerate(test_documents): 139 | assert len(embeddings[i]) == len(doc), ( 140 | f"Document {i} should have {len(doc)} chunk embeddings but has {len(embeddings[i])}" 141 | ) 142 | 143 | # Check embedding dimensions 144 | embedding_dim = model_with_prefix.base_model.get_sentence_embedding_dimension() 145 | for chunk_embedding in embeddings[i]: 146 | assert chunk_embedding.shape == (embedding_dim,), f"Chunk embedding dimension should be {embedding_dim}" 147 | 148 | # Check if embeddings are normalized 149 | if model_with_prefix.normalize_embeddings: 150 | norm = torch.norm(chunk_embedding).item() 151 | assert abs(norm - 1.0) < 1e-5, "Embeddings should be normalized to unit length" 152 | 153 | 154 | def test_embeddings_are_different_with_prefix(model, model_with_prefix, test_documents, test_queries): 155 | """Test that embeddings with prefix are different from embeddings without prefix.""" 156 | # Get embeddings without prefix 157 | doc_embeddings_no_prefix = model.embed_documents(test_documents, batch_size=1) 158 | query_embeddings_no_prefix = model.embed_queries(test_queries) 159 | 160 | # Get embeddings with prefix 161 | doc_embeddings_with_prefix = model_with_prefix.embed_documents(test_documents, batch_size=1) 162 | query_embeddings_with_prefix = model_with_prefix.embed_queries(test_queries) 163 | 164 | # Check documents 165 | for i in range(len(test_documents)): 166 | for j in range(len(test_documents[i])): 167 | # Embeddings should be different when using prefix 168 | assert not torch.allclose(doc_embeddings_no_prefix[i][j], doc_embeddings_with_prefix[i][j], atol=1e-4), ( 169 | f"Document {i} chunk {j} embeddings should be 
different with prefix" 170 | ) 171 | 172 | # Check queries (for average pooling) 173 | if model.pooling_mode == "average": 174 | for i in range(len(test_queries)): 175 | assert not torch.allclose(query_embeddings_no_prefix[i], query_embeddings_with_prefix[i], atol=1e-4), ( 176 | f"Query {i} embeddings should be different with prefix" 177 | ) 178 | -------------------------------------------------------------------------------- /tests/models/test_contextual_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from sentence_transformers import ( 4 | SentenceTransformer, 5 | ) 6 | 7 | from contextual_embeddings import LongContextEmbeddingModel 8 | 9 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 10 | N_DOCS = 10 11 | DOC_TOY_STR = "This is an example document." 12 | QUERY_TOY_STR = "toy query" 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def model_str(): 17 | return "nomic-ai/modernbert-embed-base" 18 | 19 | 20 | @pytest.fixture(scope="module") 21 | def base_model(model_str): 22 | model = SentenceTransformer(model_str) 23 | yield model 24 | del model 25 | 26 | 27 | @pytest.fixture(scope="module") 28 | def hidden_dim(base_model: SentenceTransformer): 29 | return base_model.get_sentence_embedding_dimension() 30 | 31 | 32 | @pytest.fixture(scope="module") 33 | def doc_str(base_model): 34 | return base_model.tokenizer.sep_token.join([DOC_TOY_STR for _ in range(N_DOCS)]) 35 | 36 | 37 | @pytest.fixture(scope="module") 38 | def model(base_model): 39 | return LongContextEmbeddingModel(base_model).to(DEVICE) 40 | 41 | 42 | def test_forward_multi_docs(model, doc_str): 43 | doc_inputs = model.tokenize([doc_str]) 44 | query_inputs = model.tokenize([QUERY_TOY_STR]) 45 | qc_indices = [[1]] 46 | batch = { 47 | "docs_inputs": doc_inputs, 48 | "query_inputs": query_inputs, 49 | } 50 | outputs = model(batch, queries_chunk_indices=qc_indices) 51 | 52 | assert "loss" in outputs 53 | 54 | 55 | def test_late_chunking_pooling(model: LongContextEmbeddingModel, doc_str, hidden_dim): 56 | inputs = model.tokenize([doc_str]) 57 | input_ids = inputs["input_ids"] 58 | token_embeddings = torch.randn((input_ids.shape[0], input_ids.shape[1], hidden_dim)) 59 | pooled_chunks, padding_mask = model._late_chunking_pooling(input_ids, token_embeddings) 60 | 61 | assert pooled_chunks.shape == (1, N_DOCS, hidden_dim) 62 | assert padding_mask.shape == (1, N_DOCS) 63 | 64 | # batch test 65 | batch_size = 4 66 | other_str = model.base_model.tokenizer.sep_token.join([DOC_TOY_STR for _ in range(N_DOCS - 1)]) 67 | 68 | inputs = model.tokenize([doc_str, other_str, doc_str, other_str]) 69 | input_ids = inputs["input_ids"] 70 | token_embeddings = torch.randn((input_ids.shape[0], input_ids.shape[1], hidden_dim)) 71 | pooled_chunks, padding_mask = model._late_chunking_pooling(input_ids, token_embeddings) 72 | assert pooled_chunks.shape == (batch_size, N_DOCS, hidden_dim) 73 | assert padding_mask.shape == (batch_size, N_DOCS) 74 | assert padding_mask.sum() == 2 75 | -------------------------------------------------------------------------------- /tests/models/test_li_inference_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from pylate.models import ColBERT 4 | 5 | from contextual_embeddings import LongContextEmbeddingModel 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def colbert_model(): 10 | # Use a path to a pretrained ColBERT model 11 | model_path = 
"lightonai/GTE-ModernColBERT-v1" # replace with your actual model path 12 | return ColBERT( 13 | model_path, 14 | device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", 15 | ) 16 | 17 | 18 | @pytest.fixture(scope="module") 19 | def model(colbert_model): 20 | model = LongContextEmbeddingModel(base_model=colbert_model, pooling_mode="tokens") 21 | model.batch_size = 2 22 | model.show_progress_bar = False 23 | return model 24 | 25 | 26 | @pytest.fixture 27 | def test_documents(): 28 | return [ 29 | ["This is chunk 1 of doc 1", "This is chunk 2 of doc 1"], # 2 chunks 30 | ["This is chunk 1 of doc 2", "This is chunk 2 of doc 2", "This is chunk 3 of doc 2"], # 3 chunks 31 | ["Single chunk document"], # 1 chunk 32 | ["Chunk 1 of doc 4", "Chunk 2 of doc 4", "Chunk 3 of doc 4", "Chunk 4 of doc 4"], # 4 chunks 33 | ] 34 | 35 | 36 | @pytest.fixture 37 | def test_queries(): 38 | return [ 39 | "Short query", 40 | "This is a longer query with more words", 41 | "Query 3", 42 | "Very long query that contains many words to test the embedding function", 43 | ] 44 | 45 | 46 | def test_embed_documents_token_counts(model, test_documents): 47 | """Test that embedded chunks maintain appropriate token counts""" 48 | # First, get the token counts for each chunk 49 | chunk_token_counts = {} 50 | 51 | for doc_idx, doc in enumerate(test_documents): 52 | chunk_token_counts[doc_idx] = [] 53 | for chunk_idx, chunk in enumerate(doc): 54 | # Tokenize each chunk individually (without special tokens) 55 | tokens = model.base_model.tokenize([chunk]) 56 | # Count actual tokens (excluding padding) 57 | token_count = tokens["attention_mask"].sum().item() 58 | chunk_token_counts[doc_idx].append(token_count) 59 | 60 | # Now get embeddings 61 | embeddings = model.embed_documents(test_documents, batch_size=2) 62 | 63 | # Check the output structure 64 | assert len(embeddings) == len(test_documents) 65 | 66 | # Check that each document has the right number of chunk embeddings 67 | for doc_idx, doc in enumerate(test_documents): 68 | assert len(embeddings[doc_idx]) == len(doc) 69 | 70 | # For token-level embeddings, check the shape of each embedding 71 | embedding_dim = model.base_model.get_sentence_embedding_dimension() 72 | for chunk_idx, chunk_embedding in enumerate(embeddings[doc_idx]): 73 | assert chunk_embedding.shape[-1] == embedding_dim 74 | 75 | 76 | def test_embed_queries_token_dimensions(model, test_queries): 77 | """Test that query embeddings maintain token dimensions""" 78 | # Get token counts for each query 79 | query_token_counts = [] 80 | for query in test_queries: 81 | tokens = model.base_model.tokenize([query], is_query=True) 82 | token_count = tokens["attention_mask"].sum().item() 83 | query_token_counts.append(token_count) 84 | 85 | # Get query embeddings 86 | query_embeddings = model.embed_queries(test_queries) 87 | 88 | # For token pooling with ColBERT, we expect token-level embeddings 89 | # The shape should be [num_queries, max_query_length, embedding_dim] 90 | assert query_embeddings.shape[0] == len(test_queries) 91 | 92 | embedding_dim = model.base_model.get_sentence_embedding_dimension() 93 | assert query_embeddings.shape[2] == embedding_dim 94 | 95 | # For each query, check that we have at least as many token embeddings as the original token count 96 | # (might have more due to padding) 97 | for i, count in enumerate(query_token_counts): 98 | # Count non-zero embeddings 99 | non_zero_embs = torch.sum(torch.sum(query_embeddings[i] != 0, dim=1) > 0).item() 100 | 
assert non_zero_embs >= count, f"Query {i} should have at least {count} token embeddings" 101 | 102 | 103 | @pytest.mark.parametrize("batch_size", [1, 2]) 104 | def test_embed_documents_batch_processing(model, batch_size): 105 | """Test batch processing maintains token counts""" 106 | test_documents = [ 107 | [f"Doc {i} Chunk {j}" for j in range(i % 2 + 1)] 108 | for i in range(3) # 3 documents with 1, 2, 1 chunks 109 | ] 110 | 111 | # Process with different batch sizes 112 | embeddings = model.embed_documents(test_documents, batch_size=batch_size) 113 | 114 | # Verify the structure 115 | assert len(embeddings) == len(test_documents) 116 | 117 | for doc_idx, doc in enumerate(test_documents): 118 | assert len(embeddings[doc_idx]) == len(doc) 119 | 120 | # Check embedding dimensions 121 | embedding_dim = model.base_model.get_sentence_embedding_dimension() 122 | for chunk_idx, chunk_embedding in enumerate(embeddings[doc_idx]): 123 | assert chunk_embedding.shape[-1] == embedding_dim 124 | --------------------------------------------------------------------------------
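
For completeness, here is a minimal sketch of the late-chunking flow exercised by `tests/models/test_contextual_model.py`: the chunks of a document are joined with the tokenizer's SEP token, and `_late_chunking_pooling` turns token-level embeddings back into one vector per chunk. The chunk texts and the random token embeddings are illustrative stand-ins; in practice the token embeddings come from the encoder.

```python
import torch
from sentence_transformers import SentenceTransformer

from contextual_embeddings import LongContextEmbeddingModel

base_model = SentenceTransformer("nomic-ai/modernbert-embed-base")
model = LongContextEmbeddingModel(base_model)

# A "long document" is the SEP-joined concatenation of its chunks.
chunks = ["This is chunk 1 of the doc.", "This is chunk 2 of the doc.", "This is chunk 3 of the doc."]
doc_str = base_model.tokenizer.sep_token.join(chunks)

inputs = model.tokenize([doc_str])
hidden_dim = base_model.get_sentence_embedding_dimension()

# Stand-in token embeddings, as in the unit test.
token_embeddings = torch.randn((inputs["input_ids"].shape[0], inputs["input_ids"].shape[1], hidden_dim))

pooled_chunks, padding_mask = model._late_chunking_pooling(inputs["input_ids"], token_embeddings)
assert pooled_chunks.shape == (1, len(chunks), hidden_dim)  # one embedding per chunk
assert padding_mask.shape == (1, len(chunks))
```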