├── bentocolpali ├── __init__.py ├── interfaces.py ├── utils.py ├── models.py └── service.py ├── requirements-dev.txt ├── requirements.txt ├── assets └── colpali_architecture.webp ├── pytest.ini ├── ruff.toml ├── bentofile.yaml ├── CITATION.cff ├── tests ├── conftest.py ├── test_unit.py └── test_integration.py ├── .gitignore └── README.md /bentocolpali/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-asyncio 3 | ruff 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bentoml>=1.4.8,<1.5 2 | colpali-engine>=0.3.0,<0.4.0 3 | pydantic>=2.8,<3 4 | -------------------------------------------------------------------------------- /assets/colpali_architecture.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentoml/BentoColPali/HEAD/assets/colpali_architecture.webp -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::Warning 4 | markers = 5 | slow: marks test as slow 6 | testpaths = 7 | tests 8 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 120 2 | 3 | [lint] 4 | select = ["E", "F", "W", "I", "N"] 5 | 6 | [lint.per-file-ignores] 7 | "__init__.py" = ["F401"] 8 | -------------------------------------------------------------------------------- /bentofile.yaml: 
-------------------------------------------------------------------------------- 1 | service: "bentocolpali.service:ColPaliService" 2 | labels: 3 | project: colpali 4 | include: 5 | - "bentocolpali/**/*.py" 6 | python: 7 | requirements_txt: "requirements.txt" 8 | lock_packages: false 9 | docker: 10 | cuda_version: "12.1" 11 | -------------------------------------------------------------------------------- /bentocolpali/interfaces.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class ImageDetail(Enum): 7 | """ 8 | Supported image details for `ImagePayload`. 9 | """ 10 | 11 | LOW = "low" 12 | HIGH = "high" 13 | AUTO = "auto" 14 | 15 | 16 | class ImagePayload(BaseModel): 17 | """ 18 | Image payload for `ColPaliService`. 19 | 20 | Args: 21 | url: The URL of the image or a base64-encoded image. 22 | detail: The detail of the image. Defaults to `ImageDetail.AUTO`. 23 | 24 | NOTE: `detail` is unused for now and only here to be consistent with the OpenAI API. 25 | """ 26 | 27 | url: str 28 | detail: ImageDetail = ImageDetail.AUTO 29 | -------------------------------------------------------------------------------- /bentocolpali/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | 4 | from PIL import Image 5 | 6 | 7 | def is_url(val: str) -> bool: 8 | return val.startswith("http") 9 | 10 | 11 | def convert_b64_to_pil_image(b64_image: str) -> Image.Image: 12 | """ 13 | Convert a base64 image string to a PIL Image. 
14 | """ 15 | if b64_image.startswith("data:image"): 16 | b64_image = b64_image.split(",")[1] 17 | 18 | try: 19 | image_data = base64.b64decode(b64_image) 20 | image_bytes = BytesIO(image_data) 21 | image = Image.open(image_bytes) 22 | except Exception as e: 23 | raise ValueError("Failed to convert base64 string to PIL Image") from e 24 | 25 | return image 26 | 27 | 28 | def convert_pil_to_b64_image(image: Image.Image) -> str: 29 | """ 30 | Convert a PIL Image to a base64 string. 31 | """ 32 | image_bytes = BytesIO() 33 | image.save(image_bytes, format="PNG") 34 | image_base64 = base64.b64encode(image_bytes.getvalue()).decode("utf-8") 35 | return f"data:image/png;base64,{image_base64}" 36 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Faysse" 5 | given-names: "Manuel" 6 | email: "manuel.faysse@illuin.tech" 7 | - family-names: "Sibille" 8 | given-names: "Hugues" 9 | email: "hugues.sibille@illuin.tech" 10 | - family-names: "Wu" 11 | given-names: "Tony" 12 | email: "tony.wu@illuin.tech" 13 | title: "Vision Document Retrieval (ViDoRe): Benchmark" 14 | date-released: 2024-06-26 15 | url: "https://github.com/illuin-tech/vidore-benchmark" 16 | preferred-citation: 17 | type: article 18 | authors: 19 | - family-names: "Faysse" 20 | given-names: "Manuel" 21 | - family-names: "Sibille" 22 | given-names: "Hugues" 23 | - family-names: "Wu" 24 | given-names: "Tony" 25 | - family-names: "Omrani" 26 | given-names: "Bilel" 27 | - family-names: "Viaud" 28 | given-names: "Gautier" 29 | - family-names: "Hudelot" 30 | given-names: "Céline" 31 | - family-names: "Colombo" 32 | given-names: "Pierre" 33 | doi: "arXiv.2407.01449" 34 | month: 6 35 | title: "ColPali: Efficient Document Retrieval with Vision Language Models" 36 | year: 2024 37 
| url: "https://arxiv.org/abs/2407.01449" 38 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Any, Dict, List 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from bentocolpali.service import ImagePayload 9 | 10 | PAYLOAD_TEST_PATH = "./tests/data/query_test.json" 11 | assert Path(PAYLOAD_TEST_PATH).exists(), f"File not found: {PAYLOAD_TEST_PATH}" 12 | 13 | COLPALI_EMBEDDING_DIM = 128 14 | 15 | 16 | def load_queries_and_images(file_path: str) -> Dict[str, Any]: 17 | with open(file_path, "r") as file: 18 | data = json.load(file) 19 | return data 20 | 21 | 22 | @pytest.fixture(scope="module") 23 | def queries() -> List[str]: 24 | data = load_queries_and_images(PAYLOAD_TEST_PATH) 25 | return data.get("queries", []) 26 | 27 | 28 | @pytest.fixture(scope="module") 29 | def images() -> List[Dict[str, Any]]: 30 | data = load_queries_and_images(PAYLOAD_TEST_PATH) 31 | return data.get("images", []) 32 | 33 | 34 | @pytest.fixture(scope="module") 35 | def image_payloads(images: List[Dict[str, Any]]) -> List[ImagePayload]: 36 | return [ImagePayload(**elt) for elt in images] 37 | 38 | 39 | @pytest.fixture(scope="module") 40 | def random_image_embeddings() -> np.ndarray: 41 | return np.random.rand(5, 32, COLPALI_EMBEDDING_DIM) 42 | 43 | 44 | @pytest.fixture(scope="module") 45 | def random_query_embeddings() -> np.ndarray: 46 | return np.random.rand(3, 16, COLPALI_EMBEDDING_DIM) 47 | -------------------------------------------------------------------------------- /bentocolpali/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use this script to build the models for the ColPali BentoService. 
3 | """ 4 | 5 | import argparse 6 | import os 7 | from typing import Optional, cast 8 | 9 | import bentoml 10 | import torch 11 | from colpali_engine.models import ColPali, ColPaliProcessor 12 | from colpali_engine.utils.torch_utils import get_torch_device 13 | from peft import LoraConfig, PeftModel 14 | 15 | 16 | def build_models(model_name: str, hf_token: Optional[str] = None) -> None: 17 | """ 18 | Build the model and the preprocessor for the ColPali BentoService. 19 | """ 20 | if hf_token is None: 21 | hf_token = os.getenv("HF_TOKEN") 22 | if hf_token is None: 23 | raise ValueError("HF token is required.") 24 | 25 | with bentoml.models.create(name="colpali_model") as model_ref: 26 | lora_config = LoraConfig.from_pretrained(model_name) 27 | 28 | cast( 29 | PeftModel, 30 | PeftModel.from_pretrained( 31 | ColPali.from_pretrained( 32 | lora_config.base_model_name_or_path, 33 | torch_dtype=torch.bfloat16, 34 | device_map=get_torch_device("auto"), 35 | token=hf_token, 36 | ), 37 | model_name, 38 | ), 39 | ).merge_and_unload().save_pretrained(model_ref.path) 40 | 41 | print("Model successfully built.") 42 | 43 | cast( 44 | ColPaliProcessor, 45 | ColPaliProcessor.from_pretrained( 46 | pretrained_model_name_or_path=model_name, 47 | token=hf_token, 48 | ), 49 | ).save_pretrained(model_ref.path) 50 | 51 | print("Preprocessor successfully built.") 52 | 53 | return 54 | 55 | 56 | def parse_arguments() -> argparse.Namespace: 57 | parser = argparse.ArgumentParser(description="Build the models for the ColPali BentoService.") 58 | parser.add_argument( 59 | "--model-name", 60 | type=str, 61 | required=True, 62 | help="Name or path of the model to build.", 63 | ) 64 | parser.add_argument( 65 | "--hf-token", 66 | type=str, 67 | required=False, 68 | help="Hugging Face token for model access. 
Defaults to the HF_TOKEN environment variable.", 69 | ) 70 | return parser.parse_args() 71 | 72 | 73 | def main(): 74 | args = parse_arguments() 75 | build_models(model_name=args.model_name, hf_token=args.hf_token) 76 | print("ColPali models were successfully built.") 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /tests/test_unit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the `ColPaliService` BentoML service. 3 | """ 4 | 5 | from typing import List 6 | from unittest.mock import AsyncMock, Mock, patch 7 | 8 | import numpy as np 9 | import pytest 10 | import torch 11 | 12 | from bentocolpali.service import ColPaliService, ImagePayload 13 | from tests.conftest import COLPALI_EMBEDDING_DIM 14 | 15 | 16 | @pytest.mark.asyncio 17 | async def test_embed_images(image_payloads: List[ImagePayload]): 18 | service = ColPaliService() 19 | 20 | mock_output = torch.rand(len(image_payloads), 32, COLPALI_EMBEDDING_DIM, dtype=torch.float32, device="cpu") 21 | 22 | with patch.object(service.model, "forward", new=Mock(return_value=mock_output)) as mock_service_method: 23 | image_embeddings = await service.embed_images(image_payloads) 24 | mock_service_method.assert_called_once() 25 | 26 | assert isinstance(image_embeddings, np.ndarray) 27 | assert image_embeddings.ndim == 3 28 | assert image_embeddings.shape == mock_output.shape 29 | 30 | 31 | @pytest.mark.asyncio 32 | async def test_embed_queries(queries: List[str]): 33 | service = ColPaliService() 34 | 35 | mock_output = torch.rand(len(queries), 32, COLPALI_EMBEDDING_DIM, dtype=torch.float32, device="cpu") 36 | 37 | with patch.object(service.model, "forward", new=Mock(return_value=mock_output)) as mock_service_method: 38 | query_embeddings = await service.embed_queries(queries) 39 | mock_service_method.assert_called_once() 40 | 41 | assert isinstance(query_embeddings, np.ndarray) 42 
| assert query_embeddings.ndim == 3 43 | assert query_embeddings.shape == mock_output.shape 44 | 45 | 46 | @pytest.mark.asyncio 47 | async def test_score_embeddings( 48 | random_image_embeddings: np.ndarray, 49 | random_query_embeddings: np.ndarray, 50 | ): 51 | service = ColPaliService() 52 | 53 | scores = await service.score_embeddings( 54 | image_embeddings=random_image_embeddings, 55 | query_embeddings=random_query_embeddings, 56 | ) 57 | 58 | assert isinstance(scores, np.ndarray) 59 | assert scores.ndim == 2 60 | assert scores.shape == (len(random_query_embeddings), len(random_image_embeddings)) 61 | 62 | 63 | @pytest.mark.asyncio 64 | async def test_score( 65 | image_payloads: List[ImagePayload], 66 | queries: List[str], 67 | ): 68 | service = ColPaliService() 69 | 70 | mock_embed_images = AsyncMock(return_value=np.random.rand(len(image_payloads), 32, COLPALI_EMBEDDING_DIM)) 71 | mock_embed_queries = AsyncMock(return_value=np.random.rand(len(queries), 16, COLPALI_EMBEDDING_DIM)) 72 | 73 | with patch.object(service, "embed_images", new=mock_embed_images): 74 | with patch.object(service, "embed_queries", new=mock_embed_queries): 75 | scores = await service.score(images=image_payloads, queries=queries) 76 | 77 | assert isinstance(scores, np.ndarray) 78 | assert scores.ndim == 2 79 | -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from typing import List 3 | 4 | import bentoml 5 | import numpy as np 6 | import pytest 7 | 8 | from bentocolpali.service import ImagePayload 9 | from tests.conftest import COLPALI_EMBEDDING_DIM 10 | 11 | BENTOML_PORT = "50001" 12 | LOCAL_BENTOML_URL = f"http://localhost:{BENTOML_PORT}" 13 | 14 | 15 | class TestColPaliServiceIntegration: 16 | """ 17 | Integration tests for the ColPali. 
18 | 19 | To run these tests, first you need to build the BentoService: 20 | 21 | ```bash 22 | bentoml build -f bentofile.yaml 23 | ``` 24 | """ 25 | 26 | @pytest.mark.slow 27 | def test_embed_images_route(self, image_payloads: List[ImagePayload]): 28 | with subprocess.Popen( 29 | ["bentoml", "serve", "bentocolpali.service:ColPaliService", "-p", BENTOML_PORT] 30 | ) as server_proc: 31 | try: 32 | with bentoml.SyncHTTPClient(LOCAL_BENTOML_URL) as client: 33 | embeddings = client.embed_images(image_payloads) 34 | 35 | assert embeddings.ndim == 3 36 | assert embeddings.shape[0] == len(image_payloads) 37 | assert embeddings.shape[2] == COLPALI_EMBEDDING_DIM 38 | 39 | finally: 40 | server_proc.terminate() 41 | 42 | @pytest.mark.slow 43 | def test_embed_queries_route(self, queries: List[str]): 44 | with subprocess.Popen( 45 | ["bentoml", "serve", "bentocolpali.service:ColPaliService", "-p", BENTOML_PORT] 46 | ) as server_proc: 47 | try: 48 | with bentoml.SyncHTTPClient(LOCAL_BENTOML_URL) as client: 49 | embeddings = client.embed_queries(queries) 50 | 51 | assert embeddings.ndim == 3 52 | assert embeddings.shape[0] == len(queries) 53 | assert embeddings.shape[2] == COLPALI_EMBEDDING_DIM 54 | 55 | finally: 56 | server_proc.terminate() 57 | 58 | @pytest.mark.slow 59 | def test_score_embeddings_route( 60 | self, 61 | random_image_embeddings: np.ndarray, 62 | random_query_embeddings: np.ndarray, 63 | ): 64 | with subprocess.Popen( 65 | ["bentoml", "serve", "bentocolpali.service:ColPaliService", "-p", BENTOML_PORT] 66 | ) as server_proc: 67 | try: 68 | with bentoml.SyncHTTPClient(LOCAL_BENTOML_URL) as client: 69 | scores = client.score_embeddings( 70 | image_embeddings=random_image_embeddings, 71 | query_embeddings=random_query_embeddings, 72 | ) 73 | 74 | assert scores.ndim == 2 75 | assert scores.shape[0] == random_query_embeddings.shape[0] 76 | assert scores.shape[1] == random_image_embeddings.shape[0] 77 | 78 | finally: 79 | server_proc.terminate() 80 | 81 | 
@pytest.mark.slow 82 | def test_score_route( 83 | self, 84 | image_payloads: List[ImagePayload], 85 | queries: List[str], 86 | ): 87 | with subprocess.Popen( 88 | ["bentoml", "serve", "bentocolpali.service:ColPaliService", "-p", BENTOML_PORT] 89 | ) as server_proc: 90 | try: 91 | with bentoml.SyncHTTPClient(LOCAL_BENTOML_URL) as client: 92 | scores = client.score(images=image_payloads, queries=queries) 93 | 94 | assert scores.ndim == 2 95 | assert scores.shape[0] == len(queries) 96 | assert scores.shape[1] == len(image_payloads) 97 | 98 | # Check that the maximum score of each row lies on the diagonal of the score matrix 99 | assert (scores.argmax(axis=1) == range(len(queries))).all() 100 | 101 | finally: 102 | server_proc.terminate() 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | .DS_Store 3 | /.vscode/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | notebooks/*.png 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 115 | .pdm.toml 116 | .pdm-python 117 | .pdm-build/ 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
167 | #.idea/ 168 | -------------------------------------------------------------------------------- /bentocolpali/service.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Annotated, List, cast 4 | 5 | import bentoml 6 | import numpy as np 7 | import torch 8 | from annotated_types import MinLen 9 | from colpali_engine.models import ColPali, ColPaliProcessor 10 | from colpali_engine.utils.torch_utils import get_torch_device 11 | from PIL import Image 12 | 13 | from .interfaces import ImagePayload 14 | from .utils import convert_b64_to_pil_image, is_url 15 | 16 | logger = logging.getLogger("bentoml") 17 | 18 | 19 | def create_model_pipeline(path: str) -> ColPali: 20 | return cast( 21 | ColPali, 22 | ColPali.from_pretrained( 23 | pretrained_model_name_or_path=path, 24 | torch_dtype=torch.bfloat16, 25 | device_map=get_torch_device("auto"), 26 | local_files_only=True, 27 | low_cpu_mem_usage=True, 28 | ), 29 | ).eval() 30 | 31 | 32 | def create_processor_pipeline(path: str) -> ColPaliProcessor: 33 | return cast( 34 | ColPaliProcessor, 35 | ColPaliProcessor.from_pretrained( 36 | pretrained_model_name_or_path=path, 37 | local_files_only=True, 38 | ), 39 | ) 40 | 41 | 42 | @bentoml.service( 43 | name="colpali", 44 | workers=1, 45 | traffic={"concurrency": 64}, 46 | ) 47 | class ColPaliService: 48 | """ 49 | ColPali service for embedding images and queries, and scoring them. 50 | Provides batch processing capabilities. 51 | 52 | NOTE: You need to build the model using `bentoml.models.create(name="colpali_model")` before using this service. 
53 | """ 54 | 55 | _model_ref: bentoml.Model = bentoml.models.get("colpali_model") 56 | 57 | def __init__(self) -> None: 58 | self.model: ColPali = create_model_pipeline(path=self._model_ref.path) 59 | self.processor: ColPaliProcessor = create_processor_pipeline(path=self._model_ref.path) 60 | logger.info(f"ColPali loaded on device: {self.model.device}") 61 | 62 | @bentoml.api( 63 | batchable=True, 64 | batch_dim=(0, 0), 65 | max_batch_size=64, 66 | max_latency_ms=30_000, 67 | ) 68 | async def embed_images( 69 | self, 70 | items: List[ImagePayload], 71 | ) -> np.ndarray: 72 | """ 73 | Generate image embeddings of shape (batch_size, sequence_length, embedding_dim). 74 | """ 75 | 76 | images: List[Image.Image] = [] 77 | 78 | for item in items: 79 | if is_url(item.url): 80 | raise NotImplementedError("URLs are not supported.") 81 | images += [convert_b64_to_pil_image(item.url)] 82 | 83 | batch_images = self.processor.process_images(images).to(self.model.device) 84 | 85 | with torch.inference_mode(): 86 | image_embeddings = self.model(**batch_images) 87 | 88 | return image_embeddings.cpu().to(torch.float32).detach().numpy() 89 | 90 | @bentoml.api( 91 | batchable=True, 92 | batch_dim=(0, 0), 93 | max_batch_size=64, 94 | max_latency_ms=30_000, 95 | ) 96 | async def embed_queries( 97 | self, 98 | items: List[str], 99 | ) -> np.ndarray: 100 | """ 101 | Generate query embeddings of shape (batch_size, sequence_length, embedding_dim). 102 | """ 103 | batch_queries = self.processor.process_queries(items).to(self.model.device) 104 | 105 | with torch.inference_mode(): 106 | query_embeddings = self.model(**batch_queries) 107 | 108 | return query_embeddings.cpu().to(torch.float32).detach().numpy() 109 | 110 | @bentoml.api 111 | async def score_embeddings( 112 | self, 113 | image_embeddings: np.ndarray, 114 | query_embeddings: np.ndarray, 115 | ) -> np.ndarray: 116 | """ 117 | Returns the late-interaction/MaxSim scores of shape (num_queries, num_images). 
118 | 119 | Args: 120 | image_embeddings: The image embeddings of shape (num_images, sequence_length, embedding_dim). 121 | query_embeddings: The query embeddings of shape (num_queries, sequence_length, embedding_dim). 122 | """ 123 | 124 | image_embeddings_torch: List[torch.Tensor] = [torch.Tensor(x) for x in image_embeddings] 125 | query_embeddings_torch: List[torch.Tensor] = [torch.Tensor(x) for x in query_embeddings] 126 | 127 | return ( 128 | self.processor.score( 129 | qs=query_embeddings_torch, 130 | ps=image_embeddings_torch, 131 | ) 132 | .cpu() 133 | .to(torch.float32) 134 | .numpy() 135 | ) 136 | 137 | @bentoml.api 138 | async def score( 139 | self, 140 | images: Annotated[List[ImagePayload], MinLen(1)], 141 | queries: Annotated[List[str], MinLen(1)], 142 | ) -> np.ndarray: 143 | """ 144 | Returns the late-interaction/MaxSim scores of the queries against the images. 145 | """ 146 | 147 | image_embeddings, query_embeddings = await asyncio.gather( 148 | self.embed_images(images), 149 | self.embed_queries(queries), 150 | ) 151 | 152 | return await self.score_embeddings( 153 | image_embeddings=image_embeddings, 154 | query_embeddings=query_embeddings, 155 | ) 156 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |