├── .github ├── dependabot.yml └── workflows │ ├── publish.yml │ ├── ruff.yml │ └── test.yml ├── .gitignore ├── CHANGELOG.md ├── CITATION.cff ├── LICENSE ├── README.md ├── assets └── colpali_architecture.webp ├── colpali_engine ├── __init__.py ├── collators │ ├── __init__.py │ └── visual_retriever_collator.py ├── compression │ ├── __init__.py │ └── token_pooling │ │ ├── __init__.py │ │ ├── base_token_pooling.py │ │ ├── hierarchical_token_pooling.py │ │ └── lambda_token_pooling.py ├── data │ ├── __init__.py │ ├── dataset.py │ └── sampler.py ├── interpretability │ ├── __init__.py │ ├── similarity_map_utils.py │ └── similarity_maps.py ├── loss │ ├── __init__.py │ ├── bi_encoder_losses.py │ └── late_interaction_losses.py ├── models │ ├── __init__.py │ ├── idefics3 │ │ ├── __init__.py │ │ └── colidefics3 │ │ │ ├── __init__.py │ │ │ ├── modeling_colidefics3.py │ │ │ └── processing_colidefics3.py │ ├── paligemma │ │ ├── __init__.py │ │ ├── bipali │ │ │ ├── __init__.py │ │ │ ├── modeling_bipali.py │ │ │ └── processing_bipali.py │ │ └── colpali │ │ │ ├── __init__.py │ │ │ ├── modeling_colpali.py │ │ │ └── processing_colpali.py │ ├── qwen2 │ │ ├── __init__.py │ │ ├── biqwen2 │ │ │ ├── __init__.py │ │ │ ├── modeling_biqwen2.py │ │ │ └── processing_biqwen2.py │ │ └── colqwen2 │ │ │ ├── __init__.py │ │ │ ├── modeling_colqwen2.py │ │ │ └── processing_colqwen2.py │ └── qwen2_5 │ │ ├── __init__.py │ │ ├── biqwen2_5 │ │ ├── __init__.py │ │ ├── modeling_biqwen2_5.py │ │ └── processing_biqwen2_5.py │ │ └── colqwen2_5 │ │ ├── __init__.py │ │ ├── modeling_colqwen2_5.py │ │ └── processing_colqwen2_5.py ├── trainer │ ├── __init__.py │ ├── colmodel_training.py │ └── contrastive_trainer.py └── utils │ ├── __init__.py │ ├── dataset_transformation.py │ ├── gpu_stats.py │ ├── processing_utils.py │ ├── torch_utils.py │ └── transformers_wrappers.py ├── pyproject.toml ├── scripts ├── compute_hardnegs.py ├── configs │ ├── data │ │ ├── debug_data.yaml │ │ └── test_data.yaml │ ├── idefics │ │ └── train_colsmolvlm_model.yaml │ ├── pali │ │ ├── train_bipali_all_model.yaml │ │ ├── train_bipali_model.yaml │ │ ├── train_bipali_pairwise_256_model.yaml │ │ ├── train_bipali_pairwise_hardneg_model.yaml │ │ ├── train_bipali_pairwise_model.yaml │ │ ├── train_colpali2_pt_model.yaml │ │ ├── train_colpali_all_model.yaml │ │ ├── train_colpali_docmatix_hardneg_model.yaml │ │ ├── train_colpali_docmatix_model.yaml │ │ ├── train_colpali_hardneg_debug_model.yaml │ │ ├── train_colpali_hardneg_model.yaml │ │ ├── train_colpali_model.yaml │ │ └── train_colpali_pt_model.yaml │ ├── qwen2 │ │ ├── deprecated │ │ │ ├── train_biqwen2_docmatix_model.yaml │ │ │ ├── train_biqwen2_warmup_model.yaml │ │ │ ├── train_colqwen2_docmatix_model.yaml │ │ │ ├── train_colqwen2_hardneg_model.yaml │ │ │ └── train_colqwen2_wikiss_model.yaml │ │ ├── train_biqwen2_hardneg_model.py │ │ ├── train_biqwen2_hardneg_model.yaml │ │ ├── train_biqwen2_model.yaml │ │ ├── train_colqwen2_infonce_model.yaml │ │ └── train_colqwen2_model.yaml │ └── tr_args │ │ ├── default_neg_tr_args.yaml │ │ ├── default_tr_args.yaml │ │ ├── eval_tr_args.yaml │ │ └── resume_neg_tr_args.yaml └── train │ └── train_colbert.py └── tests ├── collators └── test_visual_retriever_collator.py ├── compression ├── __init__.py └── token_pooling │ ├── __init__.py │ ├── conftest.py │ ├── example_pooling_functions.py │ ├── test_hierarchical_pooling.py │ └── test_lambda_pooling.py ├── data ├── test_dataset.py └── test_sampler.py ├── interpretability ├── test_similarity_map_utils.py └── test_similarity_maps.py ├── 
models ├── idefics3 │ └── colidefics3 │ │ ├── test_modeling_colidefics3.py │ │ └── test_processing_colidefics3.py ├── paligemma │ └── colpali │ │ ├── test_modeling_colpali.py │ │ └── test_processing_colpali.py ├── qwen2 │ └── colqwen2 │ │ ├── test_modeling_colqwen2.py │ │ └── test_processing_colqwen2.py └── qwen2_5 │ └── colqwen2_5 │ ├── test_modeling_colqwen2_5.py │ └── test_processing_colqwen2_5.py └── utils ├── test_processing_utils.py └── test_torch_utils.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/dependabot-options-reference#package-ecosystem- 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "monthly" 8 | - package-ecosystem: "pip" 9 | directory: "/" 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | release-build: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.x" 20 | 21 | - name: Build release distributions 22 | run: | 23 | # NOTE: put your own distribution build steps here. 24 | python -m pip install build 25 | python -m build 26 | 27 | - name: Upload distributions 28 | uses: actions/upload-artifact@v4 29 | with: 30 | name: release-dists 31 | path: dist/ 32 | 33 | pypi-publish: 34 | runs-on: ubuntu-latest 35 | 36 | needs: 37 | - release-build 38 | 39 | permissions: 40 | # IMPORTANT: this permission is mandatory for trusted publishing 41 | id-token: write 42 | 43 | # Dedicated environments with protections for publishing are strongly recommended. 
44 | environment: 45 | name: pypi 46 | # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status: 47 | url: https://pypi.org/project/colpali-engine/ 48 | 49 | steps: 50 | - name: Retrieve release distributions 51 | uses: actions/download-artifact@v4 52 | with: 53 | name: release-dists 54 | path: dist/ 55 | 56 | - name: Publish release distributions to PyPI 57 | uses: pypa/gh-action-pypi-publish@release/v1 58 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | jobs: 8 | ruff: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: astral-sh/ruff-action@v3 13 | - run: ruff check --fix 14 | - run: ruff format --check 15 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test (non-slow only) 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install -e ".[all]" 28 | 29 | - name: Run tests with pytest (except "slow" tests) 30 | run: | 31 | pytest -m "not slow" 32 | env: 33 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | !*/configs/data/ 3 | .DS_Store 4 | /.vscode/ 5 | /data/ 6 | /logs/ 7 | /models/ 8 | /outputs/ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | notebooks/*.png 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 120 | .pdm.toml 121 | .pdm-python 122 | .pdm-build/ 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
172 | #.idea/ 173 | *.ipynb 174 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Faysse" 5 | given-names: "Manuel" 6 | email: "manuel.faysse@illuin.tech" 7 | - family-names: "Sibille" 8 | given-names: "Hugues" 9 | email: "hugues.sibille@illuin.tech" 10 | - family-names: "Wu" 11 | given-names: "Tony" 12 | email: "tony.wu@illuin.tech" 13 | title: "Vision Document Retrieval (ViDoRe): Benchmark" 14 | date-released: 2024-06-26 15 | url: "https://github.com/illuin-tech/vidore-benchmark" 16 | preferred-citation: 17 | type: article 18 | authors: 19 | - family-names: "Faysse" 20 | given-names: "Manuel" 21 | - family-names: "Sibille" 22 | given-names: "Hugues" 23 | - family-names: "Wu" 24 | given-names: "Tony" 25 | - family-names: "Omrani" 26 | given-names: "Bilel" 27 | - family-names: "Viaud" 28 | given-names: "Gautier" 29 | - family-names: "Hudelot" 30 | given-names: "Céline" 31 | - family-names: "Colombo" 32 | given-names: "Pierre" 33 | doi: "arXiv.2407.01449" 34 | month: 6 35 | title: "ColPali: Efficient Document Retrieval with Vision Language Models" 36 | year: 2024 37 | url: "https://arxiv.org/abs/2407.01449" 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Manuel Faysse, Hugues Sibille, Tony Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/colpali_architecture.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/illuin-tech/colpali/fbf9dcc70ef591dcadd1aa73ab019e97a60f272a/assets/colpali_architecture.webp -------------------------------------------------------------------------------- /colpali_engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ( 2 | BiPali, 3 | BiPaliProj, 4 | BiQwen2, 5 | BiQwen2_5, 6 | BiQwen2_5_Processor, 7 | BiQwen2Processor, 8 | ColIdefics3, 9 | ColIdefics3Processor, 10 | ColPali, 11 | ColPaliProcessor, 12 | ColQwen2, 13 | ColQwen2_5, 14 | ColQwen2_5_Processor, 15 | ColQwen2Processor, 16 | ) 17 | -------------------------------------------------------------------------------- /colpali_engine/collators/__init__.py: -------------------------------------------------------------------------------- 1 | from .visual_retriever_collator import VisualRetrieverCollator 2 | -------------------------------------------------------------------------------- /colpali_engine/collators/visual_retriever_collator.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict, List, Union 3 | 4 | from PIL.Image import Image 5 | 6 | from colpali_engine.data.dataset import ColPaliEngineDataset 7 | from colpali_engine.models.paligemma import ColPaliProcessor 8 | from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor 9 | 10 | 11 | def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]: 12 | """ 13 | Prefix all keys in a dictionary with the given prefix. 14 | """ 15 | return {f"{prefix}{k}": v for k, v in data.items()} 16 | 17 | 18 | class VisualRetrieverCollator: 19 | """ 20 | Collator for training vision retrieval models. 21 | """ 22 | 23 | # Prefixes 24 | query_prefix = "query_" 25 | pos_doc_prefix = "doc_" 26 | neg_doc_prefix = "neg_doc_" 27 | 28 | def __init__( 29 | self, 30 | processor: BaseVisualRetrieverProcessor, 31 | max_length: int = 2048, 32 | ): 33 | self.processor = processor 34 | self.max_length = max_length 35 | self.image_token_id = None 36 | 37 | # If processor is one of the supported types, extract the token id. 38 | if isinstance(self.processor, (ColPaliProcessor,)): 39 | image_token = "<image>" 40 | try: 41 | idx = self.processor.tokenizer.additional_special_tokens.index(image_token) 42 | self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[idx] 43 | except ValueError: 44 | self.image_token_id = None 45 | 46 | # Force padding to be on the right for ColPaliProcessor. 47 | if isinstance(self.processor, ColPaliProcessor) and self.processor.tokenizer.padding_side != "right": 48 | print("Setting padding side to right") 49 | self.processor.tokenizer.padding_side = "right" 50 | 51 | def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]: 52 | queries: List[Union[None, str, Image]] = [] 53 | pos_targets: List[Union[str, Image]] = [] 54 | neg_targets: List[Union[str, Image]] = [] 55 | 56 | # Parse the examples. 57 | for example in examples: 58 | assert ColPaliEngineDataset.QUERY_KEY in example, f"Missing {ColPaliEngineDataset.QUERY_KEY} in example."
59 | query = example[ColPaliEngineDataset.QUERY_KEY] 60 | sampled_query = random.choice(query) if isinstance(query, list) else query 61 | queries.append(sampled_query) 62 | 63 | assert ColPaliEngineDataset.POS_TARGET_KEY in example, ( 64 | f"Missing {ColPaliEngineDataset.POS_TARGET_KEY} in example." 65 | ) 66 | pos_tgt = example[ColPaliEngineDataset.POS_TARGET_KEY] 67 | sample_pos = random.choice(pos_tgt) if isinstance(pos_tgt, list) else pos_tgt 68 | pos_targets.append(sample_pos) 69 | 70 | neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None) 71 | if neg_tgt is not None: 72 | sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt 73 | neg_targets.append(sampled_neg) 74 | 75 | # Process queries. 76 | if all(q is None for q in queries): 77 | batch_query = None 78 | elif any(q is None for q in queries): 79 | raise ValueError("Some queries are None. This collator does not support None queries yet.") 80 | else: 81 | batch_query = self.auto_collate(queries, prefix=self.query_prefix) 82 | 83 | # Process targets. 84 | batch_pos_target = self.auto_collate(pos_targets, prefix=self.pos_doc_prefix) 85 | batch_neg_target = self.auto_collate(neg_targets, prefix=self.neg_doc_prefix) if neg_targets else {} 86 | 87 | return { 88 | **batch_query, 89 | **batch_pos_target, 90 | **batch_neg_target, 91 | } 92 | 93 | def auto_collate(self, batch: List[Union[str, Image]], prefix: str = "") -> Dict[str, Any]: 94 | """Automatically collate a batch of documents.""" 95 | # Convert Document objects to their underlying data. 96 | if isinstance(batch[0], str): 97 | return self.collate_texts(batch, prefix=prefix) 98 | elif isinstance(batch[0], Image): 99 | return self.collate_images(batch, prefix=prefix) 100 | else: 101 | raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.") 102 | 103 | def collate_images(self, images: List[Image], prefix: str = "") -> Dict[str, Any]: 104 | """Collate images into a batch.""" 105 | # Process images. 106 | batch_im = self.processor.process_images(images=images) 107 | # Prefix keys to avoid collisions. 108 | return prefix_keys(batch_im, prefix) 109 | 110 | def collate_texts(self, texts: List[str], prefix: str = "") -> Dict[str, Any]: 111 | """Collate texts into a batch.""" 112 | # Process texts. 113 | batch_text = self.processor.process_queries( 114 | queries=texts, 115 | max_length=self.max_length, 116 | ) 117 | # Prefix keys to avoid collisions. 
118 | return prefix_keys(batch_text, prefix) 119 | -------------------------------------------------------------------------------- /colpali_engine/compression/__init__.py: -------------------------------------------------------------------------------- 1 | from .token_pooling import ( 2 | BaseTokenPooler, 3 | HierarchicalTokenPooler, 4 | LambdaTokenPooler, 5 | TokenPoolingOutput, 6 | ) 7 | -------------------------------------------------------------------------------- /colpali_engine/compression/token_pooling/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_token_pooling import BaseTokenPooler, TokenPoolingOutput 2 | from .hierarchical_token_pooling import HierarchicalTokenPooler 3 | from .lambda_token_pooling import LambdaTokenPooler 4 | -------------------------------------------------------------------------------- /colpali_engine/compression/token_pooling/lambda_token_pooling.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | from typing import Callable, Dict, List, Optional, Tuple 3 | 4 | import torch 5 | 6 | from colpali_engine.compression.token_pooling.base_token_pooling import BaseTokenPooler 7 | 8 | 9 | class LambdaTokenPooler(BaseTokenPooler): 10 | """ 11 | Token pooler that applies a user-defined pooling function to multi-vector embeddings. 12 | 13 | This pooler allows users to define custom pooling methods rather than relying on pre-defined pooling strategies. 14 | 15 | Example: 16 | 17 | ```python 18 | # Define a custom pooling function that reduces sequence length by half 19 | def custom_pooling(embedding: torch.Tensor) -> torch.Tensor: 20 | token_length = embedding.size(0) 21 | # Resize to half the original length by averaging pairs of tokens 22 | half_length = token_length // 2 + (token_length % 2) 23 | pooled_embeddings = torch.zeros( 24 | (half_length, embedding.size(1)), 25 | dtype=embedding.dtype, 26 | device=embedding.device, 27 | ) 28 | 29 | for i in range(half_length): 30 | start_idx = i * 2 31 | end_idx = min(start_idx + 2, token_length) 32 | cluster_indices = torch.arange(start_idx, end_idx) 33 | pooled_embeddings[i] = embedding[cluster_indices].mean(dim=0) 34 | pooled_embeddings[i] = torch.nn.functional.normalize(pooled_embeddings[i], p=2, dim=-1) 35 | 36 | return pooled_embeddings 37 | 38 | 39 | # Create a LambdaTokenPooler with the custom function 40 | pooler = LambdaTokenPooler(pool_func=custom_pooling) 41 | outputs = pooler.pool_embeddings(embeddings) 42 | ``` 43 | """ 44 | 45 | def __init__( 46 | self, 47 | pool_func: Callable[[torch.Tensor], torch.Tensor], 48 | ): 49 | """ 50 | Initialize the LambdaTokenPooler with a custom pooling function. 51 | 52 | Args: 53 | pool_func: A function that takes a 2D tensor (token_length, embedding_dim) and returns pooled embeddings, 54 | i.e. a tensor of shape (num_clusters, embedding_dim)). 55 | """ 56 | self.pool_func = pool_func 57 | 58 | def _pool_embeddings_impl( 59 | self, 60 | embeddings: List[torch.Tensor], 61 | num_workers: Optional[int] = None, 62 | ) -> Tuple[ 63 | List[torch.Tensor], 64 | Optional[List[Dict[int, Tuple[torch.Tensor]]]], 65 | ]: 66 | """ 67 | Apply the custom pooling function to each embedding in the list. 
68 | 69 | Args: 70 | embeddings: List of 2D tensors to pool 71 | num_workers: Number of workers for parallel processing 72 | 73 | Returns: 74 | Tuple containing: 75 | - List of pooled embeddings 76 | - None (no cluster ID mapping in this implementation) 77 | """ 78 | if num_workers and num_workers > 1: 79 | with ThreadPoolExecutor(num_workers) as executor: 80 | # NOTE: We opted for a thread-based pool because most of the heavy lifting is done in C-level libraries 81 | # (NumPy, Torch, and SciPy) which usually release the GIL. 82 | pooled_embeddings = list(executor.map(self.pool_func, embeddings)) 83 | elif num_workers is None or num_workers == 1: 84 | # Process embeddings sequentially 85 | pooled_embeddings = [self.pool_func(emb) for emb in embeddings] 86 | else: 87 | raise ValueError(f"Invalid number of workers: {num_workers}") 88 | 89 | return pooled_embeddings, None 90 | -------------------------------------------------------------------------------- /colpali_engine/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import ColPaliEngineDataset, Corpus 2 | from .sampler import SingleDatasetBatchSampler 3 | -------------------------------------------------------------------------------- /colpali_engine/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | from datasets import Dataset as HFDataset 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | Document = Union[str, Image.Image] 8 | 9 | 10 | class Corpus: 11 | """ 12 | Corpus class for handling retrieving with simple mapping. 13 | This class is meant to be overridden by the user to handle their own corpus. 14 | 15 | Args: 16 | corpus_data (List[Dict[str, Any]]): List of dictionaries containing doc data. 17 | docid_to_idx_mapping (Optional[Dict[str, int]]): Optional mapping from doc IDs to indices. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | corpus_data: List[Dict[str, Any]], 23 | docid_to_idx_mapping: Optional[Dict[str, int]] = None, 24 | doc_column_name: str = "doc", 25 | ): 26 | """ 27 | Initialize the corpus with the provided data. 28 | """ 29 | self.corpus_data = corpus_data 30 | self.docid_to_idx_mapping = docid_to_idx_mapping 31 | self.doc_column_name = doc_column_name 32 | 33 | assert isinstance( 34 | self.corpus_data, 35 | (list, Dataset, HFDataset), 36 | ), "Corpus data must be a map-style dataset" 37 | 38 | assert self.doc_column_name in self.corpus_data[0], f"Corpus data must contain a column {self.doc_column_name}." 39 | 40 | def __len__(self) -> int: 41 | """ 42 | Return the number of docs in the corpus. 43 | 44 | Returns: 45 | int: The number of docs in the corpus. 46 | """ 47 | return len(self.corpus_data) 48 | 49 | def retrieve(self, docid: Any) -> Document: 50 | """ 51 | Get the corpus row from the given Doc ID. 52 | 53 | Args: 54 | docid (str): The id of the document. 55 | 56 | Returns: 57 | Document: The document retrieved from the corpus. 
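Example (illustrative sketch; `image_0`, `image_1` and the doc IDs below are placeholder objects created by the caller, not fixtures shipped with this class):

```python
corpus = Corpus(
    corpus_data=[{"doc": image_0}, {"doc": image_1}],
    docid_to_idx_mapping={"doc-0": 0, "doc-1": 1},
)
corpus.retrieve("doc-1")  # returns image_1 via the docid -> index mapping

corpus_no_mapping = Corpus(corpus_data=[{"doc": image_0}, {"doc": image_1}])
corpus_no_mapping.retrieve(1)  # without a mapping, the doc ID is used directly as an index
```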
58 | """ 59 | if self.docid_to_idx_mapping is not None: 60 | doc_idx = self.docid_to_idx_mapping[docid] 61 | else: 62 | doc_idx = docid 63 | return self.corpus_data[doc_idx][self.doc_column_name] 64 | 65 | 66 | class ColPaliEngineDataset(Dataset): 67 | # Output keys 68 | QUERY_KEY = "query" 69 | POS_TARGET_KEY = "pos_target" 70 | NEG_TARGET_KEY = "neg_target" 71 | 72 | def __init__( 73 | self, 74 | data: List[Dict[str, Any]], 75 | corpus: Optional[Corpus] = None, 76 | query_column_name: str = "query", 77 | pos_target_column_name: str = "pos_target", 78 | neg_target_column_name: str = None, 79 | ): 80 | """ 81 | Initialize the dataset with the provided data and external document corpus. 82 | 83 | Args: 84 | data (Dict[str, List[Any]]): A dictionary containing the dataset samples. 85 | corpus (Optional[Corpus]): An optional external document corpus to retrieve 86 | documents (images) from. 87 | """ 88 | self.data = data 89 | self.corpus = corpus 90 | 91 | # Column args 92 | self.query_column_name = query_column_name 93 | self.pos_target_column_name = pos_target_column_name 94 | self.neg_target_column_name = neg_target_column_name 95 | 96 | assert isinstance( 97 | self.data, 98 | (list, Dataset, HFDataset), 99 | ), "Data must be a map-style dataset" 100 | 101 | assert self.query_column_name in self.data[0], f"Data must contain the {self.query_column_name} column" 102 | assert self.pos_target_column_name in self.data[0], f"Data must contain a {self.pos_target_column_name} column" 103 | if self.neg_target_column_name is not None: 104 | assert self.neg_target_column_name in self.data[0], ( 105 | f"Data must contain a {self.neg_target_column_name} column" 106 | ) 107 | 108 | def __len__(self) -> int: 109 | """Return the number of samples in the dataset.""" 110 | return len(self.data) 111 | 112 | def __getitem__(self, idx: int) -> Dict[str, Any]: 113 | sample = self.data[idx] 114 | 115 | query = sample[self.query_column_name] 116 | 117 | pos_targets = sample[self.pos_target_column_name] 118 | if not isinstance(pos_targets, list): 119 | pos_targets = [pos_targets] 120 | 121 | if self.neg_target_column_name is not None: 122 | neg_targets = sample[self.neg_target_column_name] 123 | if not isinstance(neg_targets, list): 124 | neg_targets = [neg_targets] 125 | else: 126 | neg_targets = None 127 | 128 | # If an external document corpus is provided, retrieve the documents from it. 129 | if self.corpus is not None: 130 | pos_targets = [self.corpus.retrieve(doc_id) for doc_id in pos_targets] 131 | if neg_targets is not None: 132 | neg_targets = [self.corpus.retrieve(doc_id) for doc_id in neg_targets] 133 | 134 | return { 135 | self.QUERY_KEY: query, 136 | self.POS_TARGET_KEY: pos_targets, 137 | self.NEG_TARGET_KEY: neg_targets, 138 | } 139 | 140 | def take(self, n: int) -> "ColPaliEngineDataset": 141 | """ 142 | Take the first n samples from the dataset. 143 | 144 | Args: 145 | n (int): The number of samples to take. 146 | 147 | Returns: 148 | ColPaliEngineDataset: A new dataset containing the first n samples. 
149 | """ 150 | return self.__class__( 151 | self.data.take(n), 152 | self.corpus, 153 | self.query_column_name, 154 | self.pos_target_column_name, 155 | self.neg_target_column_name, 156 | ) 157 | -------------------------------------------------------------------------------- /colpali_engine/data/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import BatchSampler, Dataset 6 | 7 | 8 | class SingleDatasetBatchSampler(BatchSampler): 9 | """ 10 | A batch sampler that samples from a single dataset per batch and handles distribution across GPUs. 11 | 12 | Args: 13 | datasets (List[Dataset]): List of datasets to sample from 14 | batch_size (int): Global batch size (will be divided across GPUs) 15 | drop_last (bool): Whether to drop the last incomplete batch 16 | generator (Optional[torch.Generator]): Random number generator 17 | """ 18 | 19 | def __init__( 20 | self, 21 | datasets: List[Dataset], 22 | global_batch_size: int, 23 | drop_last: bool = True, 24 | generator: Optional[torch.Generator] = None, 25 | ): 26 | # round down each dataset if not divible by global batch size 27 | for i in range(len(datasets)): 28 | if len(datasets[i]) % global_batch_size != 0: 29 | total_samples = (len(datasets[i]) // global_batch_size) * global_batch_size 30 | datasets[i] = datasets[i].take(total_samples) 31 | 32 | self.datasets = datasets 33 | self.global_batch_size = global_batch_size 34 | self.drop_last = drop_last 35 | self.generator = generator or torch.Generator() 36 | self.initial_seed = self.generator.initial_seed() 37 | 38 | # Calculate dataset sizes and create index mappings 39 | self.dataset_sizes = [len(dataset) for dataset in datasets] 40 | #### get start of each dataset ##### 41 | self.cumsum_sizes = np.cumsum([0] + self.dataset_sizes).tolist() 42 | self.total_size = sum(self.dataset_sizes) 43 | 44 | # Create shuffled indices for each dataset 45 | self.indices_per_dataset = [ 46 | torch.randperm(size, generator=self.generator).tolist() for size in self.dataset_sizes 47 | ] 48 | self.current_positions = [0] * len(datasets) 49 | 50 | self.available_datasets = list(range(len(datasets))) 51 | self.max_positions = [(size // self.global_batch_size) * self.global_batch_size for size in self.dataset_sizes] 52 | 53 | def __iter__(self) -> Iterator[List[int]]: 54 | # Reset current positions and available datasets 55 | self.current_positions = [0] * len(self.datasets) 56 | self.available_datasets = list(range(len(self.datasets))) 57 | 58 | while self.available_datasets: 59 | # Randomly select from available datasets 60 | dataset_idx_index = torch.randint(len(self.available_datasets), size=(1,), generator=self.generator).item() 61 | dataset_idx = self.available_datasets[dataset_idx_index] 62 | 63 | # Get indices for the current dataset 64 | dataset_indices = self.indices_per_dataset[dataset_idx] 65 | current_pos = self.current_positions[dataset_idx] 66 | 67 | # Check if we have enough samples for a full batch 68 | end_pos = current_pos + self.global_batch_size 69 | 70 | if end_pos <= self.max_positions[dataset_idx]: 71 | # Get batch indices 72 | batch_indices = [idx + self.cumsum_sizes[dataset_idx] for idx in dataset_indices[current_pos:end_pos]] 73 | 74 | # Update position 75 | self.current_positions[dataset_idx] = end_pos 76 | 77 | # If dataset is exhausted, remove from available datasets 78 | if end_pos >= self.max_positions[dataset_idx]: 79 | 
self.available_datasets.remove(dataset_idx) 80 | 81 | yield batch_indices 82 | else: 83 | # This dataset doesn't have enough samples for another batch 84 | self.available_datasets.remove(dataset_idx) 85 | 86 | def set_epoch(self, epoch): 87 | """ 88 | Sets the epoch for this sampler. 89 | 90 | Args: 91 | epoch (int): Epoch number 92 | """ 93 | torch_gen = torch.Generator() 94 | 95 | # Set seed based on epoch to ensure different shuffling each epoch 96 | new_seed = self.initial_seed + epoch 97 | torch_gen.manual_seed(new_seed) 98 | self.generator.manual_seed(new_seed) 99 | 100 | # Reshuffle indices for each dataset 101 | self.indices_per_dataset = [torch.randperm(size, generator=torch_gen).tolist() for size in self.dataset_sizes] 102 | 103 | @property 104 | def batch_size(self) -> int: 105 | return self.global_batch_size 106 | 107 | def __len__(self) -> int: 108 | return sum(size // self.global_batch_size for size in self.dataset_sizes) 109 | -------------------------------------------------------------------------------- /colpali_engine/interpretability/__init__.py: -------------------------------------------------------------------------------- 1 | from .similarity_map_utils import ( 2 | get_similarity_maps_from_embeddings, 3 | normalize_similarity_map, 4 | ) 5 | from .similarity_maps import ( 6 | plot_all_similarity_maps, 7 | plot_similarity_map, 8 | ) 9 | -------------------------------------------------------------------------------- /colpali_engine/interpretability/similarity_map_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | EPSILON = 1e-10 7 | 8 | 9 | def get_similarity_maps_from_embeddings( 10 | image_embeddings: torch.Tensor, 11 | query_embeddings: torch.Tensor, 12 | n_patches: Union[Tuple[int, int], List[Tuple[int, int]]], 13 | image_mask: torch.Tensor, 14 | ) -> List[torch.Tensor]: 15 | """ 16 | Get the batched similarity maps between the query embeddings and the image embeddings. 17 | Each element in the returned list is a tensor of shape (query_tokens, n_patches_x, n_patches_y). 18 | 19 | Args: 20 | image_embeddings: tensor of shape (batch_size, image_tokens, dim) 21 | query_embeddings: tensor of shape (batch_size, query_tokens, dim) 22 | n_patches: number of patches per dimension for each image in the batch. If a single tuple is provided, 23 | the same number of patches is used for all images in the batch (broadcasted). 24 | image_mask: tensor of shape (batch_size, image_tokens). Used to filter out the embeddings 25 | that are not related to the image 26 | """ 27 | 28 | if isinstance(n_patches, tuple): 29 | n_patches = [n_patches] * image_embeddings.size(0) 30 | 31 | similarity_maps: List[torch.Tensor] = [] 32 | 33 | for idx in range(image_embeddings.size(0)): 34 | # Sanity check 35 | if image_mask[idx].sum() != n_patches[idx][0] * n_patches[idx][1]: 36 | raise ValueError( 37 | f"The number of patches ({n_patches[idx][0]} x {n_patches[idx][1]} = " 38 | f"{n_patches[idx][0] * n_patches[idx][1]}) " 39 | f"does not match the number of non-padded image tokens ({image_mask[idx].sum()})." 
40 | ) 41 | 42 | # Rearrange the output image tensor to explicitly represent the 2D grid of patches 43 | image_embedding_grid = rearrange( 44 | image_embeddings[idx][image_mask[idx]], # (n_patches_x * n_patches_y, dim) 45 | "(h w) c -> w h c", 46 | w=n_patches[idx][0], 47 | h=n_patches[idx][1], 48 | ) # (n_patches_x, n_patches_y, dim) 49 | 50 | similarity_map = torch.einsum( 51 | "nk,ijk->nij", query_embeddings[idx], image_embedding_grid 52 | ) # (batch_size, query_tokens, n_patches_x, n_patches_y) 53 | 54 | similarity_maps.append(similarity_map) 55 | 56 | return similarity_maps 57 | 58 | 59 | def normalize_similarity_map(similarity_map: torch.Tensor) -> torch.Tensor: 60 | """ 61 | Normalize the similarity map to have values in the range [0, 1]. 62 | 63 | Args: 64 | similarity_map: tensor of shape (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y) 65 | """ 66 | if similarity_map.ndim not in [2, 3]: 67 | raise ValueError( 68 | "The input tensor must have 2 dimensions (n_patch_x, n_patch_y) or " 69 | "3 dimensions (batch_size, n_patch_x, n_patch_y)." 70 | ) 71 | 72 | # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y) 73 | min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0] # (1, 1) or (batch_size, 1, 1) 74 | 75 | # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y) 76 | max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] # (1, 1) or (batch_size, 1, 1) 77 | 78 | # Normalize the tensor 79 | # NOTE: Add a small epsilon to avoid division by zero. 80 | similarity_map_normalized = (similarity_map - min_vals) / ( 81 | max_vals - min_vals + EPSILON 82 | ) # (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y) 83 | 84 | return similarity_map_normalized 85 | -------------------------------------------------------------------------------- /colpali_engine/interpretability/similarity_maps.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import seaborn as sns 6 | import torch 7 | from einops import rearrange 8 | from PIL import Image 9 | 10 | from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map 11 | 12 | 13 | def plot_similarity_map( 14 | image: Image.Image, 15 | similarity_map: torch.Tensor, 16 | figsize: Tuple[int, int] = (8, 8), 17 | show_colorbar: bool = False, 18 | ) -> Tuple[plt.Figure, plt.Axes]: 19 | """ 20 | Plot and overlay a similarity map over the input image. 21 | 22 | A similarity map is a 2D tensor where each element (i, j) represents the similarity score between a chosen query 23 | token and the associated image patch at position (i, j). Thus, the higher the similarity score, the brighter the 24 | color of the patch. 
25 | 26 | To show the returned similarity map, use: 27 | 28 | ```python 29 | >>> fig, ax = plot_similarity_map(image, similarity_map) 30 | >>> fig.show() 31 | ``` 32 | 33 | Args: 34 | image: PIL image 35 | similarity_map: tensor of shape (n_patches_x, n_patches_y) 36 | figsize: size of the figure 37 | show_colorbar: whether to show a colorbar 38 | """ 39 | 40 | # Convert the image to an array 41 | img_array = np.array(image.convert("RGBA")) # (height, width, channels) 42 | 43 | # Normalize the similarity map and convert it to Pillow image 44 | similarity_map_array = ( 45 | normalize_similarity_map(similarity_map).to(torch.float32).cpu().numpy() 46 | ) # (n_patches_x, n_patches_y) 47 | 48 | # Reshape the similarity map to match the PIL shape convention 49 | similarity_map_array = rearrange(similarity_map_array, "h w -> w h") # (n_patches_y, n_patches_x) 50 | 51 | similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize( 52 | image.size, Image.Resampling.BICUBIC 53 | ) 54 | 55 | # Create the figure 56 | with plt.style.context("dark_background"): 57 | fig, ax = plt.subplots(figsize=figsize) 58 | 59 | ax.imshow(img_array) 60 | im = ax.imshow( 61 | similarity_map_image, 62 | cmap=sns.color_palette("mako", as_cmap=True), 63 | alpha=0.5, 64 | ) 65 | 66 | if show_colorbar: 67 | fig.colorbar(im) 68 | ax.set_axis_off() 69 | fig.tight_layout() 70 | 71 | return fig, ax 72 | 73 | 74 | def plot_all_similarity_maps( 75 | image: Image.Image, 76 | query_tokens: List[str], 77 | similarity_maps: torch.Tensor, 78 | figsize: Tuple[int, int] = (8, 8), 79 | show_colorbar: bool = False, 80 | add_title: bool = True, 81 | ) -> List[Tuple[plt.Figure, plt.Axes]]: 82 | """ 83 | For each token in the query, plot and overlay a similarity map over the input image. 84 | 85 | A similarity map is a 2D tensor where each element (i, j) represents the similarity score between a chosen query 86 | token and the associated image patch at position (i, j). Thus, the higher the similarity score, the brighter the 87 | color of the patch. 
88 | 89 | Args: 90 | image: PIL image 91 | query_tokens: list of query tokens 92 | similarity_maps: tensor of shape (query_tokens, n_patches_x, n_patches_y) 93 | figsize: size of the figure 94 | show_colorbar: whether to show a colorbar 95 | add_title: whether to add a title with the token and the max similarity score 96 | 97 | Example usage for one query-image pair: 98 | 99 | ```python 100 | >>> from colpali_engine.interpretability.similarity_map_utils import get_similarity_maps_from_embeddings 101 | 102 | >>> batch_images = processor.process_images([image]).to(device) 103 | >>> batch_queries = processor.process_queries([query]).to(device) 104 | 105 | >>> with torch.no_grad(): 106 | image_embeddings = model.forward(**batch_images) 107 | query_embeddings = model.forward(**batch_queries) 108 | 109 | >>> n_patches = processor.get_n_patches( 110 | image_size=image.size, 111 | patch_size=model.patch_size 112 | ) 113 | >>> image_mask = processor.get_image_mask(batch_images) 114 | 115 | >>> batched_similarity_maps = get_similarity_maps_from_embeddings( 116 | image_embeddings=image_embeddings, 117 | query_embeddings=query_embeddings, 118 | n_patches=n_patches, 119 | image_mask=image_mask, 120 | ) 121 | >>> similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y) 122 | 123 | >>> plots = plot_all_similarity_maps( 124 | image=image, 125 | query_tokens=query_tokens, 126 | similarity_maps=similarity_maps, 127 | ) 128 | 129 | >>> for fig, ax in plots: 130 | fig.show() 131 | ``` 132 | """ 133 | 134 | plots: List[Tuple[plt.Figure, plt.Axes]] = [] 135 | 136 | for idx, token in enumerate(query_tokens): 137 | fig, ax = plot_similarity_map( 138 | image=image, 139 | similarity_map=similarity_maps[idx], 140 | figsize=figsize, 141 | show_colorbar=show_colorbar, 142 | ) 143 | 144 | if add_title: 145 | max_sim_score = similarity_maps[idx].max().item() 146 | ax.set_title(f"Token #{idx}: `{token}`. MaxSim score: {max_sim_score:.2f}", fontsize=14) 147 | 148 | plots.append((fig, ax)) 149 | 150 | return plots 151 | -------------------------------------------------------------------------------- /colpali_engine/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .bi_encoder_losses import ( 2 | BiEncoderLoss, 3 | BiPairwiseCELoss, 4 | BiPairwiseNegativeCELoss, 5 | ) 6 | from .late_interaction_losses import ( 7 | ColbertLoss, 8 | ColbertPairwiseCELoss, 9 | ColbertPairwiseNegativeCELoss, 10 | ) 11 | -------------------------------------------------------------------------------- /colpali_engine/loss/bi_encoder_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F # noqa: N812 3 | from torch.nn import CrossEntropyLoss 4 | 5 | 6 | class BiEncoderLoss(torch.nn.Module): 7 | def __init__(self, temperature: float = 0.02): 8 | """ 9 | InfoNCE loss. 10 | 11 | Args: 12 | temperature: The temperature to use for the loss (`new_scores = scores / temperature`). 
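Example (illustrative only; batch size and embedding dimension are arbitrary, and the random tensors are L2-normalized stand-ins for real bi-encoder embeddings):

```python
import torch
import torch.nn.functional as F  # noqa: N812

from colpali_engine.loss import BiEncoderLoss

loss_fn = BiEncoderLoss(temperature=0.02)
query_embeddings = F.normalize(torch.randn(8, 128), dim=-1)  # (batch_size, dim)
doc_embeddings = F.normalize(torch.randn(8, 128), dim=-1)  # row i is the positive doc for query i
loss = loss_fn(query_embeddings, doc_embeddings)
```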
13 | """ 14 | super().__init__() 15 | self.ce_loss = CrossEntropyLoss() 16 | self.temperature = temperature 17 | if not temperature > 0: 18 | raise ValueError("Temperature must be strictly positive") 19 | 20 | def forward(self, query_embeddings, doc_embeddings): 21 | """ 22 | query_embeddings: (batch_size, dim) 23 | doc_embeddings: (batch_size, dim) 24 | """ 25 | 26 | scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) 27 | loss_rowwise = self.ce_loss(scores / self.temperature, torch.arange(scores.shape[0], device=scores.device)) 28 | 29 | return loss_rowwise 30 | 31 | 32 | class BiNegativeCELoss(torch.nn.Module): 33 | def __init__(self, temperature: float = 0.02, in_batch_term=False): 34 | """ 35 | InfoNCE loss with negatives. 36 | 37 | Args: 38 | temperature: The temperature to use for the loss (`new_scores = scores / temperature`). 39 | in_batch_term: Whether to include the in-batch term in the loss. 40 | 41 | """ 42 | super().__init__() 43 | self.ce_loss = CrossEntropyLoss() 44 | self.temperature = temperature 45 | if not temperature > 0: 46 | raise ValueError("Temperature must be strictly positive") 47 | self.in_batch_term = in_batch_term 48 | 49 | def forward(self, query_embeddings, doc_embeddings, neg_doc_embeddings): 50 | """ 51 | query_embeddings: (batch_size, dim) 52 | doc_embeddings: (batch_size, dim) 53 | neg_doc_embeddings: (batch_size, dim) 54 | """ 55 | 56 | # Compute the ColBERT scores 57 | pos_scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings).diagonal() / self.temperature 58 | neg_scores = torch.einsum("bd,cd->bc", query_embeddings, neg_doc_embeddings).diagonal() / self.temperature 59 | 60 | loss = F.softplus(neg_scores - pos_scores).mean() 61 | 62 | if self.in_batch_term: 63 | scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) 64 | loss += self.ce_loss(scores / self.temperature, torch.arange(scores.shape[0], device=scores.device)) 65 | 66 | return loss / 2 67 | 68 | 69 | class BiPairwiseCELoss(torch.nn.Module): 70 | def __init__(self): 71 | """ 72 | Pairwise loss. 73 | """ 74 | super().__init__() 75 | self.ce_loss = CrossEntropyLoss() 76 | 77 | def forward(self, query_embeddings, doc_embeddings): 78 | """ 79 | query_embeddings: (batch_size, dim) 80 | doc_embeddings: (batch_size, dim) 81 | """ 82 | 83 | scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) 84 | 85 | pos_scores = scores.diagonal() 86 | neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 87 | neg_scores = neg_scores.max(dim=1)[0] 88 | 89 | loss = F.softplus(neg_scores - pos_scores).mean() 90 | 91 | return loss 92 | 93 | 94 | class BiPairwiseNegativeCELoss(torch.nn.Module): 95 | def __init__(self, in_batch_term=False): 96 | """ 97 | Pairwise loss with negatives. 98 | Args: 99 | in_batch_term: Whether to include the in-batch term in the loss. 
100 | """ 101 | super().__init__() 102 | self.ce_loss = CrossEntropyLoss() 103 | self.in_batch_term = in_batch_term 104 | 105 | def forward(self, query_embeddings, doc_embeddings, neg_doc_embeddings): 106 | """ 107 | query_embeddings: (batch_size, dim) 108 | doc_embeddings: (batch_size, dim) 109 | neg_doc_embeddings: (batch_size, dim) 110 | """ 111 | 112 | # Compute the ColBERT scores 113 | pos_scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings).diagonal() 114 | neg_scores = torch.einsum("bd,cd->bc", query_embeddings, neg_doc_embeddings).diagonal() 115 | 116 | loss = F.softplus(neg_scores - pos_scores).mean() 117 | 118 | if self.in_batch_term: 119 | scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) 120 | 121 | # Positive scores are the diagonal of the scores matrix. 122 | pos_scores = scores.diagonal() # (batch_size,) 123 | 124 | neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 # (batch_size, batch_size) 125 | neg_scores = neg_scores.max(dim=1)[0] # (batch_size,) 126 | 127 | loss += F.softplus(neg_scores - pos_scores).mean() 128 | 129 | return loss / 2 130 | -------------------------------------------------------------------------------- /colpali_engine/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .idefics3 import ColIdefics3, ColIdefics3Processor 2 | from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor 3 | from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor 4 | from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor 5 | -------------------------------------------------------------------------------- /colpali_engine/models/idefics3/__init__.py: -------------------------------------------------------------------------------- 1 | from .colidefics3 import ColIdefics3, ColIdefics3Processor 2 | -------------------------------------------------------------------------------- /colpali_engine/models/idefics3/colidefics3/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_colidefics3 import ColIdefics3 2 | from .processing_colidefics3 import ColIdefics3Processor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/idefics3/colidefics3/modeling_colidefics3.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from transformers import Idefics3Model, Idefics3PreTrainedModel 3 | 4 | 5 | class ColIdefics3(Idefics3PreTrainedModel): 6 | """ 7 | Initializes the ColIdefics3 model. 8 | 9 | Args: 10 | config : The model configuration. 11 | mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings 12 | except those of the image at inference. 13 | Defaults to False --> Do not mask any embeddings during forward pass. 
14 | """ 15 | 16 | def __init__(self, config, mask_non_image_embeddings: bool = False): 17 | super(ColIdefics3, self).__init__(config=config) 18 | self.model: Idefics3Model = Idefics3Model(config) 19 | self.dim = 128 20 | self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim) 21 | self.mask_non_image_embeddings = mask_non_image_embeddings 22 | self.main_input_name = "doc_input_ids" 23 | 24 | def forward(self, *args, **kwargs): 25 | """ 26 | Forward pass through Llama and the linear layer for dimensionality reduction 27 | 28 | Args: 29 | - input_ids (torch.LongTensor): The input tokens tensor. 30 | - attention_mask (torch.LongTensor): The attention mask tensor. 31 | 32 | Returns: 33 | - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) 34 | """ 35 | outputs = self.model(*args, **kwargs) 36 | last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) 37 | proj = self.linear(last_hidden_states) 38 | # normalize l2 norm 39 | proj = proj / proj.norm(dim=-1, keepdim=True) 40 | proj = proj * kwargs["attention_mask"].unsqueeze(-1) 41 | 42 | if "pixel_values" in kwargs and self.mask_non_image_embeddings: 43 | # Pools only the image embeddings 44 | image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) 45 | proj = proj * image_mask 46 | return proj 47 | -------------------------------------------------------------------------------- /colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from PIL import Image 5 | from transformers import BatchEncoding, Idefics3Processor 6 | 7 | from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor 8 | 9 | 10 | class ColIdefics3Processor(BaseVisualRetrieverProcessor, Idefics3Processor): 11 | """ 12 | Processor for ColIdefics3. 13 | """ 14 | 15 | query_prefix: ClassVar[str] = "Query: " 16 | query_augmentation_token: ClassVar[str] = "" 17 | image_token: ClassVar[str] = "" 18 | visual_prompt_prefix: ClassVar[str] = "<|im_start|>user\nDescribe the image." 19 | 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | 23 | @property 24 | def image_token_id(self) -> int: 25 | return self.tokenizer.convert_tokens_to_ids(self.image_token) 26 | 27 | def process_images( 28 | self, 29 | images: List[Image.Image], 30 | context_prompts: Optional[List[str]] = None, 31 | ) -> BatchEncoding: 32 | """ 33 | Process images for ColIdefics3. 34 | 35 | Args: 36 | images: List of PIL images. 37 | context_prompts: List of optional context prompts, i.e. some text description of the context of the image. 38 | """ 39 | 40 | texts_doc: List[str] = [] 41 | images = [[image.convert("RGB")] for image in images] 42 | 43 | if context_prompts: 44 | if len(images) != len(context_prompts): 45 | raise ValueError("Length of images and context prompts must match.") 46 | texts_doc = context_prompts 47 | else: 48 | texts_doc = [self.visual_prompt_prefix] * len(images) 49 | 50 | batch_doc = self( 51 | text=texts_doc, 52 | images=images, 53 | return_tensors="pt", 54 | padding="longest", 55 | ) 56 | return batch_doc 57 | 58 | def process_queries( 59 | self, 60 | queries: List[str], 61 | max_length: int = 50, 62 | suffix: Optional[str] = None, 63 | ) -> BatchEncoding: 64 | """ 65 | Process queries for ColIdefics3. 
66 | """ 67 | if suffix is None: 68 | suffix = self.query_augmentation_token * 10 69 | texts_query: List[str] = [] 70 | 71 | for query in queries: 72 | query = self.query_prefix + query + suffix + "\n" 73 | texts_query.append(query) 74 | 75 | batch_query = self.tokenizer( 76 | text=texts_query, 77 | return_tensors="pt", 78 | padding="longest", 79 | ) 80 | 81 | return batch_query 82 | 83 | def score( 84 | self, 85 | qs: List[torch.Tensor], 86 | ps: List[torch.Tensor], 87 | device: Optional[Union[str, torch.device]] = None, 88 | **kwargs, 89 | ) -> torch.Tensor: 90 | """ 91 | Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. 92 | """ 93 | return self.score_multi_vector(qs, ps, device=device, **kwargs) 94 | 95 | def get_n_patches( 96 | self, 97 | image_size: Tuple[int, int], 98 | patch_size: int, 99 | ) -> Tuple[int, int]: 100 | raise NotImplementedError("This method is not implemented for ColIdefics3.") 101 | -------------------------------------------------------------------------------- /colpali_engine/models/paligemma/__init__.py: -------------------------------------------------------------------------------- 1 | from .bipali import BiPali, BiPaliProcessor, BiPaliProj 2 | from .colpali import ColPali, ColPaliProcessor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/paligemma/bipali/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_bipali import BiPali, BiPaliProj 2 | from .processing_bipali import BiPaliProcessor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/paligemma/bipali/modeling_bipali.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig 6 | from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel 7 | 8 | 9 | class BiPali(PaliGemmaPreTrainedModel): 10 | """ 11 | BiPali is an implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. 12 | Representations are average pooled to obtain a single vector representation. 
13 | """ 14 | 15 | def __init__(self, config: PaliGemmaConfig): 16 | super(BiPali, self).__init__(config=config) 17 | model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config) 18 | if model.language_model._tied_weights_keys is not None: 19 | self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys] 20 | self.model: PaliGemmaForConditionalGeneration = model 21 | self.main_input_name = "doc_input_ids" 22 | self.post_init() 23 | 24 | def get_input_embeddings(self): 25 | return self.model.language_model.get_input_embeddings() 26 | 27 | def set_input_embeddings(self, value): 28 | self.model.language_model.set_input_embeddings(value) 29 | 30 | def get_output_embeddings(self): 31 | return self.model.language_model.get_output_embeddings() 32 | 33 | def set_output_embeddings(self, new_embeddings): 34 | self.model.language_model.set_output_embeddings(new_embeddings) 35 | 36 | def set_decoder(self, decoder): 37 | self.model.language_model.set_decoder(decoder) 38 | 39 | def get_decoder(self): 40 | return self.model.language_model.get_decoder() 41 | 42 | def tie_weights(self): 43 | return self.model.language_model.tie_weights() 44 | 45 | def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: 46 | model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) 47 | # update vocab size 48 | self.config.text_config.vocab_size = model_embeds.num_embeddings 49 | self.config.vocab_size = model_embeds.num_embeddings 50 | self.model.vocab_size = model_embeds.num_embeddings 51 | return model_embeds 52 | 53 | def forward(self, *args, **kwargs): 54 | # delete output_hidden_states from kwargs 55 | kwargs.pop("output_hidden_states", None) 56 | if "pixel_values" in kwargs: 57 | kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype) 58 | 59 | outputs = self.model(*args, output_hidden_states=True, **kwargs) 60 | last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size) 61 | # pooling -mean on attention mask==1 62 | proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum( 63 | kwargs["attention_mask"], dim=1, keepdim=True 64 | ) 65 | proj = proj / proj.norm(dim=-1, keepdim=True) 66 | return proj 67 | 68 | 69 | class BiPaliProj(PaliGemmaPreTrainedModel): 70 | """ 71 | BiPaliProj is a BiPali model with a projection layer for dimensionality reduction. 
72 | """ 73 | 74 | def __init__(self, config: PaliGemmaConfig): 75 | super(BiPaliProj, self).__init__(config=config) 76 | model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config) 77 | if model.language_model._tied_weights_keys is not None: 78 | self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys] 79 | self.model: PaliGemmaForConditionalGeneration = model 80 | self.main_input_name = "doc_input_ids" 81 | self.dim = 1024 82 | self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) 83 | self.post_init() 84 | 85 | def get_input_embeddings(self): 86 | return self.model.language_model.get_input_embeddings() 87 | 88 | def set_input_embeddings(self, value): 89 | self.model.language_model.set_input_embeddings(value) 90 | 91 | def get_output_embeddings(self): 92 | return self.model.language_model.get_output_embeddings() 93 | 94 | def set_output_embeddings(self, new_embeddings): 95 | self.model.language_model.set_output_embeddings(new_embeddings) 96 | 97 | def set_decoder(self, decoder): 98 | self.model.language_model.set_decoder(decoder) 99 | 100 | def get_decoder(self): 101 | return self.model.language_model.get_decoder() 102 | 103 | def tie_weights(self): 104 | return self.model.language_model.tie_weights() 105 | 106 | def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: 107 | model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) 108 | # update vocab size 109 | self.config.text_config.vocab_size = model_embeds.num_embeddings 110 | self.config.vocab_size = model_embeds.num_embeddings 111 | self.model.vocab_size = model_embeds.num_embeddings 112 | return model_embeds 113 | 114 | def forward(self, *args, **kwargs): 115 | # delete output_hidden_states from kwargs 116 | kwargs.pop("output_hidden_states", None) 117 | if "pixel_values" in kwargs: 118 | kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype) 119 | 120 | outputs = self.model(*args, output_hidden_states=True, **kwargs) 121 | last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size) 122 | 123 | # pooling -mean on attention mask==1 124 | proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum( 125 | kwargs["attention_mask"], dim=1, keepdim=True 126 | ) 127 | proj = self.custom_text_proj(proj) 128 | proj = proj / proj.norm(dim=-1, keepdim=True) 129 | return proj 130 | -------------------------------------------------------------------------------- /colpali_engine/models/paligemma/bipali/processing_bipali.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | 5 | from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor 6 | 7 | 8 | class BiPaliProcessor(ColPaliProcessor): 9 | """ 10 | Processor for BiPali. Mirrors the `ColPaliProcessor` class. 11 | """ 12 | 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | 16 | def score( 17 | self, 18 | qs: List[torch.Tensor], 19 | ps: List[torch.Tensor], 20 | device: Optional[Union[str, torch.device]] = None, 21 | **kwargs, 22 | ) -> torch.Tensor: 23 | """ 24 | Compute the dot product score for the given single-vector query and passage embeddings. 
25 | """ 26 | return self.score_single_vector(qs, ps, device=device) 27 | -------------------------------------------------------------------------------- /colpali_engine/models/paligemma/colpali/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_colpali import ColPali 2 | from .processing_colpali import ColPaliProcessor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/paligemma/colpali/modeling_colpali.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar, Optional 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.models.paligemma.modeling_paligemma import ( 6 | PaliGemmaConfig, 7 | PaliGemmaForConditionalGeneration, 8 | PaliGemmaPreTrainedModel, 9 | ) 10 | 11 | 12 | class ColPali(PaliGemmaPreTrainedModel): 13 | """ 14 | ColPali model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. 15 | 16 | Args: 17 | config (PaliGemmaConfig): The model configuration. 18 | mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings 19 | except those of the image at inference. 20 | Defaults to False --> Do not mask any embeddings during forward pass. 21 | """ 22 | 23 | main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related 24 | 25 | def __init__(self, config: PaliGemmaConfig, mask_non_image_embeddings: bool = False): 26 | super().__init__(config=config) 27 | 28 | model = PaliGemmaForConditionalGeneration(config=config) 29 | if model.language_model._tied_weights_keys is not None: 30 | self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys] 31 | self.model = model 32 | 33 | # TODO: Wait for ColPali2 to create a ColPaliConfig to allow specifying the embedding dimension. 34 | # We could do it now but it would break all the models trying to load the model from the checkpoint. 
35 | self.dim = 128 36 | self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) 37 | 38 | self.mask_non_image_embeddings = mask_non_image_embeddings 39 | 40 | self.post_init() 41 | 42 | def forward(self, *args, **kwargs) -> torch.Tensor: 43 | # Delete output_hidden_states from kwargs 44 | kwargs.pop("output_hidden_states", None) 45 | if "pixel_values" in kwargs: 46 | kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype) 47 | 48 | outputs = self.model(*args, output_hidden_states=True, **kwargs) # (batch_size, sequence_length, hidden_size) 49 | last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size) 50 | proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) 51 | 52 | # L2 normalization 53 | proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) 54 | 55 | proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim) 56 | 57 | if "pixel_values" in kwargs and self.mask_non_image_embeddings: 58 | # Pools only the image embeddings 59 | image_mask = (kwargs["input_ids"] == self.config.image_token_index).unsqueeze(-1) 60 | proj = proj * image_mask 61 | return proj 62 | 63 | def get_input_embeddings(self): 64 | return self.model.language_model.get_input_embeddings() 65 | 66 | def set_input_embeddings(self, value): 67 | self.model.language_model.set_input_embeddings(value) 68 | 69 | def get_output_embeddings(self): 70 | return self.model.language_model.get_output_embeddings() 71 | 72 | def set_output_embeddings(self, new_embeddings): 73 | self.model.language_model.set_output_embeddings(new_embeddings) 74 | 75 | def set_decoder(self, decoder): 76 | self.model.language_model.set_decoder(decoder) 77 | 78 | def get_decoder(self): 79 | return self.model.language_model.get_decoder() 80 | 81 | def tie_weights(self): 82 | return self.model.language_model.tie_weights() 83 | 84 | def resize_token_embeddings( 85 | self, 86 | new_num_tokens: Optional[int] = None, 87 | pad_to_multiple_of=None, 88 | ) -> nn.Embedding: 89 | model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) 90 | 91 | # Update vocab size 92 | self.config.text_config.vocab_size = model_embeds.num_embeddings 93 | self.config.vocab_size = model_embeds.num_embeddings 94 | self.model.vocab_size = model_embeds.num_embeddings 95 | 96 | return model_embeds 97 | 98 | @property 99 | def patch_size(self) -> int: 100 | return self.model.vision_tower.config.patch_size 101 | -------------------------------------------------------------------------------- /colpali_engine/models/paligemma/colpali/processing_colpali.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from PIL import Image 5 | from transformers import BatchFeature, PaliGemmaProcessor 6 | 7 | from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor 8 | 9 | 10 | class ColPaliProcessor(BaseVisualRetrieverProcessor, PaliGemmaProcessor): 11 | """ 12 | Processor for ColPali. 13 | """ 14 | 15 | visual_prompt_prefix: ClassVar[str] = "Describe the image." 16 | query_prefix: ClassVar[str] = "Query: " 17 | 18 | def __init__(self, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | 21 | @property 22 | def query_augmentation_token(self) -> str: 23 | """ 24 | Return the query augmentation token. 
25 | Query augmentation buffers are used as reasoning buffers during inference. 26 | """ 27 | return self.tokenizer.pad_token 28 | 29 | def process_images( 30 | self, 31 | images: List[Image.Image], 32 | context_prompts: Optional[List[str]] = None, 33 | ) -> BatchFeature: 34 | """ 35 | Process images for ColPali. 36 | 37 | Args: 38 | images: List of PIL images. 39 | context_prompts: List of optional context prompts, i.e. some text description of the context of the image. 40 | """ 41 | 42 | if context_prompts: 43 | if len(images) != len(context_prompts): 44 | raise ValueError("Length of images and context prompts must match.") 45 | texts_doc = context_prompts 46 | else: 47 | texts_doc = [self.visual_prompt_prefix] * len(images) 48 | 49 | images = [image.convert("RGB") for image in images] 50 | 51 | batch_doc = self( 52 | text=texts_doc, 53 | images=images, 54 | return_tensors="pt", 55 | padding="longest", 56 | ) 57 | return batch_doc 58 | 59 | def process_queries( 60 | self, 61 | queries: List[str], 62 | max_length: int = 50, 63 | suffix: Optional[str] = None, 64 | ) -> BatchFeature: 65 | """ 66 | Process queries for ColPali. 67 | """ 68 | 69 | if suffix is None: 70 | suffix = self.query_augmentation_token * 10 71 | texts_query: List[str] = [] 72 | 73 | for query in queries: 74 | query = self.tokenizer.bos_token + self.query_prefix + query 75 | query += suffix # add suffix (pad tokens) 76 | 77 | # NOTE: Make input ISO to PaliGemma's processor 78 | query += "\n" 79 | 80 | texts_query.append(query) 81 | 82 | batch_query = self.tokenizer( 83 | texts_query, 84 | text_pair=None, 85 | return_token_type_ids=False, 86 | return_tensors="pt", 87 | padding="longest", 88 | max_length=max_length, 89 | ) 90 | 91 | return batch_query 92 | 93 | def score( 94 | self, 95 | qs: List[torch.Tensor], 96 | ps: List[torch.Tensor], 97 | device: Optional[Union[str, torch.device]] = None, 98 | **kwargs, 99 | ) -> torch.Tensor: 100 | """ 101 | Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. 
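For orientation, the snippet below sketches the end-to-end retrieval flow this processor enables. It is a minimal illustrative example, not part of the library: the checkpoint name, the blank test image, and the query string are placeholders to adapt to your setup.

import torch
from PIL import Image

from colpali_engine.models import ColPali, ColPaliProcessor

model_name = "vidore/colpali-v1.2"  # assumed checkpoint; swap in the model you actually use
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = ColPali.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device).eval()
processor = ColPaliProcessor.from_pretrained(model_name)

images = [Image.new("RGB", (448, 448), color="white")]  # stand-in for document page images
queries = ["What is the estimated revenue for 2024?"]

batch_images = processor.process_images(images).to(device)
batch_queries = processor.process_queries(queries).to(device)

with torch.no_grad():
    image_embeddings = model(**batch_images)   # (n_images, seq_len, 128), multi-vector
    query_embeddings = model(**batch_queries)  # (n_queries, seq_len, 128), multi-vector

# Late-interaction (MaxSim) scores, shape (n_queries, n_images); padded tensors or
# lists of per-item tensors are both accepted.
scores = processor.score_multi_vector(query_embeddings, image_embeddings)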
102 | """ 103 | return self.score_multi_vector(qs, ps, device=device, **kwargs) 104 | 105 | def get_n_patches( 106 | self, 107 | image_size: Tuple[int, int], 108 | patch_size: int, 109 | ) -> Tuple[int, int]: 110 | n_patches_x = self.image_processor.size["width"] // patch_size 111 | n_patches_y = self.image_processor.size["height"] // patch_size 112 | 113 | return n_patches_x, n_patches_y 114 | 115 | def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: 116 | return batch_images.input_ids == self.image_token_id 117 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2/__init__.py: -------------------------------------------------------------------------------- 1 | from .biqwen2 import BiQwen2, BiQwen2Processor 2 | from .colqwen2 import ColQwen2, ColQwen2Processor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2/biqwen2/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_biqwen2 import BiQwen2 2 | from .processing_biqwen2 import BiQwen2Processor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2/biqwen2/modeling_biqwen2.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar, List, Literal, Optional 2 | 3 | import torch 4 | from transformers.models.qwen2_vl import Qwen2VLConfig, Qwen2VLForConditionalGeneration 5 | 6 | 7 | class BiQwen2(Qwen2VLForConditionalGeneration): 8 | """ 9 | BiQwen2 is an implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. 10 | Representations are pooled to obtain a single vector representation. Based on the Qwen2-VL backbone. 
11 | """ 12 | 13 | main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related 14 | 15 | def __init__(self, config: Qwen2VLConfig): 16 | super().__init__(config=config) 17 | self.padding_side = "left" 18 | self.post_init() 19 | 20 | def inner_forward( 21 | self, 22 | input_ids: torch.LongTensor = None, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | position_ids: Optional[torch.LongTensor] = None, 25 | past_key_values: Optional[List[torch.FloatTensor]] = None, 26 | inputs_embeds: Optional[torch.FloatTensor] = None, 27 | use_cache: Optional[bool] = None, 28 | output_attentions: Optional[bool] = None, 29 | output_hidden_states: Optional[bool] = None, 30 | return_dict: Optional[bool] = None, 31 | pixel_values: Optional[torch.Tensor] = None, 32 | pixel_values_videos: Optional[torch.FloatTensor] = None, 33 | image_grid_thw: Optional[torch.LongTensor] = None, 34 | video_grid_thw: Optional[torch.LongTensor] = None, 35 | ) -> torch.Tensor: 36 | if inputs_embeds is None: 37 | inputs_embeds = self.model.embed_tokens(input_ids) 38 | if pixel_values is not None: 39 | pixel_values = pixel_values.type(self.visual.get_dtype()) 40 | image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) 41 | image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) 42 | image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 43 | inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) 44 | 45 | if pixel_values_videos is not None: 46 | pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) 47 | video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) 48 | video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) 49 | video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 50 | inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) 51 | 52 | if attention_mask is not None: 53 | attention_mask = attention_mask.to(inputs_embeds.device) 54 | 55 | outputs = self.model( 56 | input_ids=None, 57 | position_ids=position_ids, 58 | attention_mask=attention_mask, 59 | past_key_values=past_key_values, 60 | inputs_embeds=inputs_embeds, 61 | use_cache=use_cache, 62 | output_attentions=output_attentions, 63 | output_hidden_states=output_hidden_states, 64 | return_dict=return_dict, 65 | ) 66 | 67 | hidden_states = outputs[0] 68 | return hidden_states 69 | 70 | def forward( 71 | self, 72 | pooling_strategy: Literal["cls", "last", "mean"] = "last", 73 | *args, 74 | **kwargs, 75 | ) -> torch.Tensor: 76 | """ 77 | Forward pass for BiQwen2.5 model. 78 | 79 | Args: 80 | pooling_strategy: The strategy to use for pooling the hidden states. 81 | *args: Variable length argument list. 82 | **kwargs: Additional keyword arguments. 83 | 84 | Returns: 85 | torch.Tensor: Dense embeddings (batch_size, hidden_size). 
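To make the three pooling strategies above concrete, here is a toy sketch on random tensors; the shapes are illustrative only, and the real forward pass additionally L2-normalizes the result.

import torch

hidden_states = torch.randn(2, 5, 8)  # (batch_size, sequence_length, hidden_size)
attention_mask = torch.ones(2, 5)     # all positions are real tokens in this toy case

cls_pooled = hidden_states[:, 0]      # "cls": first token embedding
last_pooled = hidden_states[:, -1]    # "last": last token (valid because inputs are left-padded)
mask = attention_mask.unsqueeze(-1)   # (batch_size, sequence_length, 1)
mean_pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)  # "mean": masked average

# Each pooled tensor has shape (batch_size, hidden_size).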
86 | """ 87 | kwargs.pop("output_hidden_states", None) 88 | 89 | # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding 90 | if "pixel_values" in kwargs: 91 | offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] # (batch_size,) 92 | kwargs["pixel_values"] = torch.cat( 93 | [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)], 94 | dim=0, 95 | ) 96 | 97 | position_ids, rope_deltas = self.get_rope_index( 98 | input_ids=kwargs["input_ids"], 99 | image_grid_thw=kwargs.get("image_grid_thw", None), 100 | video_grid_thw=None, 101 | attention_mask=kwargs.get("attention_mask", None), 102 | ) 103 | last_hidden_states = self.inner_forward( 104 | *args, 105 | **kwargs, 106 | position_ids=position_ids, 107 | use_cache=False, 108 | output_hidden_states=True, 109 | ) # (batch_size, sequence_length, hidden_size) 110 | 111 | # Get CLS token embedding, last token, or mean pool over sequence 112 | if pooling_strategy == "cls": 113 | # Use CLS token (first token) embedding 114 | pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) 115 | elif pooling_strategy == "last": 116 | # Use the last token since inputs are left-padded 117 | pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) 118 | elif pooling_strategy == "mean": 119 | # Mean pooling over sequence length 120 | mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) 121 | pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) 122 | else: 123 | raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") 124 | 125 | # L2 normalization 126 | pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True) 127 | return pooled_output 128 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2/biqwen2/processing_biqwen2.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | from transformers import BatchFeature 5 | 6 | from colpali_engine.models.qwen2.colqwen2 import ColQwen2Processor 7 | 8 | 9 | class BiQwen2Processor(ColQwen2Processor): 10 | """ 11 | Processor for BiQwen2. Mirrors the `ColQwen2Processor` class. 12 | """ 13 | 14 | def process_queries( 15 | self, 16 | queries: List[str], 17 | max_length: int = 50, 18 | suffix: Optional[str] = None, 19 | ) -> BatchFeature: 20 | """ 21 | Process queries for BiQwen2. 22 | """ 23 | if suffix is None: 24 | suffix = self.query_augmentation_token # we remove buffer tokens 25 | texts_query: List[str] = [] 26 | 27 | for query in queries: 28 | query = self.query_prefix + query + suffix 29 | texts_query.append(query) 30 | 31 | batch_query = self( 32 | text=texts_query, 33 | return_tensors="pt", 34 | padding="longest", 35 | ) 36 | 37 | return batch_query 38 | 39 | def score( 40 | self, 41 | qs: List[torch.Tensor], 42 | ps: List[torch.Tensor], 43 | device: Optional[Union[str, torch.device]] = None, 44 | **kwargs, 45 | ) -> torch.Tensor: 46 | """ 47 | Compute the dot product score for the given single-vector query and passage embeddings. 
48 | """ 49 | return self.score_single_vector(qs, ps, device=device) 50 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2/colqwen2/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_colqwen2 import ColQwen2 2 | from .processing_colqwen2 import ColQwen2Processor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar, List, Optional 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.models.qwen2_vl import Qwen2VLConfig, Qwen2VLForConditionalGeneration 6 | 7 | 8 | class ColQwen2(Qwen2VLForConditionalGeneration): 9 | """ 10 | ColQwen2 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. 11 | 12 | Args: 13 | config (Qwen2VLConfig): The model configuration. 14 | mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings 15 | except those of the image at inference. 16 | Defaults to False --> Do not mask any embeddings during forward pass. 17 | """ 18 | 19 | main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related 20 | 21 | def __init__(self, config: Qwen2VLConfig, mask_non_image_embeddings: bool = False): 22 | super().__init__(config=config) 23 | self.dim = 128 24 | self.custom_text_proj = nn.Linear(self.model.config.hidden_size, self.dim) 25 | self.padding_side = "left" 26 | self.mask_non_image_embeddings = mask_non_image_embeddings 27 | self.post_init() 28 | 29 | def inner_forward( 30 | self, 31 | input_ids: torch.LongTensor = None, 32 | attention_mask: Optional[torch.Tensor] = None, 33 | position_ids: Optional[torch.LongTensor] = None, 34 | past_key_values: Optional[List[torch.FloatTensor]] = None, 35 | inputs_embeds: Optional[torch.FloatTensor] = None, 36 | use_cache: Optional[bool] = None, 37 | output_attentions: Optional[bool] = None, 38 | output_hidden_states: Optional[bool] = None, 39 | return_dict: Optional[bool] = None, 40 | pixel_values: Optional[torch.Tensor] = None, 41 | pixel_values_videos: Optional[torch.FloatTensor] = None, 42 | image_grid_thw: Optional[torch.LongTensor] = None, 43 | video_grid_thw: Optional[torch.LongTensor] = None, 44 | ) -> torch.Tensor: 45 | if inputs_embeds is None: 46 | inputs_embeds = self.model.embed_tokens(input_ids) 47 | if pixel_values is not None: 48 | pixel_values = pixel_values.type(self.visual.get_dtype()) 49 | image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) 50 | image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) 51 | image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 52 | inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) 53 | 54 | if pixel_values_videos is not None: 55 | pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) 56 | video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) 57 | video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) 58 | video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 59 | inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) 60 | 61 | if attention_mask is not None: 62 | attention_mask = attention_mask.to(inputs_embeds.device) 63 | 64 | outputs = self.model( 65 | 
input_ids=None, 66 | position_ids=position_ids, 67 | attention_mask=attention_mask, 68 | past_key_values=past_key_values, 69 | inputs_embeds=inputs_embeds, 70 | use_cache=use_cache, 71 | output_attentions=output_attentions, 72 | output_hidden_states=output_hidden_states, 73 | return_dict=return_dict, 74 | ) 75 | 76 | hidden_states = outputs[0] 77 | return hidden_states 78 | 79 | def forward(self, *args, **kwargs) -> torch.Tensor: 80 | kwargs.pop("output_hidden_states", None) 81 | 82 | # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding 83 | if "pixel_values" in kwargs: 84 | offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] # (batch_size,) 85 | kwargs["pixel_values"] = torch.cat( 86 | [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)], 87 | dim=0, 88 | ) 89 | 90 | position_ids, rope_deltas = self.get_rope_index( 91 | input_ids=kwargs["input_ids"], 92 | image_grid_thw=kwargs.get("image_grid_thw", None), 93 | video_grid_thw=None, 94 | attention_mask=kwargs.get("attention_mask", None), 95 | ) 96 | last_hidden_states = self.inner_forward( 97 | *args, **kwargs, position_ids=position_ids, use_cache=False, output_hidden_states=True 98 | ) # (batch_size, sequence_length, hidden_size) 99 | 100 | proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) 101 | 102 | # L2 normalization 103 | proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) 104 | proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim) 105 | 106 | if "pixel_values" in kwargs and self.mask_non_image_embeddings: 107 | # Pools only the image embeddings 108 | image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) 109 | proj = proj * image_mask 110 | return proj 111 | 112 | @property 113 | def patch_size(self) -> int: 114 | return self.visual.config.patch_size 115 | 116 | @property 117 | def spatial_merge_size(self) -> int: 118 | return self.visual.config.spatial_merge_size 119 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2_5/__init__.py: -------------------------------------------------------------------------------- 1 | from .biqwen2_5 import BiQwen2_5, BiQwen2_5_Processor 2 | from .colqwen2_5 import ColQwen2_5, ColQwen2_5_Processor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2_5/biqwen2_5/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_biqwen2_5 import BiQwen2_5 2 | from .processing_biqwen2_5 import BiQwen2_5_Processor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2_5/biqwen2_5/processing_biqwen2_5.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | from transformers import BatchFeature 5 | 6 | from colpali_engine.models.qwen2_5.colqwen2_5 import ColQwen2_5_Processor 7 | 8 | 9 | class BiQwen2_5_Processor(ColQwen2_5_Processor): # noqa: N801 10 | """ 11 | Processor for BiQwen2.5. 12 | """ 13 | 14 | def process_queries( 15 | self, 16 | queries: List[str], 17 | max_length: int = 50, 18 | suffix: Optional[str] = None, 19 | ) -> BatchFeature: 20 | """ 21 | Process queries for BiQwen2.5. 
22 | """ 23 | if suffix is None: 24 | suffix = self.query_augmentation_token # we remove buffer tokens 25 | texts_query: List[str] = [] 26 | 27 | for query in queries: 28 | query = self.query_prefix + query + suffix 29 | texts_query.append(query) 30 | 31 | batch_query = self( 32 | text=texts_query, 33 | return_tensors="pt", 34 | padding="longest", 35 | ) 36 | 37 | return batch_query 38 | 39 | def score( 40 | self, 41 | qs: List[torch.Tensor], 42 | ps: List[torch.Tensor], 43 | device: Optional[Union[str, torch.device]] = None, 44 | **kwargs, 45 | ) -> torch.Tensor: 46 | """ 47 | Compute the cosine similarity for the given single-vector query and passage embeddings. 48 | """ 49 | return self.score_single_vector(qs, ps, device=device) 50 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2_5/colqwen2_5/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_colqwen2_5 import ColQwen2_5 2 | from .processing_colqwen2_5 import ColQwen2_5_Processor 3 | -------------------------------------------------------------------------------- /colpali_engine/models/qwen2_5/colqwen2_5/modeling_colqwen2_5.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar, List, Optional 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration 6 | 7 | 8 | class ColQwen2_5(Qwen2_5_VLForConditionalGeneration): # noqa: N801 9 | """ 10 | ColQwen2.5 model implementation, following the architecture from the "ColPali: Efficient Document Retrieval 11 | with Vision Language Models" paper. Based on the Qwen2.5-VL backbone. 12 | 13 | Args: 14 | config (Qwen2_5_VLConfig): The model configuration. 15 | mask_non_image_embeddings (Optional[bool]): Whether to ignore all token embeddings 16 | except those of the image at inference. 17 | Defaults to False --> Do not mask any embeddings during forward pass. 
18 | """ 19 | 20 | main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related 21 | 22 | def __init__(self, config: Qwen2_5_VLConfig, mask_non_image_embeddings: bool = False): 23 | super().__init__(config=config) 24 | self.dim = 128 25 | self.custom_text_proj = nn.Linear(self.model.config.hidden_size, self.dim) 26 | self.padding_side = "left" 27 | self.mask_non_image_embeddings = mask_non_image_embeddings 28 | self.post_init() 29 | 30 | def inner_forward( 31 | self, 32 | input_ids: torch.LongTensor = None, 33 | attention_mask: Optional[torch.Tensor] = None, 34 | position_ids: Optional[torch.LongTensor] = None, 35 | past_key_values: Optional[List[torch.FloatTensor]] = None, 36 | inputs_embeds: Optional[torch.FloatTensor] = None, 37 | use_cache: Optional[bool] = None, 38 | output_attentions: Optional[bool] = None, 39 | output_hidden_states: Optional[bool] = None, 40 | return_dict: Optional[bool] = None, 41 | pixel_values: Optional[torch.Tensor] = None, 42 | pixel_values_videos: Optional[torch.FloatTensor] = None, 43 | image_grid_thw: Optional[torch.LongTensor] = None, 44 | video_grid_thw: Optional[torch.LongTensor] = None, 45 | ) -> torch.Tensor: 46 | if inputs_embeds is None: 47 | inputs_embeds = self.model.embed_tokens(input_ids) 48 | if pixel_values is not None: 49 | pixel_values = pixel_values.type(self.visual.dtype) 50 | image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) 51 | image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) 52 | image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 53 | inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) 54 | 55 | if pixel_values_videos is not None: 56 | pixel_values_videos = pixel_values_videos.type(self.visual.dtype) 57 | video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) 58 | video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) 59 | video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 60 | inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) 61 | 62 | if attention_mask is not None: 63 | attention_mask = attention_mask.to(inputs_embeds.device) 64 | 65 | outputs = self.model( 66 | input_ids=None, 67 | position_ids=position_ids, 68 | attention_mask=attention_mask, 69 | past_key_values=past_key_values, 70 | inputs_embeds=inputs_embeds, 71 | use_cache=use_cache, 72 | output_attentions=output_attentions, 73 | output_hidden_states=output_hidden_states, 74 | return_dict=return_dict, 75 | ) 76 | 77 | hidden_states = outputs[0] 78 | return hidden_states 79 | 80 | def forward(self, *args, **kwargs) -> torch.Tensor: 81 | kwargs.pop("output_hidden_states", None) 82 | 83 | # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding 84 | if "pixel_values" in kwargs: 85 | offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] # (batch_size,) 86 | kwargs["pixel_values"] = torch.cat( 87 | [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)], 88 | dim=0, 89 | ) 90 | 91 | position_ids, rope_deltas = self.get_rope_index( 92 | input_ids=kwargs["input_ids"], 93 | image_grid_thw=kwargs.get("image_grid_thw", None), 94 | video_grid_thw=None, 95 | attention_mask=kwargs.get("attention_mask", None), 96 | ) 97 | last_hidden_states = self.inner_forward( 98 | *args, 99 | **kwargs, 100 | position_ids=position_ids, 101 | use_cache=False, 102 | output_hidden_states=True, 103 | ) # 
(batch_size, sequence_length, hidden_size) 104 | 105 | proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) 106 | 107 | # L2 normalization 108 | proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) 109 | proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim) 110 | 111 | if "pixel_values" in kwargs and self.mask_non_image_embeddings: 112 | # Pools only the image embeddings 113 | image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) 114 | proj = proj * image_mask 115 | return proj 116 | 117 | @property 118 | def patch_size(self) -> int: 119 | return self.visual.config.patch_size 120 | 121 | @property 122 | def spatial_merge_size(self) -> int: 123 | return self.visual.config.spatial_merge_size 124 | -------------------------------------------------------------------------------- /colpali_engine/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .colmodel_training import ColModelTraining, ColModelTrainingConfig 2 | from .contrastive_trainer import ContrastiveTrainer 3 | -------------------------------------------------------------------------------- /colpali_engine/trainer/colmodel_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Callable, Dict, List, Optional, Union 4 | 5 | from peft import LoraConfig, PeftModel, get_peft_model 6 | from transformers import ( 7 | PreTrainedModel, 8 | TrainingArguments, 9 | ) 10 | 11 | from colpali_engine.collators import VisualRetrieverCollator 12 | from colpali_engine.data.dataset import ColPaliEngineDataset 13 | from colpali_engine.loss.late_interaction_losses import ( 14 | ColbertLoss, 15 | ) 16 | from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer 17 | from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary 18 | from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor 19 | 20 | 21 | @dataclass 22 | class ColModelTrainingConfig: 23 | model: Union[PreTrainedModel, PeftModel] 24 | processor: BaseVisualRetrieverProcessor 25 | train_dataset: Union[ColPaliEngineDataset, List[ColPaliEngineDataset]] 26 | eval_dataset: Optional[Union[ColPaliEngineDataset, Dict[str, ColPaliEngineDataset]]] = None 27 | tr_args: Optional[TrainingArguments] = None 28 | output_dir: Optional[str] = None 29 | max_length: int = 256 30 | run_eval: bool = True 31 | run_train: bool = True 32 | peft_config: Optional[LoraConfig] = None 33 | loss_func: Optional[Callable] = ColbertLoss() 34 | pretrained_peft_model_name_or_path: Optional[str] = None 35 | """ 36 | Config class used for training a ColVision model. 37 | """ 38 | 39 | def __post_init__(self): 40 | """ 41 | Initialize the model and tokenizer if not provided 42 | """ 43 | if self.output_dir is None: 44 | sanitized_name = str(self.model.name_or_path).replace("/", "_") 45 | self.output_dir = f"./models/{sanitized_name}" 46 | 47 | if self.tr_args is None: 48 | print("No training arguments provided. 
Using default.") 49 | self.tr_args = TrainingArguments(output_dir=self.output_dir) 50 | elif self.tr_args.output_dir is None or self.tr_args.output_dir == "trainer_output": 51 | self.tr_args.output_dir = self.output_dir 52 | 53 | if isinstance(self.tr_args.learning_rate, str): 54 | print("Casting learning rate to float") 55 | self.tr_args.learning_rate = float(self.tr_args.learning_rate) 56 | 57 | self.tr_args.remove_unused_columns = False 58 | 59 | if self.pretrained_peft_model_name_or_path is not None: 60 | print("Loading pretrained PEFT model") 61 | self.model.load_adapter(self.pretrained_peft_model_name_or_path, is_trainable=True) 62 | 63 | if self.peft_config is not None: 64 | print("Configurating PEFT model") 65 | if self.pretrained_peft_model_name_or_path is None: 66 | self.model = get_peft_model(self.model, self.peft_config) 67 | self.model.print_trainable_parameters() 68 | else: 69 | print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.") 70 | 71 | print_gpu_utilization() 72 | 73 | 74 | class ColModelTraining: 75 | """ 76 | Class that contains the training and evaluation logic for a ColVision model. 77 | """ 78 | 79 | def __init__(self, config: ColModelTrainingConfig) -> None: 80 | self.config = config 81 | self.model = self.config.model 82 | self.current_git_hash = os.popen("git rev-parse HEAD").read().strip() 83 | self.train_dataset = self.config.train_dataset 84 | self.eval_dataset = self.config.eval_dataset 85 | self.collator = VisualRetrieverCollator( 86 | processor=self.config.processor, 87 | max_length=self.config.max_length, 88 | ) 89 | 90 | def train(self) -> None: 91 | trainer = ContrastiveTrainer( 92 | model=self.model, 93 | train_dataset=self.train_dataset, 94 | eval_dataset=self.eval_dataset, 95 | args=self.config.tr_args, 96 | data_collator=self.collator, 97 | loss_func=self.config.loss_func, 98 | is_vision_model=self.config.processor is not None, 99 | ) 100 | 101 | trainer.args.remove_unused_columns = False 102 | 103 | result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint) 104 | print_summary(result) 105 | 106 | def eval(self) -> None: 107 | raise NotImplementedError("Evaluation is not implemented yet.") 108 | 109 | def save(self): 110 | """ 111 | Save the model with its training config, as well as the tokenizer and processor if provided. 
112 | """ 113 | self.model.save_pretrained(self.config.output_dir) 114 | self.config.processor.save_pretrained(self.config.output_dir) 115 | 116 | # Save git hash of the commit at beginning of training 117 | with open(f"{self.config.output_dir}/git_hash.txt", "w") as f: 118 | f.write(self.current_git_hash) 119 | -------------------------------------------------------------------------------- /colpali_engine/trainer/contrastive_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import DatasetDict 3 | from torch.utils.data import ConcatDataset, DataLoader, Dataset 4 | from transformers import Trainer, is_datasets_available 5 | from transformers.trainer_utils import seed_worker 6 | 7 | from colpali_engine.data.sampler import SingleDatasetBatchSampler 8 | 9 | 10 | class ContrastiveTrainer(Trainer): 11 | def __init__(self, loss_func, is_vision_model, *args, **kwargs): 12 | if isinstance(kwargs["train_dataset"], DatasetDict): 13 | dataset_list = list(kwargs["train_dataset"].values()) 14 | elif isinstance(kwargs["train_dataset"], list): 15 | dataset_list = kwargs["train_dataset"] 16 | else: 17 | dataset_list = None 18 | 19 | if dataset_list is not None: 20 | kwargs["train_dataset"] = ConcatDataset(dataset_list) 21 | 22 | super().__init__(*args, **kwargs) 23 | self.loss_func = loss_func 24 | self.is_vision_model = is_vision_model # Unused argument, will be removed in 0.4.0 25 | self.args.remove_unused_columns = False # Safety, don't remove dataset columns from dataloader 26 | self.dataset_list = dataset_list 27 | 28 | def get_train_dataloader(self): 29 | ######## adapted from Transformers Trainer (gross) ######## 30 | """ 31 | Returns the training [`~torch.utils.data.DataLoader`]. 32 | 33 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 34 | training if necessary) otherwise. 35 | 36 | Subclass and override this method if you want to inject some custom behavior. 
37 | """ 38 | if self.dataset_list is None: 39 | return super().get_train_dataloader() 40 | 41 | train_dataset = self.train_dataset 42 | data_collator = self.data_collator 43 | if is_datasets_available() and isinstance(train_dataset, Dataset): 44 | train_dataset = self._remove_unused_columns(train_dataset, description="training") 45 | else: 46 | data_collator = self._get_collator_with_removed_columns(data_collator, description="training") 47 | 48 | dataloader_params = { 49 | ######### don't set batch size, mutually exclusive from batch sampler ###### 50 | "collate_fn": data_collator, 51 | "num_workers": self.args.dataloader_num_workers, 52 | "pin_memory": self.args.dataloader_pin_memory, 53 | "persistent_workers": self.args.dataloader_persistent_workers, 54 | } 55 | 56 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 57 | ###### batch_sampler set instead of sampler in trainer code ####### 58 | dataloader_params["batch_sampler"] = self._get_train_sampler() 59 | dataloader_params["drop_last"] = self.args.dataloader_drop_last 60 | dataloader_params["worker_init_fn"] = seed_worker 61 | dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor 62 | 63 | dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 64 | return dataloader 65 | 66 | def _get_train_sampler(self): 67 | if self.dataset_list is None: 68 | return super()._get_train_sampler() 69 | 70 | generator = torch.Generator() 71 | generator.manual_seed(self.args.seed) 72 | return SingleDatasetBatchSampler( 73 | self.dataset_list, 74 | self.args.train_batch_size, 75 | drop_last=self.args.dataloader_drop_last, 76 | generator=generator, 77 | ) 78 | 79 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 80 | query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) 81 | # feed only kwargs with 'doc_' prefix 82 | doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) 83 | if "neg_doc_input_ids" in inputs: 84 | neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) 85 | loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) 86 | return (loss, (query_outputs, doc_outputs, neg_doc_outputs)) if return_outputs else loss 87 | 88 | if "labels" in inputs: 89 | loss = self.loss_func(query_outputs, doc_outputs, inputs["labels"]) 90 | else: 91 | loss = self.loss_func(query_outputs, doc_outputs) 92 | return (loss, (query_outputs, doc_outputs)) if return_outputs else loss 93 | 94 | def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True): 95 | """This function is used to generate predictions and return the loss for the given inputs.""" 96 | if not prediction_loss_only: 97 | raise ValueError("prediction_step is only called with prediction_loss_only=True") 98 | 99 | with torch.no_grad(): 100 | # feed only kwargs with 'doc_' prefix 101 | doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) 102 | query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) 103 | if "neg_doc_input_ids" in inputs: 104 | neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) 105 | loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) 106 | return loss, None, None 107 | 108 | if "labels" in inputs: 109 | loss = self.loss_func(query_outputs, doc_outputs, inputs["labels"]) 110 | else: 111 | loss = 
self.loss_func(query_outputs, doc_outputs) 112 | return loss, None, None 113 | -------------------------------------------------------------------------------- /colpali_engine/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/illuin-tech/colpali/fbf9dcc70ef591dcadd1aa73ab019e97a60f272a/colpali_engine/utils/__init__.py -------------------------------------------------------------------------------- /colpali_engine/utils/gpu_stats.py: -------------------------------------------------------------------------------- 1 | # cond import 2 | try: 3 | from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit 4 | 5 | def print_gpu_utilization(): 6 | nvmlInit() 7 | handle = nvmlDeviceGetHandleByIndex(0) 8 | info = nvmlDeviceGetMemoryInfo(handle) 9 | print(f"GPU memory occupied: {info.used // 1024**2} MB.") 10 | 11 | def print_summary(result): 12 | print(f"Time: {result.metrics['train_runtime']:.2f}") 13 | print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}") 14 | print_gpu_utilization() 15 | 16 | except ImportError: 17 | print("pynvml not found. GPU stats will not be printed.") 18 | 19 | def print_summary(result): 20 | print(f"Time: {result.metrics['train_runtime']:.2f}") 21 | print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}") 22 | 23 | def print_gpu_utilization(): 24 | pass 25 | -------------------------------------------------------------------------------- /colpali_engine/utils/processing_utils.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | from PIL import Image 6 | from transformers import BatchEncoding, BatchFeature, ProcessorMixin 7 | 8 | from colpali_engine.utils.torch_utils import get_torch_device 9 | 10 | 11 | class BaseVisualRetrieverProcessor(ABC, ProcessorMixin): 12 | """ 13 | Base class for visual retriever processors. 14 | """ 15 | 16 | @abstractmethod 17 | def process_images( 18 | self, 19 | images: List[Image.Image], 20 | ) -> Union[BatchFeature, BatchEncoding]: 21 | pass 22 | 23 | @abstractmethod 24 | def process_queries( 25 | self, 26 | queries: List[str], 27 | max_length: int = 50, 28 | suffix: Optional[str] = None, 29 | ) -> Union[BatchFeature, BatchEncoding]: 30 | pass 31 | 32 | @abstractmethod 33 | def score( 34 | self, 35 | qs: List[torch.Tensor], 36 | ps: List[torch.Tensor], 37 | device: Optional[Union[str, torch.device]] = None, 38 | **kwargs, 39 | ) -> torch.Tensor: 40 | pass 41 | 42 | @staticmethod 43 | def score_single_vector( 44 | qs: List[torch.Tensor], 45 | ps: List[torch.Tensor], 46 | device: Optional[Union[str, torch.device]] = None, 47 | ) -> torch.Tensor: 48 | """ 49 | Compute the dot product score for the given single-vector query and passage embeddings. 
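A toy call illustrating the expected inputs and output shape of the single-vector scorer; the values are random, and the static method can be used without instantiating a concrete processor.

import torch

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor

qs = [torch.randn(128) for _ in range(2)]  # one (embedding_dim,) vector per query
ps = [torch.randn(128) for _ in range(3)]  # one (embedding_dim,) vector per passage

scores = BaseVisualRetrieverProcessor.score_single_vector(qs, ps, device="cpu")
print(scores.shape)  # torch.Size([2, 3])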
50 | """ 51 | device = device or get_torch_device("auto") 52 | 53 | if len(qs) == 0: 54 | raise ValueError("No queries provided") 55 | if len(ps) == 0: 56 | raise ValueError("No passages provided") 57 | 58 | qs_stacked = torch.stack(qs).to(device) 59 | ps_stacked = torch.stack(ps).to(device) 60 | 61 | scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked) 62 | assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" 63 | 64 | scores = scores.to(torch.float32) 65 | return scores 66 | 67 | @staticmethod 68 | def score_multi_vector( 69 | qs: Union[torch.Tensor, List[torch.Tensor]], 70 | ps: Union[torch.Tensor, List[torch.Tensor]], 71 | batch_size: int = 128, 72 | device: Optional[Union[str, torch.device]] = None, 73 | ) -> torch.Tensor: 74 | """ 75 | Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector 76 | query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the 77 | image of a document page. 78 | 79 | Because the embedding tensors are multi-vector and can thus have different shapes, they 80 | should be fed as: 81 | (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim) 82 | (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually 83 | obtained by padding the list of tensors. 84 | 85 | Args: 86 | qs (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings. 87 | ps (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings. 88 | batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores. 89 | device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not 90 | provided, uses `get_torch_device("auto")`. 91 | 92 | Returns: 93 | `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score 94 | tensor is saved on the "cpu" device. 95 | """ 96 | device = device or get_torch_device("auto") 97 | 98 | if len(qs) == 0: 99 | raise ValueError("No queries provided") 100 | if len(ps) == 0: 101 | raise ValueError("No passages provided") 102 | 103 | scores_list: List[torch.Tensor] = [] 104 | 105 | for i in range(0, len(qs), batch_size): 106 | scores_batch = [] 107 | qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to( 108 | device 109 | ) 110 | for j in range(0, len(ps), batch_size): 111 | ps_batch = torch.nn.utils.rnn.pad_sequence( 112 | ps[j : j + batch_size], batch_first=True, padding_value=0 113 | ).to(device) 114 | scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)) 115 | scores_batch = torch.cat(scores_batch, dim=1).cpu() 116 | scores_list.append(scores_batch) 117 | 118 | scores = torch.cat(scores_list, dim=0) 119 | assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" 120 | 121 | scores = scores.to(torch.float32) 122 | return scores 123 | 124 | @abstractmethod 125 | def get_n_patches( 126 | self, 127 | image_size: Tuple[int, int], 128 | *args, 129 | **kwargs, 130 | ) -> Tuple[int, int]: 131 | """ 132 | Get the number of patches (n_patches_x, n_patches_y) that will be used to process an 133 | image of size (height, width) with the given patch size. 
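Relating to `score_multi_vector` documented above, a toy call with ragged-length inputs (random values, arbitrary lengths) shows the accepted input convention and the output shape:

import torch

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor

qs = [torch.randn(6, 128), torch.randn(9, 128)]                               # 2 queries
ps = [torch.randn(1030, 128), torch.randn(780, 128), torch.randn(1030, 128)]  # 3 document pages

scores = BaseVisualRetrieverProcessor.score_multi_vector(qs, ps, batch_size=2, device="cpu")
print(scores.shape)  # torch.Size([2, 3])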
134 | """ 135 | pass 136 | -------------------------------------------------------------------------------- /colpali_engine/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | from typing import List, TypeVar 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | logger = logging.getLogger(__name__) 9 | T = TypeVar("T") 10 | 11 | 12 | def get_torch_device(device: str = "auto") -> str: 13 | """ 14 | Returns the device (string) to be used by PyTorch. 15 | 16 | `device` arg defaults to "auto" which will use: 17 | - "cuda:0" if available 18 | - else "mps" if available 19 | - else "cpu". 20 | """ 21 | 22 | if device == "auto": 23 | if torch.cuda.is_available(): 24 | device = "cuda:0" 25 | elif torch.backends.mps.is_available(): # for Apple Silicon 26 | device = "mps" 27 | else: 28 | device = "cpu" 29 | logger.info(f"Using device: {device}") 30 | 31 | return device 32 | 33 | 34 | def tear_down_torch(): 35 | """ 36 | Teardown for PyTorch. 37 | Clears GPU cache for both CUDA and MPS. 38 | """ 39 | gc.collect() 40 | if torch.cuda.is_available(): 41 | torch.cuda.empty_cache() 42 | if torch.backends.mps.is_available(): 43 | torch.mps.empty_cache() 44 | 45 | 46 | class ListDataset(Dataset[T]): 47 | def __init__(self, elements: List[T]): 48 | self.elements = elements 49 | 50 | def __len__(self) -> int: 51 | return len(self.elements) 52 | 53 | def __getitem__(self, idx: int) -> T: 54 | return self.elements[idx] 55 | 56 | 57 | def unbind_padded_multivector_embeddings( 58 | embeddings: torch.Tensor, 59 | padding_value: float = 0.0, 60 | padding_side: str = "left", 61 | ) -> List[torch.Tensor]: 62 | """ 63 | Removes padding elements from a batch of multivector embeddings. 64 | 65 | Args: 66 | embeddings (torch.Tensor): A tensor of shape (batch_size, seq_length, dim) with padding. 67 | padding_value (float): The value used for padding. Each padded token is assumed 68 | to be a vector where every element equals this value. 69 | padding_side (str): Either "left" or "right". This indicates whether the padded 70 | elements appear at the beginning (left) or end (right) of the sequence. 71 | 72 | Returns: 73 | List[torch.Tensor]: A list of tensors, one per sequence in the batch, where 74 | each tensor has shape (new_seq_length, dim) and contains only the non-padding elements. 
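A small round-trip sketch, assuming left padding with all-zero vectors as described above:

import torch

from colpali_engine.utils.torch_utils import unbind_padded_multivector_embeddings

dim = 4
seq_a = torch.randn(3, dim)
seq_b = torch.randn(5, dim)

padded = torch.stack([
    torch.cat([torch.zeros(2, dim), seq_a], dim=0),  # left-padded to length 5
    seq_b,
])  # (batch_size=2, seq_length=5, dim)

unpadded = unbind_padded_multivector_embeddings(padded, padding_value=0.0, padding_side="left")
print([tuple(t.shape) for t in unpadded])  # [(3, 4), (5, 4)]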
75 | """ 76 | results: List[torch.Tensor] = [] 77 | 78 | for seq in embeddings: 79 | is_padding = torch.all(seq.eq(padding_value), dim=-1) 80 | 81 | if padding_side == "left": 82 | non_padding_indices = (~is_padding).nonzero(as_tuple=False) 83 | if non_padding_indices.numel() == 0: 84 | valid_seq = seq[:0] 85 | else: 86 | first_valid_idx = non_padding_indices[0].item() 87 | valid_seq = seq[first_valid_idx:] 88 | elif padding_side == "right": 89 | non_padding_indices = (~is_padding).nonzero(as_tuple=False) 90 | if non_padding_indices.numel() == 0: 91 | valid_seq = seq[:0] 92 | else: 93 | last_valid_idx = non_padding_indices[-1].item() 94 | valid_seq = seq[: last_valid_idx + 1] 95 | else: 96 | raise ValueError("padding_side must be either 'left' or 'right'.") 97 | results.append(valid_seq) 98 | 99 | return results 100 | -------------------------------------------------------------------------------- /colpali_engine/utils/transformers_wrappers.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | if importlib.util.find_spec("transformers") is not None: 4 | from transformers import AutoProcessor, AutoTokenizer 5 | from transformers.tokenization_utils import PreTrainedTokenizer 6 | 7 | class AllPurposeWrapper: 8 | def __new__(cls, class_to_instanciate, *args, **kwargs): 9 | return class_to_instanciate.from_pretrained(*args, **kwargs) 10 | 11 | class AutoProcessorWrapper: 12 | def __new__(cls, *args, **kwargs): 13 | return AutoProcessor.from_pretrained(*args, **kwargs) 14 | 15 | class AutoTokenizerWrapper(PreTrainedTokenizer): 16 | def __new__(cls, *args, **kwargs): 17 | return AutoTokenizer.from_pretrained(*args, **kwargs) 18 | 19 | else: 20 | raise ModuleNotFoundError("Transformers must be loaded") 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.version] 6 | source = "vcs" 7 | 8 | [tool.hatch.build.targets.wheel] 9 | include = ["colpali_engine"] 10 | 11 | [project] 12 | name = "colpali_engine" 13 | dynamic = ["version"] 14 | description = "The code used to train and run inference with the ColPali architecture." 
15 | authors = [ 16 | { name = "Manuel Faysse", email = "manuel.faysse@illuin.tech" }, 17 | { name = "Hugues Sibille", email = "hugues.sibille@illuin.tech" }, 18 | { name = "Tony Wu", email = "tony.wu@illuin.tech" }, 19 | ] 20 | maintainers = [ 21 | { name = "Manuel Faysse", email = "manuel.faysse@illuin.tech" }, 22 | { name = "Tony Wu", email = "tony.wu@illuin.tech" }, 23 | ] 24 | readme = "README.md" 25 | requires-python = ">=3.9" 26 | classifiers = [ 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: MIT License", 29 | "Intended Audience :: Science/Research", 30 | "Intended Audience :: Developers", 31 | "Operating System :: OS Independent", 32 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 33 | ] 34 | 35 | dependencies = [ 36 | "numpy", 37 | "peft>=0.14.0,<0.16.0", 38 | "pillow>=10.0.0", 39 | "requests", 40 | "scipy", 41 | "torch>=2.5.0,<2.8.0", 42 | "transformers>=4.51.1,<4.52.0", 43 | ] 44 | 45 | [project.optional-dependencies] 46 | train = [ 47 | "accelerate>=0.34.0,<1.7.0", 48 | "bitsandbytes", 49 | "configue>=5.0.0", 50 | "datasets>=2.19.1", 51 | "mteb>=1.16.3,<2", 52 | "pillow>=10.0.0,<11.3.0", 53 | "typer>=0.15.1", 54 | ] 55 | 56 | interpretability = [ 57 | "einops>=0.8.0,<1.0.0", 58 | "matplotlib>=3.9.0,<4.0.0", 59 | "seaborn>=0.13.2,<1.0.0", 60 | ] 61 | 62 | dev = ["pytest>=8.0.0", "ruff>=0.4.0"] 63 | 64 | all = [ 65 | "colpali-engine[dev]", 66 | "colpali-engine[interpretability]", 67 | "colpali-engine[train]", 68 | ] 69 | 70 | [project.urls] 71 | homepage = "https://github.com/illuin-tech/colpali" 72 | 73 | [tool.pytest.ini_options] 74 | filterwarnings = ["ignore::Warning"] 75 | markers = ["slow: marks test as slow"] 76 | testpaths = ["tests"] 77 | 78 | [tool.ruff] 79 | line-length = 120 80 | 81 | [tool.ruff.lint] 82 | select = ["E", "F", "W", "I", "N"] 83 | 84 | [tool.ruff.lint.per-file-ignores] 85 | "__init__.py" = ["F401"] 86 | -------------------------------------------------------------------------------- /scripts/compute_hardnegs.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import datasets 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | 9 | from colpali_engine.models import BiQwen2, BiQwen2Processor 10 | from colpali_engine.utils.dataset_transformation import load_train_set 11 | 12 | train_set = load_train_set() 13 | 14 | 15 | COMPUTE_EMBEDDINGS = False 16 | COMPUTE_HARDNEGS = False 17 | 18 | if COMPUTE_HARDNEGS or COMPUTE_EMBEDDINGS: 19 | print("Loading base model") 20 | model = BiQwen2.from_pretrained( 21 | "./models/biqwen2-warmup-256-newpad-0e", 22 | torch_dtype=torch.bfloat16, 23 | device_map="cuda", 24 | attn_implementation="flash_attention_2" if torch.cuda.is_available() else None, 25 | ).eval() 26 | 27 | print("Loading processor") 28 | processor = BiQwen2Processor.from_pretrained("./models/biqwen2-warmup-256-newpad-0e") 29 | 30 | if COMPUTE_EMBEDDINGS: 31 | print("Loading images") 32 | print("Images loaded") 33 | 34 | document_set = train_set["train"] 35 | print("Filtering dataset") 36 | print(document_set) 37 | initial_list = document_set["image_filename"] 38 | _, unique_indices = np.unique(initial_list, return_index=True, axis=0) 39 | filtered_dataset = document_set.select(unique_indices.tolist()) 40 | filtered_dataset = filtered_dataset.map( 41 | lambda example: {"image": example["image"], "image_filename": example["image_filename"]}, num_proc=16 42 | ) 43 | # keep only column 
image and image_filename and source if it exists 44 | cols_to_remove = [col for col in filtered_dataset.column_names if col not in ["image", "image_filename"]] 45 | filtered_dataset = filtered_dataset.remove_columns(cols_to_remove) 46 | # save it 47 | print("Saving filtered dataset") 48 | print(filtered_dataset) 49 | filtered_dataset.save_to_disk("data_dir/filtered_dataset", max_shard_size="200MB") 50 | 51 | print("Processing images") 52 | # run inference - docs 53 | dataloader = DataLoader( 54 | filtered_dataset, 55 | batch_size=8, 56 | shuffle=False, 57 | collate_fn=lambda x: processor.process_images([a["image"] for a in x]), 58 | ) 59 | print("Computing embeddings") 60 | 61 | ds = [] 62 | for batch_doc in tqdm(dataloader): 63 | with torch.no_grad(): 64 | batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()} 65 | embeddings_doc = model(**batch_doc) 66 | ds.extend(list(torch.unbind(embeddings_doc.to("cpu")))) 67 | 68 | ds = torch.stack(ds) 69 | 70 | # save embeddings 71 | torch.save(ds, "data_dir/filtered_dataset_embeddings.pt") 72 | 73 | if not COMPUTE_EMBEDDINGS: 74 | ds = torch.load("data_dir/filtered_dataset_embeddings.pt") 75 | 76 | 77 | if COMPUTE_HARDNEGS: 78 | # compute hard negatives 79 | ds = cast(torch.Tensor, ds).to("cuda") 80 | 81 | # iterate on the train set 82 | mined_hardnegs = [] 83 | 84 | for i in tqdm(range(0, len(train_set["train"]), 8)): 85 | samples = train_set["train"][i : i + 8] 86 | batch_query = processor.process_queries(samples["query"]) 87 | with torch.no_grad(): 88 | batch_query = {k: v.to(model.device) for k, v in batch_query.items()} 89 | embeddings_query = model(**batch_query) 90 | 91 | # compute scores 92 | scores = torch.einsum("bd,cd->bc", embeddings_query, ds) 93 | # get top 100 indexes 94 | top100 = scores.topk(100, dim=1).indices 95 | # indices to list 96 | top100 = top100.tolist() 97 | # append to mined_hardnegs 98 | mined_hardnegs.extend(top100) 99 | 100 | # save mined hardnegs as txt 101 | with open("data_dir/mined_hardnegs_filtered.txt", "w") as f: 102 | for item in mined_hardnegs: 103 | f.write("%s\n" % item) 104 | 105 | 106 | with open("data_dir/mined_hardnegs_filtered.txt") as f: 107 | mined_hardnegs = f.readlines() 108 | 109 | 110 | filtered_dataset = datasets.load_from_disk("data_dir/filtered_dataset") 111 | filenames = list(filtered_dataset["image_filename"]) 112 | 113 | 114 | def mapper_fn(example, idx): 115 | tmp = { 116 | "negative_passages": [int(x) for x in mined_hardnegs[idx][1:-2].strip().split(",")], 117 | "query": example["query"], 118 | "positive_passages": [filenames.index(example["image_filename"])], 119 | } 120 | 121 | tmp["gold_in_top_100"] = tmp["positive_passages"][0] in tmp["negative_passages"] 122 | # remove gold index from negs if it is there 123 | if tmp["gold_in_top_100"]: 124 | tmp["negative_passages"].remove(tmp["positive_passages"][0]) 125 | return tmp 126 | 127 | 128 | final_dataset = train_set["train"].map(mapper_fn, with_indices=True, num_proc=16) 129 | # drop image 130 | final_dataset = final_dataset.remove_columns("image") 131 | final_dataset.save_to_disk("data_dir/final_dataset") 132 | -------------------------------------------------------------------------------- /scripts/configs/data/debug_data.yaml: -------------------------------------------------------------------------------- 1 | syntheticDocQA_energy: 2 | (): colpali_engine.utils.dataset_transformation.load_eval_set 3 | dataset_path: vidore/syntheticDocQA_energy_test 
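A quick aside on the config format (this note and the sketch below are additions, not files from the repository): the YAML files under scripts/configs are configue-style configs. The `():` key names a callable to instantiate, and tags such as `!path`, `!ext`, and `!import` resolve filesystem paths, Python objects, and nested YAML files. In the training entry point (scripts/train/train_colbert.py, shown further down) they are loaded with configue.load(config_file, sub_path="config"). As a minimal sketch under stated assumptions, a data config like the debug_data.yaml above could also be loaded on its own, provided the configue and datasets packages from the train extra are installed and the referenced Hub dataset is reachable; only the script framing below is illustrative.

# Hypothetical standalone loading of a configue-style data config.
# The file path, the load_eval_set callable, and the dataset_path value all
# come from this repository; this wrapper script itself is illustrative only.
import configue

eval_datasets = configue.load("scripts/configs/data/debug_data.yaml")

# The `():` entry resolves to colpali_engine.utils.dataset_transformation.load_eval_set,
# so each value is the evaluation set built from the given dataset_path.
for name, dataset in eval_datasets.items():
    print(name, dataset)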
-------------------------------------------------------------------------------- /scripts/configs/data/test_data.yaml: -------------------------------------------------------------------------------- 1 | # eval_dataset_loader: 2 | syntheticDocQA_energy: 3 | (): colpali_engine.utils.dataset_transformation.load_eval_set 4 | dataset_path: !path ../../../data_dir/syntheticDocQA_energy_test 5 | syntheticDocQA_healthcare_industry: 6 | (): colpali_engine.utils.dataset_transformation.load_eval_set 7 | dataset_path: !path ../../../data_dir/syntheticDocQA_healthcare_industry_test 8 | syntheticDocQA_artificial_intelligence_test: 9 | (): colpali_engine.utils.dataset_transformation.load_eval_set 10 | dataset_path: !path ../../../data_dir/syntheticDocQA_artificial_intelligence_test 11 | syntheticDocQA_government_reports: 12 | (): colpali_engine.utils.dataset_transformation.load_eval_set 13 | dataset_path: !path ../../../data_dir/syntheticDocQA_government_reports_test 14 | infovqa_subsampled: 15 | (): colpali_engine.utils.dataset_transformation.load_eval_set 16 | dataset_path: !path ../../../data_dir/infovqa_test_subsampled 17 | docvqa_subsampled: 18 | (): colpali_engine.utils.dataset_transformation.load_eval_set 19 | dataset_path: !path ../../../data_dir/docvqa_test_subsampled 20 | arxivqa_subsampled: 21 | (): colpali_engine.utils.dataset_transformation.load_eval_set 22 | dataset_path: !path ../../../data_dir/arxivqa_test_subsampled 23 | tabfquad_subsampled: 24 | (): colpali_engine.utils.dataset_transformation.load_eval_set 25 | dataset_path: !path ../../../data_dir/tabfquad_test_subsampled 26 | tatdqa: 27 | (): colpali_engine.utils.dataset_transformation.load_eval_set 28 | dataset_path: !path ../../../data_dir/tatdqa_test 29 | shift_project: 30 | (): colpali_engine.utils.dataset_transformation.load_eval_set 31 | dataset_path: !path ../../../data_dir/shiftproject_test -------------------------------------------------------------------------------- /scripts/configs/idefics/train_colsmolvlm_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/colsmolvlm 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.ColIdefics3Processor 7 | pretrained_model_name_or_path: "./models/ColSmolVLM-base" 8 | # num_image_tokens: 2048 9 | # max_length: 50 10 | 11 | model: 12 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 13 | class_to_instanciate: !ext colpali_engine.models.ColIdefics3 14 | pretrained_model_name_or_path: "./models/ColSmolVLM-base" 15 | torch_dtype: !ext torch.bfloat16 16 | # use_cache: false 17 | attn_implementation: "flash_attention_2" 18 | # device_map: "auto" 19 | # quantization_config: 20 | # (): transformers.BitsAndBytesConfig 21 | # load_in_4bit: true 22 | # bnb_4bit_quant_type: "nf4" 23 | # bnb_4bit_compute_dtype: "bfloat16" 24 | # bnb_4bit_use_double_quant: true 25 | 26 | train_dataset: 27 | (): colpali_engine.utils.dataset_transformation.load_train_set 28 | eval_dataset: !import ../data/test_data.yaml 29 | 30 | # max_length: 50 31 | run_eval: true 32 | loss_func: 33 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss 34 | tr_args: 35 | (): transformers.training_args.TrainingArguments 36 | output_dir: null 37 | overwrite_output_dir: true 38 | num_train_epochs: 3 39 | per_device_train_batch_size: 32 40 | 
gradient_checkpointing: true 41 | gradient_checkpointing_kwargs: { "use_reentrant": false } 42 | # gradient_checkpointing: true 43 | # 6 x 8 gpus = 48 batch size 44 | # gradient_accumulation_steps: 4 45 | per_device_eval_batch_size: 32 46 | eval_strategy: "steps" 47 | dataloader_num_workers: 4 48 | # bf16: true 49 | save_steps: 500 50 | logging_steps: 10 51 | eval_steps: 100 52 | warmup_steps: 100 53 | learning_rate: 5e-4 54 | save_total_limit: 1 55 | # resume_from_checkpoint: true 56 | # optim: "paged_adamw_8bit" 57 | # wandb logging 58 | # wandb_project: "colqwen2" 59 | # run_name: "colqwen2-ba32-nolora" 60 | report_to: "wandb" 61 | 62 | 63 | peft_config: 64 | (): peft.LoraConfig 65 | r: 32 66 | lora_alpha: 32 67 | lora_dropout: 0.1 68 | init_lora_weights: "gaussian" 69 | bias: "none" 70 | task_type: "FEATURE_EXTRACTION" 71 | target_modules: '(.*(model.text_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 72 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 73 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_bipali_all_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/without_tabfquad_no_pairwise/train_bipali_all_mean-3b-mix-448 4 | processor: 5 | () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper 6 | pretrained_model_name_or_path: "./models/paligemma-3b-mix-448" 7 | max_length: 50 8 | model: 9 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.BiPali 11 | pretrained_model_name_or_path: "./models/paligemma-3b-mix-448-base" 12 | torch_dtype: !ext torch.bfloat16 13 | # device_map: "auto" 14 | # quantization_config: 15 | # (): transformers.BitsAndBytesConfig 16 | # load_in_4bit: true 17 | # bnb_4bit_quant_type: "nf4" 18 | # bnb_4bit_compute_dtype: "bfloat16" 19 | # bnb_4bit_use_double_quant: true 20 | 21 | train_dataset: 22 | (): colpali_engine.utils.dataset_transformation.load_train_set 23 | eval_dataset: !import ../data/test_data.yaml 24 | 25 | max_length: 50 26 | run_eval: true 27 | 28 | loss_func: 29 | (): colpali_engine.loss.bi_encoder_losses.BiEncoderLoss 30 | tr_args: !import ../tr_args/default_tr_args.yaml 31 | peft_config: 32 | (): peft.LoraConfig 33 | r: 32 34 | lora_alpha: 32 35 | lora_dropout: 0.1 36 | init_lora_weights: "gaussian" 37 | bias: "none" 38 | task_type: "FEATURE_EXTRACTION" 39 | target_modules: '(.*(language_model|vision_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(multi_modal_projector\.linear).*$)' 40 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$' 41 | 42 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_bipali_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/right_pad/train_bipali_reproduction 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.BiPaliProcessor 7 | pretrained_model_name_or_path: "./models/paligemma-3b-mix-448" 8 | model: 9 | (): 
colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.BiPali 11 | pretrained_model_name_or_path: "./models/paligemma-3b-mix-448" 12 | torch_dtype: !ext torch.bfloat16 13 | # device_map: "auto" 14 | # quantization_config: 15 | # (): transformers.BitsAndBytesConfig 16 | # load_in_4bit: true 17 | # bnb_4bit_quant_type: "nf4" 18 | # bnb_4bit_compute_dtype: "bfloat16" 19 | # bnb_4bit_use_double_quant: true 20 | 21 | train_dataset: 22 | (): colpali_engine.utils.dataset_transformation.load_train_set_detailed 23 | eval_dataset: !import ../data/test_data.yaml 24 | 25 | max_length: 50 26 | run_eval: true 27 | 28 | loss_func: 29 | (): colpali_engine.loss.bi_encoder_losses.BiEncoderLoss 30 | tr_args: !import ../tr_args/default_tr_args.yaml 31 | peft_config: 32 | (): peft.LoraConfig 33 | r: 32 34 | lora_alpha: 32 35 | lora_dropout: 0.1 36 | init_lora_weights: "gaussian" 37 | bias: "none" 38 | task_type: "FEATURE_EXTRACTION" 39 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)' 40 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$' 41 | 42 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_bipali_pairwise_256_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/right_pad/train_bipali_pairwise_256_pt 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.BiPaliProcessor 7 | pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base" 8 | max_length: 50 9 | model: 10 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 11 | class_to_instanciate: !ext colpali_engine.models.BiPali 12 | pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base" 13 | torch_dtype: !ext torch.bfloat16 14 | attn_implementation: "flash_attention_2" 15 | # use_cache: false 16 | # device_map: "auto" 17 | # quantization_config: 18 | # (): transformers.BitsAndBytesConfig 19 | # load_in_4bit: true 20 | # bnb_4bit_quant_type: "nf4" 21 | # bnb_4bit_compute_dtype: "bfloat16" 22 | # bnb_4bit_use_double_quant: true 23 | 24 | train_dataset: 25 | (): colpali_engine.utils.dataset_transformation.load_train_set 26 | eval_dataset: !import ../data/test_data.yaml 27 | 28 | max_length: 50 29 | run_eval: true 30 | 31 | loss_func: 32 | (): colpali_engine.loss.bi_encoder_losses.BiPairwiseCELoss 33 | tr_args: 34 | (): transformers.training_args.TrainingArguments 35 | output_dir: null 36 | overwrite_output_dir: true 37 | num_train_epochs: 3 38 | per_device_train_batch_size: 64 39 | gradient_checkpointing: true 40 | gradient_checkpointing_kwargs: { "use_reentrant": false } 41 | # 6 x 8 gpus = 48 batch size 42 | # gradient_accumulation_steps: 4 43 | per_device_eval_batch_size: 64 44 | eval_strategy: "steps" 45 | dataloader_num_workers: 8 46 | # bf16: true 47 | save_steps: 500 48 | logging_steps: 10 49 | eval_steps: 100 50 | warmup_steps: 100 51 | learning_rate: 5e-4 52 | save_total_limit: 1 53 | resume_from_checkpoint: false 54 | report_to: "wandb" 55 | peft_config: 56 | (): peft.LoraConfig 57 | r: 32 58 | lora_alpha: 32 59 | lora_dropout: 0.1 60 | init_lora_weights: "gaussian" 61 | bias: "none" 62 | task_type: "FEATURE_EXTRACTION" 63 | target_modules: 
'(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)' 64 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$' 65 | 66 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_bipali_pairwise_hardneg_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/right_pad/train_bipali_pairwise_hardneg_proj 4 | processor: 5 | () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper 6 | pretrained_model_name_or_path: "./models/paligemma-3b-mix-448" 7 | max_length: 50 8 | model: 9 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.BiPaliProj 11 | pretrained_model_name_or_path: "./models/paligemma-3b-mix-448" 12 | torch_dtype: !ext torch.bfloat16 13 | # device_map: "auto" 14 | # quantization_config: 15 | # (): transformers.BitsAndBytesConfig 16 | # load_in_4bit: true 17 | # bnb_4bit_quant_type: "nf4" 18 | # bnb_4bit_compute_dtype: "bfloat16" 19 | # bnb_4bit_use_double_quant: true 20 | 21 | train_dataset: 22 | (): colpali_engine.utils.dataset_transformation.load_train_set_ir_negs 23 | eval_dataset: !import ../data/test_data.yaml 24 | 25 | max_length: 50 26 | run_eval: true 27 | 28 | loss_func: 29 | (): colpali_engine.loss.bi_encoder_losses.BiPairwiseNegativeCELoss 30 | in_batch_term: true 31 | tr_args: !import ../tr_args/default_neg_tr_args.yaml 32 | peft_config: 33 | (): peft.LoraConfig 34 | r: 32 35 | lora_alpha: 32 36 | lora_dropout: 0.1 37 | init_lora_weights: "gaussian" 38 | bias: "none" 39 | task_type: "FEATURE_EXTRACTION" 40 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)' 41 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$' 42 | 43 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_bipali_pairwise_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/right_pad/train_bipali_pairwise_reproduction 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.BiPaliProcessor 7 | pretrained_model_name_or_path: "./models/paligemma-3b-mix-448" 8 | max_length: 50 9 | model: 10 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 11 | class_to_instanciate: !ext colpali_engine.models.BiPali 12 | pretrained_model_name_or_path: "./models/paligemma-3b-mix-448" 13 | torch_dtype: !ext torch.bfloat16 14 | # device_map: "auto" 15 | # quantization_config: 16 | # (): transformers.BitsAndBytesConfig 17 | # load_in_4bit: true 18 | # bnb_4bit_quant_type: "nf4" 19 | # bnb_4bit_compute_dtype: "bfloat16" 20 | # bnb_4bit_use_double_quant: true 21 | 22 | train_dataset: 23 | (): colpali_engine.utils.dataset_transformation.load_train_set_detailed 24 | eval_dataset: !import ../data/test_data.yaml 25 | 26 | max_length: 50 27 | run_eval: true 28 | 29 | loss_func: 30 | (): colpali_engine.loss.bi_encoder_losses.BiPairwiseCELoss 31 | tr_args: !import ../tr_args/default_tr_args.yaml 32 | peft_config: 33 | (): peft.LoraConfig 34 | r: 32 35 | 
lora_alpha: 32 36 | lora_dropout: 0.1 37 | init_lora_weights: "gaussian" 38 | bias: "none" 39 | task_type: "FEATURE_EXTRACTION" 40 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)' 41 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$' 42 | 43 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_colpali2_pt_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/train_colpali2-3b-pt-448-5e5 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.ColPaliProcessor 7 | pretrained_model_name_or_path: "./models/colpaligemma2-3b-pt-448-base" # "./models/paligemma-3b-mix-448" 8 | max_length: 50 9 | model: 10 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 11 | class_to_instanciate: !ext colpali_engine.models.ColPali 12 | pretrained_model_name_or_path: "./models/colpaligemma2-3b-pt-448-base" 13 | torch_dtype: !ext torch.bfloat16 14 | attn_implementation: "flash_attention_2" 15 | # device_map: "auto" 16 | # quantization_config: 17 | # (): transformers.BitsAndBytesConfig 18 | # load_in_4bit: true 19 | # bnb_4bit_quant_type: "nf4" 20 | # bnb_4bit_compute_dtype: "bfloat16" 21 | # bnb_4bit_use_double_quant: true 22 | 23 | train_dataset: 24 | (): colpali_engine.utils.dataset_transformation.load_train_set 25 | eval_dataset: !import ../data/test_data.yaml 26 | 27 | max_length: 50 28 | run_eval: true 29 | 30 | loss_func: 31 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss 32 | tr_args: 33 | (): transformers.training_args.TrainingArguments 34 | output_dir: null 35 | overwrite_output_dir: true 36 | num_train_epochs: 5 37 | per_device_train_batch_size: 32 38 | gradient_checkpointing: true 39 | gradient_checkpointing_kwargs: { "use_reentrant": false } 40 | # 6 x 8 gpus = 48 batch size 41 | # gradient_accumulation_steps: 4 42 | per_device_eval_batch_size: 32 43 | eval_strategy: "steps" 44 | dataloader_num_workers: 16 45 | # bf16: true 46 | save_steps: 500 47 | logging_steps: 10 48 | eval_steps: 100 49 | warmup_steps: 100 50 | learning_rate: 5e-5 51 | save_total_limit: 1 52 | resume_from_checkpoint: false 53 | report_to: "wandb" 54 | 55 | peft_config: 56 | (): peft.LoraConfig 57 | r: 32 58 | lora_alpha: 32 59 | lora_dropout: 0.1 60 | init_lora_weights: "gaussian" 61 | bias: "none" 62 | task_type: "FEATURE_EXTRACTION" 63 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 64 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 65 | 66 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_colpali_all_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/without_tabfquad_no_pairwise/train_colpali_all-3b-mix-448 4 | processor: 5 | () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper 6 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448" 7 | 
max_length: 50 8 | model: 9 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.ColPali 11 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" 12 | torch_dtype: !ext torch.bfloat16 13 | # device_map: "auto" 14 | # quantization_config: 15 | # (): transformers.BitsAndBytesConfig 16 | # load_in_4bit: true 17 | # bnb_4bit_quant_type: "nf4" 18 | # bnb_4bit_compute_dtype: "bfloat16" 19 | # bnb_4bit_use_double_quant: true 20 | 21 | train_dataset: 22 | (): colpali_engine.utils.dataset_transformation.load_train_set 23 | eval_dataset: !import ../data/test_data.yaml 24 | 25 | max_length: 50 26 | run_eval: true 27 | loss_func: 28 | (): colpali_engine.loss.late_interaction_losses.ColbertLoss 29 | tr_args: !import ../tr_args/default_tr_args.yaml 30 | peft_config: 31 | (): peft.LoraConfig 32 | r: 32 33 | lora_alpha: 32 34 | lora_dropout: 0.1 35 | init_lora_weights: "gaussian" 36 | bias: "none" 37 | task_type: "FEATURE_EXTRACTION" 38 | target_modules: '(.*(language_model|vision_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(multi_modal_projector\.linear).*$|.*(custom_text_proj).*$)' 39 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 40 | 41 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_colpali_docmatix_hardneg_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/train_colpali_docmatix_hardneg_ib_3b-mix-448 4 | processor: 5 | () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper 6 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448" 7 | max_length: 50 8 | model: 9 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.ColPali 11 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" 12 | torch_dtype: !ext torch.bfloat16 13 | # device_map: "auto" 14 | # quantization_config: 15 | # (): transformers.BitsAndBytesConfig 16 | # load_in_4bit: true 17 | # bnb_4bit_quant_type: "nf4" 18 | # bnb_4bit_compute_dtype: "bfloat16" 19 | # bnb_4bit_use_double_quant: true 20 | 21 | train_dataset: 22 | (): colpali_engine.utils.dataset_transformation.load_docmatix_ir_negs 23 | eval_dataset: !import ../data/test_data.yaml 24 | 25 | max_length: 50 26 | run_eval: true 27 | 28 | loss_func: 29 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss 30 | in_batch_term: true 31 | tr_args: !import ../tr_args/default_neg_tr_args.yaml 32 | peft_config: 33 | (): peft.LoraConfig 34 | r: 32 35 | lora_alpha: 32 36 | lora_dropout: 0.1 37 | init_lora_weights: "gaussian" 38 | bias: "none" 39 | task_type: "FEATURE_EXTRACTION" 40 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 41 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 42 | 43 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_colpali_docmatix_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): 
colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/train_colpali-docmatix-3b-mix-448 4 | processor: 5 | () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper 6 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448" 7 | max_length: 50 8 | model: 9 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.ColPali 11 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" 12 | torch_dtype: !ext torch.bfloat16 13 | # quantization_config: 14 | # (): transformers.BitsAndBytesConfig 15 | # load_in_4bit: true 16 | # bnb_4bit_quant_type: "nf4" 17 | # bnb_4bit_compute_dtype: "bfloat16" 18 | # bnb_4bit_use_double_quant: true 19 | 20 | train_dataset: 21 | (): colpali_engine.utils.dataset_transformation.load_train_set_with_docmatix 22 | eval_dataset: !import ../data/test_data.yaml 23 | 24 | max_length: 50 25 | run_eval: true 26 | loss_func: 27 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss 28 | tr_args: !import ../tr_args/default_tr_args.yaml 29 | peft_config: 30 | (): peft.LoraConfig 31 | r: 32 32 | lora_alpha: 32 33 | lora_dropout: 0.1 34 | init_lora_weights: "gaussian" 35 | bias: "none" 36 | task_type: "FEATURE_EXTRACTION" 37 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 38 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 39 | 40 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_colpali_hardneg_debug_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/without_tabfquad/train_colpali-3b-mix-448-debug 4 | processor: 5 | () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper 6 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448" 7 | max_length: 50 8 | model: 9 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.ColPali 11 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" 12 | torch_dtype: !ext torch.bfloat16 13 | # device_map: "auto" 14 | # quantization_config: 15 | # (): transformers.BitsAndBytesConfig 16 | # load_in_4bit: true 17 | # bnb_4bit_quant_type: "nf4" 18 | # bnb_4bit_compute_dtype: "bfloat16" 19 | # bnb_4bit_use_double_quant: true 20 | 21 | train_dataset: 22 | (): colpali_engine.utils.dataset_transformation.load_train_set_ir_negs 23 | eval_dataset: !import ../data/test_data.yaml 24 | 25 | max_length: 50 26 | run_eval: true 27 | 28 | loss_func: 29 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss 30 | in_batch_term: true 31 | tr_args: !import ../tr_args/resume_neg_tr_args.yaml 32 | peft_config: 33 | (): peft.LoraConfig 34 | r: 32 35 | lora_alpha: 32 36 | lora_dropout: 0.1 37 | init_lora_weights: "gaussian" 38 | bias: "none" 39 | task_type: "FEATURE_EXTRACTION" 40 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 41 | # target_modules: 
'(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 42 | 43 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_colpali_hardneg_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/right_pad/train_colpali_hardneg_long 4 | processor: 5 | () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper 6 | pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base" # "./models/paligemma-3b-mix-448" 7 | max_length: 50 8 | model: 9 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.ColPali 11 | pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base" 12 | torch_dtype: !ext torch.bfloat16 13 | # device_map: "auto" 14 | # quantization_config: 15 | # (): transformers.BitsAndBytesConfig 16 | # load_in_4bit: true 17 | # bnb_4bit_quant_type: "nf4" 18 | # bnb_4bit_compute_dtype: "bfloat16" 19 | # bnb_4bit_use_double_quant: true 20 | 21 | train_dataset: 22 | (): colpali_engine.utils.dataset_transformation.load_train_set_ir_negs 23 | eval_dataset: !import ../data/test_data.yaml 24 | 25 | max_length: 50 26 | run_eval: true 27 | 28 | loss_func: 29 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss 30 | in_batch_term: true 31 | tr_args: 32 | (): transformers.training_args.TrainingArguments 33 | output_dir: null 34 | overwrite_output_dir: true 35 | num_train_epochs: 5 36 | per_device_train_batch_size: 4 37 | # 6 x 8 gpus = 48 batch size 38 | # gradient_accumulation_steps: 4 39 | per_device_eval_batch_size: 4 40 | eval_strategy: "steps" 41 | # dataloader_num_workers: 8 42 | # bf16: true 43 | save_steps: 500 44 | logging_steps: 10 45 | eval_steps: 50 46 | warmup_steps: 1000 47 | learning_rate: 5e-5 48 | save_total_limit: 1 49 | resume_from_checkpoint: true 50 | # optim: "paged_adamw_8bit" 51 | 52 | peft_config: 53 | (): peft.LoraConfig 54 | r: 32 55 | lora_alpha: 32 56 | lora_dropout: 0.1 57 | init_lora_weights: "gaussian" 58 | bias: "none" 59 | task_type: "FEATURE_EXTRACTION" 60 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 61 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 62 | 63 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_colpali_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/right_pad/train_colpali-3b-mix-448 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.ColPaliProcessor 7 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448" 8 | max_length: 50 9 | model: 10 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 11 | class_to_instanciate: !ext colpali_engine.models.ColPali 12 | pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" 13 | torch_dtype: !ext torch.bfloat16 14 | # device_map: "auto" 15 | # quantization_config: 16 | # (): 
transformers.BitsAndBytesConfig 17 | # load_in_4bit: true 18 | # bnb_4bit_quant_type: "nf4" 19 | # bnb_4bit_compute_dtype: "bfloat16" 20 | # bnb_4bit_use_double_quant: true 21 | 22 | train_dataset: 23 | (): colpali_engine.utils.dataset_transformation.load_train_set 24 | eval_dataset: !import ../data/test_data.yaml 25 | 26 | max_length: 50 27 | run_eval: true 28 | loss_func: 29 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss 30 | tr_args: !import ../tr_args/default_tr_args.yaml 31 | peft_config: 32 | (): peft.LoraConfig 33 | r: 32 34 | lora_alpha: 32 35 | lora_dropout: 0.1 36 | init_lora_weights: "gaussian" 37 | bias: "none" 38 | task_type: "FEATURE_EXTRACTION" 39 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 40 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 41 | 42 | -------------------------------------------------------------------------------- /scripts/configs/pali/train_colpali_pt_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/right_pad/train_colpali-3b-pt-448 4 | processor: 5 | () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper 6 | pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base" # "./models/paligemma-3b-mix-448" 7 | max_length: 50 8 | model: 9 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 10 | class_to_instanciate: !ext colpali_engine.models.ColPali 11 | pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base" 12 | torch_dtype: !ext torch.bfloat16 13 | # device_map: "auto" 14 | # quantization_config: 15 | # (): transformers.BitsAndBytesConfig 16 | # load_in_4bit: true 17 | # bnb_4bit_quant_type: "nf4" 18 | # bnb_4bit_compute_dtype: "bfloat16" 19 | # bnb_4bit_use_double_quant: true 20 | 21 | train_dataset: 22 | (): colpali_engine.utils.dataset_transformation.load_train_set 23 | eval_dataset: !import ../data/test_data.yaml 24 | 25 | max_length: 50 26 | run_eval: true 27 | 28 | loss_func: 29 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss 30 | tr_args: !import ../tr_args/default_tr_args.yaml 31 | peft_config: 32 | (): peft.LoraConfig 33 | r: 32 34 | lora_alpha: 32 35 | lora_dropout: 0.1 36 | init_lora_weights: "gaussian" 37 | bias: "none" 38 | task_type: "FEATURE_EXTRACTION" 39 | target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 40 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 41 | 42 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/deprecated/train_biqwen2_docmatix_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/biqwen2-docmatix-256 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor 7 | pretrained_model_name_or_path: "./models/colqwen2_base" # "./models/paligemma-3b-mix-448" 8 | # max_length: 50 9 | 10 | model: 11 | (): 
colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 12 | class_to_instanciate: !ext colpali_engine.models.BiQwen2 13 | pretrained_model_name_or_path: "./models/colqwen2_base" 14 | torch_dtype: !ext torch.bfloat16 15 | use_cache: false 16 | attn_implementation: "flash_attention_2" 17 | # device_map: "auto" 18 | # quantization_config: 19 | # (): transformers.BitsAndBytesConfig 20 | # load_in_4bit: true 21 | # bnb_4bit_quant_type: "nf4" 22 | # bnb_4bit_compute_dtype: "bfloat16" 23 | # bnb_4bit_use_double_quant: true 24 | 25 | dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_docmatix_ir_negs 26 | eval_dataset_loader: !import ../data/test_data.yaml 27 | 28 | # max_length: 50 29 | run_eval: true 30 | 31 | loss_func: 32 | (): colpali_engine.loss.bi_encoder_losses.BiPairwiseNegativeCELoss 33 | tr_args: 34 | (): transformers.training_args.TrainingArguments 35 | output_dir: null 36 | overwrite_output_dir: true 37 | num_train_epochs: 1 38 | per_device_train_batch_size: 64 39 | gradient_checkpointing: true 40 | gradient_checkpointing_kwargs: { "use_reentrant": false } 41 | # 6 x 8 gpus = 48 batch size 42 | # gradient_accumulation_steps: 4 43 | per_device_eval_batch_size: 64 44 | eval_strategy: "steps" 45 | dataloader_num_workers: 8 46 | # bf16: true 47 | save_steps: 500 48 | logging_steps: 10 49 | eval_steps: 100 50 | warmup_steps: 100 51 | max_steps: 2000 52 | learning_rate: 5e-5 53 | save_total_limit: 1 54 | # optim: "paged_adamw_8bit" 55 | peft_config: 56 | (): peft.LoraConfig 57 | r: 32 58 | lora_alpha: 32 59 | lora_dropout: 0.1 60 | init_lora_weights: "gaussian" 61 | bias: "none" 62 | task_type: "FEATURE_EXTRACTION" 63 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)' 64 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 65 | 66 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/deprecated/train_biqwen2_warmup_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/biqwen2-warmup-256-newpad-0e 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor 7 | pretrained_model_name_or_path: "./models/colqwen2_base" # "./models/paligemma-3b-mix-448" 8 | # max_length: 50 9 | 10 | model: 11 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 12 | class_to_instanciate: !ext colpali_engine.models.BiQwen2 13 | pretrained_model_name_or_path: "./models/qwen2-warmup" 14 | torch_dtype: !ext torch.bfloat16 15 | use_cache: false 16 | attn_implementation: "flash_attention_2" 17 | # device_map: "auto" 18 | # quantization_config: 19 | # (): transformers.BitsAndBytesConfig 20 | # load_in_4bit: true 21 | # bnb_4bit_quant_type: "nf4" 22 | # bnb_4bit_compute_dtype: "bfloat16" 23 | # bnb_4bit_use_double_quant: true 24 | 25 | dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set 26 | eval_dataset_loader: !import ../data/test_data.yaml 27 | 28 | # max_length: 50 29 | run_eval: true 30 | 31 | loss_func: 32 | (): colpali_engine.loss.bi_encoder_losses.BiPairwiseCELoss 33 | tr_args: 34 | (): transformers.training_args.TrainingArguments 35 | output_dir: null 36 | overwrite_output_dir: true 37 | num_train_epochs: 1 38 | 
per_device_train_batch_size: 64 39 | gradient_checkpointing: true 40 | gradient_checkpointing_kwargs: { "use_reentrant": false } 41 | # 6 x 8 gpus = 48 batch size 42 | # gradient_accumulation_steps: 4 43 | per_device_eval_batch_size: 64 44 | eval_strategy: "steps" 45 | dataloader_num_workers: 8 46 | # bf16: true 47 | save_steps: 500 48 | logging_steps: 10 49 | max_steps: 1 50 | eval_steps: 100 51 | warmup_steps: 100 52 | learning_rate: 5e-5 53 | save_total_limit: 1 54 | resume_from_checkpoint: false 55 | # optim: "paged_adamw_8bit" 56 | peft_config: 57 | (): peft.LoraConfig 58 | r: 32 59 | lora_alpha: 32 60 | lora_dropout: 0.1 61 | init_lora_weights: "gaussian" 62 | bias: "none" 63 | task_type: "FEATURE_EXTRACTION" 64 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)' 65 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 66 | 67 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/deprecated/train_colqwen2_docmatix_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/colqwen2-docmatix-ba256-ckpt-1000s 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor 7 | pretrained_model_name_or_path: "./models/colqwen2_base" # "./models/paligemma-3b-mix-448" 8 | # num_image_tokens: 2048 9 | # max_length: 50 10 | 11 | model: 12 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 13 | class_to_instanciate: !ext colpali_engine.models.ColQwen2 14 | pretrained_model_name_or_path: "./models/colqwen2_base" 15 | torch_dtype: !ext torch.bfloat16 16 | use_cache: false 17 | attn_implementation: "flash_attention_2" 18 | # device_map: "auto" 19 | # quantization_config: 20 | # (): transformers.BitsAndBytesConfig 21 | # load_in_4bit: true 22 | # bnb_4bit_quant_type: "nf4" 23 | # bnb_4bit_compute_dtype: "bfloat16" 24 | # bnb_4bit_use_double_quant: true 25 | 26 | dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_docmatix_ir_negs 27 | eval_dataset_loader: !import ../data/test_data.yaml 28 | 29 | # max_length: 50 30 | run_eval: true 31 | 32 | loss_func: 33 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss 34 | tr_args: 35 | (): transformers.training_args.TrainingArguments 36 | output_dir: null 37 | overwrite_output_dir: true 38 | num_train_epochs: 1 39 | per_device_train_batch_size: 64 40 | gradient_checkpointing: true 41 | gradient_checkpointing_kwargs: {"use_reentrant": false} 42 | # 6 x 8 gpus = 48 batch size 43 | # gradient_accumulation_steps: 4 44 | per_device_eval_batch_size: 64 45 | eval_strategy: "steps" 46 | dataloader_num_workers: 8 47 | # bf16: true 48 | save_steps: 500 49 | logging_steps: 10 50 | eval_steps: 500 51 | warmup_steps: 100 52 | max_steps: 1000 53 | learning_rate: 5e-5 54 | save_total_limit: 1 55 | # resume_from_checkpoint: true 56 | # optim: "paged_adamw_8bit" 57 | peft_config: 58 | (): peft.LoraConfig 59 | r: 32 60 | lora_alpha: 32 61 | lora_dropout: 0.1 62 | init_lora_weights: "gaussian" 63 | bias: "none" 64 | task_type: "FEATURE_EXTRACTION" 65 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 66 | # target_modules: 
'(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 67 | 68 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/deprecated/train_colqwen2_hardneg_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/colqwen2-hardneg-256-5e 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor 7 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 8 | # max_length: 50 9 | 10 | model: 11 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 12 | class_to_instanciate: !ext colpali_engine.models.ColQwen2 13 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 14 | torch_dtype: !ext torch.bfloat16 15 | use_cache: false 16 | attn_implementation: "flash_attention_2" 17 | # device_map: "auto" 18 | # quantization_config: 19 | # (): transformers.BitsAndBytesConfig 20 | # load_in_4bit: true 21 | # bnb_4bit_quant_type: "nf4" 22 | # bnb_4bit_compute_dtype: "bfloat16" 23 | # bnb_4bit_use_double_quant: true 24 | 25 | dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set_ir_negs 26 | eval_dataset_loader: !import ../data/test_data.yaml 27 | 28 | # max_length: 50 29 | run_eval: true 30 | 31 | loss_func: 32 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss 33 | in_batch_term: true 34 | tr_args: 35 | (): transformers.training_args.TrainingArguments 36 | output_dir: null 37 | overwrite_output_dir: true 38 | num_train_epochs: 5 39 | per_device_train_batch_size: 64 40 | gradient_checkpointing: true 41 | gradient_checkpointing_kwargs: {"use_reentrant": false} 42 | # 6 x 8 gpus = 48 batch size 43 | # gradient_accumulation_steps: 4 44 | per_device_eval_batch_size: 64 45 | eval_strategy: "steps" 46 | dataloader_num_workers: 8 47 | # bf16: true 48 | save_steps: 500 49 | logging_steps: 10 50 | eval_steps: 100 51 | warmup_steps: 100 52 | learning_rate: 2e-4 53 | save_total_limit: 1 54 | # resume_from_checkpoint: true 55 | # optim: "paged_adamw_8bit" peft_config: 56 | peft_config: 57 | (): peft.LoraConfig 58 | r: 32 59 | lora_alpha: 32 60 | lora_dropout: 0.1 61 | init_lora_weights: "gaussian" 62 | bias: "none" 63 | task_type: "FEATURE_EXTRACTION" 64 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 65 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 66 | 67 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/deprecated/train_colqwen2_wikiss_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/colqwen2-wikiss-ba64 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor 7 | pretrained_model_name_or_path: "./models/colqwen2_base" # "./models/paligemma-3b-mix-448" 8 | # num_image_tokens: 2048 9 | # max_length: 50 10 | 11 | model: 12 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 13 | 
class_to_instanciate: !ext colpali_engine.models.ColQwen2 14 | pretrained_model_name_or_path: "./models/colqwen2_base" 15 | torch_dtype: !ext torch.bfloat16 16 | use_cache: false 17 | attn_implementation: "flash_attention_2" 18 | # device_map: "auto" 19 | # quantization_config: 20 | # (): transformers.BitsAndBytesConfig 21 | # load_in_4bit: true 22 | # bnb_4bit_quant_type: "nf4" 23 | # bnb_4bit_compute_dtype: "bfloat16" 24 | # bnb_4bit_use_double_quant: true 25 | 26 | dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_wikiss 27 | eval_dataset_loader: !import ../data/test_data.yaml 28 | 29 | # max_length: 50 30 | run_eval: true 31 | 32 | loss_func: 33 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss 34 | tr_args: 35 | (): transformers.training_args.TrainingArguments 36 | output_dir: null 37 | overwrite_output_dir: true 38 | num_train_epochs: 1 39 | per_device_train_batch_size: 16 40 | gradient_checkpointing: true 41 | # 6 x 8 gpus = 48 batch size 42 | # gradient_accumulation_steps: 4 43 | per_device_eval_batch_size: 16 44 | eval_strategy: "steps" 45 | dataloader_num_workers: 16 46 | # bf16: true 47 | save_steps: 500 48 | logging_steps: 10 49 | eval_steps: 500 50 | warmup_steps: 1000 51 | learning_rate: 5e-4 52 | save_total_limit: 1 53 | # resume_from_checkpoint: true 54 | # optim: "paged_adamw_8bit" 55 | peft_config: 56 | (): peft.LoraConfig 57 | r: 32 58 | lora_alpha: 32 59 | lora_dropout: 0.1 60 | init_lora_weights: "gaussian" 61 | bias: "none" 62 | task_type: "FEATURE_EXTRACTION" 63 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 64 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 65 | 66 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/train_biqwen2_hardneg_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | import torch 6 | from peft import LoraConfig 7 | from transformers import TrainingArguments 8 | 9 | from colpali_engine.loss.bi_encoder_losses import BiNegativeCELoss 10 | from colpali_engine.models import BiQwen2, BiQwen2Processor 11 | from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig 12 | from colpali_engine.utils.dataset_transformation import load_train_set_ir_negs 13 | 14 | config = ColModelTrainingConfig( 15 | output_dir="./models/biqwen2-hardneg-5e-0304", 16 | processor=BiQwen2Processor.from_pretrained( 17 | pretrained_model_name_or_path="./models/base_models/colqwen2-base", 18 | ), 19 | model=BiQwen2.from_pretrained( 20 | pretrained_model_name_or_path="./models/base_models/colqwen2-base", 21 | torch_dtype=torch.bfloat16, 22 | use_cache=False, 23 | attn_implementation="flash_attention_2", 24 | ), 25 | dataset_loading_func=load_train_set_ir_negs, 26 | eval_dataset_loader=None, 27 | run_eval=True, 28 | loss_func=BiNegativeCELoss(in_batch_term=True), 29 | tr_args=TrainingArguments( 30 | output_dir=None, 31 | overwrite_output_dir=True, 32 | num_train_epochs=5, 33 | per_device_train_batch_size=64, 34 | gradient_checkpointing=True, 35 | gradient_checkpointing_kwargs={"use_reentrant": False}, 36 | per_device_eval_batch_size=64, 37 | eval_strategy="steps", 38 | dataloader_num_workers=8, 39 | save_steps=500, 40 | logging_steps=10, 41 | eval_steps=100, 42 | warmup_steps=100, 43 | 
learning_rate=2e-4, 44 | save_total_limit=1, 45 | ), 46 | peft_config=LoraConfig( 47 | r=32, 48 | lora_alpha=32, 49 | lora_dropout=0.1, 50 | init_lora_weights="gaussian", 51 | bias="none", 52 | task_type="FEATURE_EXTRACTION", 53 | target_modules="(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)", 54 | ), 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | # ensure output_dir exists 60 | os.makedirs(config.output_dir, exist_ok=True) 61 | # version this script by copying it into the output dir 62 | current_script = Path(__file__) 63 | shutil.copy(current_script, Path(config.output_dir) / current_script.name) 64 | 65 | training_app = ColModelTraining(config) 66 | 67 | training_app.train() 68 | training_app.save() 69 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/train_biqwen2_hardneg_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/biqwen2-hardneg-5e-0304 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor 7 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 8 | # max_length: 50 9 | 10 | model: 11 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 12 | class_to_instanciate: !ext colpali_engine.models.BiQwen2 13 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 14 | torch_dtype: !ext torch.bfloat16 15 | use_cache: false 16 | attn_implementation: "flash_attention_2" 17 | # device_map: "auto" 18 | # quantization_config: 19 | # (): transformers.BitsAndBytesConfig 20 | # load_in_4bit: true 21 | # bnb_4bit_quant_type: "nf4" 22 | # bnb_4bit_compute_dtype: "bfloat16" 23 | # bnb_4bit_use_double_quant: true 24 | 25 | train_dataset: 26 | (): colpali_engine.utils.dataset_transformation.load_train_set_ir_negs 27 | eval_dataset: !import ../data/test_data.yaml 28 | 29 | # max_length: 50 30 | run_eval: true 31 | 32 | loss_func: 33 | (): colpali_engine.loss.bi_encoder_losses.BiNegativeCELoss 34 | in_batch_term: true 35 | tr_args: 36 | (): transformers.training_args.TrainingArguments 37 | output_dir: null 38 | overwrite_output_dir: true 39 | num_train_epochs: 5 40 | per_device_train_batch_size: 64 41 | gradient_checkpointing: true 42 | gradient_checkpointing_kwargs: { "use_reentrant": false } 43 | # 6 x 8 gpus = 48 batch size 44 | # gradient_accumulation_steps: 4 45 | per_device_eval_batch_size: 64 46 | eval_strategy: "steps" 47 | dataloader_num_workers: 8 48 | # bf16: true 49 | save_steps: 500 50 | logging_steps: 10 51 | eval_steps: 100 52 | warmup_steps: 100 53 | learning_rate: 2e-4 54 | save_total_limit: 1 55 | # optim: "paged_adamw_8bit" 56 | peft_config: 57 | (): peft.LoraConfig 58 | r: 32 59 | lora_alpha: 32 60 | lora_dropout: 0.1 61 | init_lora_weights: "gaussian" 62 | bias: "none" 63 | task_type: "FEATURE_EXTRACTION" 64 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)' 65 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 66 | 67 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/train_biqwen2_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): 
colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/biqwen2-ba256-5e-0304 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor 7 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 8 | max_num_visual_tokens: 1024 9 | 10 | model: 11 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 12 | class_to_instanciate: !ext colpali_engine.models.BiQwen2 13 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 14 | torch_dtype: !ext torch.bfloat16 15 | use_cache: false 16 | attn_implementation: "flash_attention_2" 17 | 18 | 19 | train_dataset: 20 | (): colpali_engine.utils.dataset_transformation.load_train_set 21 | eval_dataset: !import ../data/test_data.yaml 22 | 23 | # max_length: 50 24 | run_eval: true 25 | 26 | loss_func: 27 | (): colpali_engine.loss.bi_encoder_losses.BiEncoderLoss 28 | tr_args: 29 | (): transformers.training_args.TrainingArguments 30 | output_dir: null 31 | overwrite_output_dir: true 32 | num_train_epochs: 5 33 | per_device_train_batch_size: 64 34 | gradient_checkpointing: true 35 | gradient_checkpointing_kwargs: { "use_reentrant": false } 36 | # 6 x 8 gpus = 48 batch size 37 | # gradient_accumulation_steps: 4 38 | per_device_eval_batch_size: 8 39 | eval_strategy: "steps" 40 | dataloader_num_workers: 4 41 | # bf16: true 42 | save_steps: 500 43 | logging_steps: 10 44 | eval_steps: 100 45 | warmup_steps: 100 46 | learning_rate: 2e-4 47 | save_total_limit: 1 48 | resume_from_checkpoint: false 49 | report_to: "wandb" 50 | 51 | # optim: "paged_adamw_8bit" 52 | peft_config: 53 | (): peft.LoraConfig 54 | r: 32 55 | lora_alpha: 32 56 | lora_dropout: 0.1 57 | init_lora_weights: "gaussian" 58 | bias: "none" 59 | task_type: "FEATURE_EXTRACTION" 60 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)' 61 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 62 | 63 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/train_colqwen2_infonce_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/colqwen2-ba256-infonce-5e-0304 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor 7 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 8 | max_num_visual_tokens: 1024 9 | 10 | model: 11 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 12 | class_to_instanciate: !ext colpali_engine.models.ColQwen2 13 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 14 | torch_dtype: !ext torch.bfloat16 15 | use_cache: false 16 | attn_implementation: "flash_attention_2" 17 | # device_map: "auto" 18 | # quantization_config: 19 | # (): transformers.BitsAndBytesConfig 20 | # load_in_4bit: true 21 | # bnb_4bit_quant_type: "nf4" 22 | # bnb_4bit_compute_dtype: "bfloat16" 23 | # bnb_4bit_use_double_quant: true 24 | 25 | train_dataset: 26 | (): colpali_engine.utils.dataset_transformation.load_train_set 27 | eval_dataset: !import ../data/test_data.yaml 28 | 29 | # max_length: 50 30 | run_eval: true 31 | loss_func: 32 | (): 
colpali_engine.loss.late_interaction_losses.ColbertLoss 33 | tr_args: 34 | (): transformers.training_args.TrainingArguments 35 | output_dir: null 36 | overwrite_output_dir: true 37 | num_train_epochs: 5 38 | per_device_train_batch_size: 64 39 | gradient_checkpointing: true 40 | gradient_checkpointing_kwargs: { "use_reentrant": false } 41 | # gradient_checkpointing: true 42 | # 6 x 8 gpus = 48 batch size 43 | # gradient_accumulation_steps: 4 44 | per_device_eval_batch_size: 8 45 | eval_strategy: "steps" 46 | dataloader_num_workers: 8 47 | # bf16: true 48 | save_steps: 500 49 | logging_steps: 10 50 | eval_steps: 100 51 | warmup_steps: 100 52 | learning_rate: 2e-4 53 | save_total_limit: 1 54 | # resume_from_checkpoint: true 55 | # optim: "paged_adamw_8bit" 56 | # wandb logging 57 | # wandb_project: "colqwen2" 58 | # run_name: "colqwen2-ba32-nolora" 59 | report_to: "wandb" 60 | 61 | 62 | peft_config: 63 | (): peft.LoraConfig 64 | r: 32 65 | lora_alpha: 32 66 | lora_dropout: 0.1 67 | init_lora_weights: "gaussian" 68 | bias: "none" 69 | task_type: "FEATURE_EXTRACTION" 70 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 71 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 72 | 73 | -------------------------------------------------------------------------------- /scripts/configs/qwen2/train_colqwen2_model.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig 3 | output_dir: !path ../../../models/colqwen2-ba256-5e-0304 4 | processor: 5 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 6 | class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor 7 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 8 | max_num_visual_tokens: 1024 9 | 10 | model: 11 | (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper 12 | class_to_instanciate: !ext colpali_engine.models.ColQwen2 13 | pretrained_model_name_or_path: "./models/base_models/colqwen2-base" 14 | torch_dtype: !ext torch.bfloat16 15 | use_cache: false 16 | attn_implementation: "flash_attention_2" 17 | 18 | train_dataset: 19 | (): colpali_engine.utils.dataset_transformation.load_train_set 20 | eval_dataset: !import ../data/test_data.yaml 21 | 22 | # max_length: 50 23 | run_eval: true 24 | loss_func: 25 | (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss 26 | tr_args: 27 | (): transformers.training_args.TrainingArguments 28 | output_dir: null 29 | overwrite_output_dir: true 30 | num_train_epochs: 5 31 | per_device_train_batch_size: 64 32 | gradient_checkpointing: true 33 | gradient_checkpointing_kwargs: { "use_reentrant": false } 34 | # 6 x 8 gpus = 48 batch size 35 | # gradient_accumulation_steps: 4 36 | per_device_eval_batch_size: 8 37 | eval_strategy: "steps" 38 | dataloader_num_workers: 8 39 | # bf16: true 40 | save_steps: 500 41 | logging_steps: 10 42 | eval_steps: 100 43 | warmup_steps: 100 44 | learning_rate: 2e-4 45 | save_total_limit: 1 46 | # resume_from_checkpoint: true 47 | # optim: "paged_adamw_8bit" 48 | # wandb logging 49 | # wandb_project: "colqwen2" 50 | # run_name: "colqwen2-ba32-nolora" 51 | report_to: "wandb" 52 | 53 | 54 | peft_config: 55 | (): peft.LoraConfig 56 | r: 32 57 | lora_alpha: 32 58 | lora_dropout: 0.1 59 | init_lora_weights: "gaussian" 60 | bias: "none" 61 | task_type: "FEATURE_EXTRACTION" 
62 | target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 63 | # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' 64 | 65 | -------------------------------------------------------------------------------- /scripts/configs/tr_args/default_neg_tr_args.yaml: -------------------------------------------------------------------------------- 1 | (): transformers.training_args.TrainingArguments 2 | output_dir: null 3 | overwrite_output_dir: true 4 | num_train_epochs: 1 5 | per_device_train_batch_size: 4 6 | # 6 x 8 gpus = 48 batch size 7 | # gradient_accumulation_steps: 4 8 | per_device_eval_batch_size: 4 9 | eval_strategy: "steps" 10 | dataloader_num_workers: 8 11 | # bf16: true 12 | save_steps: 500 13 | logging_steps: 10 14 | eval_steps: 50 15 | warmup_steps: 100 16 | learning_rate: 5e-5 17 | save_total_limit: 1 18 | # optim: "paged_adamw_8bit" 19 | -------------------------------------------------------------------------------- /scripts/configs/tr_args/default_tr_args.yaml: -------------------------------------------------------------------------------- 1 | (): transformers.training_args.TrainingArguments 2 | output_dir: null 3 | overwrite_output_dir: true 4 | num_train_epochs: 1 5 | per_device_train_batch_size: 4 6 | # 6 x 8 gpus = 48 batch size 7 | # gradient_accumulation_steps: 4 8 | per_device_eval_batch_size: 4 9 | eval_strategy: "steps" 10 | # dataloader_num_workers: 8 11 | # bf16: true 12 | save_steps: 500 13 | logging_steps: 10 14 | eval_steps: 100 15 | warmup_steps: 500 16 | learning_rate: 5e-5 17 | save_total_limit: 1 18 | # optim: "paged_adamw_8bit" 19 | -------------------------------------------------------------------------------- /scripts/configs/tr_args/eval_tr_args.yaml: -------------------------------------------------------------------------------- 1 | (): transformers.training_args.TrainingArguments 2 | output_dir: null 3 | overwrite_output_dir: true 4 | num_train_epochs: 1 5 | per_device_train_batch_size: 4 6 | # 6 x 8 gpus = 48 batch size 7 | # gradient_accumulation_steps: 4 8 | per_device_eval_batch_size: 4 9 | max_steps: 10 10 | eval_strategy: "steps" 11 | # dataloader_num_workers: 8 12 | # bf16: true 13 | save_steps: 500 14 | logging_steps: 10 15 | eval_steps: 50 16 | warmup_steps: 100 17 | learning_rate: 5e-5 18 | save_total_limit: 1 19 | # optim: "paged_adamw_8bit" -------------------------------------------------------------------------------- /scripts/configs/tr_args/resume_neg_tr_args.yaml: -------------------------------------------------------------------------------- 1 | (): transformers.training_args.TrainingArguments 2 | output_dir: null 3 | overwrite_output_dir: true 4 | num_train_epochs: 3 5 | per_device_train_batch_size: 4 6 | # 6 x 8 gpus = 48 batch size 7 | # gradient_accumulation_steps: 4 8 | per_device_eval_batch_size: 4 9 | eval_strategy: "steps" 10 | # dataloader_num_workers: 8 11 | # bf16: true 12 | save_steps: 500 13 | logging_steps: 10 14 | eval_steps: 50 15 | warmup_steps: 100 16 | learning_rate: 5e-5 17 | save_total_limit: 1 18 | resume_from_checkpoint: true 19 | # optim: "paged_adamw_8bit" 20 | -------------------------------------------------------------------------------- /scripts/train/train_colbert.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import configue 5 | import typer 6 | 7 | from 
colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig 8 | from colpali_engine.utils.gpu_stats import print_gpu_utilization 9 | 10 | app = typer.Typer(pretty_exceptions_enable=False) 11 | 12 | 13 | @app.command() 14 | def main(config_file: Path) -> None: 15 | """ 16 | Training script for ColVision models. 17 | 18 | Args: 19 | config_file (Path): Path to the configuration file. 20 | """ 21 | print_gpu_utilization() 22 | 23 | print("Loading config") 24 | config = configue.load(config_file, sub_path="config") 25 | 26 | print("Creating Setup") 27 | if isinstance(config, ColModelTrainingConfig): 28 | training_app = ColModelTraining(config) 29 | else: 30 | raise ValueError("Config must be of type ColModelTrainingConfig") 31 | 32 | if config.run_train: 33 | print("Training model") 34 | training_app.train() 35 | training_app.save() 36 | os.system(f"cp {config_file} {training_app.config.output_dir}/training_config.yml") 37 | 38 | print("Done!") 39 | 40 | 41 | if __name__ == "__main__": 42 | app() 43 | -------------------------------------------------------------------------------- /tests/collators/test_visual_retriever_collator.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, cast 2 | 3 | import pytest 4 | from PIL import Image 5 | 6 | from colpali_engine.collators.visual_retriever_collator import VisualRetrieverCollator 7 | from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor 8 | 9 | 10 | class TestColPaliCollator: 11 | @pytest.fixture(scope="class") 12 | def colpali_processor_path(self) -> str: 13 | return "vidore/colpali-v1.2" 14 | 15 | @pytest.fixture(scope="class") 16 | def processor_from_pretrained(self, colpali_processor_path: str) -> Generator[ColPaliProcessor, None, None]: 17 | yield cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(colpali_processor_path)) 18 | 19 | @pytest.fixture(scope="class") 20 | def colpali_collator( 21 | self, processor_from_pretrained: ColPaliProcessor 22 | ) -> Generator[VisualRetrieverCollator, None, None]: 23 | yield VisualRetrieverCollator(processor=processor_from_pretrained) 24 | 25 | def test_colpali_collator_call(self, colpali_collator: VisualRetrieverCollator): 26 | example_image = Image.new("RGB", (16, 16), color="red") 27 | examples = [ 28 | {"query": "What is this?", "pos_target": example_image}, 29 | ] 30 | 31 | result = colpali_collator(examples) 32 | 33 | assert isinstance(result, dict) 34 | assert "doc_input_ids" in result 35 | assert "doc_attention_mask" in result 36 | assert "doc_pixel_values" in result 37 | assert "query_input_ids" in result 38 | assert "query_attention_mask" in result 39 | 40 | def test_colpali_collator_call_with_neg_images(self, colpali_collator: VisualRetrieverCollator): 41 | example_image = Image.new("RGB", (16, 16), color="red") 42 | neg_image = Image.new("RGB", (16, 16), color="blue") 43 | examples = [ 44 | { 45 | "query": "What is this?", 46 | "pos_target": example_image, 47 | "neg_target": neg_image, 48 | }, 49 | ] 50 | 51 | result = colpali_collator(examples) 52 | 53 | assert isinstance(result, dict) 54 | assert "doc_input_ids" in result 55 | assert "doc_attention_mask" in result 56 | assert "doc_pixel_values" in result 57 | assert "query_input_ids" in result 58 | assert "query_attention_mask" in result 59 | assert "neg_doc_input_ids" in result 60 | assert "neg_doc_attention_mask" in result 61 | assert "neg_doc_pixel_values" in result 62 | 
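A minimal, hypothetical usage sketch (not part of the repository) showing how the collator exercised in the test above could feed a PyTorch DataLoader; the checkpoint name, queries, and batch size mirror the test fixtures and are illustrative only.

# Hypothetical sketch: wiring VisualRetrieverCollator into a DataLoader.
# Assumes the "vidore/colpali-v1.2" checkpoint used by the test fixtures above.
from PIL import Image
from torch.utils.data import DataLoader

from colpali_engine.collators.visual_retriever_collator import VisualRetrieverCollator
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor

processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
collator = VisualRetrieverCollator(processor=processor)

samples = [
    {"query": "What is this?", "pos_target": Image.new("RGB", (16, 16), color="red")},
    {"query": "And this one?", "pos_target": Image.new("RGB", (16, 16), color="blue")},
]

# A plain list works as a map-style dataset; the collator turns each batch of
# examples into the doc_*/query_* tensors asserted in the test above.
loader = DataLoader(samples, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
print(sorted(batch.keys()))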
-------------------------------------------------------------------------------- /tests/compression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/illuin-tech/colpali/fbf9dcc70ef591dcadd1aa73ab019e97a60f272a/tests/compression/__init__.py -------------------------------------------------------------------------------- /tests/compression/token_pooling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/illuin-tech/colpali/fbf9dcc70ef591dcadd1aa73ab019e97a60f272a/tests/compression/token_pooling/__init__.py -------------------------------------------------------------------------------- /tests/compression/token_pooling/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | 5 | @pytest.fixture 6 | def sample_embedding() -> torch.Tensor: 7 | return torch.tensor( 8 | [ 9 | [1.0, 0.0, 0.0], 10 | [0.0, 1.0, 0.0], 11 | [0.0, 0.0, 1.0], 12 | [1.0, 0.0, 0.0], 13 | [1.0, 0.0, 0.0], 14 | [1.0, 0.0, 0.0], 15 | ], 16 | dtype=torch.float32, 17 | ) # (6, 3) 18 | -------------------------------------------------------------------------------- /tests/compression/token_pooling/example_pooling_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def identity_pool_func(embedding: torch.Tensor) -> torch.Tensor: 5 | """ 6 | Returns the unchanged embeddings. 7 | 8 | Args: 9 | embedding (torch.Tensor): The embeddings to pool (token_length, embedding_dim). 10 | 11 | Returns: 12 | torch.Tensor: The unchanged embeddings (token_length, embedding_dim). 13 | """ 14 | return embedding # (token_length, embedding_dim) 15 | 16 | 17 | def halve_pool_func(embedding: torch.Tensor) -> torch.Tensor: 18 | """ 19 | Pools the embeddings by averaging all pairs of consecutive vectors. 20 | The resulting token length is half the original length (rounded up); the embedding dimension is unchanged. 21 | 22 | Args: 23 | embedding (torch.Tensor): The embeddings to pool (token_length, embedding_dim). 24 | 25 | Returns: 26 | torch.Tensor: The pooled embeddings (half_length, embedding_dim). 27 | """ 28 | token_length = embedding.size(0) 29 | half_length = token_length // 2 + (token_length % 2) 30 | 31 | pooled_embeddings = torch.zeros( 32 | (half_length, embedding.size(1)), 33 | dtype=embedding.dtype, 34 | device=embedding.device, 35 | ) 36 | 37 | for i in range(half_length): 38 | start_idx = i * 2 39 | end_idx = min(start_idx + 2, token_length) 40 | cluster_indices = torch.arange(start_idx, end_idx) # (2) 41 | pooled_embeddings[i] = embedding[cluster_indices].mean(dim=0) # (embedding_dim) 42 | 43 | return pooled_embeddings # (half_length, embedding_dim) 44 | 45 | 46 | def mean_pool_func(embedding: torch.Tensor) -> torch.Tensor: 47 | """ 48 | Pools the embeddings by averaging all the token embeddings into a single L2-normalized vector. 49 | 50 | Args: 51 | embedding (torch.Tensor): The embeddings to pool (token_length, embedding_dim). 52 | 53 | Returns: 54 | torch.Tensor: The pooled embeddings (1, embedding_dim).
55 | """ 56 | pooled_embedding = torch.mean(embedding, dim=0, keepdim=True) # (1, embedding_dim) 57 | pooled_embedding = torch.nn.functional.normalize(pooled_embedding, p=2, dim=-1) # (1, embedding_dim) 58 | return pooled_embedding # (1, embedding_dim) 59 | -------------------------------------------------------------------------------- /tests/compression/token_pooling/test_lambda_pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from colpali_engine.compression.token_pooling import LambdaTokenPooler, TokenPoolingOutput 4 | 5 | from .example_pooling_functions import ( 6 | halve_pool_func, 7 | identity_pool_func, 8 | mean_pool_func, 9 | ) 10 | 11 | 12 | def test_lambda_token_pooler_initialization(): 13 | pooler = LambdaTokenPooler(pool_func=identity_pool_func) 14 | assert pooler.pool_func == identity_pool_func 15 | 16 | 17 | def test_lambda_token_pooler_with_identity_function(sample_embedding: torch.Tensor): 18 | pooler = LambdaTokenPooler(pool_func=identity_pool_func) 19 | outputs = pooler.pool_embeddings([sample_embedding], return_dict=True) 20 | 21 | assert isinstance(outputs, TokenPoolingOutput) 22 | assert torch.allclose(outputs.pooled_embeddings[0], sample_embedding) 23 | 24 | 25 | def test_lambda_token_pooler_with_mean_pooling(sample_embedding: torch.Tensor): 26 | pooler = LambdaTokenPooler(pool_func=mean_pool_func) 27 | outputs = pooler.pool_embeddings([sample_embedding], return_dict=True) 28 | 29 | assert isinstance(outputs, TokenPoolingOutput) 30 | assert outputs.pooled_embeddings[0].shape[0] == 1 31 | assert outputs.pooled_embeddings[0].shape[1] == sample_embedding.shape[1] 32 | 33 | 34 | def test_lambda_token_pooler_with_batched_input(): 35 | batch_embeddings = torch.rand(3, 10, 768) 36 | 37 | pooler = LambdaTokenPooler(pool_func=halve_pool_func) 38 | outputs = pooler.pool_embeddings(batch_embeddings, return_dict=True) 39 | 40 | assert isinstance(outputs, TokenPoolingOutput) 41 | assert isinstance(outputs.pooled_embeddings, torch.Tensor) 42 | assert outputs.pooled_embeddings.shape == (3, 5, 768) 43 | 44 | 45 | def test_lambda_token_pooler_with_list_input(): 46 | list_embeddings = [ 47 | torch.rand(10, 768), 48 | torch.rand(15, 768), 49 | ] 50 | 51 | pooler = LambdaTokenPooler(pool_func=halve_pool_func) 52 | outputs = pooler.pool_embeddings(list_embeddings, return_dict=True) 53 | 54 | assert isinstance(outputs, TokenPoolingOutput) 55 | assert len(outputs.pooled_embeddings) == len(list_embeddings) 56 | 57 | # First embedding should be pooled to half size 58 | assert outputs.pooled_embeddings[0].shape[0] == 5 59 | assert outputs.pooled_embeddings[0].shape[1] == 768 60 | 61 | # Second embedding should be pooled to half size (rounded up) 62 | assert outputs.pooled_embeddings[1].shape[0] == 8 63 | assert outputs.pooled_embeddings[1].shape[1] == 768 64 | -------------------------------------------------------------------------------- /tests/data/test_dataset.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from torch.utils.data import Dataset 3 | 4 | from colpali_engine.data import ColPaliEngineDataset, Corpus 5 | 6 | 7 | # --------------------------------------------------------------------------- # 8 | # Helper utilities # 9 | # --------------------------------------------------------------------------- # 10 | class DummyMapDataset(Dataset): 11 | """ 12 | Minimal map‑style dataset that includes a `.take()` method so we can 13 | exercise 
ColPaliEngineDataset.take() without depending on HF datasets. 14 | """ 15 | 16 | def __init__(self, samples): 17 | self._samples = list(samples) 18 | 19 | def __len__(self): 20 | return len(self._samples) 21 | 22 | def __getitem__(self, idx): 23 | return self._samples[idx] 24 | 25 | def take(self, n): 26 | return DummyMapDataset(self._samples[:n]) 27 | 28 | 29 | # --------------------------------------------------------------------------- # 30 | # Fixtures & samples # 31 | # --------------------------------------------------------------------------- # 32 | @pytest.fixture 33 | def corpus(): 34 | data = [{"doc": f"doc_{i}"} for i in range(3)] 35 | return Corpus(corpus_data=data) 36 | 37 | 38 | @pytest.fixture 39 | def data_no_neg(): 40 | """3 samples – *no* neg_target column at all.""" 41 | return [ 42 | {"query": "q0", "pos_target": 0}, 43 | {"query": "q1", "pos_target": [1]}, 44 | {"query": "q2", "pos_target": [2]}, 45 | ] 46 | 47 | 48 | @pytest.fixture 49 | def data_with_neg(): 50 | """2 samples – every sample has a neg_target column.""" 51 | return [ 52 | {"query": "q0", "pos_target": 1, "neg_target": 0}, 53 | {"query": "q1", "pos_target": [2], "neg_target": [0, 1]}, 54 | ] 55 | 56 | 57 | # --------------------------------------------------------------------------- # 58 | # Tests – NO negatives case # 59 | # --------------------------------------------------------------------------- # 60 | def test_no_negatives_basic(data_no_neg): 61 | ds = ColPaliEngineDataset(data_no_neg) # neg_target_column_name defaults to None 62 | assert len(ds) == 3 63 | 64 | sample = ds[0] 65 | assert sample[ColPaliEngineDataset.QUERY_KEY] == "q0" 66 | assert sample[ColPaliEngineDataset.POS_TARGET_KEY] == [0] 67 | # NEG_TARGET_KEY should be None 68 | assert sample[ColPaliEngineDataset.NEG_TARGET_KEY] is None 69 | 70 | 71 | def test_no_negatives_with_corpus_resolution(data_no_neg, corpus): 72 | ds = ColPaliEngineDataset(data_no_neg, corpus=corpus) 73 | s1 = ds[1] 74 | # pos_target indices 1 should be resolved to the actual doc string 75 | assert s1[ColPaliEngineDataset.POS_TARGET_KEY] == ["doc_1"] 76 | # still no negatives 77 | assert s1[ColPaliEngineDataset.NEG_TARGET_KEY] is None 78 | 79 | 80 | # --------------------------------------------------------------------------- # 81 | # Tests – WITH negatives case # 82 | # --------------------------------------------------------------------------- # 83 | def test_with_negatives_basic(data_with_neg): 84 | ds = ColPaliEngineDataset( 85 | data_with_neg, 86 | neg_target_column_name="neg_target", 87 | ) 88 | assert len(ds) == 2 89 | 90 | s0 = ds[0] 91 | assert s0[ColPaliEngineDataset.POS_TARGET_KEY] == [1] 92 | assert s0[ColPaliEngineDataset.NEG_TARGET_KEY] == [0] 93 | 94 | 95 | def test_with_negatives_and_corpus(data_with_neg, corpus): 96 | ds = ColPaliEngineDataset( 97 | data_with_neg, 98 | corpus=corpus, 99 | neg_target_column_name="neg_target", 100 | ) 101 | s1 = ds[1] 102 | # pos 2 -> "doc_2", negs 0,1 -> "doc_0", "doc_1" 103 | assert s1[ColPaliEngineDataset.POS_TARGET_KEY] == ["doc_2"] 104 | assert s1[ColPaliEngineDataset.NEG_TARGET_KEY] == ["doc_0", "doc_1"] 105 | 106 | 107 | # --------------------------------------------------------------------------- # 108 | # Tests for mixed / inconsistent scenarios # 109 | # --------------------------------------------------------------------------- # 110 | def test_error_if_neg_column_specified_but_missing(data_no_neg): 111 | """All samples must include the column when neg_target_column_name is given.""" 112 | with 
pytest.raises(AssertionError): 113 | ds = ColPaliEngineDataset( # noqa: F841 114 | data_no_neg, 115 | neg_target_column_name="neg_target", 116 | ) 117 | _ = ds[0] # force __getitem__ 118 | 119 | 120 | def test_error_if_data_mix_neg_and_non_neg(data_with_neg, data_no_neg): 121 | """A mixed dataset (some samples without neg_target) should fail.""" 122 | mixed = data_with_neg + data_no_neg 123 | # The first sample *does* have neg_target, so __init__ succeeds. 124 | ds = ColPaliEngineDataset( 125 | mixed, 126 | neg_target_column_name="neg_target", 127 | ) 128 | # Accessing a sample lacking the column should raise. 129 | with pytest.raises(KeyError): 130 | _ = ds[len(data_with_neg)] # first sample from the 'no_neg' part 131 | 132 | 133 | # --------------------------------------------------------------------------- # 134 | # .take() works in both modes # 135 | # --------------------------------------------------------------------------- # 136 | def test_take_returns_subset(data_no_neg): 137 | wrapped = DummyMapDataset(data_no_neg) 138 | ds = ColPaliEngineDataset(wrapped) 139 | 140 | sub_ds = ds.take(1) 141 | 142 | assert isinstance(sub_ds, ColPaliEngineDataset) 143 | assert len(sub_ds) == 1 144 | # Make sure we can still index 145 | _ = sub_ds[0] 146 | -------------------------------------------------------------------------------- /tests/data/test_sampler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | from colpali_engine.data.sampler import SingleDatasetBatchSampler 6 | 7 | 8 | class DummyDataset(Dataset): 9 | """ 10 | Minimal PyTorch dataset that also supports `.take()`. 11 | The values it returns are irrelevant to the sampler; we only care about length. 12 | """ 13 | 14 | def __init__(self, size: int, start: int = 0): 15 | self._data = list(range(start, start + size)) 16 | 17 | def __len__(self): 18 | return len(self._data) 19 | 20 | def __getitem__(self, idx): 21 | return self._data[idx] 22 | 23 | # Simulate Arrow / HF dataset API used by the sampler 24 | def take(self, total_samples: int): 25 | # Keep the same starting offset so global indices stay monotonic 26 | return DummyDataset(total_samples, start=self._data[0]) 27 | 28 | 29 | # --------------------------------------------------------------------------- # 30 | # Test helpers # 31 | # --------------------------------------------------------------------------- # 32 | def dataset_boundaries(sampler): 33 | """Return a list of (lo, hi) index ranges, one per dataset, in global space.""" 34 | cs = sampler.cumsum_sizes # cumsum has an extra leading 0 35 | return [(cs[i], cs[i + 1]) for i in range(len(cs) - 1)] 36 | 37 | 38 | def which_dataset(idx, boundaries): 39 | """Given a global idx, tell which dataset it belongs to (0‑based).""" 40 | for d, (lo, hi) in enumerate(boundaries): 41 | if lo <= idx < hi: 42 | return d 43 | raise ValueError(f"idx {idx} out of bounds") 44 | 45 | 46 | # --------------------------------------------------------------------------- # 47 | # Tests # 48 | # --------------------------------------------------------------------------- # 49 | def test_basic_iteration_and_len(): 50 | """ 51 | Two datasets, lengths 10 and 6, global batch size 4. 52 | 53 | Both datasets should be truncated (10→8, 6→4). Expect 3 batches. 
54 | """ 55 | ds = [DummyDataset(10), DummyDataset(6)] 56 | gen = torch.Generator().manual_seed(123) 57 | sampler = SingleDatasetBatchSampler(ds, global_batch_size=4, generator=gen) 58 | 59 | batches = list(iter(sampler)) 60 | 61 | # 1) __len__ matches actual number of batches 62 | assert len(batches) == len(sampler) == 3 63 | 64 | # 2) All samples are unique and count equals truncated total 65 | flat = [i for b in batches for i in b] 66 | assert len(flat) == len(set(flat)) == 12 # 8 + 4 67 | 68 | # 3) Every batch is exactly global_batch_size long 69 | assert all(len(b) == 4 for b in batches) 70 | 71 | 72 | def test_single_dataset_per_batch(): 73 | """ 74 | Ensure that every yielded batch contains indices drawn from 75 | *one—and only one—dataset*. 76 | """ 77 | ds = [DummyDataset(8), DummyDataset(8), DummyDataset(16)] 78 | sampler = SingleDatasetBatchSampler(ds, global_batch_size=4, generator=torch.Generator()) 79 | 80 | boundaries = dataset_boundaries(sampler) 81 | 82 | for batch in sampler: 83 | d0 = which_dataset(batch[0], boundaries) 84 | # All indices in the batch must map to the same dataset ID 85 | assert all(which_dataset(i, boundaries) == d0 for i in batch) 86 | 87 | 88 | def test_epoch_based_reshuffle_changes_order(): 89 | """ 90 | Calling set_epoch should reshuffle the internal order so that 91 | consecutive epochs produce different batch orderings. 92 | """ 93 | ds = [DummyDataset(8), DummyDataset(8)] 94 | gen = torch.Generator().manual_seed(999) 95 | sampler = SingleDatasetBatchSampler(ds, global_batch_size=4, generator=gen) 96 | 97 | first_epoch = list(iter(sampler)) 98 | 99 | sampler.set_epoch(1) 100 | second_epoch = list(iter(sampler)) 101 | 102 | # Pure order comparison; contents are the same but order should differ 103 | assert first_epoch != second_epoch 104 | 105 | # Same epoch again → deterministic repeat 106 | sampler.set_epoch(1) 107 | repeat_epoch = list(iter(sampler)) 108 | assert second_epoch == repeat_epoch 109 | 110 | 111 | @pytest.mark.parametrize( 112 | "lengths,batch_size,expected_batches", 113 | [ 114 | ([12], 4, 3), # single dataset, perfect fit 115 | ([13], 4, 3), # single dataset, truncated down 116 | ([7, 9], 4, 3), # truncates both 117 | ([4, 4, 4], 4, 3), # multiple, exact fit 118 | ], 119 | ) 120 | def test_len_property_various_lengths(lengths, batch_size, expected_batches): 121 | datasets = [DummyDataset(n) for n in lengths] 122 | sampler = SingleDatasetBatchSampler(datasets, global_batch_size=batch_size) 123 | assert len(sampler) == expected_batches 124 | -------------------------------------------------------------------------------- /tests/interpretability/test_similarity_maps.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import torch 5 | from matplotlib import pyplot as plt 6 | from PIL import Image 7 | 8 | from colpali_engine.interpretability.similarity_maps import plot_all_similarity_maps, plot_similarity_map 9 | 10 | 11 | @pytest.fixture 12 | def sample_image() -> Image.Image: 13 | return Image.new("RGB", (100, 100), color="white") 14 | 15 | 16 | @pytest.fixture 17 | def sample_query_tokens() -> List[str]: 18 | return ["token1", "token2"] 19 | 20 | 21 | def test_plot_similarity_map(sample_image): 22 | similarity_map = torch.rand(32, 32) 23 | fig, ax = plot_similarity_map( 24 | image=sample_image, 25 | similarity_map=similarity_map, 26 | figsize=(5, 5), 27 | show_colorbar=True, 28 | ) 29 | 30 | assert isinstance(fig, plt.Figure) 31 | assert 
isinstance(ax, plt.Axes) 32 | 33 | 34 | def test_plot_all_similarity_maps(sample_image, sample_query_tokens): 35 | similarity_maps = torch.rand(len(sample_query_tokens), 32, 32) 36 | 37 | plots = plot_all_similarity_maps( 38 | image=sample_image, 39 | query_tokens=sample_query_tokens, 40 | similarity_maps=similarity_maps, 41 | figsize=(5, 5), 42 | show_colorbar=False, 43 | add_title=True, 44 | ) 45 | 46 | assert len(plots) == len(sample_query_tokens) 47 | for fig, ax in plots: 48 | assert isinstance(fig, plt.Figure) 49 | assert isinstance(ax, plt.Axes) 50 | -------------------------------------------------------------------------------- /tests/models/idefics3/colidefics3/test_processing_colidefics3.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, cast 2 | 3 | import pytest 4 | import torch 5 | from PIL import Image 6 | 7 | from colpali_engine.models import ColIdefics3Processor 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def model_name() -> str: 12 | return "vidore/colSmol-256M" 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def processor_from_pretrained(model_name: str) -> Generator[ColIdefics3Processor, None, None]: 17 | yield cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name)) 18 | 19 | 20 | def test_load_processor_from_pretrained(processor_from_pretrained: ColIdefics3Processor): 21 | assert isinstance(processor_from_pretrained, ColIdefics3Processor) 22 | 23 | 24 | def test_process_images(processor_from_pretrained: ColIdefics3Processor): 25 | # Create a dummy image 26 | image_size = (16, 32) 27 | image = Image.new("RGB", image_size, color="black") 28 | images = [image] 29 | 30 | # Process the image 31 | batch_feature = processor_from_pretrained.process_images(images) 32 | 33 | # Assertions 34 | assert "pixel_values" in batch_feature 35 | assert batch_feature["pixel_values"].shape == torch.Size([1, 9, 3, 512, 512]) 36 | 37 | 38 | def test_process_images_with_context(processor_from_pretrained: ColIdefics3Processor): 39 | # Create a dummy image 40 | image_size = (16, 32) 41 | image = Image.new("RGB", image_size, color="black") 42 | contexts = ["Open source is the best!"] 43 | images = [image] 44 | 45 | # Process the image 46 | batch_feature = processor_from_pretrained.process_images(images, context_prompts=contexts) 47 | 48 | # Assertions 49 | assert "pixel_values" in batch_feature 50 | assert batch_feature["pixel_values"].shape == torch.Size([1, 9, 3, 512, 512]) 51 | 52 | 53 | def test_process_queries(processor_from_pretrained: ColIdefics3Processor): 54 | queries = [ 55 | "Is attention really all you need?", 56 | "Are Benjamin, Antoine, Merve, and Jo best friends?", 57 | ] 58 | 59 | # Process the queries 60 | batch_encoding = processor_from_pretrained.process_queries(queries) 61 | 62 | # Assertions 63 | assert "input_ids" in batch_encoding 64 | assert isinstance(batch_encoding["input_ids"], torch.Tensor) 65 | assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) 66 | -------------------------------------------------------------------------------- /tests/models/paligemma/colpali/test_processing_colpali.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, cast 2 | 3 | import pytest 4 | import torch 5 | from PIL import Image 6 | 7 | from colpali_engine.models import ColPaliProcessor 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def model_name() -> str: 12 | return "vidore/colpali-v1.2" 13 | 14 
| 15 | @pytest.fixture(scope="module") 16 | def processor_from_pretrained(model_name: str) -> Generator[ColPaliProcessor, None, None]: 17 | yield cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name)) 18 | 19 | 20 | def test_load_processor_from_pretrained(processor_from_pretrained: ColPaliProcessor): 21 | assert isinstance(processor_from_pretrained, ColPaliProcessor) 22 | 23 | 24 | def test_process_images(processor_from_pretrained: ColPaliProcessor): 25 | # Create a dummy image 26 | image_size = (16, 32) 27 | image = Image.new("RGB", image_size, color="black") 28 | images = [image] 29 | 30 | # Process the image 31 | batch_feature = processor_from_pretrained.process_images(images) 32 | 33 | # Assertions 34 | assert "pixel_values" in batch_feature 35 | assert batch_feature["pixel_values"].shape == torch.Size([1, 3, 448, 448]) 36 | 37 | 38 | def test_process_images_with_context(processor_from_pretrained: ColPaliProcessor): 39 | # Create a dummy image 40 | image_size = (16, 32) 41 | image = Image.new("RGB", image_size, color="black") 42 | contexts = ["Open source is the best!"] 43 | images = [image] 44 | 45 | # Process the image 46 | batch_feature = processor_from_pretrained.process_images(images, context_prompts=contexts) 47 | 48 | # Assertions 49 | assert "pixel_values" in batch_feature 50 | assert batch_feature["pixel_values"].shape == torch.Size([1, 3, 448, 448]) 51 | 52 | 53 | def test_process_queries(processor_from_pretrained: ColPaliProcessor): 54 | queries = [ 55 | "Is attention really all you need?", 56 | "Are Benjamin, Antoine, Merve, and Jo best friends?", 57 | ] 58 | 59 | # Process the queries 60 | batch_encoding = processor_from_pretrained.process_queries(queries) 61 | 62 | # Assertions 63 | assert "input_ids" in batch_encoding 64 | assert isinstance(batch_encoding["input_ids"], torch.Tensor) 65 | assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) 66 | -------------------------------------------------------------------------------- /tests/models/qwen2/colqwen2/test_modeling_colqwen2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Generator, cast 3 | 4 | import pytest 5 | import torch 6 | from datasets import load_dataset 7 | from PIL import Image 8 | from transformers.utils.import_utils import is_flash_attn_2_available 9 | 10 | from colpali_engine.models import ColQwen2, ColQwen2Processor 11 | from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | @pytest.fixture(scope="module") 17 | def model_name() -> str: 18 | return "vidore/colqwen2-v1.0" 19 | 20 | 21 | @pytest.fixture(scope="module") 22 | def model(model_name: str) -> Generator[ColQwen2, None, None]: 23 | device = get_torch_device("auto") 24 | logger.info(f"Device used: {device}") 25 | 26 | yield cast( 27 | ColQwen2, 28 | ColQwen2.from_pretrained( 29 | model_name, 30 | torch_dtype=torch.bfloat16, 31 | device_map=device, 32 | attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, 33 | ).eval(), 34 | ) 35 | tear_down_torch() 36 | 37 | 38 | @pytest.fixture(scope="module") 39 | def processor(model_name: str) -> Generator[ColQwen2Processor, None, None]: 40 | yield cast(ColQwen2Processor, ColQwen2Processor.from_pretrained(model_name)) 41 | 42 | 43 | class TestColQwen2Model: 44 | @pytest.mark.slow 45 | def test_load_model_from_pretrained(self, model: ColQwen2): 46 | assert isinstance(model, ColQwen2) 
47 | 48 | 49 | class TestColQwen2ModelIntegration: 50 | @pytest.mark.slow 51 | def test_forward_images_integration( 52 | self, 53 | model: ColQwen2, 54 | processor: ColQwen2Processor, 55 | ): 56 | # Create a batch of dummy images 57 | images = [ 58 | Image.new("RGB", (64, 64), color="white"), 59 | Image.new("RGB", (32, 32), color="black"), 60 | ] 61 | 62 | # Process the image 63 | batch_images = processor.process_images(images).to(model.device) 64 | 65 | # Forward pass 66 | model.mask_non_image_embeddings = False 67 | with torch.no_grad(): 68 | outputs = model(**batch_images) 69 | 70 | # Assertions 71 | assert isinstance(outputs, torch.Tensor) 72 | assert outputs.dim() == 3 73 | batch_size, n_visual_tokens, emb_dim = outputs.shape 74 | assert batch_size == len(images) 75 | assert emb_dim == model.dim 76 | 77 | @pytest.mark.slow 78 | def test_forward_images_with_context_integration( 79 | self, 80 | model: ColQwen2, 81 | processor: ColQwen2Processor, 82 | ): 83 | # Create a batch of dummy images 84 | images = [ 85 | Image.new("RGB", (64, 64), color="white"), 86 | Image.new("RGB", (32, 32), color="black"), 87 | ] 88 | 89 | contexts = [ 90 | "Open source is the best!", 91 | "I love to code!", 92 | ] 93 | 94 | # Process the image 95 | batch_images = processor.process_images(images, context_prompts=contexts).to(model.device) 96 | 97 | # Forward pass 98 | model.mask_non_image_embeddings = True 99 | with torch.no_grad(): 100 | outputs = model(**batch_images) 101 | 102 | # Assertions 103 | assert isinstance(outputs, torch.Tensor) 104 | assert outputs.dim() == 3 105 | batch_size, n_visual_tokens, emb_dim = outputs.shape 106 | assert batch_size == len(images) 107 | assert emb_dim == model.dim 108 | 109 | @pytest.mark.slow 110 | def test_forward_queries_integration( 111 | self, 112 | model: ColQwen2, 113 | processor: ColQwen2Processor, 114 | ): 115 | queries = [ 116 | "Is attention really all you need?", 117 | "Are Benjamin, Antoine, Merve, and Jo best friends?", 118 | ] 119 | 120 | # Process the queries 121 | batch_queries = processor.process_queries(queries).to(model.device) 122 | 123 | # Forward pass 124 | with torch.no_grad(): 125 | outputs = model(**batch_queries) 126 | 127 | # Assertions 128 | assert isinstance(outputs, torch.Tensor) 129 | assert outputs.dim() == 3 130 | batch_size, n_query_tokens, emb_dim = outputs.shape 131 | assert batch_size == len(queries) 132 | assert emb_dim == model.dim 133 | 134 | @pytest.mark.slow 135 | def test_retrieval_integration( 136 | self, 137 | model: ColQwen2, 138 | processor: ColQwen2Processor, 139 | ): 140 | # Load the test dataset 141 | ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") 142 | 143 | # Preprocess the examples 144 | batch_images = processor.process_images(images=ds["image"]).to(model.device) 145 | batch_queries = processor.process_queries(queries=ds["query"]).to(model.device) 146 | 147 | # Run inference 148 | with torch.inference_mode(): 149 | image_embeddings = model(**batch_images) 150 | query_embeddings = model(**batch_queries) 151 | 152 | # Compute retrieval scores 153 | scores = processor.score_multi_vector( 154 | qs=query_embeddings, 155 | ps=image_embeddings, 156 | ) # (len(qs), len(ps)) 157 | 158 | assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}" 159 | assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}" 160 | 161 | # Check if the maximum scores per row are in the diagonal of the matrix score 162 | assert (scores.argmax(dim=1) == 
torch.arange(len(ds), device=scores.device)).all() 163 | -------------------------------------------------------------------------------- /tests/models/qwen2/colqwen2/test_processing_colqwen2.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, cast 2 | 3 | import pytest 4 | import torch 5 | from PIL import Image 6 | 7 | from colpali_engine.models import ColQwen2Processor 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def model_name() -> str: 12 | return "vidore/colqwen2-v1.0" 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def processor_from_pretrained(model_name: str) -> Generator[ColQwen2Processor, None, None]: 17 | yield cast(ColQwen2Processor, ColQwen2Processor.from_pretrained(model_name)) 18 | 19 | 20 | def test_load_processor_from_pretrained(processor_from_pretrained: ColQwen2Processor): 21 | assert isinstance(processor_from_pretrained, ColQwen2Processor) 22 | 23 | 24 | def test_process_images(processor_from_pretrained: ColQwen2Processor): 25 | # Create a dummy image 26 | image_size = (64, 32) 27 | image = Image.new("RGB", image_size, color="black") 28 | images = [image] 29 | 30 | # Process the image 31 | batch_feature = processor_from_pretrained.process_images(images) 32 | 33 | # Assertions 34 | assert "pixel_values" in batch_feature 35 | assert isinstance(batch_feature["pixel_values"], torch.Tensor) 36 | assert batch_feature["pixel_values"].shape[0] == 1 37 | assert batch_feature["pixel_values"].shape[-1] == 1176 38 | 39 | 40 | def test_process_images_with_context(processor_from_pretrained: ColQwen2Processor): 41 | # Create a dummy image 42 | image_size = (64, 32) 43 | image = Image.new("RGB", image_size, color="black") 44 | images = [image] 45 | contexts = ["Open source is the best!"] 46 | 47 | # Process the image 48 | batch_feature = processor_from_pretrained.process_images(images, context_prompts=contexts) 49 | 50 | # Assertions 51 | assert "pixel_values" in batch_feature 52 | assert isinstance(batch_feature["pixel_values"], torch.Tensor) 53 | assert batch_feature["pixel_values"].shape[0] == 1 54 | assert batch_feature["pixel_values"].shape[-1] == 1176 55 | 56 | 57 | def test_process_queries(processor_from_pretrained: ColQwen2Processor): 58 | queries = [ 59 | "Is attention really all you need?", 60 | "Are Benjamin, Antoine, Merve, and Jo best friends?", 61 | ] 62 | 63 | # Process the queries 64 | batch_encoding = processor_from_pretrained.process_queries(queries) 65 | 66 | # Assertions 67 | assert "input_ids" in batch_encoding 68 | assert isinstance(batch_encoding["input_ids"], torch.Tensor) 69 | assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) 70 | -------------------------------------------------------------------------------- /tests/models/qwen2_5/colqwen2_5/test_processing_colqwen2_5.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, cast 2 | 3 | import pytest 4 | import torch 5 | from PIL import Image 6 | 7 | from colpali_engine.models import ColQwen2_5_Processor 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def model_name() -> str: 12 | return "vidore/colqwen2.5-v0.2" 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def processor_from_pretrained(model_name: str) -> Generator[ColQwen2_5_Processor, None, None]: 17 | yield cast(ColQwen2_5_Processor, ColQwen2_5_Processor.from_pretrained(model_name)) 18 | 19 | 20 | def test_load_processor_from_pretrained(processor_from_pretrained: 
ColQwen2_5_Processor): 21 | assert isinstance(processor_from_pretrained, ColQwen2_5_Processor) 22 | 23 | 24 | def test_process_images(processor_from_pretrained: ColQwen2_5_Processor): 25 | # Create a dummy image 26 | image_size = (64, 32) 27 | image = Image.new("RGB", image_size, color="black") 28 | images = [image] 29 | 30 | # Process the image 31 | batch_feature = processor_from_pretrained.process_images(images) 32 | 33 | # Assertions 34 | assert "pixel_values" in batch_feature 35 | assert isinstance(batch_feature["pixel_values"], torch.Tensor) 36 | assert batch_feature["pixel_values"].shape[0] == 1 37 | assert batch_feature["pixel_values"].shape[-1] == 1176 38 | 39 | 40 | def test_process_images_with_context(processor_from_pretrained: ColQwen2_5_Processor): 41 | # Create a dummy image 42 | image_size = (64, 32) 43 | image = Image.new("RGB", image_size, color="black") 44 | images = [image] 45 | contexts = ["Open source is the best!"] 46 | 47 | # Process the image 48 | batch_feature = processor_from_pretrained.process_images(images, context_prompts=contexts) 49 | 50 | # Assertions 51 | assert "pixel_values" in batch_feature 52 | assert isinstance(batch_feature["pixel_values"], torch.Tensor) 53 | assert batch_feature["pixel_values"].shape[0] == 1 54 | assert batch_feature["pixel_values"].shape[-1] == 1176 55 | 56 | 57 | def test_process_queries(processor_from_pretrained: ColQwen2_5_Processor): 58 | queries = [ 59 | "Is attention really all you need?", 60 | "Are Benjamin, Antoine, Merve, and Jo best friends?", 61 | ] 62 | 63 | # Process the queries 64 | batch_encoding = processor_from_pretrained.process_queries(queries) 65 | 66 | # Assertions 67 | assert "input_ids" in batch_encoding 68 | assert isinstance(batch_encoding["input_ids"], torch.Tensor) 69 | assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) 70 | -------------------------------------------------------------------------------- /tests/utils/test_processing_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor 4 | 5 | EMBEDDING_DIM = 32 6 | 7 | 8 | def test_score_single_vector_embeddings(): 9 | qs = [torch.randn(EMBEDDING_DIM) for _ in range(4)] 10 | ps = [torch.randn(EMBEDDING_DIM) for _ in range(8)] 11 | scores = BaseVisualRetrieverProcessor.score_single_vector(qs, ps) 12 | assert scores.shape == (len(qs), len(ps)) 13 | 14 | 15 | def test_score_multi_vector_embeddings(): 16 | # Score from list input 17 | qs = [ 18 | torch.randn(2, EMBEDDING_DIM), 19 | torch.randn(4, EMBEDDING_DIM), 20 | ] 21 | ps = [ 22 | torch.randn(8, EMBEDDING_DIM), 23 | torch.randn(4, EMBEDDING_DIM), 24 | torch.randn(16, EMBEDDING_DIM), 25 | ] 26 | scores_from_list_input = BaseVisualRetrieverProcessor.score_multi_vector(qs, ps) 27 | assert scores_from_list_input.shape == (len(qs), len(ps)) 28 | 29 | # Score from tensor input 30 | qs_padded = torch.nn.utils.rnn.pad_sequence(qs, batch_first=True) 31 | ps_padded = torch.nn.utils.rnn.pad_sequence(ps, batch_first=True) 32 | scores_from_tensor = BaseVisualRetrieverProcessor.score_multi_vector(qs_padded, ps_padded) 33 | assert scores_from_tensor.shape == (len(qs), len(ps)) 34 | 35 | assert torch.allclose(scores_from_list_input, scores_from_tensor), "Scores from list and tensor inputs should match" 36 | -------------------------------------------------------------------------------- /tests/utils/test_torch_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from colpali_engine.utils.torch_utils import unbind_padded_multivector_embeddings 4 | 5 | 6 | def test_unbind_padded_multivector_embedding_left(): 7 | # Inputs: 8 | # Sequence 1: 4 tokens with 2 left-padding rows. 9 | seq1 = torch.tensor( 10 | [ 11 | [0.0, 0.0], # padding 12 | [0.0, 0.0], # padding 13 | [1.0, 2.0], 14 | [3.0, 4.0], 15 | ], 16 | dtype=torch.float32, 17 | ) 18 | # Sequence 2: 4 tokens with 1 left-padding row. 19 | seq2 = torch.tensor( 20 | [ 21 | [0.0, 0.0], # padding 22 | [5.0, 6.0], 23 | [7.0, 8.0], 24 | [9.0, 10.0], 25 | ], 26 | dtype=torch.float32, 27 | ) 28 | padded_tensor = torch.stack([seq1, seq2], dim=0) # shape: (2, 4, 2) 29 | 30 | # Call the unbind function. 31 | unbound = unbind_padded_multivector_embeddings(padded_tensor, padding_value=0.0, padding_side="left") 32 | 33 | # Expected outputs: 34 | expected_seq1 = torch.tensor( 35 | [ 36 | [1.0, 2.0], 37 | [3.0, 4.0], 38 | ], 39 | dtype=torch.float32, 40 | ) 41 | expected_seq2 = torch.tensor( 42 | [ 43 | [5.0, 6.0], 44 | [7.0, 8.0], 45 | [9.0, 10.0], 46 | ], 47 | dtype=torch.float32, 48 | ) 49 | 50 | assert len(unbound) == 2 51 | assert torch.allclose(unbound[0], expected_seq1) 52 | assert torch.allclose(unbound[1], expected_seq2) 53 | 54 | 55 | def test_unbind_padded_multivector_embedding_right(): 56 | # Inputs: 57 | # Sequence 1: valid tokens then padding. 58 | seq1 = torch.tensor( 59 | [ 60 | [1.0, 2.0], 61 | [3.0, 4.0], 62 | [5.0, 6.0], 63 | [0.0, 0.0], # padding 64 | ], 65 | dtype=torch.float32, 66 | ) 67 | # Sequence 2: valid tokens then padding. 68 | seq2 = torch.tensor( 69 | [ 70 | [7.0, 8.0], 71 | [9.0, 10.0], 72 | [11.0, 12.0], 73 | [0.0, 0.0], # padding 74 | ], 75 | dtype=torch.float32, 76 | ) 77 | padded_tensor = torch.stack([seq1, seq2], dim=0) # shape: (2, 4, 2) 78 | 79 | # Call the unbind function. 80 | unbound = unbind_padded_multivector_embeddings(padded_tensor, padding_value=0.0, padding_side="right") 81 | 82 | # Expected outputs: 83 | expected_seq1 = torch.tensor( 84 | [ 85 | [1.0, 2.0], 86 | [3.0, 4.0], 87 | [5.0, 6.0], 88 | ], 89 | dtype=torch.float32, 90 | ) 91 | expected_seq2 = torch.tensor( 92 | [ 93 | [7.0, 8.0], 94 | [9.0, 10.0], 95 | [11.0, 12.0], 96 | ], 97 | dtype=torch.float32, 98 | ) 99 | 100 | assert len(unbound) == 2 101 | assert torch.allclose(unbound[0], expected_seq1) 102 | assert torch.allclose(unbound[1], expected_seq2) 103 | --------------------------------------------------------------------------------
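A closing, hypothetical sketch (shapes, the padding value, and the padding side are illustrative assumptions, not repository code): the two utilities tested above can be chained, unbinding left-padded multivector embeddings back into variable-length sequences and then scoring them with the multi-vector scorer.

import torch

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import unbind_padded_multivector_embeddings

# Two left-padded query embeddings with 2 and 3 real tokens respectively (dim 3).
padded_queries = torch.zeros(2, 4, 3)
padded_queries[0, 2:] = torch.randn(2, 3)  # rows 0-1 are padding
padded_queries[1, 1:] = torch.randn(3, 3)  # row 0 is padding

queries = unbind_padded_multivector_embeddings(padded_queries, padding_value=0.0, padding_side="left")

# Variable-length document embeddings.
docs = [torch.randn(5, 3), torch.randn(7, 3)]

# Multi-vector retrieval scores, one per (query, document) pair.
scores = BaseVisualRetrieverProcessor.score_multi_vector(queries, docs)
print(scores.shape)  # torch.Size([2, 2])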