├── .env.example ├── .gitignore ├── LICENSE ├── README.md ├── generate_embeddings.py ├── requirements.txt ├── start_web.py └── templates ├── README.md ├── base.html ├── display_image.html ├── index.html └── query_results.html /.env.example: -------------------------------------------------------------------------------- 1 | DATA_DIR="./data/" 2 | DB_FILENAME=images.db 3 | IMAGE_DIRECTORY=./images 4 | ANONYMIZED_TELEMETRY=False 5 | LOG_LEVEL=INFO 6 | CLIP_MODEL="openai/clip-vit-large-patch14" 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | data 3 | chroma 4 | chroma-*.log 5 | *.pyc 6 | *.pyc 7 | mlx_model 8 | .env 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Shi Sheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📸 Embed-Photos 🖼️ 2 | 3 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) 4 | 5 | Welcome to Embed-Photos, a powerful photo similarity search engine built by [@harperreed](https://github.com/harperreed)! 🎉 This project leverages the CLIP (Contrastive Language-Image Pre-training) model to find visually similar images based on textual descriptions. 🔍🖼️ 6 | 7 | ## 🌟 Features 8 | 9 | - 🚀 Fast and efficient image search using the CLIP model 10 | - 💻 Works on Apple Silicon (MLX) only 11 | - 💾 Persistent storage of image embeddings using SQLite and Chroma 12 | - 🌐 Web interface for easy interaction and exploration 13 | - 🔒 Secure image serving and handling 14 | - 📊 Logging and monitoring for performance analysis 15 | - 🔧 Configurable settings using environment variables 16 | 17 | ## Screenshot 18 | 19 | ![image](https://github.com/harperreed/photo-similarity-search/assets/18504/7df51659-84b0-4efb-9647-58a544743ea5) 20 | 21 | 22 | ## 📂 Repository Structure 23 | 24 | ``` 25 | embed-photos/ 26 | ├── README.md 27 | ├── generate_embeddings.py 28 | ├── requirements.txt 29 | ├── start_web.py 30 | └── templates 31 | ├── README.md 32 | ├── base.html 33 | ├── display_image.html 34 | ├── index.html 35 | ├── output.txt 36 | └── query_results.html 37 | ``` 38 | 39 | - `generate_embeddings.py`: Script to generate image embeddings using the CLIP model 40 | - `requirements.txt`: Lists the required Python dependencies 41 | - `start_web.py`: Flask web application for the photo similarity search 42 | - 
`templates/`: Contains HTML templates for the web interface 43 | 44 | ## 🚀 Getting Started 45 | 46 | 1. Clone the repository: 47 | ``` 48 | git clone https://github.com/harperreed/photo-similarity-search.git 49 | ``` 50 | 51 | 2. Install the required dependencies: 52 | ``` 53 | pip install -r requirements.txt 54 | ``` 55 | 56 | 3. Configure the application by setting the necessary environment variables in a `.env` file. 57 | 58 | 4. Generate image embeddings: 59 | ``` 60 | python generate_embeddings.py 61 | ``` 62 | 63 | 5. Start the web application: 64 | ``` 65 | python start_web.py 66 | ``` 67 | 68 | 6. Open your web browser and navigate to `http://localhost:5000` to explore the photo similarity search! 69 | 70 | ## Todo 71 | 72 | - Use SigLIP instead of CLIP 73 | - Add a more robust config 74 | - Make MLX optional 75 | 76 | ## 🙏 Acknowledgments 77 | 78 | The Embed-Photos project builds upon Apple's MLX framework and the CLIP model, and leverages various open-source libraries. We extend our gratitude to the authors and contributors of these projects. 79 | 80 | Happy searching!
🔍✨ 81 | -------------------------------------------------------------------------------- /generate_embeddings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import msgpack 3 | import socket 4 | import uuid 5 | import logging 6 | import time 7 | from dotenv import load_dotenv 8 | import sqlite3 9 | import hashlib 10 | import requests 11 | from io import BytesIO 12 | import signal 13 | from concurrent.futures import ThreadPoolExecutor 14 | from logging.handlers import RotatingFileHandler 15 | import chromadb 16 | import json 17 | import numpy as np 18 | 19 | import mlx_clip 20 | 21 | 22 | # Generate unique ID for the machine 23 | host_name = socket.gethostname() 24 | unique_id = uuid.uuid5(uuid.NAMESPACE_DNS, host_name + str(uuid.getnode())) 25 | 26 | # Configure logging 27 | log_app_name = "app" 28 | log_level = os.getenv('LOG_LEVEL', 'INFO') 29 | log_level = getattr(logging, log_level.upper()) 30 | 31 | file_handler = RotatingFileHandler(f"{log_app_name}_{unique_id}.log", maxBytes=10485760, backupCount=10) 32 | file_handler.setLevel(log_level) 33 | file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 34 | file_handler.setFormatter(file_formatter) 35 | 36 | console_handler = logging.StreamHandler() 37 | console_handler.setLevel(log_level) 38 | console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 39 | console_handler.setFormatter(console_formatter) 40 | 41 | logger = logging.getLogger(log_app_name) 42 | logger.setLevel(log_level) 43 | logger.addHandler(file_handler) 44 | logger.addHandler(console_handler) 45 | 46 | # Load environment variables 47 | load_dotenv() 48 | 49 | 50 | 51 | logger.info(f"Running on machine ID: {unique_id}") 52 | 53 | # Retrieve values from .env 54 | DATA_DIR = os.getenv('DATA_DIR', './') 55 | SQLITE_DB_FILENAME = os.getenv('DB_FILENAME', 'images.db') 56 | FILELIST_CACHE_FILENAME = os.getenv('CACHE_FILENAME', 
'filelist_cache.msgpack') 57 | SOURCE_IMAGE_DIRECTORY = os.getenv('IMAGE_DIRECTORY', 'images') 58 | CHROMA_DB_PATH = os.getenv('CHROME_PATH', f"{DATA_DIR}/{unique_id}_chroma") 59 | CHROMA_COLLECTION_NAME = os.getenv('CHROME_COLLECTION', "images") 60 | CLIP_MODEL = os.getenv('CLIP_MODEL', "openai/clip-vit-base-patch32") 61 | 62 | logger.debug("Configuration loaded.") 63 | # Log the configuration for debugging 64 | logger.debug(f"Configuration - DATA_DIR: {DATA_DIR}") 65 | logger.debug(f"Configuration - DB_FILENAME: {SQLITE_DB_FILENAME}") 66 | logger.debug(f"Configuration - CACHE_FILENAME: {FILELIST_CACHE_FILENAME}") 67 | logger.debug(f"Configuration - IMAGE_DIRECTORY: {SOURCE_IMAGE_DIRECTORY}") 68 | logger.debug(f"Configuration - CHROME_PATH: {CHROMA_DB_PATH}") 69 | logger.debug(f"Configuration - CHROME_COLLECTION: {CHROMA_COLLECTION_NAME}") 70 | logger.debug(f"Configuration - CLIP_MODEL: {CLIP_MODEL}") 71 | logger.debug("Configuration loaded.") 72 | 73 | # Append the unique ID to the db file path and cache file path 74 | SQLITE_DB_FILEPATH = f"{DATA_DIR}{str(unique_id)}_{SQLITE_DB_FILENAME}" 75 | FILELIST_CACHE_FILEPATH = os.path.join(DATA_DIR, f"{unique_id}_{FILELIST_CACHE_FILENAME}") 76 | 77 | 78 | 79 | # Graceful shutdown handler 80 | def graceful_shutdown(signum, frame): 81 | logger.info("Caught signal, shutting down gracefully...") 82 | if 'connection' in globals(): 83 | connection.close() 84 | logger.info("Database connection closed.") 85 | exit(0) 86 | 87 | # Register the signal handlers for graceful shutdown 88 | signal.signal(signal.SIGINT, graceful_shutdown) 89 | signal.signal(signal.SIGTERM, graceful_shutdown) 90 | 91 | # Instantiate the MLX CLIP model 92 | clip = mlx_clip.mlx_clip("mlx_model", hf_repo=CLIP_MODEL) 93 | 94 | # Check if data dir exists, if it doesn't - then create it 95 | if not os.path.exists(DATA_DIR): 96 | logger.info("Creating data directory ...") 97 | os.makedirs(DATA_DIR) 98 | 99 | # Create a connection to the SQLite database
100 | connection = sqlite3.connect(SQLITE_DB_FILEPATH) 101 | 102 | def create_table(): 103 | """ 104 | Creates the 'images' table in the SQLite database if it doesn't exist. 105 | """ 106 | with connection: 107 | connection.execute(''' 108 | CREATE TABLE IF NOT EXISTS images ( 109 | id INTEGER PRIMARY KEY, 110 | filename TEXT NOT NULL, 111 | file_path TEXT NOT NULL, 112 | file_date TEXT NOT NULL, 113 | file_md5 TEXT NOT NULL, 114 | embeddings BLOB 115 | ) 116 | ''') 117 | connection.execute('CREATE INDEX IF NOT EXISTS idx_filename ON images (filename)') 118 | connection.execute('CREATE INDEX IF NOT EXISTS idx_file_path ON images (file_path)') 119 | logger.info("Table 'images' ensured to exist.") 120 | 121 | 122 | 123 | def file_generator(directory): 124 | """ 125 | Generates file paths for all files in the specified directory and its subdirectories. 126 | 127 | :param directory: The directory path to search for files. 128 | :return: A generator yielding file paths. 129 | """ 130 | logger.debug(f"Generating file paths for directory: {directory}") 131 | for root, _, files in os.walk(directory): 132 | for file in files: 133 | yield os.path.join(root, file) 134 | 135 | def hydrate_cache(directory, cache_file_path): 136 | """ 137 | Loads or generates a cache of file paths for the specified directory. 138 | 139 | :param directory: The directory path to search for files. 140 | :param cache_file_path: The path to the cache file. 141 | :return: A list of cached file paths. 142 | """ 143 | logger.info(f"Hydrating cache for {directory} using {cache_file_path}...") 144 | if os.path.exists(cache_file_path): 145 | try: 146 | with open(cache_file_path, 'rb') as f: 147 | cached_files = msgpack.load(f) 148 | logger.info(f"Loaded cached files from {cache_file_path}") 149 | if len(cached_files) == 0: 150 | logger.warning(f"Cache file {cache_file_path} is empty. 
Regenerating cache...") 151 | cached_files = list(file_generator(directory)) 152 | with open(cache_file_path, 'wb') as f: 153 | msgpack.dump(cached_files, f) 154 | logger.info(f"Regenerated cache with {len(cached_files)} files and dumped to {cache_file_path}") 155 | except (msgpack.UnpackException, IOError) as e: 156 | logger.error(f"Error loading cache file {cache_file_path}: {e}. Regenerating cache...") 157 | cached_files = list(file_generator(directory)) 158 | with open(cache_file_path, 'wb') as f: 159 | msgpack.dump(cached_files, f) 160 | logger.info(f"Regenerated cache with {len(cached_files)} files and dumped to {cache_file_path}") 161 | else: 162 | logger.info(f"Cache file not found at {cache_file_path}. Creating cache dirlist for {directory}...") 163 | cached_files = list(file_generator(directory)) 164 | try: 165 | with open(cache_file_path, 'wb') as f: 166 | msgpack.dump(cached_files, f) 167 | logger.info(f"Created cache with {len(cached_files)} files and dumped to {cache_file_path}") 168 | except IOError as e: 169 | logger.error(f"Error creating cache file {cache_file_path}: {e}. Proceeding without cache.") 170 | return cached_files 171 | 172 | 173 | def update_db(image): 174 | """ 175 | Updates the database with the image embeddings. 176 | 177 | :param image: A dictionary containing image information. 178 | """ 179 | try: 180 | embeddings_blob = sqlite3.Binary(msgpack.dumps(image.get('embeddings', []))) 181 | with sqlite3.connect(SQLITE_DB_FILEPATH) as conn: 182 | conn.execute("UPDATE images SET embeddings = ? WHERE filename = ?", 183 | (embeddings_blob, image['filename'])) 184 | logger.debug(f"Database updated successfully for image: {image['filename']}") 185 | except sqlite3.Error as e: 186 | logger.error(f"Database update failed for image: {image['filename']}. Error: {e}") 187 | 188 | def process_image(file_path): 189 | """ 190 | Processes an image file by extracting metadata and inserting it into the database. 
191 | 192 | :param file_path: The path to the image file. 193 | """ 194 | file = os.path.basename(file_path) 195 | file_date = time.ctime(os.path.getmtime(file_path)) 196 | with open(file_path, 'rb') as f: 197 | file_content = f.read() 198 | file_md5 = hashlib.md5(file_content).hexdigest() 199 | conn = None 200 | try: 201 | conn = sqlite3.connect(SQLITE_DB_FILEPATH) 202 | with conn: 203 | cursor = conn.cursor() 204 | cursor.execute(''' 205 | SELECT EXISTS(SELECT 1 FROM images WHERE filename=? AND file_path=? LIMIT 1) 206 | ''', (file, file_path)) 207 | result = cursor.fetchone() 208 | file_exists = result[0] if result else False 209 | if not file_exists: 210 | cursor.execute(''' 211 | INSERT INTO images (filename, file_path, file_date, file_md5) 212 | VALUES (?, ?, ?, ?) 213 | ''', (file, file_path, file_date, file_md5)) 214 | logger.debug(f'Inserted {file} with metadata into the database.') 215 | else: 216 | logger.debug(f'File {file} already exists in the database. Skipping insertion.') 217 | except sqlite3.Error as e: 218 | logger.error(f'Error processing image {file}: {e}') 219 | finally: 220 | if conn: 221 | conn.close() 222 | 223 | def process_embeddings(photo): 224 | """ 225 | Processes image embeddings by uploading them to the embedding server and updating the database. 226 | 227 | :param photo: A dictionary containing photo information. 228 | """ 229 | logger.debug(f"Processing photo: {photo['filename']}") 230 | if photo['embeddings']: 231 | logger.debug(f"Photo {photo['filename']} already has embeddings. 
Skipping.") 232 | return 233 | 234 | try: 235 | start_time = time.time() 236 | imemb = clip.image_encoder(photo['file_path']) 237 | photo['embeddings'] = imemb 238 | update_db(photo) 239 | end_time = time.time() 240 | logger.debug(f"Processed embeddings for {photo['filename']} in {end_time - start_time:.5f} seconds") 241 | except Exception as e: 242 | logger.error(f"Error generating embeddings for {photo['filename']}: {e}") 243 | 244 | 245 | def main(): 246 | """ 247 | Main function to process images and embeddings. 248 | """ 249 | cache_start_time = time.time() 250 | cached_files = hydrate_cache(SOURCE_IMAGE_DIRECTORY, FILELIST_CACHE_FILEPATH) 251 | cache_end_time = time.time() 252 | logger.info(f"Cache operation took {cache_end_time - cache_start_time:.2f} seconds") 253 | logger.info(f"Directory has {len(cached_files)} files: {SOURCE_IMAGE_DIRECTORY}") 254 | 255 | create_table() 256 | 257 | with ThreadPoolExecutor() as executor: 258 | futures = [] 259 | for file_path in cached_files: 260 | if file_path.lower().endswith('.jpg'): 261 | future = executor.submit(process_image, file_path) 262 | futures.append(future) 263 | for future in futures: 264 | future.result() 265 | with connection: 266 | cursor = connection.cursor() 267 | cursor.execute("SELECT filename, file_path, file_date, file_md5, embeddings FROM images") 268 | photos = [{'filename': row[0], 'file_path': row[1], 'file_date': row[2], 'file_md5': row[3], 'embeddings': msgpack.loads(row[4]) if row[4] else []} for row in cursor.fetchall()] 269 | # for photo in photos: 270 | # photo['embeddings'] = msgpack.loads(photo['embeddings']) if photo['embeddings'] else [] 271 | 272 | num_photos = len(photos) 273 | 274 | logger.info(f"Loaded {len(photos)} photos from database") 275 | # can't use ThreadPoolExecutor here because of MLX memory constraints 276 | start_time = time.time() 277 | photo_ite = 0 278 | for photo in photos: 279 | process_embeddings(photo) 280 | photo_ite += 1 281 | if log_level != logging.DEBUG: 282 | if
photo_ite % 100 == 0: 283 | logger.info(f"Processed {photo_ite}/{num_photos} photos") 284 | end_time = time.time() 285 | logger.info(f"Generated embeddings for {len(photos)} photos in {end_time - start_time:.2f} seconds") 286 | connection.close() 287 | logger.info("Database connection closed.") 288 | 289 | 290 | logger.info(f"Initializing Chroma DB: {CHROMA_COLLECTION_NAME}") 291 | client = chromadb.PersistentClient(path=CHROMA_DB_PATH) 292 | collection = client.get_or_create_collection(name=CHROMA_COLLECTION_NAME) 293 | 294 | logger.info(f"Adding embeddings for {len(photos)} photos to Chroma") 295 | start_time = time.time() 296 | 297 | photo_ite = 0 298 | for photo in photos: 299 | # Skip processing if the photo does not have embeddings 300 | if not photo['embeddings']: 301 | logger.debug(f"[{photo_ite}/{num_photos}] Photo {photo['filename']} has no embeddings. Skipping addition to Chroma.") 302 | continue 303 | 304 | try: 305 | # Skip the photo if it is already in the Chroma collection 306 | item = collection.get(ids=[photo['filename']]) 307 | if item['ids'] != []: 308 | continue 309 | collection.add( 310 | embeddings=[photo["embeddings"]], 311 | documents=[photo['filename']], 312 | ids=[photo['filename']] 313 | ) 314 | logger.debug(f"[{photo_ite}/{num_photos}] Added embedding to Chroma for {photo['filename']}") 315 | photo_ite += 1 316 | if log_level != logging.DEBUG: 317 | if photo_ite % 100 == 0: 318 | logger.info(f"Processed {photo_ite}/{num_photos} photos") 319 | except Exception as e: 320 | # Log an error if the addition to Chroma fails 321 | logger.error(f"[{photo_ite}/{num_photos}] Failed to add embedding to Chroma for {photo['filename']}: {e}") 322 | end_time = time.time() 323 | logger.info(f"Inserted embeddings for {len(photos)} photos into Chroma in {end_time - start_time:.2f} seconds") 324 | 325 | if __name__ == "__main__": 326 | main() 327 | -------------------------------------------------------------------------------- /requirements.txt:
-------------------------------------------------------------------------------- 1 | chromadb 2 | open-clip-torch 3 | transformers 4 | python-dotenv 5 | flask 6 | msgpack 7 | mlx 8 | numpy 9 | torch 10 | huggingface_hub 11 | Pillow 12 | git+https://github.com/harperreed/mlx_clip.git 13 | -------------------------------------------------------------------------------- /start_web.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import logging 4 | import os 5 | import requests 6 | import random 7 | import signal 8 | import socket 9 | import sqlite3 10 | import time 11 | import uuid 12 | from concurrent.futures import ThreadPoolExecutor 13 | from dotenv import load_dotenv 14 | from flask import jsonify, g, send_file 15 | from flask import Flask, render_template, request, redirect, url_for 16 | from io import BytesIO 17 | from logging.handlers import RotatingFileHandler 18 | import msgpack 19 | import numpy as np 20 | import chromadb 21 | from PIL import Image, ImageOps 22 | import mlx_clip 23 | 24 | 25 | 26 | # Generate unique ID for the machine 27 | host_name = socket.gethostname() 28 | unique_id = uuid.uuid5(uuid.NAMESPACE_DNS, host_name + str(uuid.getnode())) 29 | 30 | # Configure logging 31 | log_app_name = "web" 32 | log_level = os.getenv('LOG_LEVEL', 'DEBUG') 33 | log_level = getattr(logging, log_level.upper()) 34 | 35 | file_handler = RotatingFileHandler(f"{log_app_name}_{unique_id}.log", maxBytes=10485760, backupCount=10) 36 | file_handler.setLevel(log_level) 37 | file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 38 | file_handler.setFormatter(file_formatter) 39 | 40 | console_handler = logging.StreamHandler() 41 | console_handler.setLevel(log_level) 42 | console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 43 | console_handler.setFormatter(console_formatter) 44 | 45 | logger = logging.getLogger(log_app_name) 46 | 
logger.setLevel(log_level) 47 | logger.addHandler(file_handler) 48 | logger.addHandler(console_handler) 49 | 50 | # Load environment variables 51 | load_dotenv() 52 | 53 | 54 | 55 | logger.info(f"Running on machine ID: {unique_id}") 56 | 57 | # Retrieve values from .env 58 | DATA_DIR = os.getenv('DATA_DIR', './') 59 | SQLITE_DB_FILENAME = os.getenv('DB_FILENAME', 'images.db') 60 | FILELIST_CACHE_FILENAME = os.getenv('CACHE_FILENAME', 'filelist_cache.msgpack') 61 | SOURCE_IMAGE_DIRECTORY = os.getenv('IMAGE_DIRECTORY', 'images') 62 | CHROMA_DB_PATH = os.getenv('CHROME_PATH', f"{DATA_DIR}{unique_id}_chroma") 63 | CHROMA_COLLECTION_NAME = os.getenv('CHROME_COLLECTION', "images") 64 | NUM_IMAGE_RESULTS = int(os.getenv('NUM_IMAGE_RESULTS', 52)) 65 | CLIP_MODEL = os.getenv('CLIP_MODEL', "openai/clip-vit-base-patch32") 66 | 67 | logger.debug("Configuration loaded.") 68 | # Log the configuration for debugging 69 | logger.debug(f"Configuration - DATA_DIR: {DATA_DIR}") 70 | logger.debug(f"Configuration - DB_FILENAME: {SQLITE_DB_FILENAME}") 71 | logger.debug(f"Configuration - CACHE_FILENAME: {FILELIST_CACHE_FILENAME}") 72 | logger.debug(f"Configuration - SOURCE_IMAGE_DIRECTORY: {SOURCE_IMAGE_DIRECTORY}") 73 | logger.debug(f"Configuration - CHROME_PATH: {CHROMA_DB_PATH}") 74 | logger.debug(f"Configuration - CHROME_COLLECTION: {CHROMA_COLLECTION_NAME}") 75 | logger.debug(f"Configuration - NUM_IMAGE_RESULTS: {NUM_IMAGE_RESULTS}") 76 | logger.debug(f"Configuration - CLIP_MODEL: {CLIP_MODEL}") 77 | logger.debug("Configuration loaded.") 78 | 79 | # Append the unique ID to the db file path and cache file path 80 | SQLITE_DB_FILEPATH = f"{DATA_DIR}{str(unique_id)}_{SQLITE_DB_FILENAME}" 81 | FILELIST_CACHE_FILEPATH = os.path.join(DATA_DIR, f"{unique_id}_{FILELIST_CACHE_FILENAME}") 82 | 83 | # Create a connection pool for the SQLite database 84 | connection = sqlite3.connect(SQLITE_DB_FILEPATH) 85 | 86 | app = Flask(__name__) 87 | 88 | # Graceful shutdown handler 89 | def 
graceful_shutdown(signum, frame): 90 | logger.info("Caught signal, shutting down gracefully...") 91 | if 'connection' in globals(): 92 | connection.close() 93 | logger.info("Database connection closed.") 94 | exit(0) 95 | 96 | # Register the signal handlers for graceful shutdown 97 | signal.signal(signal.SIGINT, graceful_shutdown) 98 | signal.signal(signal.SIGTERM, graceful_shutdown) 99 | 100 | # Instantiate the MLX CLIP model 101 | clip = mlx_clip.mlx_clip("mlx_model", hf_repo=CLIP_MODEL) 102 | 103 | logger.info(f"Initializing Chroma DB: {CHROMA_COLLECTION_NAME}") 104 | client = chromadb.PersistentClient(path=CHROMA_DB_PATH) 105 | collection = client.get_or_create_collection(name=CHROMA_COLLECTION_NAME) 106 | items = collection.get()["ids"] 107 | 108 | logger.info(f"Chroma collection contains {len(items)} items") 109 | # Web routes 110 | 111 | 112 | @app.teardown_appcontext 113 | def close_connection(exception): 114 | db = getattr(g, "_database", None) 115 | if db is not None: 116 | db.close() 117 | 118 | 119 | @app.route("/") 120 | def index(): 121 | images = collection.get()["ids"] 122 | logger.debug(f"NUM_IMAGE_RESULTS: {NUM_IMAGE_RESULTS}") 123 | logger.debug(f"Collection contains {len(images)} images") 124 | random_items = random.sample(images, min(NUM_IMAGE_RESULTS, len(images))) 125 | logger.debug(f"Selected random images: {random_items}") 126 | # Display a form or some introduction text 127 | return render_template("index.html", images=random_items) 128 | 129 | 130 | @app.route("/image/<filename>") 131 | def serve_specific_image(filename): 132 | # Construct the filepath and check if it exists 133 | logger.debug(f"Serving page for image: {filename}") 134 | 135 | filepath = os.path.join(SOURCE_IMAGE_DIRECTORY, filename) 136 | logger.debug(f"Resolved filepath: {filepath}") 137 | if not os.path.exists(filepath): 138 | return "Image not found", 404 139 | 140 | image = collection.get(ids=[filename], include=["embeddings"]) 141 | results = collection.query( 142 | query_embeddings=image["embeddings"], n_results=(NUM_IMAGE_RESULTS + 1) 143 | ) 144 | 145 | images = [] 146 | for ids in results["ids"]: 147 | for id in ids: 148 | # Adjust the path as needed 149 | image_url = url_for("serve_image", filename=id) 150 |
images.append({"url": image_url, "id": id}) 151 | 152 | # Use the proxy function to serve the image if it exists 153 | image_url = url_for("serve_image", filename=filename) 154 | 155 | # Render the template with the specific image 156 | return render_template("display_image.html", image=image_url, images=images[1:]) 157 | 158 | 159 | @app.route("/random") 160 | def random_image(): 161 | images = collection.get()["ids"] 162 | image = random.choice(images) if images else None 163 | 164 | if image: 165 | return redirect(url_for("serve_specific_image", filename=image)) 166 | else: 167 | return "No images found", 404 168 | 169 | 170 | @app.route("/text-query", methods=["GET"]) 171 | def text_query(): 172 | 173 | # Read the query text from the request's GET parameters 174 | # and search the Chroma collection with its CLIP embedding 175 | text = request.args.get("text") 176 | 177 | # Use the MLX CLIP model to generate embeddings from the text 178 | embeddings = clip.text_encoder(text) 179 | 180 | results = collection.query(query_embeddings=embeddings, n_results=(NUM_IMAGE_RESULTS + 1)) 181 | images = [] 182 | for ids in results["ids"]: 183 | for id in ids: 184 | # Adjust the path as needed 185 | image_url = url_for("serve_image", filename=id) 186 | images.append({"url": image_url, "id": id}) 187 | 188 | return render_template( 189 | "query_results.html", images=images, text=text, title="Text Query Results" 190 | ) 191 | 192 | 193 | @app.route("/img/<filename>") 194 | def serve_image(filename): 195 | """ 196 | Serve a resized image directly from the filesystem outside of the static directory. 197 | """ 198 | 199 | 200 | # Construct the full file path. Be careful with security implications. 201 | # Ensure that you validate `filename` to prevent directory traversal attacks.
202 | filepath = os.path.join(SOURCE_IMAGE_DIRECTORY, os.path.basename(filename)) 203 | if not os.path.exists(filepath): 204 | # You can return a default image or a 404 error if the file does not exist. 205 | return "Image not found", 404 206 | 207 | # Check the file size 208 | file_size = os.path.getsize(filepath) 209 | if file_size > 1 * 1024 * 1024: # File size is greater than 1 megabyte 210 | with Image.open(filepath) as img: 211 | # Resize the image to half the original size 212 | img.thumbnail((img.width // 2, img.height // 2)) 213 | img = ImageOps.exif_transpose(img) 214 | # Save the resized image to a BytesIO object 215 | img_io = BytesIO() 216 | img.save(img_io, 'JPEG', quality=85) 217 | img_io.seek(0) 218 | return send_file(img_io, mimetype='image/jpeg') 219 | 220 | return send_file(filepath) 221 | 222 | 223 | if __name__ == "__main__": 224 | app.run(debug=True, host="0.0.0.0") 225 | -------------------------------------------------------------------------------- /templates/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Features ✨ 3 | 4 | - 🌐 Web-based interface for easy access 5 | - 🔍 Search images using natural language queries 6 | - 🎨 Display similar images based on semantic similarity 7 | - 🔀 Explore random images for serendipitous discoveries 8 | - 📱 Share interesting images with others (when supported by the browser) 9 | 10 | ## Structure 📂 11 | 12 | ``` 13 | templates/ 14 | ├── base.html 15 | ├── display_image.html 16 | ├── index.html 17 | └── query_results.html 18 | ``` 19 | 20 | - `base.html`: The base template that provides the common structure and layout for all pages. 21 | - `display_image.html`: Template for displaying a single image and its similar images. 22 | - `index.html`: The main page template that showcases a grid of random images. 23 | - `query_results.html`: Template for displaying the search results based on a user's query.
24 | -------------------------------------------------------------------------------- /templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | {% block title %}{% endblock %} ~ Semantic Image Search 5 | 9 | 10 | 11 | 12 |
13 |
16 |
19 | 34 | 37 | 38 | 39 | Image Search 40 | 41 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |
65 |
66 |
67 |
68 |
69 | 70 | {% block header %}{% endblock %} 71 | 72 |
73 | 74 |
{% block content %}{% endblock %}
75 |
76 |
77 |
78 | 79 | {% block js %}{% endblock %} 80 | 81 | 82 | -------------------------------------------------------------------------------- /templates/display_image.html: -------------------------------------------------------------------------------- 1 | {% extends 'base.html' %} {% block title %} Display Image {% endblock %} {% 2 | block header %} {% endblock %} {% block content %} 3 | image 9 | 10 | 19 | 20 | 34 |
35 | 36 |

50 Similar Images

37 | 38 |
39 |
40 | {% for image in images %} 41 | 42 |
46 | image 48 |
49 |
50 | {% endfor %} 51 |
52 | {% endblock %} 53 | -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | {% extends 'base.html' %} 2 | 3 | {% block title %} 4 | Image Explorer 5 | {% endblock %} 6 | 7 | {% block header %} 8 |

50 random images

9 | {% endblock %} 10 | 11 | 12 | {% block content %} 13 |
14 | {% for image in images %} 15 | 16 |
20 | image 22 |
23 |
24 | {% endfor %} 25 |
26 | 27 | 28 | {% endblock %} 29 | 30 | {% block js %} 31 | {% endblock %} 32 | -------------------------------------------------------------------------------- /templates/query_results.html: -------------------------------------------------------------------------------- 1 | {% extends 'base.html' %} 2 | 3 | {% block title %} Results for {{text}} {% endblock %} 4 | 5 | 6 | {% block header %} 7 |

Results for {{text}}

8 | {% endblock %} 9 | 10 | {% block content %} 11 |
12 | {% for image in images %} 13 | 14 |
18 | image 20 |
21 |
22 | {% endfor %} 23 |
24 | {% endblock %} 25 | 26 | {% block js %} {% endblock %} 27 | --------------------------------------------------------------------------------