├── requirements.txt ├── README.md ├── .gitignore └── app.py /requirements.txt: -------------------------------------------------------------------------------- 1 | # requirements.txt 2 | git+https://github.com/illuin-tech/colpali 3 | streamlit 4 | torch 5 | Pillow 6 | numpy 7 | pdf2image 8 | setuptools 9 | ollama 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🖼️ Image RAG (Colpali + LLaMA Vision) 2 | 3 | A powerful Retrieval-Augmented Generation (RAG) system combining Colpali's ColQwen image embeddings with LLaMA Vision via Ollama. 4 | 5 | ## 🌟 Key Features 6 | 7 | - 🧬 ColQwen model for generating powerful image embeddings via Colpali 8 | - 🤖 LLaMA Vision integration through Ollama for image understanding 9 | - 📥 Intelligent image indexing with duplicate detection 10 | - 💬 Natural language image queries 11 | - 📄 PDF document support 12 | - 🔍 Semantic similarity search 13 | - 📊 Efficient SQLite storage 14 | 15 | ## 🛠️ Technical Stack 16 | 17 | - **Embedding Model**: ColQwen via Colpali 18 | - **Vision Model**: LLaMA Vision via Ollama 19 | - **Frontend**: Streamlit 20 | - **Database**: SQLite 21 | - **Image Processing**: Pillow, pdf2image 22 | - **ML Framework**: PyTorch 23 | 24 | 25 | ## ⚡ Quick Start 26 | 27 | 1. Install Poppler (required for PDF support): 28 | 29 | **Mac:** 30 | ```bash 31 | brew install poppler 32 | ``` 33 | 34 | **Windows:** 35 | 1. Download the latest poppler package from: https://github.com/oschwartz10612/poppler-windows/releases/ 36 | 2. Extract the downloaded zip to a location (e.g., `C:\Program Files\poppler`) 37 | 3. Add bin directory to PATH: 38 | - Open System Properties > Advanced > Environment Variables 39 | - Under System Variables, find and select "Path" 40 | - Click "Edit" > "New" 41 | - Add the bin path (e.g., `C:\Program Files\poppler\bin`) 42 | 4. Verify installation: 43 | ```bash 44 | pdftoppm -h 45 | ``` 46 | 47 | 2. Clone and setup environment: 48 | ```bash 49 | git clone https://github.com/kturung/colpali-llama-vision-rag.git 50 | python -m venv venv 51 | source venv/bin/activate # For Mac/Linux 52 | # or 53 | .\venv\Scripts\activate # For Windows 54 | pip install -r requirements.txt 55 | ``` 56 | 57 | 3. Install Ollama from https://ollama.com 58 | 59 | 4. Launch application: 60 | ```bash 61 | streamlit run app.py 62 | ``` 63 | 64 | > Note: Restart your terminal/IDE after modifying PATH variables 65 | 66 | 67 | ## 💡 Usage 68 | 69 | ### 📤 Adding Images 70 | 1. Navigate to "➕ Add to Index" 71 | 2. Upload images/PDFs 72 | 3. System automatically: 73 | - Generates ColQwen embeddings 74 | - Checks for duplicates 75 | - Stores in SQLite 76 | 77 | ### 🔎 Querying 78 | 1. Go to "🔍 Query Index" 79 | 2. Enter natural language query 80 | 3. View similar images 81 | 4. Get LLaMA Vision analysis 82 | 83 | 84 | ## 💾 Database Schema 85 | 86 | ```sql 87 | CREATE TABLE embeddings ( 88 | id INTEGER PRIMARY KEY AUTOINCREMENT, 89 | image_base64 TEXT, 90 | image_hash TEXT UNIQUE, 91 | embedding BLOB 92 | ) 93 | ``` 94 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | image_embeddings.db 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import torch 3 | from PIL import Image 4 | import sqlite3 5 | import numpy as np 6 | import pickle 7 | import base64 8 | import io 9 | from colpali_engine.models import ColQwen2, ColQwen2Processor 10 | import gc 11 | from pdf2image import convert_from_bytes 12 | from io import BytesIO 13 | import hashlib # Import hashlib for hashing 14 | 15 | def get_device(): 16 | if torch.cuda.is_available(): 17 | return "cuda" 18 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 19 | return "mps" 20 | else: 21 | return "cpu" 22 | 23 | device_map = get_device() 24 | 25 | # Function to load the model and processor 26 | @st.cache_resource 27 | def load_model(): 28 | model = ColQwen2.from_pretrained( 29 | "vidore/colqwen2-v0.1", 30 | torch_dtype=torch.bfloat16, 31 | device_map=device_map # Use "mps" if on Apple Silicon; otherwise, use "cpu" or "cuda" 32 | ) 33 | 34 | processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v0.1") 35 | 36 | return model, processor 37 | 38 | # Function to get a database connection 39 | def get_db_connection(): 40 | conn = sqlite3.connect('image_embeddings.db') 41 | return conn 42 | 43 | def process_and_index_image(image, img_str, image_hash, processor, model): 44 | # Store in database 45 | conn = get_db_connection() 46 | c = conn.cursor() 47 | c.execute(''' 48 | CREATE TABLE IF NOT EXISTS embeddings ( 49 | id INTEGER PRIMARY KEY AUTOINCREMENT, 50 | image_base64 TEXT, 51 | image_hash TEXT UNIQUE, 52 | embedding BLOB 53 | ) 54 | ''') 55 | # Check if the image hash already exists 56 | c.execute('SELECT id FROM embeddings WHERE image_hash = ?', (image_hash,)) 57 | result = c.fetchone() 58 | if result: 59 | # Image already indexed 60 | conn.close() 61 | return 62 | # Process image to get embedding 63 | batch_images = processor.process_images([image]).to(model.device) 64 | with torch.no_grad(): 65 | image_embeddings = model(**batch_images) 66 | image_embedding = image_embeddings[0].cpu().to(torch.float32).numpy() 67 | # Serialize the embedding 68 | embedding_bytes = pickle.dumps(image_embedding) 69 | c.execute('INSERT INTO embeddings (image_base64, image_hash, embedding) VALUES (?, ?, ?)', (img_str, image_hash, embedding_bytes)) 70 | conn.commit() 71 | conn.close() 72 | 73 | def clear_cache(): 74 | """Clear GPU memory cache for different platforms.""" 75 | try: 76 | if torch.cuda.is_available(): 77 | torch.cuda.empty_cache() 78 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 79 | torch.mps.empty_cache() 80 | # CPU doesn't need explicit cache clearing 81 | except Exception as e: 82 | print(f"Warning: Could not clear cache: {str(e)}") 83 | 84 | def main(): 85 | st.title("📷 Image RAG(Colpali + Llama Vision)") 86 | 87 | model, processor = load_model() 88 | 89 | # Initialize session state for image hashes 90 | if 'image_hashes' not in st.session_state: 91 | st.session_state.image_hashes = set() 92 | 93 | # Use st.radio for tab selection 94 | tab = st.radio("Navigation", ["➕ Add to Index", "🔍 Query Index"]) 95 | 96 | if tab == "➕ Add to Index": 97 | st.header("Add Images to Index") 98 | # File uploader 99 | uploaded_files = st.file_uploader("Upload Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg', 'pdf']) 100 | if uploaded_files: 101 | # Process the uploaded images 102 | for uploaded_file in uploaded_files: 103 | if uploaded_file.type == 'application/pdf': 104 | images = convert_from_bytes(uploaded_file.read()) 105 | for image in images: 106 | buffer = BytesIO() 107 | image.save(buffer, format="PNG") 108 | byte_data = buffer.getvalue() 109 | img_str = base64.b64encode(byte_data).decode('utf-8') 110 | # Compute image hash 111 | image_hash = hashlib.sha256(byte_data).hexdigest() 112 | if image_hash in st.session_state.image_hashes: 113 | continue # Skip if already processed in this session 114 | process_and_index_image(image, img_str, image_hash, processor, model) 115 | st.session_state.image_hashes.add(image_hash) 116 | else: 117 | # Read image data 118 | image_data = uploaded_file.read() 119 | # Compute image hash 120 | image_hash = hashlib.sha256(image_data).hexdigest() 121 | if image_hash in st.session_state.image_hashes: 122 | continue # Skip if already processed in this session 123 | # Convert to PIL Image 124 | image = Image.open(io.BytesIO(image_data)).convert('RGB') 125 | # Encode image to base64 126 | buffered = io.BytesIO() 127 | image.save(buffered, format="PNG") 128 | img_str = base64.b64encode(buffered.getvalue()).decode() 129 | process_and_index_image(image, img_str, image_hash, processor, model) 130 | st.session_state.image_hashes.add(image_hash) 131 | st.success("Images added to index.") 132 | 133 | elif tab == "🔍 Query Index": 134 | st.header("Query Index") 135 | query = st.text_input("Enter your query") 136 | if query: 137 | # Process query 138 | with torch.no_grad(): 139 | batch_query = processor.process_queries([query]).to(model.device) 140 | query_embedding = model(**batch_query) 141 | query_embedding_cpu = query_embedding.cpu().to(torch.float32).numpy()[0] 142 | 143 | # Retrieve image embeddings from database 144 | conn = get_db_connection() 145 | c = conn.cursor() 146 | c.execute('SELECT image_base64, embedding FROM embeddings') 147 | rows = c.fetchall() 148 | conn.close() 149 | 150 | if not rows: 151 | st.warning("No images found in the index. Please add images first.") 152 | return 153 | 154 | # Set fixed sequence length 155 | fixed_seq_len = 620 # Adjust based on your embeddings 156 | 157 | image_embeddings_list = [] 158 | image_base64_list = [] 159 | 160 | for row in rows: 161 | image_base64, embedding_bytes = row 162 | embedding = pickle.loads(embedding_bytes) 163 | seq_len, embedding_dim = embedding.shape 164 | 165 | # Adjust to fixed sequence length 166 | if seq_len < fixed_seq_len: 167 | padding = np.zeros((fixed_seq_len - seq_len, embedding_dim), dtype=embedding.dtype) 168 | embedding_fixed = np.concatenate([embedding, padding], axis=0) 169 | elif seq_len > fixed_seq_len: 170 | embedding_fixed = embedding[:fixed_seq_len, :] 171 | else: 172 | embedding_fixed = embedding # No adjustment needed 173 | 174 | image_embeddings_list.append(embedding_fixed) 175 | image_base64_list.append(image_base64) 176 | 177 | # Stack embeddings 178 | retrieved_image_embeddings = np.stack(image_embeddings_list) 179 | 180 | # Adjust query embedding 181 | seq_len_q, embedding_dim_q = query_embedding_cpu.shape 182 | 183 | if seq_len_q < fixed_seq_len: 184 | padding = np.zeros((fixed_seq_len - seq_len_q, embedding_dim_q), dtype=query_embedding_cpu.dtype) 185 | query_embedding_fixed = np.concatenate([query_embedding_cpu, padding], axis=0) 186 | elif seq_len_q > fixed_seq_len: 187 | query_embedding_fixed = query_embedding_cpu[:fixed_seq_len, :] 188 | else: 189 | query_embedding_fixed = query_embedding_cpu 190 | 191 | # Convert to tensors 192 | query_embedding_tensor = torch.from_numpy(query_embedding_fixed).to(model.device).unsqueeze(0) 193 | retrieved_image_embeddings_tensor = torch.from_numpy(retrieved_image_embeddings).to(model.device) 194 | 195 | # Compute similarity scores 196 | with torch.no_grad(): 197 | scores = processor.score_multi_vector(query_embedding_tensor, retrieved_image_embeddings_tensor) 198 | scores_np = scores.cpu().numpy().flatten() 199 | del query_embedding_tensor, retrieved_image_embeddings_tensor, scores # Free up memory 200 | clear_cache() 201 | 202 | # Combine images and scores 203 | similarities = list(zip(image_base64_list, scores_np)) 204 | 205 | # Sort by similarity 206 | similarities.sort(key=lambda x: x[1], reverse=True) 207 | 208 | if similarities: 209 | st.write("Most similar image:") 210 | img_str, score = similarities[0] 211 | st.write(f"Similarity Score: {score:.4f}") 212 | # Decode image from base64 213 | img_data = base64.b64decode(img_str) 214 | image = Image.open(io.BytesIO(img_data)) 215 | st.image(image) 216 | else: 217 | st.write("No similar images found.") 218 | 219 | st.write("AI Response:") 220 | 221 | import ollama 222 | 223 | response_container = st.empty() 224 | 225 | 226 | # Spinner only for the initial API call 227 | 228 | 229 | stream = ollama.chat( 230 | model="llama3.2-vision", 231 | messages=[ 232 | { 233 | 'role': 'user', 234 | 'content': "Please answer the following question using only the information visible in the provided image" 235 | " Do not use any of your own knowledge, training data, or external sources." 236 | " Base your response solely on the content depicted within the image." 237 | " If there is no relation with question and image," 238 | f" you can respond with 'Question is not related to image'.\nHere is the question: {query}", 239 | 'images': [img_data] 240 | } 241 | ], 242 | stream=True 243 | ) 244 | 245 | 246 | 247 | collected_chunks = [] 248 | stream_iter = iter(stream) 249 | 250 | with st.spinner('⏳ Generating Response...'): 251 | try: 252 | # Get the first chunk 253 | first_chunk = next(stream_iter) 254 | chunk_content = first_chunk['message']['content'] 255 | collected_chunks.append(chunk_content) 256 | # Display the initial response 257 | complete_response = ''.join(collected_chunks) 258 | response_container.markdown(complete_response) 259 | except StopIteration: 260 | # Handle if no chunks are received 261 | pass 262 | 263 | # Continue streaming the rest of the response 264 | for chunk in stream_iter: 265 | chunk_content = chunk['message']['content'] 266 | collected_chunks.append(chunk_content) 267 | complete_response = ''.join(collected_chunks) 268 | response_container.markdown(complete_response) 269 | 270 | 271 | clear_cache() 272 | gc.collect() 273 | 274 | if __name__ == "__main__": 275 | main() 276 | --------------------------------------------------------------------------------