├── .gitignore ├── .vscode └── settings.json ├── Justfile ├── LICENSE ├── README.md ├── examples └── vector_search │ ├── Attention_is_all_you_need.pdf │ └── vector_search.ipynb ├── midrasai ├── __init__.py ├── _abc.py ├── _constants.py ├── cli.py ├── client │ ├── __init__.py │ └── main.py ├── local │ ├── __init__.py │ ├── main.py │ └── server.py ├── types.py └── vectordb │ ├── __init__.py │ └── _qdrant.py ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── assets ├── Attention_is_all_you_need.pdf ├── Colpali-example1.png ├── Colpali-example2.png └── Colpali-example3.png ├── test_local_midras.py ├── test_midras_client.py ├── test_midras_server.py └── vectordb └── test_qdrant.py /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache/ 2 | __pycache__/ 3 | .venv/ 4 | .ruff_cache/ 5 | dist/ 6 | .env 7 | .aider* 8 | .vscode -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "basic" 3 | } -------------------------------------------------------------------------------- /Justfile: -------------------------------------------------------------------------------- 1 | default: 2 | @just -l 3 | 4 | # Install development dependencies 5 | install: 6 | @poetry install --all-extras --no-cache 7 | 8 | # Run formatter & linter 9 | tidy: 10 | @poetry run ruff check 11 | @poetry run ruff format 12 | 13 | # Publish to pypi 14 | publish: 15 | @poetry build 16 | @poetry publish 17 | 18 | # Run all tests 19 | test-all: 20 | @poetry run pytest -q 21 | 22 | # Run tests for local Midras 23 | test-local: 24 | @poetry run pytest -v tests/test_local_midras.py 25 | 26 | # Run tests for Midras server 27 | test-server: 28 | @poetry run pytest -v tests/test_midras_server.py 29 | 30 | # Run tests for Midras client 31 | test-client: 32 | @poetry run pytest -v tests/test_midras_client.py 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Anibal Angulo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
 22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # MidrasAI
  2 | 
  3 | MidrasAI provides a simple API for using the ColPali model, a multi-modal model for text and image retrieval.
  4 | It allows local access to the model and integrates a vector database for efficient storage and semantic search.
  5 | 
  6 | Setting up the model as a server for remote access is still a work in progress (an experimental setup is sketched at the end of this README).
  7 | 
  8 | ## Getting started
  9 | 
 10 | Note: This is an alpha version of MidrasAI. All feedback and suggestions are welcome!
 11 | 
 12 | ### Local Dependencies
 13 | 
 14 | - ColPali access: ColPali is based on PaliGemma, so you will need to request access to the model [here](https://huggingface.co/google/paligemma-3b-mix-448). Then authenticate through the huggingface-cli (`huggingface-cli login`) so the model can be downloaded.
 15 | - Poppler: Midras uses `pdf2image` to convert PDFs to images. This library requires `poppler` to be installed on your system. Check out the installation instructions [here](https://poppler.freedesktop.org/).
 16 | - Hardware: ColPali is a 3B parameter model, so I recommend using a GPU with at least 8GB of VRAM.
 17 | 
 18 | ### Installation
 19 | 
 20 | If running locally, you can install MidrasAI and its dependencies with pip, poetry, or uv:
 21 | ```bash
 22 | # pip
 23 | pip install 'midrasai[local]'
 24 | 
 25 | # poetry
 26 | poetry add 'midrasai[local]'
 27 | 
 28 | # uv
 29 | uv add 'midrasai[local]'
 30 | ```
 31 | 
 32 | ### Usage
 33 | 
 34 | #### Starting the ColPali model
 35 | 
 36 | To load the ColPali model locally, you just need to use the `LocalMidras` class:
 37 | 
 38 | ```python3
 39 | from midrasai.local import LocalMidras
 40 | 
 41 | midras = LocalMidras() # Make sure you're logged in to HuggingFace so you can download the model
 42 | ```
 43 | 
 44 | #### Creating an index
 45 | 
 46 | To create an index, you can use the `create_index` method with the name of the index you want to create:
 47 | ```python3
 48 | midras.create_index("my_index")
 49 | ```
 50 | 
 51 | #### Using the model to embed data
 52 | 
 53 | The `LocalMidras` class provides a couple of convenience methods for embedding data.
 54 | You can use the `embed_pdf` method to embed a single PDF, or the `embed_images` method to embed a list of PIL images. Here's how to use them:
 55 | 
 56 | ```python3
 57 | # Embed a single PDF
 58 | path_to_pdf = "path/to/pdf.pdf"
 59 | 
 60 | pdf_response = midras.embed_pdf(path_to_pdf, include_images=True)
 61 | ```
 62 | 
 63 | ```python3
 64 | # Embed a list of PIL images (from PIL import Image)
 65 | images = [Image.open("path/to/image.png"), Image.open("path/to/another_image.png")]
 66 | 
 67 | image_response = midras.embed_images(images)
 68 | ```
 69 | 
 70 | #### Inserting data into an index
 71 | 
 72 | Once you have your data embeddings, you can insert a data point into your index with the `add_point` method:
 73 | 
 74 | ```python3
 75 | midras.add_point(
 76 |     index="my_index", # name of the index you want to add to
 77 |     id=1, # id of this data point, can be any integer or string
 78 |     embedding=pdf_response.embeddings[0], # the embedding you created in the previous step
 79 |     data={ # any additional data you want to store with this point, can be any dictionary
 80 |         "something": "hi",
 81 |         "something_else": 123
 82 |     }
 83 | )
 84 | ```
 85 | 
 86 | #### Searching an index
 87 | 
 88 | After you've added data to your index, you can start searching for relevant data. You can use the `query` method to do this:
 89 | 
 90 | ```python3
 91 | query = "What is the meaning of life?"
 92 | 
 93 | results = midras.query("my_index", query=query)
 94 | 
 95 | # Top 3 relevant data points
 96 | for result in results[:3]:
 97 |     # Each result will have a score, which is a measure of how relevant the data is to the query
 98 |     print(f"score: {result.score}")
 99 |     # Each result will also have any additional data you stored with it
100 |     print(f"data: {result.data}")
101 | ```
102 | 
103 | If you want a more detailed example including RAG, check out the [example vector search notebook](https://github.com/Midras-AI-Systems/midrasai/blob/main/examples/vector_search/vector_search.ipynb).
104 | 
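105 | ### Running the model as a server (experimental)
106 | 
107 | Remote access is still a work in progress, but the package already ships a small FastAPI server (exposed through the `midras-server` command when the `local` extra is installed, listening on `127.0.0.1:8000` by default) and an HTTP-based `Midras` client that talks to it. The snippet below is a minimal sketch of that flow; the placeholder API key and localhost URL are illustrative assumptions, not a deployment guide.
108 | 
109 | ```python3
110 | from midrasai import Midras
111 | 
112 | # Start the server first, in another terminal: midras-server
113 | # The bundled local server does not validate the API key, so any string works here.
114 | midras = Midras(api_key="my-key", base_url="http://localhost:8000")
115 | 
116 | # Embedding happens on the server; the vector index lives in the client process.
117 | pdf_response = midras.embed_pdf("path/to/pdf.pdf")
118 | 
119 | midras.create_index("my_index")
120 | midras.add_point(index="my_index", id=1, embedding=pdf_response.embeddings[0], data={"something": "hi"})
121 | 
122 | results = midras.query("my_index", query="What is this document about?")
123 | ```
124 | 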
--------------------------------------------------------------------------------
/examples/vector_search/Attention_is_all_you_need.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ajac-zero/midrasai/9580bd665f9648015561aef89b5592165b9885e3/examples/vector_search/Attention_is_all_you_need.pdf
--------------------------------------------------------------------------------
/midrasai/__init__.py:
--------------------------------------------------------------------------------
 1 | from midrasai.client import AsyncMidras, Midras
 2 | 
 3 | __all__ = ["Midras", "AsyncMidras"]
 4 | 
--------------------------------------------------------------------------------
/midrasai/_abc.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from base64 import b64encode
 3 | from io import BytesIO
 4 | from typing import Any, Awaitable
 5 | 
 6 | from midrasai.types import ColBERT, MidrasResponse, Mode, QueryResult
 7 | 
 8 | 
 9 | class BaseMidras(ABC):
10 |     @abstractmethod
11 |     def embed_pdf(
12 |         self, pdf: str | bytes, batch_size: int = 10, include_images: bool = False
13 |     ) -> MidrasResponse: ...
14 | 
15 |     @abstractmethod
16 |     def embed_images(
17 |         self, images: list, mode: Mode = Mode.Standard
18 |     ) -> MidrasResponse: ...
19 | 
20 |     @abstractmethod
21 |     def embed_queries(
22 |         self, queries: list[str], mode: Mode = Mode.Standard
23 |     ) -> MidrasResponse: ...
24 | 
25 |     @abstractmethod
26 |     def create_index(self, name: str) -> bool | Awaitable[bool]: ...
27 | 
28 |     @abstractmethod
29 |     def add_point(
30 |         self, index: str, id: str | int, embedding: ColBERT, data: dict[str, Any]
31 |     ) -> Any: ...
32 | 
33 |     @abstractmethod
34 |     def query(self, index: str, query: str, quantity: int = 5) -> list[QueryResult]: ...
35 | 
36 |     def base64_encode_image_list(self, pil_images: list) -> list[str]:
37 |         base64_images = []
38 |         for image in pil_images:
39 |             with BytesIO() as buffer:
40 |                 image.save(buffer, format=image.format)
41 |                 base64_image = b64encode(buffer.getvalue()).decode("utf-8")
42 |                 base64_images.append(base64_image)
43 |         return base64_images
44 | 
45 | 
46 | class AsyncBaseMidras(ABC):
47 |     @abstractmethod
48 |     async def embed_pdf(
49 |         self, pdf: str | bytes, batch_size: int = 10, include_images: bool = False
50 |     ) -> MidrasResponse: ...
51 | 
52 |     @abstractmethod
53 |     async def embed_images(
54 |         self, images: list, mode: Mode = Mode.Standard
55 |     ) -> MidrasResponse: ...
56 | 
57 |     @abstractmethod
58 |     async def embed_queries(
59 |         self, queries: list[str], mode: Mode = Mode.Standard
60 |     ) -> MidrasResponse: ...
61 | 
62 |     @abstractmethod
63 |     async def create_index(self, name: str) -> bool | Awaitable[bool]: ...
64 | 
65 |     @abstractmethod
66 |     async def add_point(
67 |         self, index: str, id: str | int, embedding: ColBERT, data: dict[str, Any]
68 |     ) -> Any: ...
69 | 70 | @abstractmethod 71 | async def query( 72 | self, index: str, query: str, quantity: int = 5 73 | ) -> list[QueryResult]: ... 74 | 75 | def base64_encode_image_list(self, pil_images: list) -> list[str]: 76 | base64_images = [] 77 | for image in pil_images: 78 | with BytesIO() as buffer: 79 | image.save(buffer, format=image.format) 80 | base64_image = b64encode(buffer.getvalue()).decode("utf-8") 81 | base64_images.append(base64_image) 82 | return base64_images 83 | 84 | 85 | class VectorDB(ABC): 86 | @abstractmethod 87 | def create_index(self, name: str) -> bool: ... 88 | 89 | @abstractmethod 90 | def create_point( 91 | self, id: int | str, embedding: ColBERT, data: dict[str, Any] 92 | ) -> Any: ... 93 | 94 | @abstractmethod 95 | def save_points(self, index: str, points: list[Any]) -> Any: ... 96 | 97 | @abstractmethod 98 | def delete_index(self, name: str) -> bool: ... 99 | 100 | @abstractmethod 101 | def search( 102 | self, index: str, query_vector: ColBERT, quantity: int 103 | ) -> list[QueryResult]: ... 104 | 105 | 106 | class AsyncVectorDB(ABC): 107 | @abstractmethod 108 | async def create_index(self, name: str) -> bool: ... 109 | 110 | @abstractmethod 111 | async def create_point( 112 | self, id: int | str, embedding: ColBERT, data: dict[str, Any] 113 | ) -> Any: ... 114 | 115 | @abstractmethod 116 | async def save_points(self, index: str, points: list[Any]) -> Any: ... 117 | 118 | @abstractmethod 119 | async def delete_index(self, name: str) -> bool: ... 120 | 121 | @abstractmethod 122 | async def search( 123 | self, index: str, query_vector: ColBERT, quantity: int 124 | ) -> list[QueryResult]: ... 125 | -------------------------------------------------------------------------------- /midrasai/_constants.py: -------------------------------------------------------------------------------- 1 | CLOUD_URL = "https://backend-bold-leaf-5025.fly.dev" 2 | -------------------------------------------------------------------------------- /midrasai/cli.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | def cli(host: str = "127.0.0.1", port: int = 8000): 5 | try: 6 | import uvicorn 7 | 8 | from midrasai.local.server import app 9 | 10 | uvicorn.run(app, host=host, port=port) 11 | 12 | except ImportError: 13 | warnings.warn("Local extra dependencies not installed. 
Server unavailable.") 14 | 15 | 16 | if __name__ == "__main__": 17 | import typer 18 | 19 | typer.run(cli) 20 | -------------------------------------------------------------------------------- /midrasai/client/__init__.py: -------------------------------------------------------------------------------- 1 | from midrasai.client.main import AsyncMidras, Midras 2 | 3 | __all__ = ["Midras", "AsyncMidras"] 4 | -------------------------------------------------------------------------------- /midrasai/client/main.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import httpx 4 | 5 | from midrasai._abc import ( 6 | AsyncBaseMidras, 7 | BaseMidras, 8 | VectorDB, 9 | ) 10 | from midrasai._constants import CLOUD_URL 11 | from midrasai.types import ColBERT, MidrasResponse, Mode 12 | from midrasai.vectordb import Qdrant 13 | 14 | 15 | class Midras(BaseMidras): 16 | def __init__( 17 | self, 18 | api_key: str, 19 | *, 20 | vector_database: VectorDB | None = None, 21 | base_url: str | None = None, 22 | ): 23 | self.api_key = api_key 24 | self.client = httpx.Client(base_url=CLOUD_URL if base_url is None else base_url) 25 | self.index = vector_database if vector_database else Qdrant(location=":memory:") 26 | 27 | def embed_pdf( 28 | self, pdf: str | bytes, batch_size: int = 10, include_images: bool = False 29 | ) -> MidrasResponse: 30 | if isinstance(pdf, str): 31 | with open(pdf, "rb") as f: 32 | file_data = f.read() 33 | elif isinstance(pdf, bytes): 34 | file_data = pdf 35 | 36 | files = {"file": ("test.pdf", file_data, "application/pdf")} 37 | response = self.client.post( 38 | "/embed/pdf", 39 | files=files, 40 | headers={"Authorization": f"Bearer {self.api_key}"}, 41 | params={"batch_size": batch_size, "include_images": include_images}, 42 | ) 43 | 44 | if response.status_code == 200: 45 | return MidrasResponse.model_validate(response.json()) 46 | else: 47 | raise ValueError("Internal server error") 48 | 49 | def embed_images(self, images: list, mode: Mode = Mode.Standard) -> MidrasResponse: 50 | encoded_images = self.base64_encode_image_list(images) 51 | 52 | response = self.client.post( 53 | "/embed/images", 54 | json={"images": encoded_images}, 55 | headers={"Authorization": f"Bearer {self.api_key}"}, 56 | params={"mode": mode}, 57 | ) 58 | 59 | if response.status_code == 200: 60 | return MidrasResponse.model_validate(response.json()) 61 | else: 62 | raise ValueError("Internal server error") 63 | 64 | def create_index(self, name: str) -> bool: 65 | return self.index.create_index(name) 66 | 67 | def add_point( 68 | self, index: str, id: str | int, embedding: ColBERT, data: dict[str, Any] 69 | ): 70 | point = self.index.create_point(id=id, embedding=embedding, data=data) 71 | return self.index.save_points(index, [point]) 72 | 73 | def embed_queries( 74 | self, queries: list[str], mode: Mode = Mode.Standard 75 | ) -> MidrasResponse: 76 | response = self.client.post( 77 | "/embed/queries", 78 | json={"queries": queries}, 79 | headers={"Authorization": f"Bearer {self.api_key}"}, 80 | params={"mode": mode}, 81 | ) 82 | 83 | if response.status_code == 200: 84 | return MidrasResponse.model_validate(response.json()) 85 | else: 86 | raise ValueError("Internal server error") 87 | 88 | def query(self, index: str, query: str, quantity: int = 5): 89 | query_vector = self.embed_queries([query]).embeddings[0] 90 | return self.index.search(index, query_vector, quantity) 91 | 92 | 93 | class AsyncMidras(AsyncBaseMidras): 94 | def __init__( 95 | 
self, 96 | api_key: str, 97 | *, 98 | vector_database: VectorDB | None = None, 99 | base_url: str | None = None, 100 | ): 101 | self.api_key = api_key 102 | self.client = httpx.AsyncClient( 103 | base_url=CLOUD_URL if base_url is None else base_url 104 | ) 105 | self.index = vector_database if vector_database else Qdrant(location=":memory:") 106 | 107 | async def embed_pdf( 108 | self, pdf: str | bytes, batch_size: int = 10, include_images: bool = False 109 | ) -> MidrasResponse: 110 | if isinstance(pdf, str): 111 | with open(pdf, "rb") as f: 112 | file_data = f.read() 113 | elif isinstance(pdf, bytes): 114 | file_data = pdf 115 | 116 | files = {"file": ("test.pdf", file_data, "application/pdf")} 117 | response = await self.client.post( 118 | "/embed/pdf", 119 | files=files, 120 | headers={"Authorization": f"Bearer {self.api_key}"}, 121 | params={"batch_size": batch_size, "include_images": include_images}, 122 | ) 123 | 124 | if response.status_code == 200: 125 | return MidrasResponse.model_validate(response.json()) 126 | else: 127 | raise ValueError("Internal server error") 128 | 129 | async def embed_images( 130 | self, images: list, mode: Mode = Mode.Standard 131 | ) -> MidrasResponse: 132 | encoded_images = self.base64_encode_image_list(images) 133 | 134 | response = await self.client.post( 135 | "/embed/images", 136 | json={"images": encoded_images}, 137 | headers={"Authorization": f"Bearer {self.api_key}"}, 138 | params={"mode": mode}, 139 | ) 140 | 141 | if response.status_code == 200: 142 | return MidrasResponse.model_validate(response.json()) 143 | else: 144 | raise ValueError("Internal server error") 145 | 146 | async def create_index(self, name: str) -> bool: 147 | return self.index.create_index(name) 148 | 149 | async def add_point( 150 | self, index: str, id: str | int, embedding: ColBERT, data: dict[str, Any] 151 | ): 152 | point = self.index.create_point(id=id, embedding=embedding, data=data) 153 | return self.index.save_points(index, [point]) 154 | 155 | async def embed_queries( 156 | self, queries: list[str], mode: Mode = Mode.Standard 157 | ) -> MidrasResponse: 158 | response = await self.client.post( 159 | "/embed/queries", 160 | json={"queries": queries}, 161 | headers={"Authorization": f"Bearer {self.api_key}"}, 162 | params={"mode": mode}, 163 | ) 164 | 165 | if response.status_code == 200: 166 | return MidrasResponse.model_validate(response.json()) 167 | else: 168 | raise ValueError("Internal server error") 169 | 170 | async def query(self, index: str, query: str, quantity: int = 5): 171 | query_vector = (await self.embed_queries([query])).embeddings[0] 172 | return self.index.search(index, query_vector, quantity) 173 | -------------------------------------------------------------------------------- /midrasai/local/__init__.py: -------------------------------------------------------------------------------- 1 | from midrasai.local.main import LocalMidras 2 | 3 | __all__ = ["LocalMidras"] 4 | -------------------------------------------------------------------------------- /midrasai/local/main.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import pdf2image 4 | import torch 5 | from colpali_engine import ColPali, ColPaliProcessor 6 | 7 | from midrasai._abc import BaseMidras, VectorDB 8 | from midrasai.types import MidrasResponse 9 | from midrasai.vectordb import Qdrant 10 | 11 | 12 | class LocalMidras(BaseMidras): 13 | def __init__( 14 | self, device_map: str = "cuda:0", vector_database: VectorDB 
| None = None 15 | ): 16 | model_name = "vidore/colpali-v1.2" 17 | self.model = cast( 18 | ColPali, 19 | ColPali.from_pretrained( 20 | model_name, 21 | torch_dtype=torch.bfloat16, 22 | device_map=device_map, 23 | ), 24 | ) 25 | self.processor = cast( 26 | ColPaliProcessor, 27 | ColPaliProcessor.from_pretrained(model_name), 28 | ) 29 | self.index = vector_database if vector_database else Qdrant(location=":memory:") 30 | 31 | def embed_pdf(self, pdf, batch_size=10, include_images=False) -> MidrasResponse: 32 | if isinstance(pdf, str): 33 | images = pdf2image.convert_from_path(pdf) 34 | elif isinstance(pdf, bytes): 35 | images = pdf2image.convert_from_bytes(pdf) 36 | else: 37 | raise ValueError("Invalid type") 38 | 39 | embeddings = [] 40 | 41 | for i in range(0, len(images), batch_size): 42 | image_batch = images[i : i + batch_size] 43 | response = self.embed_images(image_batch) 44 | embeddings.extend(response.embeddings) 45 | 46 | return MidrasResponse( 47 | embeddings=embeddings, 48 | images=images if include_images else None, 49 | ) 50 | 51 | def embed_images(self, images, mode="local"): 52 | _ = mode 53 | batch_images = self.processor.process_images(images).to(self.model.device) 54 | with torch.no_grad(): 55 | image_embeddings = self.model(**batch_images) 56 | return MidrasResponse(embeddings=image_embeddings.tolist()) 57 | 58 | def embed_queries(self, queries, mode="local"): 59 | _ = mode 60 | batch_queries = self.processor.process_queries(queries).to(self.model.device) 61 | with torch.no_grad(): 62 | query_embeddings = self.model(**batch_queries) 63 | return MidrasResponse(embeddings=query_embeddings.tolist()) 64 | 65 | def create_index(self, name): 66 | return self.index.create_index(name) 67 | 68 | def add_point(self, index, id, embedding, data): 69 | point = self.index.create_point(id=id, embedding=embedding, data=data) 70 | return self.index.save_points(index, [point]) 71 | 72 | def query(self, index, query, quantity=5): 73 | query_vector = self.embed_queries([query]).embeddings[0] 74 | return self.index.search(index, query_vector, quantity) 75 | 76 | def delete_index(self, name): 77 | return self.index.delete_index(name) 78 | -------------------------------------------------------------------------------- /midrasai/local/server.py: -------------------------------------------------------------------------------- 1 | from base64 import b64decode 2 | from contextlib import asynccontextmanager 3 | from io import BytesIO 4 | from typing import cast 5 | 6 | from fastapi import FastAPI, File, UploadFile 7 | from PIL import Image 8 | from pydantic import BaseModel 9 | 10 | from midrasai.local.main import LocalMidras 11 | from midrasai.types import MidrasResponse 12 | 13 | midras = cast(LocalMidras, None) 14 | 15 | 16 | @asynccontextmanager 17 | async def lifespan(_: FastAPI): 18 | global midras 19 | midras = LocalMidras() 20 | yield 21 | 22 | 23 | app = FastAPI(lifespan=lifespan) 24 | 25 | 26 | class ImageInput(BaseModel): 27 | images: list[str] 28 | 29 | @property 30 | def pil_images(self): 31 | return [ 32 | cast(Image.Image, Image.open(BytesIO(b64decode(image)))) 33 | for image in self.images 34 | ] 35 | 36 | 37 | class TextInput(BaseModel): 38 | queries: list[str] 39 | 40 | 41 | @app.post("/embed/queries") 42 | def embed_queries(input: TextInput) -> MidrasResponse: 43 | query_embeddings = midras.embed_queries(input.queries) 44 | return query_embeddings 45 | 46 | 47 | @app.post("/embed/images") 48 | def embed_images(input: ImageInput) -> MidrasResponse: 49 | image_embeddings = 
midras.embed_images(input.pil_images) 50 | return image_embeddings 51 | 52 | 53 | @app.post("/embed/pdf") 54 | def embed_pdf(file: UploadFile = File(...)) -> MidrasResponse: 55 | image_embeddings = midras.embed_pdf(file.file.read()) 56 | return image_embeddings 57 | -------------------------------------------------------------------------------- /midrasai/types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, Dict, TypeAlias 3 | 4 | from pydantic import BaseModel, ConfigDict 5 | 6 | Embedding: TypeAlias = list[float] 7 | ColBERT: TypeAlias = list[Embedding] 8 | Base64Image: TypeAlias = str 9 | 10 | 11 | class Mode(str, Enum): 12 | Standard = "standard" 13 | Turbo = "turbo" 14 | Local = "local" 15 | 16 | 17 | class MidrasRequest(BaseModel): 18 | key: str 19 | mode: Mode = Mode.Standard 20 | base64images: list[str] | None = None 21 | queries: list[str] | None = None 22 | 23 | 24 | class MidrasResponse(BaseModel): 25 | embeddings: list[ColBERT] 26 | images: list | None = None 27 | 28 | model_config = ConfigDict(arbitrary_types_allowed=True) 29 | 30 | 31 | class QueryResult(BaseModel): 32 | id: int | str 33 | score: float 34 | data: Dict[str, Any] | None 35 | -------------------------------------------------------------------------------- /midrasai/vectordb/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | 3 | __all__ = [] 4 | 5 | if importlib.util.find_spec("qdrant_client"): 6 | from midrasai.vectordb._qdrant import AsyncQdrant, Qdrant 7 | 8 | _, __ = Qdrant, AsyncQdrant # Redundant alias for F401 9 | 10 | __all__.extend(["Qdrant", "AsyncQdrant"]) 11 | -------------------------------------------------------------------------------- /midrasai/vectordb/_qdrant.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Awaitable, Callable, Dict, Optional, Union 2 | 3 | from qdrant_client import AsyncQdrantClient, QdrantClient, models 4 | 5 | from midrasai._abc import AsyncVectorDB, VectorDB 6 | from midrasai.types import ColBERT, QueryResult 7 | 8 | 9 | class Qdrant(VectorDB): 10 | def __init__( 11 | self, 12 | location: Optional[str] = None, 13 | url: Optional[str] = None, 14 | port: Optional[int] = 6333, 15 | grpc_port: int = 6334, 16 | prefer_grpc: bool = False, 17 | https: Optional[bool] = None, 18 | api_key: Optional[str] = None, 19 | prefix: Optional[str] = None, 20 | timeout: Optional[int] = None, 21 | host: Optional[str] = None, 22 | path: Optional[str] = None, 23 | force_disable_check_same_thread: bool = False, 24 | grpc_options: Optional[Dict[str, Any]] = None, 25 | auth_token_provider: Optional[ 26 | Union[Callable[[], str], Callable[[], Awaitable[str]]] 27 | ] = None, 28 | **kwargs: Any, 29 | ): 30 | self.client = QdrantClient( 31 | location, 32 | url, 33 | port, 34 | grpc_port, 35 | prefer_grpc, 36 | https, 37 | api_key, 38 | prefix, 39 | timeout, 40 | host, 41 | path, 42 | force_disable_check_same_thread, 43 | grpc_options, 44 | auth_token_provider, 45 | **kwargs, 46 | ) 47 | 48 | def create_index(self, name: str) -> bool: 49 | return self.client.create_collection( 50 | collection_name=name, 51 | vectors_config=models.VectorParams( 52 | size=128, 53 | distance=models.Distance.COSINE, 54 | multivector_config=models.MultiVectorConfig( 55 | comparator=models.MultiVectorComparator.MAX_SIM 56 | ), 57 | ), 58 | ) 59 | 60 | def create_point( 61 | self, id: int | str, 
embedding: ColBERT, data: dict[str, Any]
 62 |     ) -> models.PointStruct:
 63 |         return models.PointStruct(id=id, payload=data, vector=embedding)
 64 | 
 65 |     def save_points(
 66 |         self, index: str, points: list[models.PointStruct]
 67 |     ) -> models.UpdateResult:
 68 |         return self.client.upsert(index, points)
 69 | 
 70 |     def delete_index(self, name: str) -> bool:
 71 |         return self.client.delete_collection(collection_name=name)
 72 | 
 73 |     def search(
 74 |         self, index: str, query_vector: ColBERT, quantity: int
 75 |     ) -> list[QueryResult]:
 76 |         result = self.client.query_points(index, query=query_vector, limit=quantity)
 77 |         return [
 78 |             QueryResult(id=point.id, score=point.score, data=point.payload)
 79 |             for point in result.points
 80 |         ]
 81 | 
 82 | 
 83 | class AsyncQdrant(AsyncVectorDB):
 84 |     def __init__(
 85 |         self,
 86 |         location: Optional[str] = None,
 87 |         url: Optional[str] = None,
 88 |         port: Optional[int] = 6333,
 89 |         grpc_port: int = 6334,
 90 |         prefer_grpc: bool = False,
 91 |         https: Optional[bool] = None,
 92 |         api_key: Optional[str] = None,
 93 |         prefix: Optional[str] = None,
 94 |         timeout: Optional[int] = None,
 95 |         host: Optional[str] = None,
 96 |         path: Optional[str] = None,
 97 |         force_disable_check_same_thread: bool = False,
 98 |         grpc_options: Optional[Dict[str, Any]] = None,
 99 |         auth_token_provider: Optional[
100 |             Union[Callable[[], str], Callable[[], Awaitable[str]]]
101 |         ] = None,
102 |         **kwargs: Any,
103 |     ):
104 |         self.client = AsyncQdrantClient(
105 |             location,
106 |             url,
107 |             port,
108 |             grpc_port,
109 |             prefer_grpc,
110 |             https,
111 |             api_key,
112 |             prefix,
113 |             timeout,
114 |             host,
115 |             path,
116 |             force_disable_check_same_thread,
117 |             grpc_options,
118 |             auth_token_provider,
119 |             **kwargs,
120 |         )
121 | 
122 |     async def create_index(self, name: str) -> bool:
123 |         return await self.client.create_collection(
124 |             collection_name=name,
125 |             vectors_config=models.VectorParams(
126 |                 size=128,
127 |                 distance=models.Distance.COSINE,
128 |                 multivector_config=models.MultiVectorConfig(
129 |                     comparator=models.MultiVectorComparator.MAX_SIM
130 |                 ),
131 |             ),
132 |         )
133 | 
134 |     async def create_point(
135 |         self, id: int | str, embedding: ColBERT, data: dict[str, Any]
136 |     ) -> models.PointStruct:
137 |         return models.PointStruct(id=id, payload=data, vector=embedding)
138 | 
139 |     async def save_points(
140 |         self, index: str, points: list[models.PointStruct]
141 |     ) -> models.UpdateResult:
142 |         return await self.client.upsert(index, points)
143 | 
144 |     async def delete_index(self, name: str) -> bool:
145 |         return await self.client.delete_collection(collection_name=name)
146 | 
147 |     async def search(
148 |         self, index: str, query_vector: ColBERT, quantity: int
149 |     ) -> list[QueryResult]:
150 |         result = await self.client.query_points(
151 |             index, query=query_vector, limit=quantity
152 |         )
153 |         return [
154 |             QueryResult(id=point.id, score=point.score, data=point.payload)
155 |             for point in result.points
156 |         ]
157 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [tool.poetry]
 2 | name = "midrasai"
 3 | version = "0.1.5"
 4 | description = "A simple framework for multimodal RAG"
 5 | authors = ["ajac-zero "]
 6 | readme = "README.md"
 7 | 
 8 | [tool.poetry.scripts]
 9 | midras-server = "midrasai.cli:cli"
10 | 
11 | [tool.poetry.dependencies]
12 | python = "^3.10"
13 | typer = "^0.12.5"
14 | qdrant-client = "^1.11.1"
15 | pdf2image = "^1.17.0"
version = "^10", optional = true } 17 | fastapi = { extras = ["standard"], version = "^0", optional = true } 18 | huggingface-hub = { extras = ["cli"], version = "^0", optional = true } 19 | colpali-engine = { version = "^0.3.1", optional = true } 20 | 21 | [tool.poetry.extras] 22 | local = ["colpali-engine", "pillow", "fastapi", "huggingface-hub"] 23 | 24 | [tool.poetry.group.dev.dependencies] 25 | ruff = "^0.6.3" 26 | pytest = "^8.3.2" 27 | httpx = "^0.27.2" 28 | 29 | [tool.poetry.group.docs.dependencies] 30 | ipykernel = "^6.29.5" 31 | matplotlib = "^3.9.2" 32 | google-generativeai = "^0.7.2" 33 | 34 | [tool.ruff] 35 | fix = true 36 | exclude = ["examples"] 37 | 38 | [tool.ruff.lint] 39 | select = ["E", "F", "I"] 40 | ignore = ["E501"] 41 | fixable = ["ALL"] 42 | 43 | [build-system] 44 | requires = ["poetry-core"] 45 | build-backend = "poetry.core.masonry.api" 46 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajac-zero/midrasai/9580bd665f9648015561aef89b5592165b9885e3/tests/__init__.py -------------------------------------------------------------------------------- /tests/assets/Attention_is_all_you_need.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajac-zero/midrasai/9580bd665f9648015561aef89b5592165b9885e3/tests/assets/Attention_is_all_you_need.pdf -------------------------------------------------------------------------------- /tests/assets/Colpali-example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajac-zero/midrasai/9580bd665f9648015561aef89b5592165b9885e3/tests/assets/Colpali-example1.png -------------------------------------------------------------------------------- /tests/assets/Colpali-example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajac-zero/midrasai/9580bd665f9648015561aef89b5592165b9885e3/tests/assets/Colpali-example2.png -------------------------------------------------------------------------------- /tests/assets/Colpali-example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajac-zero/midrasai/9580bd665f9648015561aef89b5592165b9885e3/tests/assets/Colpali-example3.png -------------------------------------------------------------------------------- /tests/test_local_midras.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | 4 | from midrasai.local import Midras 5 | from midrasai.types import MidrasResponse, QueryResult 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def m(): 10 | return Midras() 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def state(): 15 | return {} 16 | 17 | 18 | def test_create_index(m: Midras): 19 | r = m.create_index("test_index") 20 | assert r is True 21 | 22 | 23 | def test_embed_query(m: Midras, state): 24 | r = m.embed_queries(["hello", "it's me"]) 25 | assert isinstance(r, MidrasResponse) 26 | assert len(r.embeddings) == 2 27 | assert isinstance(r.embeddings[0], list) 28 | assert isinstance(r.embeddings[0][0], list) 29 | 30 | state["query"] = r.embeddings 31 | 32 | 33 | def test_embed_pdf_path(m: Midras, state): 34 | r = m.embed_pdf("./tests/assets/Attention_is_all_you_need.pdf", 
include_images=True) 35 | assert isinstance(r, MidrasResponse) 36 | assert len(r.images) == 15 # type: ignore 37 | assert len(r.embeddings) == 15 38 | assert isinstance(r.embeddings[0], list) 39 | assert isinstance(r.embeddings[0][0], list) 40 | 41 | state["pdf"] = r.embeddings 42 | 43 | 44 | def test_embed_pdf_bytes(m: Midras, state): 45 | with open("./tests/assets/Attention_is_all_you_need.pdf", "rb") as f: 46 | r = m.embed_pdf(f.read(), include_images=True) 47 | assert isinstance(r, MidrasResponse) 48 | assert len(r.images) == 15 # type: ignore 49 | assert len(r.embeddings) == 15 50 | assert isinstance(r.embeddings[0], list) 51 | assert isinstance(r.embeddings[0][0], list) 52 | 53 | state["pdf"] = r.embeddings 54 | 55 | 56 | def test_embed_images(m: Midras, state): 57 | r = m.embed_images( 58 | [ 59 | Image.open(open("./tests/assets/Colpali-example1.png", "rb")), # type: ignore 60 | Image.open(open("./tests/assets/Colpali-example2.png", "rb")), # type: ignore 61 | Image.open(open("./tests/assets/Colpali-example3.png", "rb")), # type: ignore 62 | ] 63 | ) 64 | assert isinstance(r, MidrasResponse) 65 | assert len(r.embeddings) == 3 66 | assert isinstance(r.embeddings[0], list) 67 | assert isinstance(r.embeddings[0][0], list) 68 | 69 | state["images"] = r.embeddings 70 | 71 | 72 | def test_add_point(m: Midras, state): 73 | for i, e in enumerate(state["pdf"]): 74 | r = m.add_point(index="test_index", id=i, embedding=e, data={"test": "midras"}) 75 | assert r is not None 76 | 77 | 78 | def test_query(m: Midras): 79 | q = m.query("test_index", "whats a transformer") 80 | 81 | assert len(q) == 5 82 | for r in q: 83 | assert isinstance(r, QueryResult) 84 | assert r.data.get("test") == "midras" # type: ignore 85 | -------------------------------------------------------------------------------- /tests/test_midras_client.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | 4 | from midrasai import Midras 5 | from midrasai.types import MidrasResponse, QueryResult 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def m(): 10 | return Midras(api_key="test", base_url="http://localhost:8000") 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def state(): 15 | return {} 16 | 17 | 18 | def test_create_index(m: Midras): 19 | r = m.create_index("test_index") 20 | assert r is True 21 | 22 | 23 | def test_embed_query(m: Midras, state): 24 | r = m.embed_queries(["hello", "it's me"]) 25 | assert isinstance(r, MidrasResponse) 26 | assert len(r.embeddings) == 2 27 | assert isinstance(r.embeddings[0], list) 28 | assert isinstance(r.embeddings[0][0], list) 29 | 30 | state["query"] = r.embeddings 31 | 32 | 33 | def test_embed_pdf_path(m: Midras, state): 34 | r = m.embed_pdf("./tests/assets/Attention_is_all_you_need.pdf") 35 | assert isinstance(r, MidrasResponse) 36 | 37 | assert len(r.embeddings) == 15 38 | assert isinstance(r.embeddings[0], list) 39 | assert isinstance(r.embeddings[0][0], list) 40 | 41 | state["pdf"] = r.embeddings 42 | 43 | 44 | def test_embed_pdf_bytes(m: Midras, state): 45 | with open("./tests/assets/Attention_is_all_you_need.pdf", "rb") as f: 46 | r = m.embed_pdf(f.read()) 47 | assert isinstance(r, MidrasResponse) 48 | 49 | assert len(r.embeddings) == 15 50 | assert isinstance(r.embeddings[0], list) 51 | assert isinstance(r.embeddings[0][0], list) 52 | 53 | state["pdf"] = r.embeddings 54 | 55 | 56 | def test_embed_images(m: Midras, state): 57 | images = [ 58 | Image.open("./tests/assets/Colpali-example1.png"), 59 | 
Image.open("./tests/assets/Colpali-example2.png"), 60 | Image.open("./tests/assets/Colpali-example3.png"), 61 | ] 62 | 63 | r = m.embed_images(images) # type: ignore 64 | 65 | assert isinstance(r, MidrasResponse) 66 | assert len(r.embeddings) == 3 67 | assert isinstance(r.embeddings[0], list) 68 | assert isinstance(r.embeddings[0][0], list) 69 | 70 | state["images"] = r.embeddings 71 | 72 | 73 | def test_add_point(m: Midras, state): 74 | for i, e in enumerate(state["pdf"]): 75 | r = m.add_point(index="test_index", id=i, embedding=e, data={"test": "midras"}) 76 | assert r is not None 77 | 78 | 79 | def test_query(m: Midras): 80 | q = m.query("test_index", "whats a transformer") 81 | 82 | assert len(q) == 5 83 | for r in q: 84 | assert isinstance(r, QueryResult) 85 | assert r.data.get("test") == "midras" # type: ignore 86 | -------------------------------------------------------------------------------- /tests/test_midras_server.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | 4 | import pytest 5 | from fastapi.testclient import TestClient 6 | from PIL import Image 7 | 8 | from midrasai.local.server import app 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def client(): 13 | with TestClient(app) as client: 14 | yield client 15 | 16 | 17 | def test_embed_queries(client: TestClient): 18 | queries = ["Hello!", "I exist!"] 19 | r = client.post("/embed/queries", json={"queries": queries}) 20 | 21 | assert r.status_code == 200 22 | 23 | data = r.json() 24 | assert data["images"] is None 25 | 26 | embeddings = data["embeddings"] 27 | assert len(embeddings) == 2 28 | 29 | for colbert in embeddings: 30 | assert isinstance(colbert, list) 31 | assert isinstance(colbert[0], list) 32 | 33 | 34 | def test_embed_images(client: TestClient): 35 | images = [ 36 | Image.open(open("./tests/assets/Colpali-example1.png", "rb")), 37 | Image.open(open("./tests/assets/Colpali-example2.png", "rb")), 38 | Image.open(open("./tests/assets/Colpali-example3.png", "rb")), 39 | ] 40 | 41 | encoded_images = [] 42 | for image in images: 43 | with io.BytesIO() as buffer: 44 | image.save(buffer, format="JPEG") 45 | img_str = base64.b64encode(buffer.getvalue()).decode("utf-8") 46 | encoded_images.append(img_str) 47 | 48 | r = client.post("/embed/images", json={"images": encoded_images}) 49 | 50 | assert r.status_code == 200 51 | 52 | data = r.json() 53 | assert data["images"] is None 54 | 55 | embeddings = data["embeddings"] 56 | assert len(embeddings) == 3 57 | 58 | for colbert in embeddings: 59 | assert isinstance(colbert, list) 60 | assert isinstance(colbert[0], list) 61 | 62 | 63 | def test_embed_pdf(client: TestClient): 64 | with open("./tests/assets/Attention_is_all_you_need.pdf", "rb") as f: 65 | files = {"file": ("test.pdf", f, "application/pdf")} 66 | r = client.post("/embed/pdf", files=files) 67 | 68 | assert r.status_code == 200 69 | 70 | data = r.json() 71 | assert data["images"] is None 72 | 73 | embeddings = data["embeddings"] 74 | assert len(embeddings) == 15 75 | 76 | for colbert in embeddings: 77 | assert isinstance(colbert, list) 78 | assert isinstance(colbert[0], list) 79 | -------------------------------------------------------------------------------- /tests/vectordb/test_qdrant.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from midrasai.vectordb import Qdrant 4 | 5 | 6 | @pytest.fixture 7 | def qdrant(): 8 | return Qdrant(":memory:") 9 | 10 | 11 | def 
test_create_collection(qdrant: Qdrant): 12 | result = qdrant.create_index("test_index") 13 | assert result is True 14 | --------------------------------------------------------------------------------
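Usage sketch (illustrative only, not a file in this repository): the snippet below shows how the `Qdrant` wrapper's `create_index`, `create_point`, `save_points`, and `search` methods compose end to end. The index name and the dummy embedding are made-up placeholders; a real ColBERT-style embedding (a list of 128-dimensional vectors, one per query token or image patch) would come from `LocalMidras.embed_pdf` or `embed_images`.

```python
from midrasai.vectordb import Qdrant

# In-memory Qdrant instance, matching the fixture used in the tests above
db = Qdrant(":memory:")
db.create_index("demo_index")

# A ColBERT-style embedding is a list of 128-dimensional vectors;
# two constant vectors stand in for real ColPali output here.
fake_embedding = [[0.1] * 128, [0.2] * 128]

point = db.create_point(id=1, embedding=fake_embedding, data={"source": "demo"})
db.save_points("demo_index", [point])

# MaxSim multivector search returns QueryResult objects with an id, a score, and the stored payload
results = db.search("demo_index", query_vector=[[0.1] * 128], quantity=1)
print(results[0].score, results[0].data)
```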