├── .python-version ├── 2405.19504v1.pdf ├── .gitignore ├── pyproject.toml ├── main.py ├── main_pylate.py ├── README.md └── fde_generator.py /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /2405.19504v1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sionic-ai/muvera-py/HEAD/2405.19504v1.pdf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | models--raphaelsty--neural-cherche-colbert/ 3 | .DS_Store 4 | .idea/ 5 | .locks/ -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fde-playground" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | dependencies = [ 8 | "beir>=2.2.0", 9 | "datasets>=3.6.0", 10 | "fitz>=0.0.1.dev2", 11 | "langchain>=0.3.26", 12 | "neural-cherche>=1.4.3", 13 | "nltk>=3.9.1", 14 | "numpy>=2.3.1", 15 | "pymupdf>=1.26.3", 16 | "sentence-transformers>=5.0.0", 17 | "torch>=2.7.1", 18 | ] 19 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | from dataclasses import replace 3 | 4 | import nltk 5 | import numpy as np 6 | import torch 7 | import neural_cherche.models as neural_cherche_models 8 | import neural_cherche.rank as neural_cherche_rank 9 | from datasets import load_dataset 10 | import logging 11 | 12 | from fde_generator import ( 13 | FixedDimensionalEncodingConfig, 14 | generate_query_fde, 15 | generate_document_fde_batch, 16 | ) 17 | 18 | # --- Configuration --- 19 | DATASET_REPO_ID = "zeta-alpha-ai/NanoFiQA2018" 20 | COLBERT_MODEL_NAME = "raphaelsty/neural-cherche-colbert" 21 | TOP_K = 10 22 | DEVICE = "cuda" if torch.cuda.is_available() else "mps" 23 | 24 | # --- Setup --- 25 | logging.basicConfig( 26 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 27 | ) 28 | logging.info(f"Using device: {DEVICE}") 29 | 30 | 31 | # --- Helper Functions --- 32 | def load_nanobeir_dataset(repo_id: str) -> (dict, dict, dict): 33 | logging.info(f"Loading dataset from Hugging Face Hub: '{repo_id}'...") 34 | corpus_ds = load_dataset(repo_id, "corpus", split="train") 35 | queries_ds = load_dataset(repo_id, "queries", split="train") 36 | qrels_ds = load_dataset(repo_id, "qrels", split="train") 37 | 38 | corpus = { 39 | row["_id"]: {"title": row.get("title", ""), "text": row.get("text", "")} 40 | for row in corpus_ds 41 | } 42 | queries = {row["_id"]: row["text"] for row in queries_ds} 43 | qrels = {str(row["query-id"]): {str(row["corpus-id"]): 1} for row in qrels_ds} 44 | 45 | logging.info(f"Dataset loaded: {len(corpus)} documents, {len(queries)} queries.") 46 | return corpus, queries, qrels 47 | 48 | 49 | def evaluate_recall(results: dict, qrels: dict, k: int) -> float: 50 | hits, total_queries = 0, 0 51 | for query_id, ranked_docs in results.items(): 52 | relevant_docs = set(qrels.get(str(query_id), {}).keys()) 53 | if not relevant_docs: 54 | continue 55 | total_queries += 1 56 | top_k_docs = set(list(ranked_docs.keys())[:k]) 57 | if not relevant_docs.isdisjoint(top_k_docs): 58 | hits += 1 59 | return hits / total_queries if total_queries > 0 else 0.0 60 | 61 | 62 | def to_numpy(tensor_or_array) -> np.ndarray: 63 | """Safely convert a PyTorch Tensor or a NumPy array to a float32 NumPy array.""" 64 | if isinstance(tensor_or_array, torch.Tensor): 65 | return tensor_or_array.cpu().detach().numpy().astype(np.float32) 66 | elif isinstance(tensor_or_array, np.ndarray): 67 | return tensor_or_array.astype(np.float32) 68 | else: 69 | raise TypeError(f"Unsupported type for conversion: {type(tensor_or_array)}") 70 | 71 | 72 | class ColbertNativeRetriever: 73 | """Uses neural-cherche's native ColBERT ranking (non-FDE).""" 74 | 75 | def __init__(self, model_name=COLBERT_MODEL_NAME): 76 | model = neural_cherche_models.ColBERT( 77 | model_name_or_path=model_name, device=DEVICE 78 | ) 79 | self.ranker = neural_cherche_rank.ColBERT( 80 | key="id", on=["title", "text"], model=model 81 | ) 82 | self.doc_embeddings_map = {} 83 | self.documents_for_ranker = [] 84 | 85 | def index(self, corpus: dict): 86 | self.documents_for_ranker = [ 87 | {"id": doc_id, **corpus[doc_id]} for doc_id in corpus.keys() 88 | ] 89 | logging.info( 90 | f"[{self.__class__.__name__}] Generating ColBERT embeddings for all documents..." 91 | ) 92 | self.doc_embeddings_map = self.ranker.encode_documents( 93 | documents=self.documents_for_ranker 94 | ) 95 | 96 | def search(self, query: str) -> dict: 97 | query_embeddings = self.ranker.encode_queries(queries=[query]) 98 | scores = self.ranker( 99 | queries_embeddings=query_embeddings, 100 | documents_embeddings=self.doc_embeddings_map, 101 | documents=[self.documents_for_ranker] * len(query_embeddings), 102 | ) 103 | return {item["id"]: item["similarity"] for item in scores[0]} 104 | 105 | 106 | class ColbertFdeRetriever: 107 | """Uses a real ColBERT model to generate embeddings, then FDE for search.""" 108 | 109 | def __init__(self, model_name=COLBERT_MODEL_NAME): 110 | model = neural_cherche_models.ColBERT( 111 | model_name_or_path=model_name, device=DEVICE 112 | ) 113 | self.ranker = neural_cherche_rank.ColBERT( 114 | key="id", on=["title", "text"], model=model 115 | ) 116 | self.doc_config = FixedDimensionalEncodingConfig( 117 | dimension=128, 118 | num_repetitions=20, 119 | num_simhash_projections=7, 120 | seed=42, 121 | fill_empty_partitions=True, # Config for documents 122 | ) 123 | self.fde_index, self.doc_ids = None, [] 124 | 125 | def index(self, corpus: dict): 126 | self.doc_ids = list(corpus.keys()) 127 | documents_for_ranker = [ 128 | {"id": doc_id, **corpus[doc_id]} for doc_id in self.doc_ids 129 | ] 130 | 131 | logging.info( 132 | f"[{self.__class__.__name__}] Generating native multi-vector embeddings..." 133 | ) 134 | doc_embeddings_map = self.ranker.encode_documents( 135 | documents=documents_for_ranker 136 | ) 137 | doc_embeddings_list = [ 138 | to_numpy(doc_embeddings_map[doc_id]) for doc_id in self.doc_ids 139 | ] 140 | 141 | logging.info( 142 | f"[{self.__class__.__name__}] Generating FDEs from ColBERT embeddings in BATCH mode..." 143 | ) 144 | self.fde_index = generate_document_fde_batch( 145 | doc_embeddings_list, self.doc_config 146 | ) 147 | 148 | def search(self, query: str) -> dict: 149 | query_embeddings_map = self.ranker.encode_queries(queries=[query]) 150 | query_embeddings = to_numpy(list(query_embeddings_map.values())[0]) 151 | 152 | query_config = replace(self.doc_config, fill_empty_partitions=False) 153 | 154 | query_fde = generate_query_fde(query_embeddings, query_config) 155 | scores = self.fde_index @ query_fde 156 | return dict( 157 | sorted(zip(self.doc_ids, scores), key=lambda item: item[1], reverse=True) 158 | ) 159 | 160 | 161 | if __name__ == "__main__": 162 | nltk.data.find("tokenizers/punkt") 163 | 164 | corpus, queries, qrels = load_nanobeir_dataset(DATASET_REPO_ID) 165 | 166 | logging.info("Initializing retrieval models...") 167 | retrievers = { 168 | "1. ColBERT (Native)": ColbertNativeRetriever(), 169 | "2. ColBERT + FDE": ColbertFdeRetriever(), 170 | } 171 | 172 | timings, final_results = {}, {} 173 | 174 | logging.info("--- PHASE 1: INDEXING ---") 175 | for name, retriever in retrievers.items(): 176 | start_time = time.perf_counter() 177 | retriever.index(corpus) 178 | timings[name] = {"indexing_time": time.perf_counter() - start_time} 179 | logging.info( 180 | f"'{name}' indexing finished in {timings[name]['indexing_time']:.2f} seconds." 181 | ) 182 | 183 | logging.info("--- PHASE 2: SEARCH & EVALUATION ---") 184 | for name, retriever in retrievers.items(): 185 | logging.info(f"Running search for '{name}' on {len(queries)} queries...") 186 | query_times = [] 187 | results = {} 188 | for query_id, query_text in queries.items(): 189 | start_time = time.perf_counter() 190 | results[str(query_id)] = retriever.search(query_text) 191 | query_times.append(time.perf_counter() - start_time) 192 | 193 | timings[name]["avg_query_time"] = np.mean(query_times) 194 | final_results[name] = results 195 | logging.info( 196 | f"'{name}' search finished. Avg query time: {timings[name]['avg_query_time'] * 1000:.2f} ms." 197 | ) 198 | 199 | print("\n" + "=" * 85) 200 | print(f"{'FINAL REPORT':^85}") 201 | print(f"(Dataset: {DATASET_REPO_ID})") 202 | print("=" * 85) 203 | print( 204 | f"{'Retriever':<25} | {'Indexing Time (s)':<20} | {'Avg Query Time (ms)':<22} | {'Recall@{k}'.format(k=TOP_K):<10}" 205 | ) 206 | print("-" * 85) 207 | 208 | for name in retrievers.keys(): 209 | recall = evaluate_recall(final_results[name], qrels, k=TOP_K) 210 | idx_time = timings[name]["indexing_time"] 211 | query_time_ms = timings[name]["avg_query_time"] * 1000 212 | 213 | print( 214 | f"{name:<25} | {idx_time:<20.2f} | {query_time_ms:<22.2f} | {recall:<10.4f}" 215 | ) 216 | 217 | print("=" * 85) 218 | -------------------------------------------------------------------------------- /main_pylate.py: -------------------------------------------------------------------------------- 1 | import time 2 | from dataclasses import replace 3 | import numpy as np 4 | import torch 5 | from datasets import load_dataset 6 | import logging 7 | from pylate.models import ColBERT as PylateColBERT 8 | 9 | from fde_generator import ( 10 | FixedDimensionalEncodingConfig, 11 | generate_query_fde, 12 | generate_document_fde_batch, 13 | ) 14 | 15 | DATASET_REPO_ID = "zeta-alpha-ai/NanoFiQA2018" 16 | COLBERT_MODEL_NAME = "ayushexel/colbert-ModernBERT-base-1-neg-1-epoch-gooaq-1995000" # Supports pylate models 17 | TOP_K = 10 18 | DEVICE = "cuda" if torch.cuda.is_available() else "mps" 19 | 20 | logging.basicConfig( 21 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 22 | ) 23 | logging.info(f"Using device: {DEVICE}") 24 | 25 | 26 | # --- Helper Functions --- 27 | def load_nanobeir_dataset(repo_id: str) -> (dict, dict, dict): 28 | logging.info(f"Loading dataset from Hugging Face Hub: '{repo_id}'...") 29 | corpus_ds = load_dataset(repo_id, "corpus", split="train") 30 | queries_ds = load_dataset(repo_id, "queries", split="train") 31 | qrels_ds = load_dataset(repo_id, "qrels", split="train") 32 | 33 | corpus = { 34 | row["_id"]: {"title": row.get("title", ""), "text": row.get("text", "")} 35 | for row in corpus_ds 36 | } 37 | queries = {row["_id"]: row["text"] for row in queries_ds} 38 | qrels = {str(row["query-id"]): {str(row["corpus-id"]): 1} for row in qrels_ds} 39 | 40 | logging.info(f"Dataset loaded: {len(corpus)} documents, {len(queries)} queries.") 41 | return corpus, queries, qrels 42 | 43 | 44 | def evaluate_recall(results: dict, qrels: dict, k: int) -> float: 45 | hits, total_queries = 0, 0 46 | for query_id, ranked_docs in results.items(): 47 | relevant_docs = set(qrels.get(str(query_id), {}).keys()) 48 | if not relevant_docs: 49 | continue 50 | total_queries += 1 51 | top_k_docs = set(list(ranked_docs.keys())[:k]) 52 | if not relevant_docs.isdisjoint(top_k_docs): 53 | hits += 1 54 | return hits / total_queries if total_queries > 0 else 0.0 55 | 56 | 57 | def to_numpy(tensor_or_array) -> np.ndarray: 58 | """Safely convert a PyTorch Tensor or a NumPy array to a float32 NumPy array.""" 59 | if isinstance(tensor_or_array, torch.Tensor): 60 | return tensor_or_array.cpu().detach().numpy().astype(np.float32) 61 | elif isinstance(tensor_or_array, np.ndarray): 62 | return tensor_or_array.astype(np.float32) 63 | else: 64 | raise TypeError(f"Unsupported type for conversion: {type(tensor_or_array)}") 65 | 66 | 67 | class ColbertNativeRetriever: 68 | """Uses pylate's native ColBERT ranking (non-FDE).""" 69 | 70 | def __init__(self, model_name=COLBERT_MODEL_NAME): 71 | self.model = PylateColBERT(model_name_or_path=model_name, device=DEVICE) 72 | if hasattr(self.model[0].tokenizer, "model_max_length"): # For modernbert support 73 | self.model[0].tokenizer.model_input_names = ["input_ids", "attention_mask"] 74 | self.doc_embeddings_map = {} 75 | self.doc_ids = [] 76 | 77 | def index(self, corpus: dict): 78 | self.doc_ids = list(corpus.keys()) 79 | documents_for_ranker = [{"id": doc_id, **corpus[doc_id]} for doc_id in self.doc_ids] 80 | doc_texts = [f"{doc.get('title', '')} {doc.get('text', '')}".strip() for doc in documents_for_ranker] 81 | 82 | logging.info( 83 | f"[{self.__class__.__name__}] Generating ColBERT embeddings for all documents..." 84 | ) 85 | doc_embeddings_list = self.model.encode( 86 | sentences=doc_texts, 87 | is_query=False, 88 | convert_to_tensor=True, 89 | normalize_embeddings=True, 90 | ) 91 | self.doc_embeddings_map = dict(zip(self.doc_ids, doc_embeddings_list)) 92 | 93 | def search(self, query: str) -> dict: 94 | query_embedding = self.model.encode( 95 | sentences=query, 96 | is_query=True, 97 | convert_to_tensor=True, 98 | normalize_embeddings=True, 99 | ) 100 | 101 | scores = {} 102 | with torch.no_grad(): 103 | for doc_id, doc_embedding in self.doc_embeddings_map.items(): 104 | late_interaction = torch.einsum("sh,th->st", query_embedding.to(DEVICE), doc_embedding.to(DEVICE)) 105 | score = late_interaction.max(dim=1).values.sum() 106 | scores[doc_id] = score.item() 107 | 108 | return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True)) 109 | 110 | 111 | class ColbertFdeRetriever: 112 | """Uses a real ColBERT model to generate embeddings, then FDE for search.""" 113 | 114 | def __init__(self, model_name=COLBERT_MODEL_NAME): 115 | self.model = PylateColBERT(model_name_or_path=model_name, device=DEVICE) 116 | if hasattr(self.model[0].tokenizer, "model_max_length"): 117 | self.model[0].tokenizer.model_input_names = ["input_ids", "attention_mask"] 118 | self.doc_config = FixedDimensionalEncodingConfig( 119 | dimension=128, 120 | num_repetitions=20, 121 | num_simhash_projections=7, 122 | seed=42, 123 | fill_empty_partitions=True, # Config for documents 124 | ) 125 | self.fde_index, self.doc_ids = None, [] 126 | 127 | def index(self, corpus: dict): 128 | self.doc_ids = list(corpus.keys()) 129 | documents_for_ranker = [{"id": doc_id, **corpus[doc_id]} for doc_id in self.doc_ids] 130 | doc_texts = [f"{doc.get('title', '')} {doc.get('text', '')}".strip() for doc in documents_for_ranker] 131 | 132 | logging.info(f"[{self.__class__.__name__}] Generating native multi-vector embeddings...") 133 | doc_embeddings_list = self.model.encode( 134 | sentences=doc_texts, 135 | is_query=False, 136 | convert_to_numpy=True, 137 | normalize_embeddings=True, 138 | ) 139 | 140 | logging.info(f"[{self.__class__.__name__}] Generating FDEs from ColBERT embeddings in BATCH mode...") 141 | self.fde_index = generate_document_fde_batch(doc_embeddings_list, self.doc_config) 142 | 143 | def search(self, query: str) -> dict: 144 | query_embeddings = self.model.encode( 145 | sentences=query, 146 | is_query=True, 147 | convert_to_numpy=True, 148 | normalize_embeddings=True, 149 | ) 150 | 151 | query_config = replace(self.doc_config, fill_empty_partitions=False) 152 | query_fde = generate_query_fde(query_embeddings, query_config) 153 | scores = self.fde_index @ query_fde 154 | return dict(sorted(zip(self.doc_ids, scores), key=lambda item: item[1], reverse=True)) 155 | 156 | 157 | if __name__ == "__main__": 158 | corpus, queries, qrels = load_nanobeir_dataset(DATASET_REPO_ID) 159 | 160 | logging.info("Initializing retrieval models...") 161 | retrievers = { 162 | "1. ColBERT (Native)": ColbertNativeRetriever(), 163 | "2. ColBERT + FDE": ColbertFdeRetriever(), 164 | } 165 | 166 | timings, final_results = {}, {} 167 | 168 | logging.info("--- PHASE 1: INDEXING ---") 169 | for name, retriever in retrievers.items(): 170 | start_time = time.perf_counter() 171 | retriever.index(corpus) 172 | timings[name] = {"indexing_time": time.perf_counter() - start_time} 173 | logging.info(f"'{name}' indexing finished in {timings[name]['indexing_time']:.2f} seconds.") 174 | 175 | logging.info("--- PHASE 2: SEARCH & EVALUATION ---") 176 | for name, retriever in retrievers.items(): 177 | logging.info(f"Running search for '{name}' on {len(queries)} queries...") 178 | query_times = [] 179 | results = {} 180 | for query_id, query_text in queries.items(): 181 | start_time = time.perf_counter() 182 | results[str(query_id)] = retriever.search(query_text) 183 | query_times.append(time.perf_counter() - start_time) 184 | 185 | timings[name]["avg_query_time"] = np.mean(query_times) 186 | final_results[name] = results 187 | logging.info(f"'{name}' search finished. Avg query time: {timings[name]['avg_query_time'] * 1000:.2f} ms.") 188 | 189 | print("\n" + "=" * 85) 190 | print(f"{'FINAL REPORT':^85}") 191 | print(f"(Dataset: {DATASET_REPO_ID})") 192 | print("=" * 85) 193 | print( 194 | f"{'Retriever':<25} | {'Indexing Time (s)':<20} | {'Avg Query Time (ms)':<22} | {'Recall@{k}'.format(k=TOP_K):<10}" 195 | ) 196 | print("-" * 85) 197 | 198 | for name in retrievers.keys(): 199 | recall = evaluate_recall(final_results[name], qrels, k=TOP_K) 200 | idx_time = timings[name]["indexing_time"] 201 | query_time_ms = timings[name]["avg_query_time"] * 1000 202 | 203 | print( 204 | f"{name:<25} | {idx_time:<20.2f} | {query_time_ms:<22.2f} | {recall:<10.4f}" 205 | ) 206 | 207 | print("=" * 85) 208 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python Implementation of MUVERA: Multi-Vector Retrieval via Fixed Dimensional Encodings 2 | 3 | This Python implementation was created to make the FDE algorithm more accessible while maintaining complete fidelity to the original C++ implementation. Every function and parameter has been carefully mapped to ensure identical behavior. 4 | 5 | ## What is FDE? 6 | 7 | Fixed-Dimensional Encoding (FDE) solves a fundamental problem in modern search systems: how to efficiently search through billions of documents when each document is represented by hundreds of vectors (as in ColBERT-style models). 8 | 9 | ### The Problem 10 | - **Traditional search**: Document = 1 vector → Fast but inaccurate 11 | - **Modern multi-vector search**: Document = 100s of vectors → Accurate but extremely slow 12 | 13 | ### The FDE Solution 14 | FDE transforms multiple vectors into a single fixed-size vector while preserving the similarity relationships. The magic is that the dot product between two FDE vectors approximates the original Chamfer similarity between the multi-vector sets. 15 | 16 | ### Running Guide 17 | ``` 18 | $ uv run main.py 19 | 20 | 2025-07-06 13:10:09,942 - INFO - Using device: mps 21 | 2025-07-06 13:10:09,942 - INFO - Loading dataset from Hugging Face Hub: 'zeta-alpha-ai/NanoFiQA2018'... 22 | 2025-07-06 13:10:18,583 - INFO - Dataset loaded: 4598 documents, 50 queries. 23 | 2025-07-06 13:10:18,583 - INFO - Initializing retrieval models... 24 | 2025-07-06 13:10:20,095 - INFO - --- PHASE 1: INDEXING --- 25 | 2025-07-06 13:10:20,096 - INFO - [ColbertFdeRetriever] Generating native multi-vector embeddings... 26 | ColBERT documents embeddings: 100%|██████████| 144/144 [01:05<00:00, 2.21it/s] 27 | 2025-07-06 13:11:25,420 - INFO - [ColbertFdeRetriever] Generating FDEs from ColBERT embeddings in BATCH mode... 28 | 2025-07-06 13:11:25,420 - INFO - [FDE Batch] Starting batch FDE generation for 4598 documents 29 | 2025-07-06 13:11:25,420 - INFO - [FDE Batch] Using identity projection (dim=128) 30 | 2025-07-06 13:11:25,420 - INFO - [FDE Batch] Configuration: 20 repetitions, 128 partitions, projection_dim=128 31 | 2025-07-06 13:11:25,422 - INFO - [FDE Batch] Total vectors: 1177088, avg per doc: 256.0 32 | 2025-07-06 13:11:25,627 - INFO - [FDE Batch] Concatenation completed in 0.205s 33 | 2025-07-06 13:11:25,627 - INFO - [FDE Batch] Output FDE dimension: 327680 34 | 2025-07-06 13:11:25,627 - INFO - [FDE Batch] Processing repetition 1/20 35 | 2025-07-06 13:11:33,469 - INFO - [FDE Batch] Repetition timing breakdown: 36 | 2025-07-06 13:11:33,469 - INFO - - SimHash: 0.037s 37 | 2025-07-06 13:11:33,469 - INFO - - Projection: 0.000s 38 | 2025-07-06 13:11:33,469 - INFO - - Partition indices: 0.019s 39 | 2025-07-06 13:11:33,469 - INFO - - Aggregation: 2.101s 40 | 2025-07-06 13:11:33,469 - INFO - - Averaging: 5.655s 41 | 2025-07-06 13:11:33,469 - INFO - - Filled 462482 empty partitions 42 | 2025-07-06 13:12:04,662 - INFO - [FDE Batch] Processing repetition 6/20 43 | 2025-07-06 13:12:43,054 - INFO - [FDE Batch] Processing repetition 11/20 44 | 2025-07-06 13:13:22,420 - INFO - [FDE Batch] Processing repetition 16/20 45 | 2025-07-06 13:14:01,083 - INFO - [FDE Batch] Batch generation completed in 155.663s 46 | 2025-07-06 13:14:01,083 - INFO - [FDE Batch] Average time per document: 33.85ms 47 | 2025-07-06 13:14:01,083 - INFO - [FDE Batch] Throughput: 29.5 docs/sec 48 | 2025-07-06 13:14:01,083 - INFO - [FDE Batch] Output shape: (4598, 327680) 49 | 2025-07-06 13:14:01,188 - INFO - '2. ColBERT + FDE' indexing finished in 221.09 seconds. 50 | 2025-07-06 13:14:01,188 - INFO - --- PHASE 2: SEARCH & EVALUATION --- 51 | 2025-07-06 13:14:01,188 - INFO - Running search for '2. ColBERT + FDE' on 50 queries... 52 | ColBERT queries embeddings: 100%|██████████| 1/1 [00:00<00:00, 2.44it/s] 53 | ColBERT queries embeddings: 100%|██████████| 1/1 [00:00<00:00, 11.13it/s] 54 | ColBERT queries embeddings: 100%|██████████| 1/1 [00:00<00:00, 45.49it/s] 55 | ===================================================================================== 56 | FINAL REPORT 57 | (Dataset: zeta-alpha-ai/NanoFiQA2018) 58 | ===================================================================================== 59 | Retriever | Indexing Time (s) | Avg Query Time (ms) | Recall@10 60 | ------------------------------------------------------------------------------------- 61 | 1. ColBERT (Native) | 82.31 | 1618.29 | 0.7000 62 | ===================================================================================== 63 | Retriever | Indexing Time (s) | Avg Query Time (ms) | Recall@10 64 | ------------------------------------------------------------------------------------- 65 | 2. ColBERT + FDE | 221.09 | 189.97 | 0.6400 66 | ===================================================================================== 67 | 2025-07-06 13:14:10,688 - INFO - '2. ColBERT + FDE' search finished. Avg query time: 189.97 ms. 68 | 69 | Process finished with exit code 0 70 | ``` 71 | 72 | ## Detailed Implementation Guide 73 | 74 | ### 1. Configuration Classes 75 | 76 | #### EncodingType Enum 77 | ```python 78 | class EncodingType(Enum): 79 | DEFAULT_SUM = 0 # For queries: sum vectors in each partition 80 | AVERAGE = 1 # For documents: average vectors in each partition 81 | ``` 82 | **C++ Mapping**: Directly corresponds to `FixedDimensionalEncodingConfig::EncodingType` in the proto file. 83 | 84 | #### ProjectionType Enum 85 | ```python 86 | class ProjectionType(Enum): 87 | DEFAULT_IDENTITY = 0 # No dimensionality reduction 88 | AMS_SKETCH = 1 # Use AMS sketch for reduction 89 | ``` 90 | **C++ Mapping**: Maps to `FixedDimensionalEncodingConfig::ProjectionType`. 91 | 92 | #### FixedDimensionalEncodingConfig 93 | ```python 94 | @dataclass 95 | class FixedDimensionalEncodingConfig: 96 | dimension: int = 128 # Original vector dimension 97 | num_repetitions: int = 10 # Number of independent runs 98 | num_simhash_projections: int = 6 # Controls partition granularity 99 | seed: int = 42 # Random seed 100 | encoding_type: EncodingType = DEFAULT_SUM 101 | projection_type: ProjectionType = DEFAULT_IDENTITY 102 | projection_dimension: Optional[int] = None 103 | fill_empty_partitions: bool = False 104 | final_projection_dimension: Optional[int] = None 105 | ``` 106 | **C++ Mapping**: Direct equivalent of `FixedDimensionalEncodingConfig` message in the proto file. 107 | 108 | ### 2. Internal Helper Functions 109 | 110 | #### Gray Code Functions 111 | ```python 112 | def _append_to_gray_code(gray_code: int, bit: bool) -> int: 113 | return (gray_code << 1) + (int(bit) ^ (gray_code & 1)) 114 | ``` 115 | **C++ Mapping**: Exact implementation of `internal::AppendToGrayCode()`. 116 | 117 | ```python 118 | def _gray_code_to_binary(num: int) -> int: 119 | mask = num >> 1 120 | while mask != 0: 121 | num = num ^ mask 122 | mask >>= 1 123 | return num 124 | ``` 125 | **C++ Mapping**: Equivalent to `internal::GrayCodeToBinary()`. The C++ version uses `num ^ (num >> 1)`, while Python uses a loop for clarity. 126 | 127 | #### Random Matrix Generators 128 | 129 | ```python 130 | def _simhash_matrix_from_seed(dimension: int, num_projections: int, seed: int): 131 | rng = np.random.default_rng(seed) 132 | return rng.normal(loc=0.0, scale=1.0, size=(dimension, num_projections)) 133 | ``` 134 | **C++ Mapping**: Maps to `internal::SimHashMatrixFromSeed()`. Uses Gaussian distribution for LSH. 135 | 136 | ```python 137 | def _ams_projection_matrix_from_seed(dimension: int, projection_dim: int, seed: int): 138 | # Creates sparse random matrix with one non-zero per row 139 | ``` 140 | **C++ Mapping**: Corresponds to `internal::AMSProjectionMatrixFromSeed()`. 141 | 142 | #### Partition Index Calculation 143 | ```python 144 | def _simhash_partition_index_gray(sketch_vector: np.ndarray) -> int: 145 | partition_index = 0 146 | for val in sketch_vector: 147 | partition_index = _append_to_gray_code(partition_index, val > 0) 148 | return partition_index 149 | ``` 150 | **C++ Mapping**: Direct implementation of `internal::SimHashPartitionIndex()`. 151 | 152 | ### 3. Core Algorithm 153 | 154 | The `_generate_fde_internal()` function implements the main FDE generation logic: 155 | 156 | ```python 157 | def _generate_fde_internal(point_cloud: np.ndarray, config: FixedDimensionalEncodingConfig): 158 | # Step 1: Validate inputs (matches C++ parameter validation) 159 | # Step 2: Calculate dimensions 160 | # Step 3: For each repetition: 161 | # - Apply SimHash for space partitioning 162 | # - Apply optional dimensionality reduction 163 | # - Aggregate vectors by partition 164 | # - Apply averaging for document FDE 165 | # Step 4: Optional final projection 166 | ``` 167 | 168 | **C++ Mapping**: This function combines the logic from both `GenerateQueryFixedDimensionalEncoding()` and `GenerateDocumentFixedDimensionalEncoding()` in the C++ code. 169 | 170 | ### 4. Public API 171 | 172 | ```python 173 | def generate_query_fde(point_cloud: np.ndarray, config: FixedDimensionalEncodingConfig): 174 | # Forces encoding_type to DEFAULT_SUM 175 | ``` 176 | **C++ Mapping**: Equivalent to `GenerateQueryFixedDimensionalEncoding()`. 177 | 178 | ```python 179 | def generate_document_fde(point_cloud: np.ndarray, config: FixedDimensionalEncodingConfig): 180 | # Forces encoding_type to AVERAGE 181 | ``` 182 | **C++ Mapping**: Equivalent to `GenerateDocumentFixedDimensionalEncoding()`. 183 | 184 | ```python 185 | def generate_fde(point_cloud: np.ndarray, config: FixedDimensionalEncodingConfig): 186 | # Routes based on config.encoding_type 187 | ``` 188 | **C++ Mapping**: Equivalent to `GenerateFixedDimensionalEncoding()`. 189 | 190 | ## C++ to Python Mapping 191 | 192 | ### Key Differences and Similarities 193 | 194 | | Feature | C++ Implementation | Python Implementation | Notes | 195 | |---------|-------------------|----------------------|--------| 196 | | **Matrix Operations** | Eigen library | NumPy | Functionally equivalent | 197 | | **Memory Management** | Manual with Eigen::Map | Automatic (NumPy) | Python is simpler | 198 | | **Gray Code Conversion** | `num ^ (num >> 1)` | While loop | Both produce same result | 199 | | **Error Handling** | absl::Status | Python exceptions | Different style, same checks | 200 | | **Random Number Generation** | std::mt19937 | np.random.default_rng | Same distributions | 201 | | **Configuration** | Protocol Buffers | dataclass | Same fields and defaults | 202 | 203 | ### Exact Function Mappings 204 | 205 | | C++ Function | Python Function | Purpose | 206 | |--------------|-----------------|---------| 207 | | `GenerateFixedDimensionalEncoding()` | `generate_fde()` | Top-level routing function | 208 | | `GenerateQueryFixedDimensionalEncoding()` | `generate_query_fde()` | Query FDE generation | 209 | | `GenerateDocumentFixedDimensionalEncoding()` | `generate_document_fde()` | Document FDE generation | 210 | | `internal::SimHashPartitionIndex()` | `_simhash_partition_index_gray()` | Partition assignment | 211 | | `internal::DistanceToSimHashPartition()` | `_distance_to_simhash_partition()` | Hamming distance calculation | 212 | | `internal::ApplyCountSketchToVector()` | `_apply_count_sketch_to_vector()` | Final projection | 213 | 214 | ## Usage Examples 215 | 216 | ### Basic Usage 217 | 218 | ```python 219 | import numpy as np 220 | from fde_generator import FixedDimensionalEncodingConfig, generate_query_fde, generate_document_fde 221 | 222 | # 1. Create configuration 223 | config = FixedDimensionalEncodingConfig( 224 | dimension=128, # Vector dimension 225 | num_repetitions=10, # Number of independent partitionings 226 | num_simhash_projections=6, # Creates 2^6 = 64 partitions 227 | seed=42 228 | ) 229 | 230 | # 2. Prepare data 231 | # Query: 32 vectors of 128 dimensions each 232 | query_vectors = np.random.randn(32, 128).astype(np.float32) 233 | 234 | # Document: 80 vectors of 128 dimensions each 235 | doc_vectors = np.random.randn(80, 128).astype(np.float32) 236 | 237 | # 3. Generate FDEs 238 | query_fde = generate_query_fde(query_vectors, config) 239 | doc_fde = generate_document_fde(doc_vectors, config) 240 | 241 | # 4. Compute similarity (approximates Chamfer similarity) 242 | similarity_score = np.dot(query_fde, doc_fde) 243 | print(f"Similarity: {similarity_score}") 244 | ``` 245 | 246 | ### Advanced Usage with Dimensionality Reduction 247 | 248 | ```python 249 | from fde_generator import ProjectionType, replace 250 | 251 | # Use AMS Sketch for internal projection 252 | config_with_projection = replace( 253 | config, 254 | projection_type=ProjectionType.AMS_SKETCH, 255 | projection_dimension=16 # Reduce from 128 to 16 dimensions 256 | ) 257 | 258 | # Use Count Sketch for final projection 259 | config_with_final_projection = replace( 260 | config, 261 | final_projection_dimension=1024 # Final FDE will be 1024 dimensions 262 | ) 263 | ``` 264 | 265 | ## Algorithm Walkthrough 266 | 267 | ### Step-by-Step Process 268 | 269 | 1. **Input**: Multiple vectors representing a document/query 270 | - Example: 32 vectors of 128 dimensions each 271 | 272 | 2. **Space Partitioning** (per repetition): 273 | - Apply SimHash: Multiply by random Gaussian matrix 274 | - Convert to partition indices using Gray Code 275 | - Creates 2^k_sim partitions (e.g., 64 partitions) 276 | 277 | 3. **Vector Aggregation**: 278 | - **For Queries**: Sum all vectors in each partition 279 | - **For Documents**: Average all vectors in each partition 280 | 281 | 4. **Repetition**: 282 | - Repeat steps 2-3 with different random seeds 283 | - Concatenate results from all repetitions 284 | 285 | 5. **Output**: Single FDE vector 286 | - Dimension: `num_repetitions × num_partitions × projection_dim` 287 | 288 | ### Why It Works 289 | 290 | The key insight is that FDE preserves the local structure of the vector space through LSH (Locality Sensitive Hashing). Vectors that are close in the original space are likely to: 291 | 1. End up in the same partition 292 | 2. Contribute to the same parts of the FDE vector 293 | 3. Produce high dot products when their FDEs are compared 294 | 295 | ## Performance Characteristics 296 | 297 | - **FDE Generation Time**: O(n × d × r × k) where: 298 | - n = number of vectors 299 | - d = vector dimension 300 | - r = number of repetitions 301 | - k = number of SimHash projections 302 | 303 | - **Search Time**: O(1) using standard MIPS libraries 304 | - **Memory**: Configurable via projection dimensions 305 | 306 | ## References 307 | 308 | - **Original Paper**: [MUVERA: Multi-Vector Retrieval via Fixed Dimensional Encodings](https://arxiv.org/pdf/2405.19504) 309 | - **C++ Implementation**: [Google Graph Mining Repository](https://github.com/google/graph-mining/tree/main/sketching/point_cloud) 310 | - **Blog Post**: [MUVERA: Making multi-vector retrieval as fast as single-vector search](https://research.google/blog/muvera-making-multi-vector-retrieval-as-fast-as-single-vector-search/) 311 | 312 | ## Contributing 313 | 314 | Contributions are welcome! Please ensure any changes maintain compatibility with the C++ implementation. 315 | 316 | ## License 317 | 318 | This implementation follows the same Apache 2.0 license as the original C++ code. 319 | -------------------------------------------------------------------------------- /fde_generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import numpy as np 5 | from dataclasses import dataclass, replace 6 | from enum import Enum 7 | from typing import Optional, List 8 | 9 | 10 | class EncodingType(Enum): 11 | DEFAULT_SUM = 0 12 | AVERAGE = 1 13 | 14 | 15 | class ProjectionType(Enum): 16 | DEFAULT_IDENTITY = 0 17 | AMS_SKETCH = 1 18 | 19 | 20 | @dataclass 21 | class FixedDimensionalEncodingConfig: 22 | dimension: int = 128 23 | num_repetitions: int = 10 24 | num_simhash_projections: int = 6 25 | seed: int = 42 26 | encoding_type: EncodingType = EncodingType.DEFAULT_SUM 27 | projection_type: ProjectionType = ProjectionType.DEFAULT_IDENTITY 28 | projection_dimension: Optional[int] = None 29 | fill_empty_partitions: bool = False 30 | final_projection_dimension: Optional[int] = None 31 | 32 | 33 | def _append_to_gray_code(gray_code: int, bit: bool) -> int: 34 | return (gray_code << 1) + (int(bit) ^ (gray_code & 1)) 35 | 36 | 37 | def _gray_code_to_binary(num: int) -> int: 38 | mask = num >> 1 39 | while mask != 0: 40 | num = num ^ mask 41 | mask >>= 1 42 | return num 43 | 44 | 45 | def _simhash_matrix_from_seed( 46 | dimension: int, num_projections: int, seed: int 47 | ) -> np.ndarray: 48 | rng = np.random.default_rng(seed) 49 | return rng.normal(loc=0.0, scale=1.0, size=(dimension, num_projections)).astype( 50 | np.float32 51 | ) 52 | 53 | 54 | def _ams_projection_matrix_from_seed( 55 | dimension: int, projection_dim: int, seed: int 56 | ) -> np.ndarray: 57 | rng = np.random.default_rng(seed) 58 | out = np.zeros((dimension, projection_dim), dtype=np.float32) 59 | indices = rng.integers(0, projection_dim, size=dimension) 60 | signs = rng.choice([-1.0, 1.0], size=dimension) 61 | out[np.arange(dimension), indices] = signs 62 | return out 63 | 64 | 65 | def _apply_count_sketch_to_vector( 66 | input_vector: np.ndarray, final_dimension: int, seed: int 67 | ) -> np.ndarray: 68 | rng = np.random.default_rng(seed) 69 | out = np.zeros(final_dimension, dtype=np.float32) 70 | indices = rng.integers(0, final_dimension, size=input_vector.shape[0]) 71 | signs = rng.choice([-1.0, 1.0], size=input_vector.shape[0]) 72 | np.add.at(out, indices, signs * input_vector) 73 | return out 74 | 75 | 76 | def _simhash_partition_index_gray(sketch_vector: np.ndarray) -> int: 77 | partition_index = 0 78 | for val in sketch_vector: 79 | partition_index = _append_to_gray_code(partition_index, val > 0) 80 | return partition_index 81 | 82 | 83 | def _distance_to_simhash_partition( 84 | sketch_vector: np.ndarray, partition_index: int 85 | ) -> int: 86 | num_projections = sketch_vector.size 87 | binary_representation = _gray_code_to_binary(partition_index) 88 | sketch_bits = (sketch_vector > 0).astype(int) 89 | binary_array = (binary_representation >> np.arange(num_projections - 1, -1, -1)) & 1 90 | return int(np.sum(sketch_bits != binary_array)) 91 | 92 | 93 | def _generate_fde_internal( 94 | point_cloud: np.ndarray, config: FixedDimensionalEncodingConfig 95 | ) -> np.ndarray: 96 | if point_cloud.ndim != 2 or point_cloud.shape[1] != config.dimension: 97 | raise ValueError( 98 | f"Input data shape {point_cloud.shape} is inconsistent with config dimension {config.dimension}." 99 | ) 100 | if not (0 <= config.num_simhash_projections < 32): 101 | raise ValueError( 102 | f"num_simhash_projections must be in [0, 31]: {config.num_simhash_projections}" 103 | ) 104 | 105 | num_points, original_dim = point_cloud.shape 106 | num_partitions = 2**config.num_simhash_projections 107 | 108 | use_identity_proj = config.projection_type == ProjectionType.DEFAULT_IDENTITY 109 | projection_dim = original_dim if use_identity_proj else config.projection_dimension 110 | if not use_identity_proj and (not projection_dim or projection_dim <= 0): 111 | raise ValueError( 112 | "A positive projection_dimension is required for non-identity projections." 113 | ) 114 | 115 | final_fde_dim = config.num_repetitions * num_partitions * projection_dim 116 | out_fde = np.zeros(final_fde_dim, dtype=np.float32) 117 | 118 | for rep_num in range(config.num_repetitions): 119 | current_seed = config.seed + rep_num 120 | 121 | sketches = point_cloud @ _simhash_matrix_from_seed( 122 | original_dim, config.num_simhash_projections, current_seed 123 | ) 124 | 125 | if use_identity_proj: 126 | projected_matrix = point_cloud 127 | elif config.projection_type == ProjectionType.AMS_SKETCH: 128 | ams_matrix = _ams_projection_matrix_from_seed( 129 | original_dim, projection_dim, current_seed 130 | ) 131 | projected_matrix = point_cloud @ ams_matrix 132 | 133 | rep_fde_sum = np.zeros(num_partitions * projection_dim, dtype=np.float32) 134 | partition_counts = np.zeros(num_partitions, dtype=np.int32) 135 | partition_indices = np.array( 136 | [_simhash_partition_index_gray(sketches[i]) for i in range(num_points)] 137 | ) 138 | 139 | for i in range(num_points): 140 | start_idx = partition_indices[i] * projection_dim 141 | rep_fde_sum[start_idx : start_idx + projection_dim] += projected_matrix[i] 142 | partition_counts[partition_indices[i]] += 1 143 | 144 | if config.encoding_type == EncodingType.AVERAGE: 145 | for i in range(num_partitions): 146 | start_idx = i * projection_dim 147 | if partition_counts[i] > 0: 148 | rep_fde_sum[start_idx : start_idx + projection_dim] /= ( 149 | partition_counts[i] 150 | ) 151 | elif config.fill_empty_partitions and num_points > 0: 152 | distances = [ 153 | _distance_to_simhash_partition(sketches[j], i) 154 | for j in range(num_points) 155 | ] 156 | nearest_point_idx = np.argmin(distances) 157 | rep_fde_sum[start_idx : start_idx + projection_dim] = ( 158 | projected_matrix[nearest_point_idx] 159 | ) 160 | 161 | rep_start_index = rep_num * num_partitions * projection_dim 162 | out_fde[rep_start_index : rep_start_index + rep_fde_sum.size] = rep_fde_sum 163 | 164 | if config.final_projection_dimension and config.final_projection_dimension > 0: 165 | return _apply_count_sketch_to_vector( 166 | out_fde, config.final_projection_dimension, config.seed 167 | ) 168 | 169 | return out_fde 170 | 171 | 172 | def generate_query_fde( 173 | point_cloud: np.ndarray, config: FixedDimensionalEncodingConfig 174 | ) -> np.ndarray: 175 | """Generates a Fixed Dimensional Encoding for a query point cloud (using SUM).""" 176 | if config.fill_empty_partitions: 177 | raise ValueError( 178 | "Query FDE generation does not support 'fill_empty_partitions'." 179 | ) 180 | query_config = replace(config, encoding_type=EncodingType.DEFAULT_SUM) 181 | return _generate_fde_internal(point_cloud, query_config) 182 | 183 | 184 | def generate_document_fde( 185 | point_cloud: np.ndarray, config: FixedDimensionalEncodingConfig 186 | ) -> np.ndarray: 187 | """Generates a Fixed Dimensional Encoding for a document point cloud (using AVERAGE).""" 188 | doc_config = replace(config, encoding_type=EncodingType.AVERAGE) 189 | return _generate_fde_internal(point_cloud, doc_config) 190 | 191 | 192 | def generate_fde( 193 | point_cloud: np.ndarray, config: FixedDimensionalEncodingConfig 194 | ) -> np.ndarray: 195 | if config.encoding_type == EncodingType.DEFAULT_SUM: 196 | return generate_query_fde(point_cloud, config) 197 | elif config.encoding_type == EncodingType.AVERAGE: 198 | return generate_document_fde(point_cloud, config) 199 | else: 200 | raise ValueError(f"Unsupported encoding type in config: {config.encoding_type}") 201 | 202 | 203 | def generate_document_fde_batch( 204 | doc_embeddings_list: List[np.ndarray], config: FixedDimensionalEncodingConfig 205 | ) -> np.ndarray: 206 | """ 207 | Generates FDEs for a batch of documents using highly optimized NumPy vectorization. 208 | Fully compliant with C++ implementation including all projection types. 209 | """ 210 | batch_start_time = time.perf_counter() 211 | num_docs = len(doc_embeddings_list) 212 | 213 | if num_docs == 0: 214 | logging.warning("[FDE Batch] Empty document list provided") 215 | return np.array([]) 216 | 217 | logging.info(f"[FDE Batch] Starting batch FDE generation for {num_docs} documents") 218 | 219 | # Input validation 220 | valid_docs = [] 221 | for i, doc in enumerate(doc_embeddings_list): 222 | if doc.ndim != 2: 223 | logging.warning( 224 | f"[FDE Batch] Document {i} has invalid shape (ndim={doc.ndim}), skipping" 225 | ) 226 | continue 227 | if doc.shape[1] != config.dimension: 228 | raise ValueError( 229 | f"Document {i} has incorrect dimension: expected {config.dimension}, got {doc.shape[1]}" 230 | ) 231 | if doc.shape[0] == 0: 232 | logging.warning(f"[FDE Batch] Document {i} has no vectors, skipping") 233 | continue 234 | valid_docs.append(doc) 235 | 236 | if len(valid_docs) == 0: 237 | logging.warning("[FDE Batch] No valid documents after filtering") 238 | return np.array([]) 239 | 240 | num_docs = len(valid_docs) 241 | doc_embeddings_list = valid_docs 242 | 243 | # Determine projection dimension (matching C++ logic) 244 | use_identity_proj = config.projection_type == ProjectionType.DEFAULT_IDENTITY 245 | if use_identity_proj: 246 | projection_dim = config.dimension 247 | logging.info(f"[FDE Batch] Using identity projection (dim={projection_dim})") 248 | else: 249 | if not config.projection_dimension or config.projection_dimension <= 0: 250 | raise ValueError( 251 | "A positive projection_dimension must be specified for non-identity projections" 252 | ) 253 | projection_dim = config.projection_dimension 254 | logging.info( 255 | f"[FDE Batch] Using {config.projection_type.name} projection: " 256 | f"{config.dimension} -> {projection_dim}" 257 | ) 258 | 259 | # Configuration summary 260 | num_partitions = 2**config.num_simhash_projections 261 | logging.info( 262 | f"[FDE Batch] Configuration: {config.num_repetitions} repetitions, " 263 | f"{num_partitions} partitions, projection_dim={projection_dim}" 264 | ) 265 | 266 | # Document tracking 267 | doc_lengths = np.array([len(doc) for doc in doc_embeddings_list], dtype=np.int32) 268 | total_vectors = np.sum(doc_lengths) 269 | doc_boundaries = np.insert(np.cumsum(doc_lengths), 0, 0) 270 | doc_indices = np.repeat(np.arange(num_docs), doc_lengths) 271 | 272 | logging.info( 273 | f"[FDE Batch] Total vectors: {total_vectors}, avg per doc: {total_vectors / num_docs:.1f}" 274 | ) 275 | 276 | # Concatenate all embeddings 277 | concat_start = time.perf_counter() 278 | all_points = np.vstack(doc_embeddings_list).astype(np.float32) 279 | concat_time = time.perf_counter() - concat_start 280 | logging.info(f"[FDE Batch] Concatenation completed in {concat_time:.3f}s") 281 | 282 | # Pre-allocate output 283 | final_fde_dim = config.num_repetitions * num_partitions * projection_dim 284 | out_fdes = np.zeros((num_docs, final_fde_dim), dtype=np.float32) 285 | logging.info(f"[FDE Batch] Output FDE dimension: {final_fde_dim}") 286 | 287 | # Process each repetition 288 | for rep_num in range(config.num_repetitions): 289 | # rep_start_time = time.perf_counter() 290 | current_seed = config.seed + rep_num 291 | 292 | if rep_num % 5 == 0: # Log every 5 repetitions 293 | logging.info( 294 | f"[FDE Batch] Processing repetition {rep_num + 1}/{config.num_repetitions}" 295 | ) 296 | 297 | # Step 1: SimHash projection 298 | simhash_start = time.perf_counter() 299 | simhash_matrix = _simhash_matrix_from_seed( 300 | config.dimension, config.num_simhash_projections, current_seed 301 | ) 302 | all_sketches = all_points @ simhash_matrix 303 | simhash_time = time.perf_counter() - simhash_start 304 | 305 | # Step 2: Apply dimensionality reduction if configured 306 | proj_start = time.perf_counter() 307 | if use_identity_proj: 308 | projected_points = all_points 309 | elif config.projection_type == ProjectionType.AMS_SKETCH: 310 | ams_matrix = _ams_projection_matrix_from_seed( 311 | config.dimension, projection_dim, current_seed 312 | ) 313 | projected_points = all_points @ ams_matrix 314 | else: 315 | raise ValueError(f"Unsupported projection type: {config.projection_type}") 316 | proj_time = time.perf_counter() - proj_start 317 | 318 | # Step 3: Vectorized partition index calculation 319 | partition_start = time.perf_counter() 320 | bits = (all_sketches > 0).astype(np.uint32) 321 | partition_indices = np.zeros(total_vectors, dtype=np.uint32) 322 | 323 | # Vectorized Gray Code computation 324 | for bit_idx in range(config.num_simhash_projections): 325 | partition_indices = (partition_indices << 1) + ( 326 | bits[:, bit_idx] ^ (partition_indices & 1) 327 | ) 328 | 329 | partition_time = time.perf_counter() - partition_start 330 | 331 | # Step 4: Vectorized aggregation 332 | agg_start = time.perf_counter() 333 | 334 | # Initialize storage for this repetition 335 | rep_fde_sum = np.zeros( 336 | (num_docs * num_partitions * projection_dim,), dtype=np.float32 337 | ) 338 | partition_counts = np.zeros((num_docs, num_partitions), dtype=np.int32) 339 | 340 | # Count vectors per partition per document 341 | np.add.at(partition_counts, (doc_indices, partition_indices), 1) 342 | 343 | # Aggregate vectors using flattened indexing for efficiency 344 | doc_part_indices = doc_indices * num_partitions + partition_indices 345 | base_indices = doc_part_indices * projection_dim 346 | 347 | for d in range(projection_dim): 348 | flat_indices = base_indices + d 349 | np.add.at(rep_fde_sum, flat_indices, projected_points[:, d]) 350 | 351 | # Reshape for easier manipulation 352 | rep_fde_sum = rep_fde_sum.reshape(num_docs, num_partitions, projection_dim) 353 | 354 | agg_time = time.perf_counter() - agg_start 355 | 356 | # Step 5: Convert sums to averages (for document FDE) 357 | avg_start = time.perf_counter() 358 | 359 | # Vectorized division where counts > 0 360 | non_zero_mask = partition_counts > 0 361 | counts_3d = partition_counts[:, :, np.newaxis] # Broadcasting for division 362 | 363 | # Safe division (avoid divide by zero) 364 | np.divide(rep_fde_sum, counts_3d, out=rep_fde_sum, where=counts_3d > 0) 365 | 366 | # Fill empty partitions if configured 367 | empty_filled = 0 368 | if config.fill_empty_partitions: 369 | empty_mask = ~non_zero_mask 370 | empty_docs, empty_parts = np.where(empty_mask) 371 | 372 | for doc_idx, part_idx in zip(empty_docs, empty_parts): 373 | if doc_lengths[doc_idx] == 0: 374 | continue 375 | 376 | # Get sketches for this document 377 | doc_start = doc_boundaries[doc_idx] 378 | doc_end = doc_boundaries[doc_idx + 1] 379 | doc_sketches = all_sketches[doc_start:doc_end] 380 | 381 | # Vectorized distance calculation 382 | binary_rep = _gray_code_to_binary(part_idx) 383 | target_bits = ( 384 | binary_rep >> np.arange(config.num_simhash_projections - 1, -1, -1) 385 | ) & 1 386 | distances = np.sum( 387 | (doc_sketches > 0).astype(int) != target_bits, axis=1 388 | ) 389 | 390 | nearest_local_idx = np.argmin(distances) 391 | nearest_global_idx = doc_start + nearest_local_idx 392 | 393 | rep_fde_sum[doc_idx, part_idx, :] = projected_points[nearest_global_idx] 394 | empty_filled += 1 395 | 396 | avg_time = time.perf_counter() - avg_start 397 | 398 | # Step 6: Copy results to output array 399 | rep_output_start = rep_num * num_partitions * projection_dim 400 | out_fdes[ 401 | :, rep_output_start : rep_output_start + num_partitions * projection_dim 402 | ] = rep_fde_sum.reshape(num_docs, -1) 403 | 404 | # Log timing for first repetition 405 | if rep_num == 0: 406 | logging.info("[FDE Batch] Repetition timing breakdown:") 407 | logging.info(f" - SimHash: {simhash_time:.3f}s") 408 | logging.info(f" - Projection: {proj_time:.3f}s") 409 | logging.info(f" - Partition indices: {partition_time:.3f}s") 410 | logging.info(f" - Aggregation: {agg_time:.3f}s") 411 | logging.info(f" - Averaging: {avg_time:.3f}s") 412 | if config.fill_empty_partitions: 413 | logging.info(f" - Filled {empty_filled} empty partitions") 414 | 415 | # Step 7: Apply final projection if configured 416 | if config.final_projection_dimension and config.final_projection_dimension > 0: 417 | logging.info( 418 | f"[FDE Batch] Applying final projection: {final_fde_dim} -> " 419 | f"{config.final_projection_dimension}" 420 | ) 421 | final_proj_start = time.perf_counter() 422 | 423 | # Process in chunks to avoid memory issues 424 | chunk_size = min(100, num_docs) 425 | final_fdes = [] 426 | 427 | for i in range(0, num_docs, chunk_size): 428 | chunk_end = min(i + chunk_size, num_docs) 429 | chunk_fdes = np.array( 430 | [ 431 | _apply_count_sketch_to_vector( 432 | out_fdes[j], config.final_projection_dimension, config.seed 433 | ) 434 | for j in range(i, chunk_end) 435 | ] 436 | ) 437 | final_fdes.append(chunk_fdes) 438 | 439 | out_fdes = np.vstack(final_fdes) 440 | final_proj_time = time.perf_counter() - final_proj_start 441 | logging.info( 442 | f"[FDE Batch] Final projection completed in {final_proj_time:.3f}s" 443 | ) 444 | 445 | # Final statistics and validation 446 | total_time = time.perf_counter() - batch_start_time 447 | logging.info(f"[FDE Batch] Batch generation completed in {total_time:.3f}s") 448 | logging.info( 449 | f"[FDE Batch] Average time per document: {total_time / num_docs * 1000:.2f}ms" 450 | ) 451 | logging.info(f"[FDE Batch] Throughput: {num_docs / total_time:.1f} docs/sec") 452 | logging.info(f"[FDE Batch] Output shape: {out_fdes.shape}") 453 | 454 | # Validate output dimensions 455 | expected_dim = ( 456 | final_fde_dim 457 | if not config.final_projection_dimension 458 | else config.final_projection_dimension 459 | ) 460 | assert out_fdes.shape == (num_docs, expected_dim), ( 461 | f"Output shape mismatch: {out_fdes.shape} != ({num_docs}, {expected_dim})" 462 | ) 463 | 464 | # doc_config = replace(config, encoding_type=EncodingType.AVERAGE) 465 | 466 | return out_fdes 467 | 468 | 469 | if __name__ == "__main__": 470 | print(f"\n{'=' * 20} SCENARIO 1: Basic FDE Generation {'=' * 20}") 471 | 472 | base_config = FixedDimensionalEncodingConfig( 473 | dimension=128, num_repetitions=2, num_simhash_projections=4, seed=42 474 | ) 475 | query_data = np.random.randn(32, base_config.dimension).astype(np.float32) 476 | doc_data = np.random.randn(80, base_config.dimension).astype(np.float32) 477 | 478 | query_fde = generate_query_fde(query_data, base_config) 479 | doc_fde = generate_document_fde( 480 | doc_data, replace(base_config, fill_empty_partitions=True) 481 | ) 482 | 483 | expected_dim = ( 484 | base_config.num_repetitions 485 | * (2**base_config.num_simhash_projections) 486 | * base_config.dimension 487 | ) 488 | print(f"Query FDE Shape: {query_fde.shape} (Expected: {expected_dim})") 489 | print(f"Document FDE Shape: {doc_fde.shape} (Expected: {expected_dim})") 490 | print(f"Similarity Score: {np.dot(query_fde, doc_fde):.4f}") 491 | assert query_fde.shape[0] == expected_dim 492 | 493 | print(f"\n{'=' * 20} SCENARIO 2: Inner Projection (AMS Sketch) {'=' * 20}") 494 | 495 | ams_config = replace( 496 | base_config, projection_type=ProjectionType.AMS_SKETCH, projection_dimension=16 497 | ) 498 | query_fde_ams = generate_query_fde(query_data, ams_config) 499 | expected_dim_ams = ( 500 | ams_config.num_repetitions 501 | * (2**ams_config.num_simhash_projections) 502 | * ams_config.projection_dimension 503 | ) 504 | print(f"AMS Sketch FDE Shape: {query_fde_ams.shape} (Expected: {expected_dim_ams})") 505 | assert query_fde_ams.shape[0] == expected_dim_ams 506 | 507 | print(f"\n{'=' * 20} SCENARIO 3: Final Projection (Count Sketch) {'=' * 20}") 508 | 509 | final_proj_config = replace(base_config, final_projection_dimension=1024) 510 | query_fde_final = generate_query_fde(query_data, final_proj_config) 511 | print( 512 | f"Final Projection FDE Shape: {query_fde_final.shape} (Expected: {final_proj_config.final_projection_dimension})" 513 | ) 514 | assert query_fde_final.shape[0] == final_proj_config.final_projection_dimension 515 | 516 | print(f"\n{'=' * 20} SCENARIO 4: Top-level `generate_fde` wrapper {'=' * 20}") 517 | 518 | query_fde_2 = generate_fde( 519 | query_data, replace(base_config, encoding_type=EncodingType.DEFAULT_SUM) 520 | ) 521 | doc_fde_2 = generate_fde( 522 | doc_data, replace(base_config, encoding_type=EncodingType.AVERAGE) 523 | ) 524 | print( 525 | f"Wrapper-generated Query FDE is identical: {np.allclose(query_fde, query_fde_2)}" 526 | ) 527 | print( 528 | f"Wrapper-generated Document FDE is identical: {np.allclose(doc_fde, doc_fde_2)}" 529 | ) 530 | 531 | print("\nAll test scenarios completed successfully.") 532 | --------------------------------------------------------------------------------