├── .gitignore
├── .env.example
├── requirements.txt
├── templates
│   ├── query_results.html
│   ├── index.html
│   ├── README.md
│   ├── display_image.html
│   └── base.html
├── LICENSE
├── README.md
├── start_web.py
└── generate_embeddings.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.log
2 | data
3 | chroma
4 | chroma-*.log
5 | *.pyc
7 | mlx_model
8 | .env
9 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | DATA_DIR="./data/"
2 | DB_FILENAME=images.db
3 | IMAGE_DIRECTORY=./images
4 | ANONYMIZED_TELEMETRY=False
5 | LOG_LEVEL=INFO
6 | CLIP_MODEL="openai/clip-vit-large-patch14"
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | chromadb
2 | open-clip-torch
3 | transformers
4 | python-dotenv
5 | flask
6 | msgpack
7 | mlx
8 | numpy
9 | torch
10 | huggingface_hub
11 | Pillow
12 | git+https://github.com/harperreed/mlx_clip.git
13 |
--------------------------------------------------------------------------------
/templates/query_results.html:
--------------------------------------------------------------------------------
1 | {% extends 'base.html' %}
2 |
3 | {% block title %} Results for {{text}} {% endblock %}
4 |
5 |
6 | {% block header %}
7 |
Results for {{text}}
8 | {% endblock %}
9 |
10 | {% block content %}
11 |
24 | {% endblock %}
25 |
26 | {% block js %} {% endblock %}
27 |
--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
1 | {% extends 'base.html' %}
2 |
3 | {% block title %}
4 | Image Explorer
5 | {% endblock %}
6 |
7 | {% block header %}
8 | 50 random images
9 | {% endblock %}
10 |
11 |
12 | {% block content %}
13 |
26 |
27 |
28 | {% endblock %}
29 |
30 | {% block js %}
31 | {% endblock %}
32 |
--------------------------------------------------------------------------------
/templates/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Features ✨
3 |
4 | - 🌐 Web-based interface for easy access
5 | - 🔍 Search images using natural language queries
6 | - 🎨 Display similar images based on semantic similarity
7 | - 🔀 Explore random images for serendipitous discoveries
8 | - 📱 Share interesting images with others (when supported by the browser)
9 |
10 | ## Structure 📂
11 |
12 | ```
13 | templates/
14 | ├── base.html
15 | ├── display_image.html
16 | ├── index.html
17 | └── query_results.html
18 | ```
19 |
20 | - `base.html`: The base template that provides the common structure and layout for all pages.
21 | - `display_image.html`: Template for displaying a single image and its similar images.
22 | - `index.html`: The main page template that showcases a grid of random images.
23 | - `query_results.html`: Template for displaying the search results based on a user's query.
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Shi Sheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/templates/display_image.html:
--------------------------------------------------------------------------------
1 | {% extends 'base.html' %} {% block title %} Display Image {% endblock %} {%
2 | block header %} {% endblock %} {% block content %}
3 |
9 |
10 |
19 |
20 |
34 |
35 |
36 | 50 Similar Images
37 |
38 |
39 |
52 | {% endblock %}
53 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 📸 Embed-Photos 🖼️
2 |
3 | [](LICENSE)
4 |
5 | Welcome to Embed-Photos, a powerful photo similarity search engine built by [@harperreed](https://github.com/harperreed)! 🎉 This project leverages the CLIP (Contrastive Language-Image Pre-training) model to find visually similar images based on textual descriptions. 🔍🖼️
6 |
7 | ## 🌟 Features
8 |
9 | - 🚀 Fast and efficient image search using the CLIP model
10 | - 💻 Works on Apple Silicon (MLX) only
11 | - 💾 Persistent storage of image embeddings using SQLite and Chroma
12 | - 🌐 Web interface for easy interaction and exploration
13 | - 🔒 Secure image serving and handling
14 | - 📊 Logging and monitoring for performance analysis
15 | - 🔧 Configurable settings using environment variables
16 |
17 | ## Screenshot
18 |
19 | 
20 |
21 |
22 | ## 📂 Repository Structure
23 |
24 | ```
25 | embed-photos/
26 | ├── README.md
27 | ├── generate_embeddings.py
28 | ├── requirements.txt
29 | ├── start_web.py
30 | └── templates
31 | ├── README.md
32 | ├── base.html
33 | ├── display_image.html
34 | ├── index.html
35 | └── query_results.html
37 | ```
38 |
39 | - `generate_embeddings.py`: Script to generate image embeddings using the CLIP model
40 | - `requirements.txt`: Lists the required Python dependencies
41 | - `start_web.py`: Flask web application for the photo similarity search
42 | - `templates/`: Contains HTML templates for the web interface
43 |
44 | ## 🚀 Getting Started
45 |
46 | 1. Clone the repository:
47 | ```
48 | git clone https://github.com/harperreed/photo-similarity-search.git
49 | ```
50 |
51 | 2. Install the required dependencies:
52 | ```
53 | pip install -r requirements.txt
54 | ```
55 |
56 | 3. Configure the application by setting the necessary environment variables in a `.env` file.
57 |
58 | 4. Generate image embeddings:
59 | ```
60 | python generate_embeddings.py
61 | ```
62 |
63 | 5. Start the web application:
64 | ```
65 | python start_web.py
66 | ```
67 |
68 | 6. Open your web browser and navigate to `http://localhost:5000` to explore the photo similarity search!
69 |
70 | ## Todo
71 |
72 | - Use siglip instead of clip
73 | - add a more robust config
74 | - make mlx optional
75 |
76 | ## 🙏 Acknowledgments
77 |
78 | The Embed-Photos project builds upon the work of Apple (MLX!) and the CLIP model, and leverages various open-source libraries. We extend our gratitude to the authors and contributors of these projects.
79 |
80 | Happy searching! 🔍✨
81 |
--------------------------------------------------------------------------------
/templates/base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | {% block title %}{% endblock %} ~ Semantic Image Search
5 |
9 |
10 |
11 |
12 |
13 |
66 |
67 |
68 |
69 |
70 | {% block header %}{% endblock %}
71 |
72 |
73 |
74 |
{% block content %}{% endblock %}
75 |
76 |
77 |
78 |
79 | {% block js %}{% endblock %}
80 |
81 |
82 |
--------------------------------------------------------------------------------
/start_web.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import logging
4 | import os
5 | import requests
6 | import random
7 | import signal
8 | import socket
9 | import sqlite3
10 | import time
11 | import uuid
12 | from concurrent.futures import ThreadPoolExecutor
13 | from dotenv import load_dotenv
14 | from flask import jsonify, g, send_file
15 | from flask import Flask, render_template, request, redirect, url_for
16 | from io import BytesIO
17 | from logging.handlers import RotatingFileHandler
18 | import msgpack
19 | import numpy as np
20 | import chromadb
21 | from PIL import Image, ImageOps
22 | import mlx_clip
23 |
24 |
25 |
# Generate a stable, machine-unique ID (hostname + MAC address) used to
# namespace the log file, SQLite DB and Chroma directory per machine.
host_name = socket.gethostname()
unique_id = uuid.uuid5(uuid.NAMESPACE_DNS, host_name + str(uuid.getnode()))

# Load environment variables first.
# BUG FIX: load_dotenv() originally ran *after* the logging setup, so a
# LOG_LEVEL value set in .env was silently ignored here.
load_dotenv()

# Configure logging: rotating file (10 MB x 10 backups) plus console output.
log_app_name = "web"
log_level = os.getenv('LOG_LEVEL', 'DEBUG')
log_level = getattr(logging, log_level.upper())

file_handler = RotatingFileHandler(f"{log_app_name}_{unique_id}.log", maxBytes=10485760, backupCount=10)
file_handler.setLevel(log_level)
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(file_formatter)

console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(console_formatter)

logger = logging.getLogger(log_app_name)
logger.setLevel(log_level)
logger.addHandler(file_handler)
logger.addHandler(console_handler)

logger.info(f"Running on machine ID: {unique_id}")

# Retrieve configuration values from the environment (.env)
DATA_DIR = os.getenv('DATA_DIR', './')
SQLITE_DB_FILENAME = os.getenv('DB_FILENAME', 'images.db')
FILELIST_CACHE_FILENAME = os.getenv('CACHE_FILENAME', 'filelist_cache.msgpack')
SOURCE_IMAGE_DIRECTORY = os.getenv('IMAGE_DIRECTORY', 'images')
CHROMA_DB_PATH = os.getenv('CHROME_PATH', f"{DATA_DIR}{unique_id}_chroma")
CHROMA_COLLECTION_NAME = os.getenv('CHROME_COLLECTION', "images")
NUM_IMAGE_RESULTS = int(os.getenv('NUM_IMAGE_RESULTS', 52))
CLIP_MODEL = os.getenv('CLIP_MODEL', "openai/clip-vit-base-patch32")

# Log the configuration for debugging.
# (Duplicate "Configuration loaded." message removed -- it is logged once, below.)
logger.debug(f"Configuration - DATA_DIR: {DATA_DIR}")
logger.debug(f"Configuration - DB_FILENAME: {SQLITE_DB_FILENAME}")
logger.debug(f"Configuration - CACHE_FILENAME: {FILELIST_CACHE_FILENAME}")
logger.debug(f"Configuration - SOURCE_IMAGE_DIRECTORY: {SOURCE_IMAGE_DIRECTORY}")
logger.debug(f"Configuration - CHROME_PATH: {CHROMA_DB_PATH}")
logger.debug(f"Configuration - CHROME_COLLECTION: {CHROMA_COLLECTION_NAME}")
logger.debug(f"Configuration - NUM_IMAGE_RESULTS: {NUM_IMAGE_RESULTS}")
logger.debug(f"Configuration - CLIP_MODEL: {CLIP_MODEL}")
logger.debug("Configuration loaded.")

# Append the unique ID to the db file path and cache file path
SQLITE_DB_FILEPATH = f"{DATA_DIR}{str(unique_id)}_{SQLITE_DB_FILENAME}"
FILELIST_CACHE_FILEPATH = os.path.join(DATA_DIR, f"{unique_id}_{FILELIST_CACHE_FILENAME}")

# Module-level SQLite connection shared by the app, closed on shutdown.
# NOTE(review): despite the original "pool" comment this is a single connection.
connection = sqlite3.connect(SQLITE_DB_FILEPATH)

app = Flask(__name__)
87 |
88 | # Graceful shutdown handler
def graceful_shutdown(signum, frame):
    """Signal handler: close the shared SQLite connection and exit cleanly.

    :param signum: Signal number received (SIGINT/SIGTERM).
    :param frame: Current stack frame (unused).
    """
    logger.info("Caught signal, shutting down gracefully...")
    # BUG FIX: the original tested for 'conn_pool', a name that never exists
    # in this module, so the connection was never actually closed.
    if 'connection' in globals():
        connection.close()
        logger.info("Database connection pool closed.")
    raise SystemExit(0)
95 |
# Register the signal handlers for graceful shutdown
signal.signal(signal.SIGINT, graceful_shutdown)
signal.signal(signal.SIGTERM, graceful_shutdown)

#Instantiate MLX Clip model
# Loads the configured CLIP weights (hf_repo) into the local "mlx_model" dir.
clip = mlx_clip.mlx_clip("mlx_model", hf_repo=CLIP_MODEL)

# Persistent Chroma vector store: one embedding per image filename (the id).
logger.info(f"Initializing Chrome DB: {CHROMA_COLLECTION_NAME}")
client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
collection = client.get_or_create_collection(name=CHROMA_COLLECTION_NAME)
items = collection.get()["ids"]

# NOTE(review): debug leftover -- prints the number of indexed images at startup.
print(len(items))
# WEBS
110 |
111 |
@app.teardown_appcontext
def close_connection(exception):
    """Release the request-scoped DB handle (flask.g._database), if one was opened."""
    request_db = getattr(g, "_database", None)
    if request_db is not None:
        request_db.close()
118 |
119 | @app.route("/")
120 | def index():
121 | images = collection.get()["ids"]
122 | print(NUM_IMAGE_RESULTS)
123 | print(len(images))
124 | random_items = random.sample(images, NUM_IMAGE_RESULTS)
125 | print(random_items)
126 | # Display a form or some introduction text
127 | return render_template("index.html", images=random_items)
128 |
129 |
130 | @app.route("/image/")
131 | def serve_specific_image(filename):
132 | # Construct the filepath and check if it exists
133 | print(filename)
134 |
135 | filepath = os.path.join(SOURCE_IMAGE_DIRECTORY, filename)
136 | print(filepath)
137 | if not os.path.exists(filepath):
138 | return "Image not found", 404
139 |
140 | image = collection.get(ids=[filename], include=["embeddings"])
141 | results = collection.query(
142 | query_embeddings=image["embeddings"], n_results=(NUM_IMAGE_RESULTS + 1)
143 | )
144 |
145 | images = []
146 | for ids in results["ids"]:
147 | for id in ids:
148 | # Adjust the path as needed
149 | image_url = url_for("serve_image", filename=id)
150 | images.append({"url": image_url, "id": id})
151 |
152 | # Use the proxy function to serve the image if it exists
153 | image_url = url_for("serve_image", filename=filename)
154 |
155 | # Render the template with the specific image
156 | return render_template("display_image.html", image=image_url, images=images[1:])
157 |
158 |
159 | @app.route("/random")
160 | def random_image():
161 | images = collection.get()["ids"]
162 | image = random.choice(images) if images else None
163 |
164 | if image:
165 | return redirect(url_for("serve_specific_image", filename=image))
166 | else:
167 | return "No images found", 404
168 |
169 |
170 | @app.route("/text-query", methods=["GET"])
171 | def text_query():
172 |
173 | # Assuming there's an input for embeddings; this part is tricky and needs customization
174 | # You might need to adjust how embeddings are received or generated based on user input
175 | text = request.args.get("text") # Adjusted to use GET parameters
176 |
177 | # Use the MLX Clip model to generate embeddings from the text
178 | embeddings = clip.text_encoder(text)
179 |
180 | results = collection.query(query_embeddings=embeddings, n_results=(NUM_IMAGE_RESULTS + 1))
181 | images = []
182 | for ids in results["ids"]:
183 | for id in ids:
184 | # Adjust the path as needed
185 | image_url = url_for("serve_image", filename=id)
186 | images.append({"url": image_url, "id": id})
187 |
188 | return render_template(
189 | "query_results.html", images=images, text=text, title="Text Query Results"
190 | )
191 |
192 |
193 | @app.route("/img/")
194 | def serve_image(filename):
195 | """
196 | Serve a resized image directly from the filesystem outside of the static directory.
197 | """
198 |
199 |
200 | # Construct the full file path. Be careful with security implications.
201 | # Ensure that you validate `filename` to prevent directory traversal attacks.
202 | filepath = os.path.join(SOURCE_IMAGE_DIRECTORY, filename)
203 | if not os.path.exists(filepath):
204 | # You can return a default image or a 404 error if the file does not exist.
205 | return "Image not found", 404
206 |
207 | # Check the file size
208 | file_size = os.path.getsize(filepath)
209 | if file_size > 1 * 1024 * 1024: # File size is greater than 1 megabyte
210 | with Image.open(filepath) as img:
211 | # Resize the image to half the original size
212 | img.thumbnail((img.width // 2, img.height // 2))
213 | img = ImageOps.exif_transpose(img)
214 | # Save the resized image to a BytesIO object
215 | img_io = BytesIO()
216 | img.save(img_io, 'JPEG', quality=85)
217 | img_io.seek(0)
218 | return send_file(img_io, mimetype='image/jpeg')
219 |
220 | return send_file(filepath)
221 |
222 |
223 | if __name__ == "__main__":
224 | app.run(debug=True, host="0.0.0.0")
225 |
--------------------------------------------------------------------------------
/generate_embeddings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import msgpack
3 | import socket
4 | import uuid
5 | import logging
6 | import time
7 | from dotenv import load_dotenv
8 | import sqlite3
9 | import hashlib
10 | import requests
11 | from io import BytesIO
12 | import signal
13 | from concurrent.futures import ThreadPoolExecutor
14 | from logging.handlers import RotatingFileHandler
15 | import chromadb
16 | import json
17 | import numpy as np
18 |
19 | import mlx_clip
20 |
21 |
# Generate a stable, machine-unique ID (hostname + MAC address) used to
# namespace the log file, SQLite DB, cache file and Chroma directory.
host_name = socket.gethostname()
unique_id = uuid.uuid5(uuid.NAMESPACE_DNS, host_name + str(uuid.getnode()))

# Load environment variables first.
# BUG FIX: load_dotenv() originally ran *after* the logging setup, so a
# LOG_LEVEL value set in .env was silently ignored here.
load_dotenv()

# Configure logging: rotating file (10 MB x 10 backups) plus console output.
log_app_name = "app"
log_level = os.getenv('LOG_LEVEL', 'INFO')
log_level = getattr(logging, log_level.upper())

file_handler = RotatingFileHandler(f"{log_app_name}_{unique_id}.log", maxBytes=10485760, backupCount=10)
file_handler.setLevel(log_level)
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(file_formatter)

console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(console_formatter)

logger = logging.getLogger(log_app_name)
logger.setLevel(log_level)
logger.addHandler(file_handler)
logger.addHandler(console_handler)

logger.info(f"Running on machine ID: {unique_id}")

# Retrieve configuration values from the environment (.env)
DATA_DIR = os.getenv('DATA_DIR', './')
SQLITE_DB_FILENAME = os.getenv('DB_FILENAME', 'images.db')
FILELIST_CACHE_FILENAME = os.getenv('CACHE_FILENAME', 'filelist_cache.msgpack')
SOURCE_IMAGE_DIRECTORY = os.getenv('IMAGE_DIRECTORY', 'images')
CHROMA_DB_PATH = os.getenv('CHROME_PATH', f"{DATA_DIR}/{unique_id}_chroma")
CHROMA_COLLECTION_NAME = os.getenv('CHROME_COLLECTION', "images")
CLIP_MODEL = os.getenv('CLIP_MODEL', "openai/clip-vit-base-patch32")

# Log the configuration for debugging.
# (Duplicate "Configuration loaded." message removed -- it is logged once, below.)
logger.debug(f"Configuration - DATA_DIR: {DATA_DIR}")
logger.debug(f"Configuration - DB_FILENAME: {SQLITE_DB_FILENAME}")
logger.debug(f"Configuration - CACHE_FILENAME: {FILELIST_CACHE_FILENAME}")
logger.debug(f"Configuration - IMAGE_DIRECTORY: {SOURCE_IMAGE_DIRECTORY}")
logger.debug(f"Configuration - CHROME_PATH: {CHROMA_DB_PATH}")
logger.debug(f"Configuration - CHROME_COLLECTION: {CHROMA_COLLECTION_NAME}")
logger.debug(f"Configuration - CLIP_MODEL: {CLIP_MODEL}")
logger.debug("Configuration loaded.")

# Append the unique ID to the db file path and cache file path
SQLITE_DB_FILEPATH = f"{DATA_DIR}{str(unique_id)}_{SQLITE_DB_FILENAME}"
FILELIST_CACHE_FILEPATH = os.path.join(DATA_DIR, f"{unique_id}_{FILELIST_CACHE_FILENAME}")
76 |
77 |
78 |
79 | # Graceful shutdown handler
def graceful_shutdown(signum, frame):
    """Signal handler: close the shared SQLite connection and exit cleanly.

    :param signum: Signal number received (SIGINT/SIGTERM).
    :param frame: Current stack frame (unused).
    """
    logger.info("Caught signal, shutting down gracefully...")
    # BUG FIX: the original tested for 'conn_pool', a name that never exists
    # in this module, so the connection was never actually closed.
    if 'connection' in globals():
        connection.close()
        logger.info("Database connection pool closed.")
    raise SystemExit(0)
86 |
# Register the signal handlers for graceful shutdown
signal.signal(signal.SIGINT, graceful_shutdown)
signal.signal(signal.SIGTERM, graceful_shutdown)

#Instantiate MLX Clip model
# Loads the configured CLIP weights (hf_repo) into the local "mlx_model" dir.
clip = mlx_clip.mlx_clip("mlx_model", hf_repo=CLIP_MODEL)

# Check if data dir exists, if it doesn't - then create it
# NOTE(review): the RotatingFileHandler above writes its log file to the CWD,
# so logging works even before DATA_DIR exists.
if not os.path.exists(DATA_DIR):
    logger.info("Creating data directory ...")
    os.makedirs(DATA_DIR)

# Create a connection pool for the SQLite database
# NOTE(review): a single shared connection used by the main thread; worker
# threads open their own connections in process_image()/update_db().
connection = sqlite3.connect(SQLITE_DB_FILEPATH)
101 |
def create_table():
    """
    Creates the 'images' table in the SQLite database if it doesn't exist.

    Also ensures lookup indexes on filename and file_path, which the
    existence check in process_image() relies on.
    """
    statements = (
        '''
        CREATE TABLE IF NOT EXISTS images (
            id INTEGER PRIMARY KEY,
            filename TEXT NOT NULL,
            file_path TEXT NOT NULL,
            file_date TEXT NOT NULL,
            file_md5 TEXT NOT NULL,
            embeddings BLOB
        )
        ''',
        'CREATE INDEX IF NOT EXISTS idx_filename ON images (filename)',
        'CREATE INDEX IF NOT EXISTS idx_file_path ON images (file_path)',
    )
    with connection:
        for statement in statements:
            connection.execute(statement)
    logger.info("Table 'images' ensured to exist.")
120 |
121 |
122 |
def file_generator(directory):
    """
    Lazily yield the path of every file under *directory*, recursing into
    subdirectories.

    :param directory: Root directory to scan.
    :return: A generator yielding file paths (root-joined, os.walk order).
    """
    logger.debug(f"Generating file paths for directory: {directory}")
    for dirpath, _, filenames in os.walk(directory):
        yield from (os.path.join(dirpath, name) for name in filenames)
134 |
def _scan_and_cache(directory, cache_file_path, log_verb):
    """Walk *directory*, msgpack the file list to *cache_file_path*, return it."""
    cached_files = list(file_generator(directory))
    with open(cache_file_path, 'wb') as f:
        msgpack.dump(cached_files, f)
    logger.info(f"{log_verb} cache with {len(cached_files)} files and dumped to {cache_file_path}")
    return cached_files


def hydrate_cache(directory, cache_file_path):
    """
    Load the cached list of files for *directory*, regenerating the cache
    when it is missing, empty or unreadable.

    :param directory: The directory path to search for files.
    :param cache_file_path: The path to the msgpack cache file.
    :return: A list of cached file paths.
    """
    # Refactor: the regenerate-and-dump logic appeared three times in the
    # original; it now lives in _scan_and_cache().
    logger.info(f"Hydrating cache for {directory} using {cache_file_path}...")
    if os.path.exists(cache_file_path):
        try:
            with open(cache_file_path, 'rb') as f:
                cached_files = msgpack.load(f)
            logger.info(f"Loaded cached files from {cache_file_path}")
            if len(cached_files) == 0:
                # An empty cache is treated as stale: rebuild from disk.
                logger.warning(f"Cache file {cache_file_path} is empty. Regenerating cache...")
                cached_files = _scan_and_cache(directory, cache_file_path, "Regenerated")
        except (msgpack.UnpackException, IOError) as e:
            logger.error(f"Error loading cache file {cache_file_path}: {e}. Regenerating cache...")
            cached_files = _scan_and_cache(directory, cache_file_path, "Regenerated")
    else:
        logger.info(f"Cache file not found at {cache_file_path}. Creating cache dirlist for {directory}...")
        cached_files = list(file_generator(directory))
        try:
            with open(cache_file_path, 'wb') as f:
                msgpack.dump(cached_files, f)
            logger.info(f"Created cache with {len(cached_files)} files and dumped to {cache_file_path}")
        except IOError as e:
            # Best-effort: proceed with the in-memory list even if caching failed.
            logger.error(f"Error creating cache file {cache_file_path}: {e}. Proceeding without cache.")
    return cached_files
171 |
172 |
def update_db(image):
    """
    Persist an image's embeddings back into SQLite.

    :param image: Dict with at least 'filename'; the 'embeddings' value
                  (empty list when absent) is msgpack-packed into a BLOB.
    """
    try:
        packed = msgpack.dumps(image.get('embeddings', []))
        embeddings_blob = sqlite3.Binary(packed)
        # A fresh connection per call keeps this safe from worker threads.
        with sqlite3.connect(SQLITE_DB_FILEPATH) as conn:
            conn.execute(
                "UPDATE images SET embeddings = ? WHERE filename = ?",
                (embeddings_blob, image['filename']),
            )
        logger.debug(f"Database updated successfully for image: {image['filename']}")
    except sqlite3.Error as e:
        logger.error(f"Database update failed for image: {image['filename']}. Error: {e}")
187 |
def process_image(file_path):
    """
    Extract basic metadata (basename, mtime, MD5) for an image file and
    insert a row into the 'images' table unless one already exists for the
    same filename + path.

    :param file_path: The path to the image file.
    """
    file = os.path.basename(file_path)
    file_date = time.ctime(os.path.getmtime(file_path))
    # IMPROVEMENT: hash in 1 MiB chunks instead of reading the whole file
    # into memory -- photo files can be large.
    digest = hashlib.md5()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            digest.update(chunk)
    file_md5 = digest.hexdigest()
    conn = None
    try:
        # One connection per call: this runs on ThreadPoolExecutor workers
        # and SQLite connections must not be shared across threads.
        conn = sqlite3.connect(SQLITE_DB_FILEPATH)
        with conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT EXISTS(SELECT 1 FROM images WHERE filename=? AND file_path=? LIMIT 1)
            ''', (file, file_path))
            result = cursor.fetchone()
            file_exists = result[0] if result else False
            if not file_exists:
                cursor.execute('''
                    INSERT INTO images (filename, file_path, file_date, file_md5)
                    VALUES (?, ?, ?, ?)
                ''', (file, file_path, file_date, file_md5))
                logger.debug(f'Inserted {file} with metadata into the database.')
            else:
                logger.debug(f'File {file} already exists in the database. Skipping insertion.')
    except sqlite3.Error as e:
        logger.error(f'Error processing image {file}: {e}')
    finally:
        if conn:
            conn.close()
222 |
def process_embeddings(photo):
    """
    Generate and persist a CLIP image embedding for a single photo record,
    skipping records that already carry embeddings.

    :param photo: Dict with 'filename', 'file_path' and 'embeddings' keys.
    """
    logger.debug(f"Processing photo: {photo['filename']}")
    if photo['embeddings']:
        logger.debug(f"Photo {photo['filename']} already has embeddings. Skipping.")
        return

    try:
        started = time.time()
        photo['embeddings'] = clip.image_encoder(photo['file_path'])
        update_db(photo)
        elapsed = time.time() - started
        logger.debug(f"Processed embeddings for {photo['filename']} in {elapsed:.5f} seconds")
    except Exception as e:
        logger.error(f"Error generating embeddings for {photo['filename']}: {e}")
243 |
244 |
def main():
    """
    Full embedding pipeline:

    1. Enumerate image files (via the msgpack file-list cache).
    2. Insert metadata rows for new .jpg files into SQLite (threaded).
    3. Generate CLIP embeddings for photos that lack them (serial).
    4. Mirror all embeddings into the Chroma collection (idempotent).
    """
    cache_start_time = time.time()
    cached_files = hydrate_cache(SOURCE_IMAGE_DIRECTORY, FILELIST_CACHE_FILEPATH)
    cache_end_time = time.time()
    logger.info(f"Cache operation took {cache_end_time - cache_start_time:.2f} seconds")
    logger.info(f"Directory has {len(cached_files)} files: {SOURCE_IMAGE_DIRECTORY}")

    create_table()

    # Metadata inserts are I/O bound; each worker opens its own connection.
    with ThreadPoolExecutor() as executor:
        futures = []
        for file_path in cached_files:
            if file_path.lower().endswith('.jpg'):
                future = executor.submit(process_image, file_path)
                futures.append(future)
        for future in futures:
            future.result()

    with connection:
        cursor = connection.cursor()
        cursor.execute("SELECT filename, file_path, file_date, file_md5, embeddings FROM images")
        photos = [{'filename': row[0], 'file_path': row[1], 'file_date': row[2], 'file_md5': row[3], 'embeddings': msgpack.loads(row[4]) if row[4] else []} for row in cursor.fetchall()]

    num_photos = len(photos)

    logger.info(f"Loaded {len(photos)} photos from database")
    # can't use ThreadPoolExecutor here because of the MLX memory thing
    start_time = time.time()
    photo_ite = 0
    for photo in photos:
        process_embeddings(photo)
        photo_ite += 1
        # BUG FIX: log_level is a numeric logging constant (via getattr), so
        # the original comparison against the string 'DEBUG' was always true.
        if log_level != logging.DEBUG:
            if photo_ite % 100 == 0:
                logger.info(f"Processed {photo_ite}/{num_photos} photos")
    end_time = time.time()
    logger.info(f"Generated embeddings for {len(photos)} photos in {end_time - start_time:.2f} seconds")
    connection.close()
    logger.info("Database connection pool closed.")


    logger.info(f"Initializing Chrome DB: {CHROMA_COLLECTION_NAME}")
    client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
    collection = client.get_or_create_collection(name=CHROMA_COLLECTION_NAME)

    logger.info(f"Generated embeddings for {len(photos)} photos")
    start_time = time.time()

    photo_ite = 0
    for photo in photos:
        # Skip photos without embeddings (encoder failed or file unreadable).
        if not photo['embeddings']:
            logger.debug(f"[{photo_ite}/{num_photos}] Photo {photo['filename']} has no embeddings. Skipping addition to Chroma.")
            continue

        try:
            # Idempotent re-runs: skip ids already present in the collection.
            item = collection.get(ids=[photo['filename']])
            if item['ids'] !=[]:
                continue
            collection.add(
                embeddings=[photo["embeddings"]],
                documents=[photo['filename']],
                ids=[photo['filename']]
            )
            logger.debug(f"[{photo_ite}/{num_photos}] Added embedding to Chroma for {photo['filename']}")
            photo_ite += 1
            # BUG FIX: same string-vs-constant comparison as above.
            if log_level != logging.DEBUG:
                if photo_ite % 100 == 0:
                    logger.info(f"Processed {photo_ite}/{num_photos} photos")
        except Exception as e:
            # Log an error if the addition to Chroma fails
            logger.error(f"[{photo_ite}/{num_photos}] Failed to add embedding to Chroma for {photo['filename']}: {e}")
    end_time = time.time()
    logger.info(f"Inserted embeddings {len(photos)} photos into Chroma in {end_time - start_time:.2f} seconds")
324 |
325 | if __name__ == "__main__":
326 | main()
327 |
--------------------------------------------------------------------------------